Experiments look good
This commit is contained in:
@ -15,8 +15,8 @@ from environments import helpers as h
|
||||
from environments.helpers import Constants as c
|
||||
from environments.helpers import EnvActions as a
|
||||
from environments.helpers import Rewards as r
|
||||
from environments.factory.base.objects import Agent, Tile, Action
|
||||
from environments.factory.base.registers import Actions, Entities, Agents, Doors, FloorTiles, WallTiles, PlaceHolders, \
|
||||
from environments.factory.base.objects import Agent, Floor, Action
|
||||
from environments.factory.base.registers import Actions, Entities, Agents, Doors, Floors, Walls, PlaceHolders, \
|
||||
GlobalPositions
|
||||
from environments.utility_classes import MovementProperties, ObservationProperties, MarlFrameStack
|
||||
from environments.utility_classes import AgentRenderOptions as a_obs
|
||||
@ -121,7 +121,7 @@ class BaseFactory(gym.Env):
|
||||
self.doors_have_area = doors_have_area
|
||||
self.individual_rewards = individual_rewards
|
||||
|
||||
# Reset
|
||||
# TODO: Reset ---> document this
|
||||
self.reset()
|
||||
|
||||
def __getitem__(self, item):
|
||||
@ -141,21 +141,21 @@ class BaseFactory(gym.Env):
|
||||
self._obs_shape = self._level_shape if not self.obs_prop.pomdp_r else (self.pomdp_diameter, ) * 2
|
||||
|
||||
# Walls
|
||||
walls = WallTiles.from_argwhere_coordinates(
|
||||
walls = Walls.from_argwhere_coordinates(
|
||||
np.argwhere(level_array == c.OCCUPIED_CELL),
|
||||
self._level_shape
|
||||
)
|
||||
self._entities.register_additional_items({c.WALLS: walls})
|
||||
|
||||
# Floor
|
||||
floor = FloorTiles.from_argwhere_coordinates(
|
||||
floor = Floors.from_argwhere_coordinates(
|
||||
np.argwhere(level_array == c.FREE_CELL),
|
||||
self._level_shape
|
||||
)
|
||||
self._entities.register_additional_items({c.FLOOR: floor})
|
||||
|
||||
# NOPOS
|
||||
self._NO_POS_TILE = Tile(c.NO_POS, None)
|
||||
self._NO_POS_TILE = Floor(c.NO_POS, None)
|
||||
|
||||
# Doors
|
||||
if self.parse_doors:
|
||||
@ -170,7 +170,7 @@ class BaseFactory(gym.Env):
|
||||
|
||||
# Actions
|
||||
self._actions = Actions(self.mv_prop, can_use_doors=self.parse_doors)
|
||||
if additional_actions := self.additional_actions:
|
||||
if additional_actions := self.actions_hook:
|
||||
self._actions.register_additional_items(additional_actions)
|
||||
|
||||
# Agents
|
||||
@ -202,7 +202,7 @@ class BaseFactory(gym.Env):
|
||||
self._entities.register_additional_items({c.AGENT_PLACEHOLDER: placeholder})
|
||||
|
||||
# Additional Entities from SubEnvs
|
||||
if additional_entities := self.additional_entities:
|
||||
if additional_entities := self.entities_hook:
|
||||
self._entities.register_additional_items(additional_entities)
|
||||
|
||||
if self.obs_prop.show_global_position_info:
|
||||
@ -217,7 +217,7 @@ class BaseFactory(gym.Env):
|
||||
|
||||
def reset(self) -> (np.typing.ArrayLike, int, bool, dict):
|
||||
_ = self._base_init_env()
|
||||
self.do_additional_reset()
|
||||
self.reset_hook()
|
||||
|
||||
self._steps = 0
|
||||
|
||||
@ -233,7 +233,7 @@ class BaseFactory(gym.Env):
|
||||
self._steps += 1
|
||||
|
||||
# Pre step Hook for later use
|
||||
self.hook_pre_step()
|
||||
self.pre_step_hook()
|
||||
|
||||
for action, agent in zip(actions, self[c.AGENT]):
|
||||
agent.clear_temp_state()
|
||||
@ -244,7 +244,7 @@ class BaseFactory(gym.Env):
|
||||
action_valid, reward = self._do_move_action(agent, action_obj)
|
||||
elif a.NOOP == action_obj:
|
||||
action_valid = c.VALID
|
||||
reward = dict(value=r.NOOP, reason=a.NOOP, info={f'{agent.pos}_NOOP': 1})
|
||||
reward = dict(value=r.NOOP, reason=a.NOOP, info={f'{agent.name}_NOOP': 1, 'NOOP': 1})
|
||||
elif a.USE_DOOR == action_obj:
|
||||
action_valid, reward = self._handle_door_interaction(agent)
|
||||
else:
|
||||
@ -258,7 +258,7 @@ class BaseFactory(gym.Env):
|
||||
agent.step_result = step_result
|
||||
|
||||
# Additional step and Reward, Info Init
|
||||
rewards, info = self.do_additional_step()
|
||||
rewards, info = self.step_hook()
|
||||
# Todo: Make this faster, so that only tiles of entities that can collide are searched.
|
||||
tiles_with_collisions = self.get_all_tiles_with_collisions()
|
||||
for tile in tiles_with_collisions:
|
||||
@ -297,7 +297,7 @@ class BaseFactory(gym.Env):
|
||||
info.update(self._summarize_state())
|
||||
|
||||
# Post step Hook for later use
|
||||
info.update(self.hook_post_step())
|
||||
info.update(self.post_step_hook())
|
||||
|
||||
obs, _ = self._build_observations()
|
||||
|
||||
@ -314,11 +314,11 @@ class BaseFactory(gym.Env):
|
||||
door.use()
|
||||
valid = c.VALID
|
||||
self.print(f'{agent.name} just used a {door.name} at {door.pos}')
|
||||
info_dict = {f'{agent.name}_door_use': 1}
|
||||
info_dict = {f'{agent.name}_door_use': 1, f'door_use': 1}
|
||||
# When he doesn't...
|
||||
else:
|
||||
valid = c.NOT_VALID
|
||||
info_dict = {f'{agent.name}_failed_door_use': 1}
|
||||
info_dict = {f'{agent.name}_failed_door_use': 1, 'failed_door_use': 1}
|
||||
self.print(f'{agent.name} just tried to use a door at {agent.pos}, but there is none.')
|
||||
|
||||
else:
|
||||
@ -334,7 +334,7 @@ class BaseFactory(gym.Env):
|
||||
per_agent_obsn = dict()
|
||||
# General Observations
|
||||
lvl_obs = self[c.WALLS].as_array()
|
||||
door_obs = self[c.DOORS].as_array()
|
||||
door_obs = self[c.DOORS].as_array() if self.parse_doors else None
|
||||
if self.obs_prop.render_agents == a_obs.NOT:
|
||||
global_agent_obs = None
|
||||
elif self.obs_prop.omit_agent_self and self.n_agents == 1:
|
||||
@ -342,7 +342,7 @@ class BaseFactory(gym.Env):
|
||||
else:
|
||||
global_agent_obs = self[c.AGENT].as_array().copy()
|
||||
placeholder_obs = self[c.AGENT_PLACEHOLDER].as_array() if self[c.AGENT_PLACEHOLDER] else None
|
||||
add_obs_dict = self._additional_observations()
|
||||
add_obs_dict = self.observations_hook()
|
||||
|
||||
for agent_idx, agent in enumerate(self[c.AGENT]):
|
||||
obs_dict = dict()
|
||||
@ -367,17 +367,17 @@ class BaseFactory(gym.Env):
|
||||
|
||||
obs_dict[c.WALLS] = lvl_obs
|
||||
if self.obs_prop.render_agents in [a_obs.SEPERATE, a_obs.COMBINED] and agent_obs is not None:
|
||||
obs_dict[c.AGENT] = agent_obs
|
||||
obs_dict[c.AGENT] = agent_obs[:]
|
||||
if self[c.AGENT_PLACEHOLDER] and placeholder_obs is not None:
|
||||
obs_dict[c.AGENT_PLACEHOLDER] = placeholder_obs
|
||||
if self.parse_doors and door_obs is not None:
|
||||
obs_dict[c.DOORS] = door_obs
|
||||
obs_dict[c.DOORS] = door_obs[:]
|
||||
obs_dict.update(add_obs_dict)
|
||||
obsn = np.vstack(list(obs_dict.values()))
|
||||
if self.obs_prop.pomdp_r:
|
||||
obsn = self._do_pomdp_cutout(agent, obsn)
|
||||
|
||||
raw_obs = self._additional_per_agent_raw_observations(agent)
|
||||
raw_obs = self.per_agent_raw_observations_hook(agent)
|
||||
raw_obs = {key: np.expand_dims(val, 0) if val.ndim != 3 else val for key, val in raw_obs.items()}
|
||||
obsn = np.vstack((obsn, *raw_obs.values()))
|
||||
|
||||
@ -387,6 +387,12 @@ class BaseFactory(gym.Env):
|
||||
zip(keys, idxs, list(idxs[1:]) + [idxs[-1]+1, ])}
|
||||
|
||||
# Shadow Casting
|
||||
if agent.step_result is not None:
|
||||
pass
|
||||
else:
|
||||
assert self._steps == 0
|
||||
agent.step_result = {'action_name': a.NOOP, 'action_valid': True,
|
||||
'collisions': [], 'lightmap': None}
|
||||
if self.obs_prop.cast_shadows:
|
||||
try:
|
||||
light_block_obs = [obs_idx for key, obs_idx in per_agent_expl_idx[agent.name].items()
|
||||
@ -430,17 +436,15 @@ class BaseFactory(gym.Env):
|
||||
if door_shadowing:
|
||||
# noinspection PyUnboundLocalVariable
|
||||
light_block_map[xs, ys] = 0
|
||||
if agent.step_result:
|
||||
agent.step_result['lightmap'] = light_block_map
|
||||
pass
|
||||
else:
|
||||
assert self._steps == 0
|
||||
agent.step_result = {'action_name': a.NOOP, 'action_valid': True,
|
||||
'collisions': [], 'lightmap': light_block_map}
|
||||
|
||||
agent.step_result['lightmap'] = light_block_map
|
||||
|
||||
obsn[shadowed_obs] = ((obsn[shadowed_obs] * light_block_map) + 0.) - (1 - light_block_map)
|
||||
else:
|
||||
pass
|
||||
if self._pomdp_r:
|
||||
agent.step_result['lightmap'] = np.ones(self._obs_shape)
|
||||
else:
|
||||
agent.step_result['lightmap'] = None
|
||||
|
||||
per_agent_obsn[agent.name] = obsn
|
||||
|
||||
@ -484,7 +488,7 @@ class BaseFactory(gym.Env):
|
||||
oobs = np.pad(oobs, ((0, 0), (x0_pad, x1_pad), (y0_pad, y1_pad)), 'constant')
|
||||
return oobs
|
||||
|
||||
def get_all_tiles_with_collisions(self) -> List[Tile]:
|
||||
def get_all_tiles_with_collisions(self) -> List[Floor]:
|
||||
tiles = [x for x in self[c.FLOOR] if len(x.guests_that_can_collide) > 1]
|
||||
if False:
|
||||
tiles_with_collisions = list()
|
||||
@ -503,22 +507,22 @@ class BaseFactory(gym.Env):
|
||||
valid = agent.move(new_tile)
|
||||
if valid:
|
||||
# This will spam your logs, beware!
|
||||
# self.print(f'{agent.name} just moved from {agent.last_pos} to {agent.pos}.')
|
||||
# info_dict.update({f'{agent.pos}_move': 1})
|
||||
self.print(f'{agent.name} just moved {action.identifier} from {agent.last_pos} to {agent.pos}.')
|
||||
info_dict.update({f'{agent.name}_move': 1, 'move': 1})
|
||||
pass
|
||||
else:
|
||||
valid = c.NOT_VALID
|
||||
self.print(f'{agent.name} just hit the wall at {agent.pos}.')
|
||||
info_dict.update({f'{agent.name}_wall_collide': 1})
|
||||
self.print(f'{agent.name} just hit the wall at {agent.pos}. ({action.identifier})')
|
||||
info_dict.update({f'{agent.name}_wall_collide': 1, 'wall_collide': 1})
|
||||
else:
|
||||
# Agent seems to be trying to Leave the level
|
||||
self.print(f'{agent.name} tried to leave the level {agent.pos}.')
|
||||
info_dict.update({f'{agent.name}_wall_collide': 1})
|
||||
self.print(f'{agent.name} tried to leave the level {agent.pos}. ({action.identifier})')
|
||||
info_dict.update({f'{agent.name}_wall_collide': 1, 'wall_collide': 1})
|
||||
reward_value = r.MOVEMENTS_VALID if valid else r.MOVEMENTS_FAIL
|
||||
reward = {'value': reward_value, 'reason': action.identifier, 'info': info_dict}
|
||||
return valid, reward
|
||||
|
||||
def _check_agent_move(self, agent, action: Action) -> (Tile, bool):
|
||||
def _check_agent_move(self, agent, action: Action) -> (Floor, bool):
|
||||
# Actions
|
||||
x_diff, y_diff = h.ACTIONMAP[action.identifier]
|
||||
x_new = agent.x + x_diff
|
||||
@ -556,10 +560,6 @@ class BaseFactory(gym.Env):
|
||||
|
||||
return new_tile, valid
|
||||
|
||||
@abc.abstractmethod
|
||||
def additional_per_agent_rewards(self, agent) -> List[dict]:
|
||||
return []
|
||||
|
||||
def build_reward_result(self, global_env_rewards: list) -> (int, dict):
|
||||
# Returns: Reward, Info
|
||||
info = defaultdict(lambda: 0.0)
|
||||
@ -567,7 +567,7 @@ class BaseFactory(gym.Env):
|
||||
# Gather additional sub-env rewards and calculate collisions
|
||||
for agent in self[c.AGENT]:
|
||||
|
||||
rewards = self.additional_per_agent_rewards(agent)
|
||||
rewards = self.per_agent_reward_hook(agent)
|
||||
for reward in rewards:
|
||||
agent.step_result['rewards'].append(reward)
|
||||
if collisions := agent.step_result['collisions']:
|
||||
@ -601,6 +601,12 @@ class BaseFactory(gym.Env):
|
||||
self.print(f"reward is {reward}")
|
||||
return reward, combined_info_dict
|
||||
|
||||
def start_recording(self):
|
||||
self._record_episodes = True
|
||||
|
||||
def stop_recording(self):
|
||||
self._record_episodes = False
|
||||
|
||||
# noinspection PyGlobalUndefined
|
||||
def render(self, mode='human'):
|
||||
if not self._renderer: # lazy init
|
||||
@ -621,7 +627,7 @@ class BaseFactory(gym.Env):
|
||||
for i, door in enumerate(self[c.DOORS]):
|
||||
name, state = 'door_open' if door.is_open else 'door_closed', 'blank'
|
||||
doors.append(RenderEntity(name, door.pos, 1, 'none', state, i + 1))
|
||||
additional_assets = self.render_additional_assets()
|
||||
additional_assets = self.render_assets_hook()
|
||||
|
||||
return self._renderer.render(walls + doors + additional_assets + agents)
|
||||
|
||||
@ -652,7 +658,8 @@ class BaseFactory(gym.Env):
|
||||
|
||||
# Properties which are called by the base class to extend beyond attributes of the base class
|
||||
@property
|
||||
def additional_actions(self) -> Union[Action, List[Action]]:
|
||||
@abc.abstractmethod
|
||||
def actions_hook(self) -> Union[Action, List[Action]]:
|
||||
"""
|
||||
When inheriting from this Base Class, you must implement this method!!!
|
||||
|
||||
@ -662,7 +669,8 @@ class BaseFactory(gym.Env):
|
||||
return []
|
||||
|
||||
@property
|
||||
def additional_entities(self) -> Dict[(str, Entities)]:
|
||||
@abc.abstractmethod
|
||||
def entities_hook(self) -> Dict[(str, Entities)]:
|
||||
"""
|
||||
When inheriting from this Base Class, you must implement this method!!!
|
||||
|
||||
@ -674,27 +682,39 @@ class BaseFactory(gym.Env):
|
||||
# Functions which provide additions to functions of the base class
|
||||
# Always call super!!!!!!
|
||||
@abc.abstractmethod
|
||||
def do_additional_reset(self) -> None:
|
||||
def reset_hook(self) -> None:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def do_additional_step(self) -> (List[dict], dict):
|
||||
return [], {}
|
||||
def pre_step_hook(self) -> None:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def do_additional_actions(self, agent: Agent, action: Action) -> (bool, dict):
|
||||
return None
|
||||
|
||||
@abc.abstractmethod
|
||||
def step_hook(self) -> (List[dict], dict):
|
||||
return [], {}
|
||||
|
||||
@abc.abstractmethod
|
||||
def check_additional_done(self) -> (bool, dict):
|
||||
return False, {}
|
||||
|
||||
@abc.abstractmethod
|
||||
def _additional_observations(self) -> Dict[str, np.typing.ArrayLike]:
|
||||
def observations_hook(self) -> Dict[str, np.typing.ArrayLike]:
|
||||
return {}
|
||||
|
||||
@abc.abstractmethod
|
||||
def _additional_per_agent_raw_observations(self, agent) -> Dict[str, np.typing.ArrayLike]:
|
||||
def per_agent_reward_hook(self, agent: Agent) -> Dict[str, dict]:
|
||||
return {}
|
||||
|
||||
@abc.abstractmethod
|
||||
def post_step_hook(self) -> dict:
|
||||
return {}
|
||||
|
||||
@abc.abstractmethod
|
||||
def per_agent_raw_observations_hook(self, agent) -> Dict[str, np.typing.ArrayLike]:
|
||||
additional_raw_observations = {}
|
||||
if self.obs_prop.show_global_position_info:
|
||||
global_pos_obs = np.zeros(self._obs_shape)
|
||||
@ -703,19 +723,5 @@ class BaseFactory(gym.Env):
|
||||
return additional_raw_observations
|
||||
|
||||
@abc.abstractmethod
|
||||
def additional_per_agent_reward(self, agent: Agent) -> Dict[str, dict]:
|
||||
return {}
|
||||
|
||||
@abc.abstractmethod
|
||||
def render_additional_assets(self):
|
||||
def render_assets_hook(self):
|
||||
return []
|
||||
|
||||
# Hooks for in between operations.
|
||||
# Always call super!!!!!!
|
||||
@abc.abstractmethod
|
||||
def hook_pre_step(self) -> None:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def hook_post_step(self) -> dict:
|
||||
return {}
|
||||
|
@ -9,10 +9,11 @@ from environments.helpers import Constants as c
|
||||
import itertools
|
||||
|
||||
##########################################################################
|
||||
# ##################### Base Object Definition ######################### #
|
||||
# ##################### Base Object Building Blocks ######################### #
|
||||
##########################################################################
|
||||
|
||||
|
||||
# TODO: Missing Documentation
|
||||
class Object:
|
||||
|
||||
"""General Objects for Organisation and Maintenance such as Actions etc..."""
|
||||
@ -53,8 +54,10 @@ class Object:
|
||||
|
||||
def __eq__(self, other) -> bool:
|
||||
return other == self.identifier
|
||||
# Base
|
||||
|
||||
|
||||
# TODO: Missing Documentation
|
||||
class EnvObject(Object):
|
||||
|
||||
"""Objects that hold observable information, but have no position on the env grid. Inventories etc..."""
|
||||
@ -78,27 +81,10 @@ class EnvObject(Object):
|
||||
self._register.delete_env_object(self)
|
||||
self._register = register
|
||||
return self._register == register
|
||||
# With Rendering
|
||||
|
||||
|
||||
class BoundingMixin(Object):
|
||||
|
||||
@property
|
||||
def bound_entity(self):
|
||||
return self._bound_entity
|
||||
|
||||
def __init__(self,entity_to_be_bound, *args, **kwargs):
|
||||
super(BoundingMixin, self).__init__(*args, **kwargs)
|
||||
assert entity_to_be_bound is not None
|
||||
self._bound_entity = entity_to_be_bound
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
return f'{super(BoundingMixin, self).name}({self._bound_entity.name})'
|
||||
|
||||
def belongs_to_entity(self, entity):
|
||||
return entity == self.bound_entity
|
||||
|
||||
|
||||
# TODO: Missing Documentation
|
||||
class Entity(EnvObject):
|
||||
"""Full Env Entity that lives on the env Grid. Doors, Items, Dirt etc..."""
|
||||
|
||||
@ -133,8 +119,10 @@ class Entity(EnvObject):
|
||||
|
||||
def __repr__(self):
|
||||
return super(Entity, self).__repr__() + f'(@{self.pos})'
|
||||
# With Position in Env
|
||||
|
||||
|
||||
# TODO: Missing Documentation
|
||||
class MoveableEntity(Entity):
|
||||
|
||||
@property
|
||||
@ -169,6 +157,27 @@ class MoveableEntity(Entity):
|
||||
return c.VALID
|
||||
else:
|
||||
return c.NOT_VALID
|
||||
# Can Move
|
||||
|
||||
|
||||
# TODO: Missing Documentation
|
||||
class BoundingMixin(Object):
|
||||
|
||||
@property
|
||||
def bound_entity(self):
|
||||
return self._bound_entity
|
||||
|
||||
def __init__(self,entity_to_be_bound, *args, **kwargs):
|
||||
super(BoundingMixin, self).__init__(*args, **kwargs)
|
||||
assert entity_to_be_bound is not None
|
||||
self._bound_entity = entity_to_be_bound
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
return f'{super(BoundingMixin, self).name}({self._bound_entity.name})'
|
||||
|
||||
def belongs_to_entity(self, entity):
|
||||
return entity == self.bound_entity
|
||||
|
||||
|
||||
##########################################################################
|
||||
@ -216,7 +225,7 @@ class GlobalPosition(BoundingMixin, EnvObject):
|
||||
self._normalized = normalized
|
||||
|
||||
|
||||
class Tile(EnvObject):
|
||||
class Floor(EnvObject):
|
||||
|
||||
@property
|
||||
def encoding(self):
|
||||
@ -243,7 +252,7 @@ class Tile(EnvObject):
|
||||
return self._pos
|
||||
|
||||
def __init__(self, pos, *args, **kwargs):
|
||||
super(Tile, self).__init__(*args, **kwargs)
|
||||
super(Floor, self).__init__(*args, **kwargs)
|
||||
self._guests = dict()
|
||||
self._pos = tuple(pos)
|
||||
|
||||
@ -277,7 +286,7 @@ class Tile(EnvObject):
|
||||
return dict(name=self.name, x=int(self.x), y=int(self.y))
|
||||
|
||||
|
||||
class Wall(Tile):
|
||||
class Wall(Floor):
|
||||
|
||||
@property
|
||||
def can_collide(self):
|
||||
@ -302,7 +311,7 @@ class Door(Entity):
|
||||
@property
|
||||
def encoding(self):
|
||||
# This is important, as the shadow is checked by occupation value
|
||||
return c.OCCUPIED_CELL if self.is_closed else 2
|
||||
return c.OCCUPIED_CELL if self.is_closed else 0.5
|
||||
|
||||
@property
|
||||
def str_state(self):
|
||||
@ -396,5 +405,5 @@ class Agent(MoveableEntity):
|
||||
|
||||
def summarize_state(self, **kwargs):
|
||||
state_dict = super().summarize_state(**kwargs)
|
||||
state_dict.update(valid=bool(self.temp_action_result['valid']), action=str(self.temp_action_result['action']))
|
||||
state_dict.update(valid=bool(self.step_result['action_valid']), action=str(self.step_result['action_name']))
|
||||
return state_dict
|
||||
|
@ -6,7 +6,7 @@ from typing import List, Union, Dict, Tuple
|
||||
import numpy as np
|
||||
import six
|
||||
|
||||
from environments.factory.base.objects import Entity, Tile, Agent, Door, Action, Wall, PlaceHolder, GlobalPosition, \
|
||||
from environments.factory.base.objects import Entity, Floor, Agent, Door, Action, Wall, PlaceHolder, GlobalPosition, \
|
||||
Object, EnvObject
|
||||
from environments.utility_classes import MovementProperties
|
||||
from environments import helpers as h
|
||||
@ -271,12 +271,9 @@ class GlobalPositions(EnvObjectRegister):
|
||||
|
||||
_accepted_objects = GlobalPosition
|
||||
|
||||
is_blocking_light = False
|
||||
can_be_shadowed = False
|
||||
can_collide = False
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(GlobalPositions, self).__init__(*args, is_per_agent=True, individual_slices=True, **kwargs)
|
||||
super(GlobalPositions, self).__init__(*args, is_per_agent=True, individual_slices=True, is_blocking_light = False,
|
||||
can_be_shadowed = False, can_collide = False, **kwargs)
|
||||
|
||||
def as_array(self):
|
||||
# FIXME DEBUG!!! make this lazy?
|
||||
@ -377,7 +374,7 @@ class Entities(ObjectRegister):
|
||||
return found_entities
|
||||
|
||||
|
||||
class WallTiles(EntityRegister):
|
||||
class Walls(EntityRegister):
|
||||
_accepted_objects = Wall
|
||||
|
||||
def as_array(self):
|
||||
@ -390,9 +387,9 @@ class WallTiles(EntityRegister):
|
||||
return self._array
|
||||
|
||||
def __init__(self, *args, is_blocking_light=True, **kwargs):
|
||||
super(WallTiles, self).__init__(*args, individual_slices=False,
|
||||
can_collide=True,
|
||||
is_blocking_light=is_blocking_light, **kwargs)
|
||||
super(Walls, self).__init__(*args, individual_slices=False,
|
||||
can_collide=True,
|
||||
is_blocking_light=is_blocking_light, **kwargs)
|
||||
self._value = c.OCCUPIED_CELL
|
||||
|
||||
@classmethod
|
||||
@ -411,16 +408,16 @@ class WallTiles(EntityRegister):
|
||||
|
||||
def summarize_states(self, n_steps=None):
|
||||
if n_steps == h.STEPS_START:
|
||||
return super(WallTiles, self).summarize_states(n_steps=n_steps)
|
||||
return super(Walls, self).summarize_states(n_steps=n_steps)
|
||||
else:
|
||||
return {}
|
||||
|
||||
|
||||
class FloorTiles(WallTiles):
|
||||
_accepted_objects = Tile
|
||||
class Floors(Walls):
|
||||
_accepted_objects = Floor
|
||||
|
||||
def __init__(self, *args, is_blocking_light=False, **kwargs):
|
||||
super(FloorTiles, self).__init__(*args, is_blocking_light=is_blocking_light, **kwargs)
|
||||
super(Floors, self).__init__(*args, is_blocking_light=is_blocking_light, **kwargs)
|
||||
self._value = c.FREE_CELL
|
||||
|
||||
@property
|
||||
@ -430,7 +427,7 @@ class FloorTiles(WallTiles):
|
||||
return tiles
|
||||
|
||||
@property
|
||||
def empty_tiles(self) -> List[Tile]:
|
||||
def empty_tiles(self) -> List[Floor]:
|
||||
tiles = [tile for tile in self if tile.is_empty()]
|
||||
random.shuffle(tiles)
|
||||
return tiles
|
||||
|
@ -158,19 +158,19 @@ class BatteryFactory(BaseFactory):
|
||||
self.btry_prop = btry_prop
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def _additional_per_agent_raw_observations(self, agent) -> Dict[str, np.typing.ArrayLike]:
|
||||
additional_raw_observations = super()._additional_per_agent_raw_observations(agent)
|
||||
def per_agent_raw_observations_hook(self, agent) -> Dict[str, np.typing.ArrayLike]:
|
||||
additional_raw_observations = super().per_agent_raw_observations_hook(agent)
|
||||
additional_raw_observations.update({c.BATTERIES: self[c.BATTERIES].as_array_by_entity(agent)})
|
||||
return additional_raw_observations
|
||||
|
||||
def _additional_observations(self) -> Dict[str, np.typing.ArrayLike]:
|
||||
additional_observations = super()._additional_observations()
|
||||
def observations_hook(self) -> Dict[str, np.typing.ArrayLike]:
|
||||
additional_observations = super().observations_hook()
|
||||
additional_observations.update({c.CHARGE_PODS: self[c.CHARGE_PODS].as_array()})
|
||||
return additional_observations
|
||||
|
||||
@property
|
||||
def additional_entities(self):
|
||||
super_entities = super().additional_entities
|
||||
def entities_hook(self):
|
||||
super_entities = super().entities_hook
|
||||
|
||||
empty_tiles = self[c.FLOOR].empty_tiles[:self.btry_prop.charge_locations]
|
||||
charge_pods = ChargePods.from_tiles(
|
||||
@ -185,8 +185,8 @@ class BatteryFactory(BaseFactory):
|
||||
super_entities.update({c.BATTERIES: batteries, c.CHARGE_PODS: charge_pods})
|
||||
return super_entities
|
||||
|
||||
def do_additional_step(self) -> (List[dict], dict):
|
||||
super_reward_info = super(BatteryFactory, self).do_additional_step()
|
||||
def step_hook(self) -> (List[dict], dict):
|
||||
super_reward_info = super(BatteryFactory, self).step_hook()
|
||||
|
||||
# Decharge
|
||||
batteries = self[c.BATTERIES]
|
||||
@ -230,7 +230,7 @@ class BatteryFactory(BaseFactory):
|
||||
return action_result
|
||||
pass
|
||||
|
||||
def do_additional_reset(self) -> None:
|
||||
def reset_hook(self) -> None:
|
||||
# There is Nothing to reset.
|
||||
pass
|
||||
|
||||
@ -249,8 +249,8 @@ class BatteryFactory(BaseFactory):
|
||||
pass
|
||||
pass
|
||||
|
||||
def additional_per_agent_reward(self, agent: Agent) -> Dict[str, dict]:
|
||||
reward_event_dict = super(BatteryFactory, self).additional_per_agent_reward(agent)
|
||||
def per_agent_reward_hook(self, agent: Agent) -> Dict[str, dict]:
|
||||
reward_event_dict = super(BatteryFactory, self).per_agent_reward_hook(agent)
|
||||
if self[c.BATTERIES].by_entity(agent).is_discharged:
|
||||
self.print(f'{agent.name} Battery is discharged!')
|
||||
info_dict = {f'{agent.name}_{c.BATTERY_DISCHARGED}': 1}
|
||||
@ -260,9 +260,9 @@ class BatteryFactory(BaseFactory):
|
||||
pass
|
||||
return reward_event_dict
|
||||
|
||||
def render_additional_assets(self):
|
||||
def render_assets_hook(self):
|
||||
# noinspection PyUnresolvedReferences
|
||||
additional_assets = super().render_additional_assets()
|
||||
additional_assets = super().render_assets_hook()
|
||||
charge_pods = [RenderEntity(c.CHARGE_PODS, charge_pod.tile.pos) for charge_pod in self[c.CHARGE_PODS]]
|
||||
additional_assets.extend(charge_pods)
|
||||
return additional_assets
|
||||
|
@ -147,17 +147,17 @@ class DestFactory(BaseFactory):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
@property
|
||||
def additional_actions(self) -> Union[Action, List[Action]]:
|
||||
def actions_hook(self) -> Union[Action, List[Action]]:
|
||||
# noinspection PyUnresolvedReferences
|
||||
super_actions = super().additional_actions
|
||||
super_actions = super().actions_hook
|
||||
if self.dest_prop.dwell_time:
|
||||
super_actions.append(Action(enum_ident=a.WAIT_ON_DEST))
|
||||
return super_actions
|
||||
|
||||
@property
|
||||
def additional_entities(self) -> Dict[(Enum, Entities)]:
|
||||
def entities_hook(self) -> Dict[(Enum, Entities)]:
|
||||
# noinspection PyUnresolvedReferences
|
||||
super_entities = super().additional_entities
|
||||
super_entities = super().entities_hook
|
||||
|
||||
empty_tiles = self[c.FLOOR].empty_tiles[:self.dest_prop.n_dests]
|
||||
destinations = Destinations.from_tiles(
|
||||
@ -194,9 +194,9 @@ class DestFactory(BaseFactory):
|
||||
else:
|
||||
return super_action_result
|
||||
|
||||
def do_additional_reset(self) -> None:
|
||||
def reset_hook(self) -> None:
|
||||
# noinspection PyUnresolvedReferences
|
||||
super().do_additional_reset()
|
||||
super().reset_hook()
|
||||
self._dest_spawn_timer = dict()
|
||||
|
||||
def trigger_destination_spawn(self):
|
||||
@ -222,9 +222,9 @@ class DestFactory(BaseFactory):
|
||||
else:
|
||||
self.print('No Items are spawning, limit is reached.')
|
||||
|
||||
def do_additional_step(self) -> (List[dict], dict):
|
||||
def step_hook(self) -> (List[dict], dict):
|
||||
# noinspection PyUnresolvedReferences
|
||||
super_reward_info = super().do_additional_step()
|
||||
super_reward_info = super().step_hook()
|
||||
for key, val in self._dest_spawn_timer.items():
|
||||
self._dest_spawn_timer[key] = min(self.dest_prop.spawn_frequency, self._dest_spawn_timer[key] + 1)
|
||||
for dest in list(self[c.DEST].values()):
|
||||
@ -244,14 +244,14 @@ class DestFactory(BaseFactory):
|
||||
self.trigger_destination_spawn()
|
||||
return super_reward_info
|
||||
|
||||
def _additional_observations(self) -> Dict[str, np.typing.ArrayLike]:
|
||||
additional_observations = super()._additional_observations()
|
||||
def observations_hook(self) -> Dict[str, np.typing.ArrayLike]:
|
||||
additional_observations = super().observations_hook()
|
||||
additional_observations.update({c.DEST: self[c.DEST].as_array()})
|
||||
return additional_observations
|
||||
|
||||
def additional_per_agent_reward(self, agent: Agent) -> Dict[str, dict]:
|
||||
def per_agent_reward_hook(self, agent: Agent) -> Dict[str, dict]:
|
||||
# noinspection PyUnresolvedReferences
|
||||
reward_event_dict = super().additional_per_agent_reward(agent)
|
||||
reward_event_dict = super().per_agent_reward_hook(agent)
|
||||
if len(self[c.DEST_REACHED]):
|
||||
for reached_dest in list(self[c.DEST_REACHED]):
|
||||
if agent.pos == reached_dest.pos:
|
||||
@ -261,9 +261,9 @@ class DestFactory(BaseFactory):
|
||||
reward_event_dict.update({c.DEST_REACHED: {'reward': r.DEST_REACHED, 'info': info_dict}})
|
||||
return reward_event_dict
|
||||
|
||||
def render_additional_assets(self, mode='human'):
|
||||
def render_assets_hook(self, mode='human'):
|
||||
# noinspection PyUnresolvedReferences
|
||||
additional_assets = super().render_additional_assets()
|
||||
additional_assets = super().render_assets_hook()
|
||||
destinations = [RenderEntity(c.DEST, dest.pos) for dest in self[c.DEST]]
|
||||
additional_assets.extend(destinations)
|
||||
return additional_assets
|
||||
|
@ -1,5 +1,4 @@
|
||||
import time
|
||||
from enum import Enum
|
||||
from typing import List, Union, NamedTuple, Dict
|
||||
import random
|
||||
|
||||
@ -12,7 +11,7 @@ from environments.helpers import EnvActions as BaseActions
|
||||
from environments.helpers import Rewards as BaseRewards
|
||||
|
||||
from environments.factory.base.base_factory import BaseFactory
|
||||
from environments.factory.base.objects import Agent, Action, Entity, Tile
|
||||
from environments.factory.base.objects import Agent, Action, Entity, Floor
|
||||
from environments.factory.base.registers import Entities, EntityRegister
|
||||
|
||||
from environments.factory.base.renderer import RenderEntity
|
||||
@ -43,7 +42,6 @@ class DirtProperties(NamedTuple):
|
||||
max_local_amount: int = 2 # Max dirt amount per tile.
|
||||
max_global_amount: int = 20 # Max dirt amount in the whole environment.
|
||||
dirt_smear_amount: float = 0.2 # Agents smear dirt, when not cleaning up in place.
|
||||
agent_can_interact: bool = True # Whether the agents can interact with the dirt in this environment.
|
||||
done_when_clean: bool = True
|
||||
|
||||
|
||||
@ -89,7 +87,7 @@ class DirtRegister(EntityRegister):
|
||||
self._dirt_properties: DirtProperties = dirt_properties
|
||||
|
||||
def spawn_dirt(self, then_dirty_tiles) -> bool:
|
||||
if isinstance(then_dirty_tiles, Tile):
|
||||
if isinstance(then_dirty_tiles, Floor):
|
||||
then_dirty_tiles = [then_dirty_tiles]
|
||||
for tile in then_dirty_tiles:
|
||||
if not self.amount > self.dirt_properties.max_global_amount:
|
||||
@ -128,15 +126,14 @@ r = Rewards
|
||||
class DirtFactory(BaseFactory):
|
||||
|
||||
@property
|
||||
def additional_actions(self) -> Union[Action, List[Action]]:
|
||||
super_actions = super().additional_actions
|
||||
if self.dirt_prop.agent_can_interact:
|
||||
super_actions.append(Action(str_ident=a.CLEAN_UP))
|
||||
def actions_hook(self) -> Union[Action, List[Action]]:
|
||||
super_actions = super().actions_hook
|
||||
super_actions.append(Action(str_ident=a.CLEAN_UP))
|
||||
return super_actions
|
||||
|
||||
@property
|
||||
def additional_entities(self) -> Dict[(Enum, Entities)]:
|
||||
super_entities = super().additional_entities
|
||||
def entities_hook(self) -> Dict[(str, Entities)]:
|
||||
super_entities = super().entities_hook
|
||||
dirt_register = DirtRegister(self.dirt_prop, self._level_shape)
|
||||
super_entities.update(({c.DIRT: dirt_register}))
|
||||
return super_entities
|
||||
@ -148,10 +145,11 @@ class DirtFactory(BaseFactory):
|
||||
self._dirt_rng = np.random.default_rng(env_seed)
|
||||
self._dirt: DirtRegister
|
||||
kwargs.update(env_seed=env_seed)
|
||||
# TODO: Reset ---> document this
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def render_additional_assets(self, mode='human'):
|
||||
additional_assets = super().render_additional_assets()
|
||||
def render_assets_hook(self, mode='human'):
|
||||
additional_assets = super().render_assets_hook()
|
||||
dirt = [RenderEntity('dirt', dirt.tile.pos, min(0.15 + dirt.amount, 1.5), 'scale')
|
||||
for dirt in self[c.DIRT]]
|
||||
additional_assets.extend(dirt)
|
||||
@ -167,12 +165,12 @@ class DirtFactory(BaseFactory):
|
||||
dirt.set_new_amount(max(new_dirt_amount, c.FREE_CELL.value))
|
||||
valid = c.VALID
|
||||
self.print(f'{agent.name} did just clean up some dirt at {agent.pos}.')
|
||||
info_dict = {f'{agent.name}_{a.CLEAN_UP}_VALID': 1}
|
||||
info_dict = {f'{agent.name}_{a.CLEAN_UP}_VALID': 1, 'cleanup_valid': 1}
|
||||
reward = r.CLEAN_UP_VALID
|
||||
else:
|
||||
valid = c.NOT_VALID
|
||||
self.print(f'{agent.name} just tried to clean up some dirt at {agent.pos}, but failed.')
|
||||
info_dict = {f'{agent.name}_{a.CLEAN_UP}_FAIL': 1}
|
||||
info_dict = {f'{agent.name}_{a.CLEAN_UP}_FAIL': 1, 'cleanup_fail': 1}
|
||||
reward = r.CLEAN_UP_FAIL
|
||||
|
||||
if valid and self.dirt_prop.done_when_clean and (len(self[c.DIRT]) == 0):
|
||||
@ -195,8 +193,8 @@ class DirtFactory(BaseFactory):
|
||||
n_dirt_tiles = max(0, int(new_spawn * len(free_for_dirt)))
|
||||
self[c.DIRT].spawn_dirt(free_for_dirt[:n_dirt_tiles])
|
||||
|
||||
def do_additional_step(self) -> (List[dict], dict):
|
||||
super_reward_info = super().do_additional_step()
|
||||
def step_hook(self) -> (List[dict], dict):
|
||||
super_reward_info = super().step_hook()
|
||||
if smear_amount := self.dirt_prop.dirt_smear_amount:
|
||||
for agent in self[c.AGENT]:
|
||||
if agent.temp_valid and agent.last_pos != c.NO_POS:
|
||||
@ -229,8 +227,8 @@ class DirtFactory(BaseFactory):
|
||||
else:
|
||||
return action_result
|
||||
|
||||
def do_additional_reset(self) -> None:
|
||||
super().do_additional_reset()
|
||||
def reset_hook(self) -> None:
|
||||
super().reset_hook()
|
||||
self.trigger_dirt_spawn(initial_spawn=True)
|
||||
self._next_dirt_spawn = self.dirt_prop.spawn_frequency if self.dirt_prop.spawn_frequency else -1
|
||||
|
||||
@ -242,13 +240,13 @@ class DirtFactory(BaseFactory):
|
||||
return all_cleaned, super_dict
|
||||
return super_done, super_dict
|
||||
|
||||
def _additional_observations(self) -> Dict[str, np.typing.ArrayLike]:
|
||||
additional_observations = super()._additional_observations()
|
||||
def observations_hook(self) -> Dict[str, np.typing.ArrayLike]:
|
||||
additional_observations = super().observations_hook()
|
||||
additional_observations.update({c.DIRT: self[c.DIRT].as_array()})
|
||||
return additional_observations
|
||||
|
||||
def gather_additional_info(self, agent: Agent) -> dict:
|
||||
event_reward_dict = super().additional_per_agent_reward(agent)
|
||||
event_reward_dict = super().per_agent_reward_hook(agent)
|
||||
info_dict = dict()
|
||||
|
||||
dirt = [dirt.amount for dirt in self[c.DIRT]]
|
||||
@ -280,8 +278,7 @@ if __name__ == '__main__':
|
||||
max_local_amount=1,
|
||||
spawn_frequency=0,
|
||||
max_spawn_ratio=0.05,
|
||||
dirt_smear_amount=0.0,
|
||||
agent_can_interact=True
|
||||
dirt_smear_amount=0.0
|
||||
)
|
||||
|
||||
obs_props = ObservationProperties(render_agents=aro.COMBINED, omit_agent_self=True,
|
||||
@ -294,13 +291,13 @@ if __name__ == '__main__':
|
||||
global_timings = []
|
||||
for i in range(10):
|
||||
|
||||
factory = DirtFactory(n_agents=1, done_at_collision=False,
|
||||
factory = DirtFactory(n_agents=10, done_at_collision=False,
|
||||
level_name='rooms', max_steps=1000,
|
||||
doors_have_area=False,
|
||||
obs_prop=obs_props, parse_doors=True,
|
||||
verbose=True,
|
||||
mv_prop=move_props, dirt_prop=dirt_props,
|
||||
inject_agents=[TSPDirtAgent],
|
||||
# inject_agents=[TSPDirtAgent],
|
||||
)
|
||||
|
||||
# noinspection DuplicatedCode
|
||||
@ -318,11 +315,11 @@ if __name__ == '__main__':
|
||||
env_state = factory.reset()
|
||||
if render:
|
||||
factory.render()
|
||||
tsp_agent = factory.get_injected_agents()[0]
|
||||
# tsp_agent = factory.get_injected_agents()[0]
|
||||
|
||||
rwrd = 0
|
||||
for agent_i_action in random_actions:
|
||||
agent_i_action = tsp_agent.predict()
|
||||
# agent_i_action = tsp_agent.predict()
|
||||
env_state, step_rwrd, done_bool, info_obj = factory.step(agent_i_action)
|
||||
rwrd += step_rwrd
|
||||
if render:
|
||||
|
58
environments/factory/factory_dirt_stationary_machines.py
Normal file
58
environments/factory/factory_dirt_stationary_machines.py
Normal file
@ -0,0 +1,58 @@
|
||||
from typing import Dict, List, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from environments.factory.base.objects import Agent, Entity, Action
|
||||
from environments.factory.factory_dirt import Dirt, DirtRegister, DirtFactory
|
||||
from environments.factory.base.objects import Floor
|
||||
from environments.factory.base.registers import Floors, Entities, EntityRegister
|
||||
|
||||
|
||||
class Machines(EntityRegister):
    """Entity register for the stationary-machines environment.

    At this stage the class is a plain pass-through over ``EntityRegister``:
    it forwards all constructor arguments unchanged and adds no behaviour.
    Presumably it is meant to hold ``Machine`` objects - nothing here
    enforces that yet.
    """

    def __init__(self, *args, **kwargs):
        # Explicit two-argument super(); equivalent to the zero-argument form.
        super(Machines, self).__init__(*args, **kwargs)
|
||||
|
||||
|
||||
class Machine(Entity):
    """Entity type for the stationary-machines environment.

    At this stage the class is a plain pass-through over ``Entity``: it
    forwards all constructor arguments unchanged and adds no behaviour.
    """

    def __init__(self, *args, **kwargs):
        # Explicit two-argument super(); equivalent to the zero-argument form.
        super(Machine, self).__init__(*args, **kwargs)
|
||||
|
||||
|
||||
class StationaryMachinesDirtFactory(DirtFactory):
    """``DirtFactory`` variant that stores a fixed list of machine coordinates.

    Work in progress: apart from remembering the coordinates, no machine logic
    exists yet. Every hook with a known parent implementation forwards to it,
    so the environment behaves like a plain ``DirtFactory`` instead of
    returning ``None`` from the hooks.
    """

    def __init__(self, *args, **kwargs):
        # The base constructor already calls reset(), so any state the hooks
        # may need has to exist before super().__init__() runs.
        self._machine_coords = [(6, 6), (12, 13)]
        super().__init__(*args, **kwargs)

    @property
    def entities_hook(self) -> Dict[(str, Entities)]:
        """Return the parent's entity registers; no machine register is added yet."""
        # ``entities_hook`` is a property on the parent: read it, do not call it.
        super_entities = super().entities_hook
        return super_entities

    def reset_hook(self) -> None:
        """Run the parent's reset logic (dirt spawn and spawn counter)."""
        super().reset_hook()

    def observations_hook(self) -> Dict[str, np.typing.ArrayLike]:
        """Return the parent's additional observations unchanged."""
        return super().observations_hook()

    @property
    def actions_hook(self) -> Union[Action, List[Action]]:
        """Return the parent's additional actions; machines add none so far."""
        # ``actions_hook`` is a property on the parent: read it, do not call it.
        return super().actions_hook

    def step_hook(self) -> (List[dict], dict):
        """Run the parent's per-step logic and return its result unchanged."""
        return super().step_hook()

    def per_agent_raw_observations_hook(self, agent) -> Dict[str, np.typing.ArrayLike]:
        """Return the parent's per-agent raw observations for ``agent``."""
        super_per_agent_raw_observations = super().per_agent_raw_observations_hook(agent)
        return super_per_agent_raw_observations

    def per_agent_reward_hook(self, agent: Agent) -> Dict[str, dict]:
        """Return the parent's per-agent reward information for ``agent``."""
        return super().per_agent_reward_hook(agent)

    def pre_step_hook(self) -> None:
        # NOTE(review): left as a no-op; no parent implementation is visible
        # from here - confirm whether this should forward to super().
        pass

    def post_step_hook(self) -> dict:
        # NOTE(review): returns None although annotated ``dict``; no parent
        # implementation is visible from here - confirm the expected result.
        pass
|
@ -9,7 +9,7 @@ from environments.helpers import Constants as BaseConstants
|
||||
from environments.helpers import EnvActions as BaseActions
|
||||
from environments.helpers import Rewards as BaseRewards
|
||||
from environments import helpers as h
|
||||
from environments.factory.base.objects import Agent, Entity, Action, Tile
|
||||
from environments.factory.base.objects import Agent, Entity, Action, Floor
|
||||
from environments.factory.base.registers import Entities, EntityRegister, BoundEnvObjRegister, ObjectRegister
|
||||
|
||||
from environments.factory.base.renderer import RenderEntity
|
||||
@ -25,7 +25,7 @@ class Constants(BaseConstants):
|
||||
|
||||
|
||||
class Actions(BaseActions):
|
||||
ITEM_ACTION = 'item_action'
|
||||
ITEM_ACTION = 'ITEMACTION'
|
||||
|
||||
|
||||
class Rewards(BaseRewards):
|
||||
@ -62,7 +62,7 @@ class ItemRegister(EntityRegister):
|
||||
|
||||
_accepted_objects = Item
|
||||
|
||||
def spawn_items(self, tiles: List[Tile]):
|
||||
def spawn_items(self, tiles: List[Floor]):
|
||||
items = [Item(tile, self) for tile in tiles]
|
||||
self.register_additional_items(items)
|
||||
|
||||
@ -193,16 +193,16 @@ class ItemFactory(BaseFactory):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
@property
|
||||
def additional_actions(self) -> Union[Action, List[Action]]:
|
||||
def actions_hook(self) -> Union[Action, List[Action]]:
|
||||
# noinspection PyUnresolvedReferences
|
||||
super_actions = super().additional_actions
|
||||
super_actions = super().actions_hook
|
||||
super_actions.append(Action(str_ident=a.ITEM_ACTION))
|
||||
return super_actions
|
||||
|
||||
@property
|
||||
def additional_entities(self) -> Dict[(str, Entities)]:
|
||||
def entities_hook(self) -> Dict[(str, Entities)]:
|
||||
# noinspection PyUnresolvedReferences
|
||||
super_entities = super().additional_entities
|
||||
super_entities = super().entities_hook
|
||||
|
||||
empty_tiles = self[c.FLOOR].empty_tiles[:self.item_prop.n_drop_off_locations]
|
||||
drop_offs = DropOffLocations.from_tiles(
|
||||
@ -220,13 +220,13 @@ class ItemFactory(BaseFactory):
|
||||
super_entities.update({c.DROP_OFF: drop_offs, c.ITEM: item_register, c.INVENTORY: inventories})
|
||||
return super_entities
|
||||
|
||||
def _additional_per_agent_raw_observations(self, agent) -> Dict[str, np.typing.ArrayLike]:
|
||||
additional_raw_observations = super()._additional_per_agent_raw_observations(agent)
|
||||
def per_agent_raw_observations_hook(self, agent) -> Dict[str, np.typing.ArrayLike]:
|
||||
additional_raw_observations = super().per_agent_raw_observations_hook(agent)
|
||||
additional_raw_observations.update({c.INVENTORY: self[c.INVENTORY].by_entity(agent).as_array()})
|
||||
return additional_raw_observations
|
||||
|
||||
def _additional_observations(self) -> Dict[str, np.typing.ArrayLike]:
|
||||
additional_observations = super()._additional_observations()
|
||||
def observations_hook(self) -> Dict[str, np.typing.ArrayLike]:
|
||||
additional_observations = super().observations_hook()
|
||||
additional_observations.update({c.ITEM: self[c.ITEM].as_array()})
|
||||
additional_observations.update({c.DROP_OFF: self[c.DROP_OFF].as_array()})
|
||||
return additional_observations
|
||||
@ -240,21 +240,21 @@ class ItemFactory(BaseFactory):
|
||||
valid = c.NOT_VALID
|
||||
if valid:
|
||||
self.print(f'{agent.name} just dropped of an item at {drop_off.pos}.')
|
||||
info_dict = {f'{agent.name}_DROPOFF_VALID': 1}
|
||||
info_dict = {f'{agent.name}_DROPOFF_VALID': 1, 'DROPOFF_VALID': 1}
|
||||
else:
|
||||
self.print(f'{agent.name} just tried to drop off at {agent.pos}, but failed.')
|
||||
info_dict = {f'{agent.name}_DROPOFF_FAIL': 1}
|
||||
info_dict = {f'{agent.name}_DROPOFF_FAIL': 1, 'DROPOFF_FAIL': 1}
|
||||
reward = dict(value=r.DROP_OFF_VALID if valid else r.DROP_OFF_FAIL, reason=a.ITEM_ACTION, info=info_dict)
|
||||
return valid, reward
|
||||
elif item := self[c.ITEM].by_pos(agent.pos):
|
||||
item.change_register(inventory)
|
||||
item.set_tile_to(self._NO_POS_TILE)
|
||||
self.print(f'{agent.name} just picked up an item at {agent.pos}')
|
||||
info_dict = {f'{agent.name}_{a.ITEM_ACTION}_VALID': 1}
|
||||
info_dict = {f'{agent.name}_{a.ITEM_ACTION}_VALID': 1, f'{a.ITEM_ACTION}_VALID': 1}
|
||||
return c.VALID, dict(value=r.PICK_UP_VALID, reason=a.ITEM_ACTION, info=info_dict)
|
||||
else:
|
||||
self.print(f'{agent.name} just tried to pick up an item at {agent.pos}, but failed.')
|
||||
info_dict = {f'{agent.name}_{a.ITEM_ACTION}_FAIL': 1}
|
||||
info_dict = {f'{agent.name}_{a.ITEM_ACTION}_FAIL': 1, f'{a.ITEM_ACTION}_FAIL': 1}
|
||||
return c.NOT_VALID, dict(value=r.PICK_UP_FAIL, reason=a.ITEM_ACTION, info=info_dict)
|
||||
|
||||
def do_additional_actions(self, agent: Agent, action: Action) -> (dict, dict):
|
||||
@ -269,9 +269,9 @@ class ItemFactory(BaseFactory):
|
||||
else:
|
||||
return action_result
|
||||
|
||||
def do_additional_reset(self) -> None:
|
||||
def reset_hook(self) -> None:
|
||||
# noinspection PyUnresolvedReferences
|
||||
super().do_additional_reset()
|
||||
super().reset_hook()
|
||||
self._next_item_spawn = self.item_prop.spawn_frequency
|
||||
self.trigger_item_spawn()
|
||||
|
||||
@ -284,9 +284,9 @@ class ItemFactory(BaseFactory):
|
||||
else:
|
||||
self.print('No Items are spawning, limit is reached.')
|
||||
|
||||
def do_additional_step(self) -> (List[dict], dict):
|
||||
def step_hook(self) -> (List[dict], dict):
|
||||
# noinspection PyUnresolvedReferences
|
||||
super_reward_info = super().do_additional_step()
|
||||
super_reward_info = super().step_hook()
|
||||
for item in list(self[c.ITEM].values()):
|
||||
if item.auto_despawn >= 1:
|
||||
item.set_auto_despawn(item.auto_despawn-1)
|
||||
@ -301,9 +301,9 @@ class ItemFactory(BaseFactory):
|
||||
self._next_item_spawn = max(0, self._next_item_spawn-1)
|
||||
return super_reward_info
|
||||
|
||||
def render_additional_assets(self, mode='human'):
|
||||
def render_assets_hook(self, mode='human'):
|
||||
# noinspection PyUnresolvedReferences
|
||||
additional_assets = super().render_additional_assets()
|
||||
additional_assets = super().render_assets_hook()
|
||||
items = [RenderEntity(c.ITEM, item.tile.pos) for item in self[c.ITEM] if item.tile != self._NO_POS_TILE]
|
||||
additional_assets.extend(items)
|
||||
drop_offs = [RenderEntity(c.DROP_OFF, drop_off.tile.pos) for drop_off in self[c.DROP_OFF]]
|
||||
@ -314,7 +314,7 @@ class ItemFactory(BaseFactory):
|
||||
if __name__ == '__main__':
|
||||
from environments.utility_classes import AgentRenderOptions as aro, ObservationProperties
|
||||
|
||||
render = False
|
||||
render = True
|
||||
|
||||
item_probs = ItemProperties(n_items=30, n_drop_off_locations=6)
|
||||
|
||||
@ -336,18 +336,18 @@ if __name__ == '__main__':
|
||||
obs_space = factory.observation_space
|
||||
obs_space_named = factory.named_observation_space
|
||||
|
||||
for epoch in range(4):
|
||||
for epoch in range(400):
|
||||
random_actions = [[random.randint(0, n_actions) for _
|
||||
in range(factory.n_agents)] for _
|
||||
in range(factory.max_steps + 1)]
|
||||
env_state = factory.reset()
|
||||
r = 0
|
||||
rwrd = 0
|
||||
for agent_i_action in random_actions:
|
||||
env_state, step_r, done_bool, info_obj = factory.step(agent_i_action)
|
||||
r += step_r
|
||||
rwrd += step_r
|
||||
if render:
|
||||
factory.render()
|
||||
if done_bool:
|
||||
break
|
||||
print(f'Factory run {epoch} done, reward is:\n {r}')
|
||||
print(f'Factory run {epoch} done, reward is:\n {rwrd}')
|
||||
pass
|
||||
|
@ -1,5 +1,6 @@
|
||||
import pickle
|
||||
from collections import defaultdict
|
||||
from os import PathLike
|
||||
from pathlib import Path
|
||||
from typing import List, Dict, Union
|
||||
|
||||
@ -9,14 +10,17 @@ from environments.helpers import IGNORED_DF_COLUMNS
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from plotting.compare_runs import plot_single_run
|
||||
|
||||
|
||||
class EnvMonitor(BaseCallback):
|
||||
|
||||
ext = 'png'
|
||||
|
||||
def __init__(self, env):
|
||||
def __init__(self, env, filepath: Union[str, PathLike] = None):
|
||||
super(EnvMonitor, self).__init__()
|
||||
self.unwrapped = env
|
||||
self._filepath = filepath
|
||||
self._monitor_df = pd.DataFrame()
|
||||
self._monitor_dicts = defaultdict(dict)
|
||||
|
||||
@ -67,8 +71,10 @@ class EnvMonitor(BaseCallback):
|
||||
pass
|
||||
return
|
||||
|
||||
def save_run(self, filepath: Union[Path, str]):
|
||||
def save_run(self, filepath: Union[Path, str], auto_plotting_keys=None):
|
||||
filepath = Path(filepath)
|
||||
filepath.parent.mkdir(exist_ok=True, parents=True)
|
||||
with filepath.open('wb') as f:
|
||||
pickle.dump(self._monitor_df.reset_index(), f, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
if auto_plotting_keys:
|
||||
plot_single_run(filepath, column_keys=auto_plotting_keys)
|
||||
|
@ -24,14 +24,12 @@ class EnvRecorder(BaseCallback):
|
||||
self._entities = [entities]
|
||||
else:
|
||||
self._entities = entities
|
||||
self.started = False
|
||||
self.closed = False
|
||||
|
||||
def __getattr__(self, item):
|
||||
return getattr(self.unwrapped, item)
|
||||
|
||||
def reset(self):
|
||||
self.unwrapped._record_episodes = True
|
||||
self.unwrapped.start_recording()
|
||||
return self.unwrapped.reset()
|
||||
|
||||
def _on_training_start(self) -> None:
|
||||
@ -57,6 +55,14 @@ class EnvRecorder(BaseCallback):
|
||||
else:
|
||||
pass
|
||||
|
||||
def step(self, actions):
    """Step the wrapped environment and pass its outcome to the recorder.

    :param actions: Actions forwarded unchanged to ``self.unwrapped.step``.
    :return: The untouched step result of the wrapped environment.
    """
    result = self.unwrapped.step(actions)
    # The step result is read by position: index 2 is the done flag,
    # index 3 is the info object.
    info_obj = result[3]
    done_bool = result[2]
    # First argument 0 - presumably the environment index; only a single
    # environment is wrapped here (TODO confirm).
    self._read_info(0, info_obj)
    self._read_done(0, done_bool)
    return result
|
||||
|
||||
def save_records(self, filepath: Union[Path, str], save_occupation_map=False, save_trajectory_map=False):
|
||||
filepath = Path(filepath)
|
||||
filepath.parent.mkdir(exist_ok=True, parents=True)
|
||||
|
@ -10,6 +10,45 @@ from environments.helpers import IGNORED_DF_COLUMNS, MODEL_MAP
|
||||
from plotting.plotting import prepare_plot
|
||||
|
||||
|
||||
def plot_single_run(run_path: Union[str, PathLike], use_tex: bool = False, column_keys=None):
    """Plot the monitored measurements of a single run as a line plot.

    :param run_path: Either a pickled monitor file or a directory that
                     contains a file matching ``*monitor*.pick``.
    :param use_tex: Forwarded to ``prepare_plot``.
    :param column_keys: Columns to plot; keys missing from the data are
                        skipped. If ``None``, every column that is not listed
                        in ``IGNORED_DF_COLUMNS`` is plotted.
    :raises FileNotFoundError: If ``run_path`` is a directory without a
                               matching monitor file.
    :raises ValueError: If ``run_path`` is neither an existing file nor an
                        existing directory.
    """
    run_path = Path(run_path)
    if run_path.is_dir():
        # A default for next() avoids leaking a bare StopIteration to the caller.
        monitor_file = next(run_path.glob('*monitor*.pick'), None)
        if monitor_file is None:
            raise FileNotFoundError(f'No "*monitor*.pick" file found in "{run_path}".')
    elif run_path.exists() and run_path.is_file():
        monitor_file = run_path
    else:
        raise ValueError(f'"{run_path}" is neither an existing file nor an existing directory.')

    # NOTE: unpickling executes arbitrary code - only load trusted monitor files.
    with monitor_file.open('rb') as f:
        monitor_df = pickle.load(f)

    df = monitor_df.fillna(0).rename(columns={'episode': 'Episode'}).sort_values(['Episode'])
    if column_keys is not None:
        columns = [col for col in column_keys if col in df.columns]
    else:
        columns = [col for col in df.columns if col not in IGNORED_DF_COLUMNS]

    # Long format: one row per (Episode, Measurement) pair.
    df_melted = df[columns + ['Episode']].reset_index().melt(id_vars=['Episode'],
                                                             value_vars=columns, var_name="Measurement",
                                                             value_name="Score")

    # Thin out long runs: keep only every ``skip_n``-th episode (2% of the maximum).
    if df_melted['Episode'].max() > 800:
        skip_n = round(df_melted['Episode'].max() * 0.02)
        df_melted = df_melted[df_melted['Episode'] % skip_n == 0]

    # NOTE(review): for a directory ``run_path`` the figure is written next to
    # the directory rather than into it - confirm this is intended.
    prepare_plot(run_path.parent / f'{run_path.parent.name}_monitor_lineplot.png', df_melted, use_tex=use_tex)
    print('Plotting done.')
|
||||
|
||||
|
||||
def compare_seed_runs(run_path: Union[str, PathLike], use_tex: bool = False):
|
||||
run_path = Path(run_path)
|
||||
df_list = list()
|
||||
@ -37,7 +76,10 @@ def compare_seed_runs(run_path: Union[str, PathLike], use_tex: bool = False):
|
||||
skip_n = round(df_melted['Episode'].max() * 0.02)
|
||||
df_melted = df_melted[df_melted['Episode'] % skip_n == 0]
|
||||
|
||||
prepare_plot(run_path / f'{run_path.name}_monitor_lineplot.png', df_melted, use_tex=use_tex)
|
||||
if run_path.is_dir():
|
||||
prepare_plot(run_path / f'{run_path}_monitor_lineplot.png', df_melted, use_tex=use_tex)
|
||||
elif run_path.exists() and run_path.is_file():
|
||||
prepare_plot(run_path.parent / f'{run_path.parent}_monitor_lineplot.png', df_melted, use_tex=use_tex)
|
||||
print('Plotting done.')
|
||||
|
||||
|
||||
|
@ -1,4 +1,5 @@
|
||||
import seaborn as sns
|
||||
import matplotlib as mpl
|
||||
from matplotlib import pyplot as plt
|
||||
|
||||
PALETTE = 10 * (
|
||||
@ -21,7 +22,14 @@ PALETTE = 10 * (
|
||||
def plot(filepath, ext='png'):
|
||||
plt.tight_layout()
|
||||
figure = plt.gcf()
|
||||
figure.savefig(str(filepath), format=ext)
|
||||
ax = plt.gca()
|
||||
legends = [c for c in ax.get_children() if isinstance(c, mpl.legend.Legend)]
|
||||
|
||||
if legends:
|
||||
figure.savefig(str(filepath), format=ext, bbox_extra_artists=(*legends,), bbox_inches='tight')
|
||||
else:
|
||||
figure.savefig(str(filepath), format=ext)
|
||||
|
||||
plt.show()
|
||||
plt.clf()
|
||||
|
||||
@ -30,7 +38,7 @@ def prepare_tex(df, hue, style, hue_order):
|
||||
sns.set(rc={'text.usetex': True}, style='whitegrid')
|
||||
lineplot = sns.lineplot(data=df, x='Episode', y='Score', ci=95, palette=PALETTE,
|
||||
hue_order=hue_order, hue=hue, style=style)
|
||||
# lineplot.set_title(f'{sorted(list(df["Measurement"].unique()))}')
|
||||
lineplot.set_title(f'{sorted(list(df["Measurement"].unique()))}')
|
||||
plt.legend(bbox_to_anchor=(1.02, 1), loc='upper left', borderaxespad=0)
|
||||
plt.tight_layout()
|
||||
return lineplot
|
||||
@ -48,6 +56,19 @@ def prepare_plt(df, hue, style, hue_order):
|
||||
return lineplot
|
||||
|
||||
|
||||
def prepare_center_double_column_legend(df, hue, style, hue_order):
    """Draw the Episode/Score line plot with a multi-column legend below the axes.

    All open figures are closed, seaborn is configured without LaTeX and a
    fresh 10x11 figure is opened before plotting.

    :param df: Data with the columns ``Episode`` and ``Score``.
    :param hue: Column used for colour grouping.
    :param style: Column used for line styles.
    :param hue_order: Order of the hue levels; also used as the legend labels.
    :return: The object returned by ``sns.lineplot``.
    """
    print('Struggling to plot Figure using LaTeX - going back to normal.')
    plt.close('all')
    sns.set(rc={'text.usetex': False}, style='whitegrid')
    # Only the side effect (a new current figure) is needed, not the handle.
    plt.figure(figsize=(10, 11))

    line_kwargs = dict(data=df, x='Episode', y='Score', hue=hue, style=style,
                       ci=95, palette=PALETTE, hue_order=hue_order, legend=False)
    ax = sns.lineplot(**line_kwargs)

    # The automatic legend is disabled above; build a three-column legend
    # anchored below the axes instead.
    ax.legend(hue_order, ncol=3, loc='lower center',
              title='Parameter Combinations', bbox_to_anchor=(0.5, -0.43))
    plt.tight_layout()
    return ax
|
||||
|
||||
|
||||
def prepare_plot(filepath, results_df, ext='png', hue='Measurement', style=None, use_tex=False):
|
||||
df = results_df.copy()
|
||||
df[hue] = df[hue].str.replace('_', '-')
|
||||
|
@ -4,7 +4,10 @@ from pathlib import Path
|
||||
import yaml
|
||||
from stable_baselines3 import A2C, PPO, DQN
|
||||
|
||||
from environments.factory.factory_dirt import Constants as c
|
||||
|
||||
from environments.factory.factory_dirt import DirtFactory
|
||||
from environments.logging.envmonitor import EnvMonitor
|
||||
from environments.logging.recorder import EnvRecorder
|
||||
|
||||
warnings.filterwarnings('ignore', category=FutureWarning)
|
||||
@ -16,32 +19,35 @@ if __name__ == '__main__':
|
||||
determin = False
|
||||
render = True
|
||||
record = False
|
||||
seed = 67
|
||||
verbose = True
|
||||
seed = 13
|
||||
n_agents = 1
|
||||
# out_path = Path('study_out/e_1_new_reward/no_obs/dirt/A2C_new_reward/0_A2C_new_reward')
|
||||
out_path = Path('study_out/test/dirt')
|
||||
out_path = Path('study_out/reload')
|
||||
model_path = out_path
|
||||
|
||||
with (out_path / f'env_params.json').open('r') as f:
|
||||
env_kwargs = yaml.load(f, Loader=yaml.FullLoader)
|
||||
env_kwargs.update(additional_agent_placeholder=None, n_agents=n_agents, max_steps=150)
|
||||
if gain_amount := env_kwargs.get('dirt_prop', {}).get('gain_amount', None):
|
||||
env_kwargs['dirt_prop']['max_spawn_amount'] = gain_amount
|
||||
del env_kwargs['dirt_prop']['gain_amount']
|
||||
|
||||
env_kwargs.update(record_episodes=record, done_at_collision=True)
|
||||
env_kwargs.update(n_agents=n_agents, done_at_collision=False, verbose=verbose)
|
||||
|
||||
this_model = out_path / 'model.zip'
|
||||
|
||||
model_cls = PPO # next(val for key, val in h.MODEL_MAP.items() if key in out_path.parent.name)
|
||||
models = [model_cls.load(this_model)]
|
||||
try:
|
||||
# Legacy Cleanups
|
||||
del env_kwargs['dirt_prop']['agent_can_interact']
|
||||
env_kwargs['verbose'] = True
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
# Init Env
|
||||
with DirtFactory(**env_kwargs) as env:
|
||||
env = EnvRecorder(env)
|
||||
env = EnvMonitor(env)
|
||||
env = EnvRecorder(env) if record else env
|
||||
obs_shape = env.observation_space.shape
|
||||
# Evaluation Loop for i in range(n Episodes)
|
||||
for episode in range(50):
|
||||
for episode in range(500):
|
||||
env_state = env.reset()
|
||||
rew, done_bool = 0, False
|
||||
while not done_bool:
|
||||
@ -55,7 +61,17 @@ if __name__ == '__main__':
|
||||
rew += step_r
|
||||
if render:
|
||||
env.render()
|
||||
try:
|
||||
door = next(x for x in env.unwrapped.unwrapped[c.DOORS] if x.is_open)
|
||||
print('openDoor found')
|
||||
except StopIteration:
|
||||
pass
|
||||
|
||||
if done_bool:
|
||||
break
|
||||
print(f'Factory run {episode} done, reward is:\n {rew}')
|
||||
print(f'Factory run {episode} done, steps taken {env.unwrapped.unwrapped._steps}, reward is:\n {rew}')
|
||||
env.save_run(out_path / 'reload_monitor.pick',
|
||||
auto_plotting_keys=['step_reward', 'cleanup_valid', 'cleanup_fail'])
|
||||
if record:
|
||||
env.save_records(out_path / 'reload_recorder.pick', save_occupation_map=True)
|
||||
print('all done')
|
||||
|
@ -1,3 +1,4 @@
|
||||
import itertools
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
@ -65,8 +66,8 @@ def load_model_run_baseline(policy_path, env_to_run):
|
||||
if done_bool:
|
||||
break
|
||||
print(f'Factory run {episode} done, reward is:\n {rew}')
|
||||
recorded_env_factory.save_run(filepath=policy_path / f'monitor.pick')
|
||||
recorded_env_factory.save_records(filepath=policy_path / f'recorder.json')
|
||||
recorded_env_factory.save_run(filepath=policy_path / f'baseline_monitor.pick')
|
||||
recorded_env_factory.save_records(filepath=policy_path / f'baseline_recorder.json')
|
||||
|
||||
|
||||
def load_model_run_combined(root_path, env_to_run, env_kwargs):
|
||||
@ -89,134 +90,156 @@ def load_model_run_combined(root_path, env_to_run, env_kwargs):
|
||||
env_factory.named_observation_space,
|
||||
*[x.named_observation_space for x in models])
|
||||
|
||||
monitored_env_factory = EnvMonitor(env_factory)
|
||||
recorded_env_factory = EnvRecorder(monitored_env_factory)
|
||||
env = EnvMonitor(env_factory)
|
||||
# Evaluation Loop for i in range(n Episodes)
|
||||
for episode in range(5):
|
||||
env_state = recorded_env_factory.reset()
|
||||
env_state = env.reset()
|
||||
rew, done_bool = 0, False
|
||||
while not done_bool:
|
||||
translated_observations = observation_translator(env_state)
|
||||
actions = [model.predict(translated_observations[model_idx], deterministic=True)[0]
|
||||
for model_idx, model in enumerate(models)]
|
||||
translated_actions = action_translator(actions)
|
||||
env_state, step_r, done_bool, info_obj = recorded_env_factory.step(translated_actions)
|
||||
env_state, step_r, done_bool, info_obj = env.step(translated_actions)
|
||||
rew += step_r
|
||||
if done_bool:
|
||||
break
|
||||
print(f'Factory run {episode} done, reward is:\n {rew}')
|
||||
recorded_env_factory.save_run(filepath=root_path / f'monitor.pick')
|
||||
recorded_env_factory.save_records(filepath=root_path / f'recorder.json')
|
||||
env.save_run(filepath=root_path / f'monitor_combined.pick')
|
||||
# env.save_records(filepath=root_path / f'recorder_combined.json')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# What to do:
|
||||
train = True
|
||||
individual_run = True
|
||||
individual_run = False
|
||||
combined_run = False
|
||||
multi_env = False
|
||||
|
||||
train_steps = 2e6
|
||||
train_steps = 1e6
|
||||
frames_to_stack = 3
|
||||
|
||||
# Define a global study save path
|
||||
study_root_path = Path(__file__).parent.parent / 'study_out' / f'{Path(__file__).stem}'
|
||||
paremters_of_interest = dict(
|
||||
show_global_position_info=[True, False],
|
||||
pomdp_r=[3],
|
||||
cast_shadows=[True, False],
|
||||
allow_diagonal_movement=[True],
|
||||
parse_doors=[True, False],
|
||||
doors_have_area=[True, False],
|
||||
done_at_collision=[True, False]
|
||||
)
|
||||
keys, vals = zip(*paremters_of_interest.items())
|
||||
|
||||
def policy_model_kwargs():
|
||||
return dict()
|
||||
# Then we find all permutations for those values
|
||||
p = list(itertools.product(*vals))
|
||||
|
||||
# Define Global Env Parameters
|
||||
# Define properties object parameters
|
||||
obs_props = ObservationProperties(render_agents=AgentRenderOptions.NOT,
|
||||
additional_agent_placeholder=None,
|
||||
omit_agent_self=True,
|
||||
frames_to_stack=frames_to_stack,
|
||||
pomdp_r=2, cast_shadows=True)
|
||||
move_props = MovementProperties(allow_diagonal_movement=True,
|
||||
allow_square_movement=True,
|
||||
allow_no_op=False)
|
||||
dirt_props = DirtProperties(initial_dirt_ratio=0.35, initial_dirt_spawn_r_var=0.1,
|
||||
clean_amount=0.34,
|
||||
max_spawn_amount=0.1, max_global_amount=20,
|
||||
max_local_amount=1, spawn_frequency=0, max_spawn_ratio=0.05,
|
||||
dirt_smear_amount=0.0, agent_can_interact=True)
|
||||
item_props = ItemProperties(n_items=10, spawn_frequency=30, n_drop_off_locations=2,
|
||||
max_agent_inventory_capacity=15)
|
||||
dest_props = DestProperties(n_dests=4, spawn_mode=DestModeOptions.GROUPED, spawn_frequency=1)
|
||||
factory_kwargs = dict(n_agents=1, max_steps=500, parse_doors=True,
|
||||
level_name='rooms', doors_have_area=True,
|
||||
verbose=False,
|
||||
mv_prop=move_props,
|
||||
obs_prop=obs_props,
|
||||
done_at_collision=False
|
||||
)
|
||||
# Finally we can create our list of dicts
|
||||
result = [{keys[index]: entry[index] for index in range(len(entry))} for entry in p]
|
||||
|
||||
# Bundle both environments with global kwargs and parameters
|
||||
env_map = {}
|
||||
env_map.update({'dirt': (DirtFactory, dict(dirt_prop=dirt_props,
|
||||
**factory_kwargs.copy()))})
|
||||
env_map.update({'item': (ItemFactory, dict(item_prop=item_props,
|
||||
**factory_kwargs.copy()))})
|
||||
# env_map.update({'dest': (DestFactory, dict(dest_prop=dest_props,
|
||||
# **factory_kwargs.copy()))})
|
||||
env_map.update({'combined': (DirtDestItemFactory, dict(dest_prop=dest_props,
|
||||
item_prop=item_props,
|
||||
dirt_prop=dirt_props,
|
||||
**factory_kwargs.copy()))})
|
||||
env_names = list(env_map.keys())
|
||||
for u in result:
|
||||
file_name = '_'.join('_'.join([str(y)[0] for y in x]) for x in u.items())
|
||||
study_root_path = Path(__file__).parent.parent / 'study_out' / file_name
|
||||
|
||||
# Train starts here ############################################################
|
||||
# Build Major Loop parameters, parameter versions, Env Classes and models
|
||||
if train:
|
||||
for env_key in (env_key for env_key in env_map if 'combined' != env_key):
|
||||
model_cls = h.MODEL_MAP['PPO']
|
||||
combination_path = study_root_path / env_key
|
||||
env_class, env_kwargs = env_map[env_key]
|
||||
# Model Kwargs
|
||||
policy_model_kwargs = dict(ent_coef=0.01)
|
||||
|
||||
# Output folder
|
||||
if (combination_path / 'monitor.pick').exists():
|
||||
continue
|
||||
combination_path.mkdir(parents=True, exist_ok=True)
|
||||
# Define Global Env Parameters
|
||||
# Define properties object parameters
|
||||
obs_props = ObservationProperties(render_agents=AgentRenderOptions.NOT,
|
||||
additional_agent_placeholder=None,
|
||||
omit_agent_self=True,
|
||||
frames_to_stack=frames_to_stack,
|
||||
pomdp_r=u['pomdp_r'], cast_shadows=u['cast_shadows'],
|
||||
show_global_position_info=u['show_global_position_info'])
|
||||
move_props = MovementProperties(allow_diagonal_movement=u['allow_diagonal_movement'],
|
||||
allow_square_movement=True,
|
||||
allow_no_op=False)
|
||||
dirt_props = DirtProperties(initial_dirt_ratio=0.35, initial_dirt_spawn_r_var=0.1,
|
||||
clean_amount=0.34,
|
||||
max_spawn_amount=0.1, max_global_amount=20,
|
||||
max_local_amount=1, spawn_frequency=0, max_spawn_ratio=0.05,
|
||||
dirt_smear_amount=0.0)
|
||||
item_props = ItemProperties(n_items=10, spawn_frequency=30, n_drop_off_locations=2,
|
||||
max_agent_inventory_capacity=15)
|
||||
dest_props = DestProperties(n_dests=4, spawn_mode=DestModeOptions.GROUPED, spawn_frequency=1)
|
||||
factory_kwargs = dict(n_agents=1, max_steps=500, parse_doors=u['parse_doors'],
|
||||
level_name='rooms', doors_have_area=u['doors_have_area'],
|
||||
verbose=False,
|
||||
mv_prop=move_props,
|
||||
obs_prop=obs_props,
|
||||
done_at_collision=u['done_at_collision']
|
||||
)
|
||||
|
||||
if not multi_env:
|
||||
env_factory = encapsule_env_factory(env_class, env_kwargs)()
|
||||
else:
|
||||
env_factory = SubprocVecEnv([encapsule_env_factory(env_class, env_kwargs)
|
||||
for _ in range(6)], start_method="spawn")
|
||||
# Bundle both environments with global kwargs and parameters
|
||||
env_map = {}
|
||||
env_map.update({'dirt': (DirtFactory, dict(dirt_prop=dirt_props,
|
||||
**factory_kwargs.copy()),
|
||||
['cleanup_valid', 'cleanup_fail'])})
|
||||
# env_map.update({'item': (ItemFactory, dict(item_prop=item_props,
|
||||
# **factory_kwargs.copy()),
|
||||
# ['DROPOFF_FAIL', 'ITEMACTION_FAIL', 'DROPOFF_VALID', 'ITEMACTION_VALID'])})
|
||||
# env_map.update({'dest': (DestFactory, dict(dest_prop=dest_props,
|
||||
# **factory_kwargs.copy()))})
|
||||
env_map.update({'combined': (DirtDestItemFactory, dict(dest_prop=dest_props,
|
||||
item_prop=item_props,
|
||||
dirt_prop=dirt_props,
|
||||
**factory_kwargs.copy()))})
|
||||
env_names = list(env_map.keys())
|
||||
|
||||
param_path = combination_path / f'env_params.json'
|
||||
try:
|
||||
env_factory.env_method('save_params', param_path)
|
||||
except AttributeError:
|
||||
env_factory.save_params(param_path)
|
||||
# Train starts here ############################################################
|
||||
# Build Major Loop parameters, parameter versions, Env Classes and models
|
||||
if train:
|
||||
for env_key in (env_key for env_key in env_map if 'combined' != env_key):
|
||||
model_cls = h.MODEL_MAP['PPO']
|
||||
combination_path = study_root_path / env_key
|
||||
env_class, env_kwargs, env_plot_keys = env_map[env_key]
|
||||
|
||||
# EnvMonitor Init
|
||||
callbacks = [EnvMonitor(env_factory)]
|
||||
# Output folder
|
||||
if (combination_path / 'monitor.pick').exists():
|
||||
continue
|
||||
combination_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Model Init
|
||||
model = model_cls("MlpPolicy", env_factory, **policy_model_kwargs(),
|
||||
verbose=1, seed=69, device='cpu')
|
||||
if not multi_env:
|
||||
env_factory = encapsule_env_factory(env_class, env_kwargs)()
|
||||
else:
|
||||
env_factory = SubprocVecEnv([encapsule_env_factory(env_class, env_kwargs)
|
||||
for _ in range(6)], start_method="spawn")
|
||||
|
||||
# Model train
|
||||
model.learn(total_timesteps=int(train_steps), callback=callbacks)
|
||||
param_path = combination_path / f'env_params.json'
|
||||
try:
|
||||
env_factory.env_method('save_params', param_path)
|
||||
except AttributeError:
|
||||
env_factory.save_params(param_path)
|
||||
|
||||
# Model save
|
||||
try:
|
||||
model.named_action_space = env_factory.unwrapped.named_action_space
|
||||
model.named_observation_space = env_factory.unwrapped.named_observation_space
|
||||
except AttributeError:
|
||||
model.named_action_space = env_factory.get_attr("named_action_space")[0]
|
||||
model.named_observation_space = env_factory.get_attr("named_observation_space")[0]
|
||||
save_path = combination_path / f'model.zip'
|
||||
model.save(save_path)
|
||||
# EnvMonitor Init
|
||||
callbacks = [EnvMonitor(env_factory)]
|
||||
|
||||
# Monitor Save
|
||||
callbacks[0].save_run(combination_path / 'monitor.pick')
|
||||
# Model Init
|
||||
model = model_cls("MlpPolicy", env_factory, **policy_model_kwargs,
|
||||
verbose=1, seed=69, device='cpu')
|
||||
|
||||
# Better be safe than sorry: Clean up!
|
||||
del env_factory, model
|
||||
import gc
|
||||
gc.collect()
|
||||
# Model train
|
||||
model.learn(total_timesteps=int(train_steps), callback=callbacks)
|
||||
|
||||
# Model save
|
||||
try:
|
||||
model.named_action_space = env_factory.unwrapped.named_action_space
|
||||
model.named_observation_space = env_factory.unwrapped.named_observation_space
|
||||
except AttributeError:
|
||||
model.named_action_space = env_factory.get_attr("named_action_space")[0]
|
||||
model.named_observation_space = env_factory.get_attr("named_observation_space")[0]
|
||||
save_path = combination_path / f'model.zip'
|
||||
model.save(save_path)
|
||||
|
||||
# Monitor Save
|
||||
callbacks[0].save_run(combination_path / 'monitor.pick',
|
||||
auto_plotting_keys=['step_reward', 'collision'] + env_plot_keys)
|
||||
|
||||
# Better be safe than sorry: Clean up!
|
||||
del env_factory, model
|
||||
import gc
|
||||
gc.collect()
|
||||
|
||||
# Train ends here ############################################################
|
||||
|
||||
|
Reference in New Issue
Block a user