Experiments look good

Steffen Illium
2022-01-15 12:37:58 +01:00
parent d29ccbbb71
commit 823aa075b9
14 changed files with 478 additions and 297 deletions

View File

@@ -15,8 +15,8 @@ from environments import helpers as h
 from environments.helpers import Constants as c
 from environments.helpers import EnvActions as a
 from environments.helpers import Rewards as r
-from environments.factory.base.objects import Agent, Tile, Action
-from environments.factory.base.registers import Actions, Entities, Agents, Doors, FloorTiles, WallTiles, PlaceHolders, \
+from environments.factory.base.objects import Agent, Floor, Action
+from environments.factory.base.registers import Actions, Entities, Agents, Doors, Floors, Walls, PlaceHolders, \
     GlobalPositions
 from environments.utility_classes import MovementProperties, ObservationProperties, MarlFrameStack
 from environments.utility_classes import AgentRenderOptions as a_obs
@@ -121,7 +121,7 @@ class BaseFactory(gym.Env):
         self.doors_have_area = doors_have_area
         self.individual_rewards = individual_rewards

-        # Reset
+        # TODO: Reset ---> document this
         self.reset()

     def __getitem__(self, item):
@@ -141,21 +141,21 @@ class BaseFactory(gym.Env):
         self._obs_shape = self._level_shape if not self.obs_prop.pomdp_r else (self.pomdp_diameter, ) * 2

         # Walls
-        walls = WallTiles.from_argwhere_coordinates(
+        walls = Walls.from_argwhere_coordinates(
             np.argwhere(level_array == c.OCCUPIED_CELL),
             self._level_shape
         )
         self._entities.register_additional_items({c.WALLS: walls})

         # Floor
-        floor = FloorTiles.from_argwhere_coordinates(
+        floor = Floors.from_argwhere_coordinates(
             np.argwhere(level_array == c.FREE_CELL),
             self._level_shape
         )
         self._entities.register_additional_items({c.FLOOR: floor})

         # NOPOS
-        self._NO_POS_TILE = Tile(c.NO_POS, None)
+        self._NO_POS_TILE = Floor(c.NO_POS, None)

         # Doors
         if self.parse_doors:
@@ -170,7 +170,7 @@ class BaseFactory(gym.Env):
         # Actions
         self._actions = Actions(self.mv_prop, can_use_doors=self.parse_doors)
-        if additional_actions := self.additional_actions:
+        if additional_actions := self.actions_hook:
             self._actions.register_additional_items(additional_actions)

         # Agents
@@ -202,7 +202,7 @@ class BaseFactory(gym.Env):
         self._entities.register_additional_items({c.AGENT_PLACEHOLDER: placeholder})

         # Additional Entitites from SubEnvs
-        if additional_entities := self.additional_entities:
+        if additional_entities := self.entities_hook:
             self._entities.register_additional_items(additional_entities)

         if self.obs_prop.show_global_position_info:
@@ -217,7 +217,7 @@ class BaseFactory(gym.Env):
     def reset(self) -> (np.typing.ArrayLike, int, bool, dict):
         _ = self._base_init_env()
-        self.do_additional_reset()
+        self.reset_hook()

         self._steps = 0
@@ -233,7 +233,7 @@ class BaseFactory(gym.Env):
         self._steps += 1

         # Pre step Hook for later use
-        self.hook_pre_step()
+        self.pre_step_hook()

         for action, agent in zip(actions, self[c.AGENT]):
             agent.clear_temp_state()
@@ -244,7 +244,7 @@ class BaseFactory(gym.Env):
                 action_valid, reward = self._do_move_action(agent, action_obj)
             elif a.NOOP == action_obj:
                 action_valid = c.VALID
-                reward = dict(value=r.NOOP, reason=a.NOOP, info={f'{agent.pos}_NOOP': 1})
+                reward = dict(value=r.NOOP, reason=a.NOOP, info={f'{agent.name}_NOOP': 1, 'NOOP': 1})
             elif a.USE_DOOR == action_obj:
                 action_valid, reward = self._handle_door_interaction(agent)
             else:
@@ -258,7 +258,7 @@ class BaseFactory(gym.Env):
             agent.step_result = step_result

         # Additional step and Reward, Info Init
-        rewards, info = self.do_additional_step()
+        rewards, info = self.step_hook()

         # Todo: Make this faster, so that only tiles of entities that can collide are searched.
         tiles_with_collisions = self.get_all_tiles_with_collisions()
         for tile in tiles_with_collisions:
@@ -297,7 +297,7 @@ class BaseFactory(gym.Env):
         info.update(self._summarize_state())

         # Post step Hook for later use
-        info.update(self.hook_post_step())
+        info.update(self.post_step_hook())

         obs, _ = self._build_observations()
@@ -314,11 +314,11 @@ class BaseFactory(gym.Env):
                 door.use()
                 valid = c.VALID
                 self.print(f'{agent.name} just used a {door.name} at {door.pos}')
-                info_dict = {f'{agent.name}_door_use': 1}
+                info_dict = {f'{agent.name}_door_use': 1, f'door_use': 1}
             # When he doesn't...
             else:
                 valid = c.NOT_VALID
-                info_dict = {f'{agent.name}_failed_door_use': 1}
+                info_dict = {f'{agent.name}_failed_door_use': 1, 'failed_door_use': 1}
                 self.print(f'{agent.name} just tried to use a door at {agent.pos}, but there is none.')
         else:
@@ -334,7 +334,7 @@ class BaseFactory(gym.Env):
         per_agent_obsn = dict()

         # Generel Observations
         lvl_obs = self[c.WALLS].as_array()
-        door_obs = self[c.DOORS].as_array()
+        door_obs = self[c.DOORS].as_array() if self.parse_doors else None
         if self.obs_prop.render_agents == a_obs.NOT:
             global_agent_obs = None
         elif self.obs_prop.omit_agent_self and self.n_agents == 1:
@@ -342,7 +342,7 @@ class BaseFactory(gym.Env):
         else:
             global_agent_obs = self[c.AGENT].as_array().copy()
         placeholder_obs = self[c.AGENT_PLACEHOLDER].as_array() if self[c.AGENT_PLACEHOLDER] else None
-        add_obs_dict = self._additional_observations()
+        add_obs_dict = self.observations_hook()

         for agent_idx, agent in enumerate(self[c.AGENT]):
             obs_dict = dict()
@@ -367,17 +367,17 @@ class BaseFactory(gym.Env):
             obs_dict[c.WALLS] = lvl_obs
             if self.obs_prop.render_agents in [a_obs.SEPERATE, a_obs.COMBINED] and agent_obs is not None:
-                obs_dict[c.AGENT] = agent_obs
+                obs_dict[c.AGENT] = agent_obs[:]
             if self[c.AGENT_PLACEHOLDER] and placeholder_obs is not None:
                 obs_dict[c.AGENT_PLACEHOLDER] = placeholder_obs
             if self.parse_doors and door_obs is not None:
-                obs_dict[c.DOORS] = door_obs
+                obs_dict[c.DOORS] = door_obs[:]
             obs_dict.update(add_obs_dict)
             obsn = np.vstack(list(obs_dict.values()))
             if self.obs_prop.pomdp_r:
                 obsn = self._do_pomdp_cutout(agent, obsn)

-            raw_obs = self._additional_per_agent_raw_observations(agent)
+            raw_obs = self.per_agent_raw_observations_hook(agent)
             raw_obs = {key: np.expand_dims(val, 0) if val.ndim != 3 else val for key, val in raw_obs.items()}
             obsn = np.vstack((obsn, *raw_obs.values()))
@@ -387,6 +387,12 @@ class BaseFactory(gym.Env):
                                     zip(keys, idxs, list(idxs[1:]) + [idxs[-1]+1, ])}

             # Shadow Casting
+            if agent.step_result is not None:
+                pass
+            else:
+                assert self._steps == 0
+                agent.step_result = {'action_name': a.NOOP, 'action_valid': True,
+                                     'collisions': [], 'lightmap': None}
             if self.obs_prop.cast_shadows:
                 try:
                     light_block_obs = [obs_idx for key, obs_idx in per_agent_expl_idx[agent.name].items()
@@ -430,17 +436,15 @@ class BaseFactory(gym.Env):
                 if door_shadowing:
                     # noinspection PyUnboundLocalVariable
                     light_block_map[xs, ys] = 0
-                if agent.step_result:
-                    agent.step_result['lightmap'] = light_block_map
-                    pass
-                else:
-                    assert self._steps == 0
-                    agent.step_result = {'action_name': a.NOOP, 'action_valid': True,
-                                         'collisions': [], 'lightmap': light_block_map}
+
+                agent.step_result['lightmap'] = light_block_map
+
                 obsn[shadowed_obs] = ((obsn[shadowed_obs] * light_block_map) + 0.) - (1 - light_block_map)
             else:
-                pass
+                if self._pomdp_r:
+                    agent.step_result['lightmap'] = np.ones(self._obs_shape)
+                else:
+                    agent.step_result['lightmap'] = None

             per_agent_obsn[agent.name] = obsn
@@ -484,7 +488,7 @@ class BaseFactory(gym.Env):
             oobs = np.pad(oobs, ((0, 0), (x0_pad, x1_pad), (y0_pad, y1_pad)), 'constant')
         return oobs

-    def get_all_tiles_with_collisions(self) -> List[Tile]:
+    def get_all_tiles_with_collisions(self) -> List[Floor]:
         tiles = [x for x in self[c.FLOOR] if len(x.guests_that_can_collide) > 1]
         if False:
             tiles_with_collisions = list()
@@ -503,22 +507,22 @@ class BaseFactory(gym.Env):
             valid = agent.move(new_tile)
             if valid:
                 # This will spam your logs, beware!
-                # self.print(f'{agent.name} just moved from {agent.last_pos} to {agent.pos}.')
-                # info_dict.update({f'{agent.pos}_move': 1})
+                self.print(f'{agent.name} just moved {action.identifier} from {agent.last_pos} to {agent.pos}.')
+                info_dict.update({f'{agent.name}_move': 1, 'move': 1})
                 pass
             else:
                 valid = c.NOT_VALID
-                self.print(f'{agent.name} just hit the wall at {agent.pos}.')
-                info_dict.update({f'{agent.name}_wall_collide': 1})
+                self.print(f'{agent.name} just hit the wall at {agent.pos}. ({action.identifier})')
+                info_dict.update({f'{agent.name}_wall_collide': 1, 'wall_collide': 1})
         else:
             # Agent seems to be trying to Leave the level
-            self.print(f'{agent.name} tried to leave the level {agent.pos}.')
-            info_dict.update({f'{agent.name}_wall_collide': 1})
+            self.print(f'{agent.name} tried to leave the level {agent.pos}. ({action.identifier})')
+            info_dict.update({f'{agent.name}_wall_collide': 1, 'wall_collide': 1})
         reward_value = r.MOVEMENTS_VALID if valid else r.MOVEMENTS_FAIL
         reward = {'value': reward_value, 'reason': action.identifier, 'info': info_dict}
         return valid, reward

-    def _check_agent_move(self, agent, action: Action) -> (Tile, bool):
+    def _check_agent_move(self, agent, action: Action) -> (Floor, bool):
         # Actions
         x_diff, y_diff = h.ACTIONMAP[action.identifier]
         x_new = agent.x + x_diff
@@ -556,10 +560,6 @@ class BaseFactory(gym.Env):
         return new_tile, valid

-    @abc.abstractmethod
-    def additional_per_agent_rewards(self, agent) -> List[dict]:
-        return []
-
     def build_reward_result(self, global_env_rewards: list) -> (int, dict):
         # Returns: Reward, Info
         info = defaultdict(lambda: 0.0)
@@ -567,7 +567,7 @@ class BaseFactory(gym.Env):
         # Gather additional sub-env rewards and calculate collisions
         for agent in self[c.AGENT]:

-            rewards = self.additional_per_agent_rewards(agent)
+            rewards = self.per_agent_reward_hook(agent)
             for reward in rewards:
                 agent.step_result['rewards'].append(reward)
             if collisions := agent.step_result['collisions']:
@@ -601,6 +601,12 @@ class BaseFactory(gym.Env):
             self.print(f"reward is {reward}")
         return reward, combined_info_dict

+    def start_recording(self):
+        self._record_episodes = True
+
+    def stop_recording(self):
+        self._record_episodes = False
+
     # noinspection PyGlobalUndefined
     def render(self, mode='human'):
         if not self._renderer:  # lazy init
@@ -621,7 +627,7 @@ class BaseFactory(gym.Env):
         for i, door in enumerate(self[c.DOORS]):
             name, state = 'door_open' if door.is_open else 'door_closed', 'blank'
             doors.append(RenderEntity(name, door.pos, 1, 'none', state, i + 1))
-        additional_assets = self.render_additional_assets()
+        additional_assets = self.render_assets_hook()
         return self._renderer.render(walls + doors + additional_assets + agents)
@@ -652,7 +658,8 @@ class BaseFactory(gym.Env):
     # Properties which are called by the base class to extend beyond attributes of the base class
     @property
-    def additional_actions(self) -> Union[Action, List[Action]]:
+    @abc.abstractmethod
+    def actions_hook(self) -> Union[Action, List[Action]]:
         """
         When heriting from this Base Class, you musst implement this methode!!!
@@ -662,7 +669,8 @@ class BaseFactory(gym.Env):
         return []

     @property
-    def additional_entities(self) -> Dict[(str, Entities)]:
+    @abc.abstractmethod
+    def entities_hook(self) -> Dict[(str, Entities)]:
         """
         When heriting from this Base Class, you musst implement this methode!!!
@@ -674,27 +682,39 @@ class BaseFactory(gym.Env):
     # Functions which provide additions to functions of the base class
     # Always call super!!!!!!
     @abc.abstractmethod
-    def do_additional_reset(self) -> None:
+    def reset_hook(self) -> None:
         pass

     @abc.abstractmethod
-    def do_additional_step(self) -> (List[dict], dict):
-        return [], {}
+    def pre_step_hook(self) -> None:
+        pass

     @abc.abstractmethod
     def do_additional_actions(self, agent: Agent, action: Action) -> (bool, dict):
         return None

+    @abc.abstractmethod
+    def step_hook(self) -> (List[dict], dict):
+        return [], {}
+
     @abc.abstractmethod
     def check_additional_done(self) -> (bool, dict):
         return False, {}

     @abc.abstractmethod
-    def _additional_observations(self) -> Dict[str, np.typing.ArrayLike]:
+    def observations_hook(self) -> Dict[str, np.typing.ArrayLike]:
         return {}

     @abc.abstractmethod
-    def _additional_per_agent_raw_observations(self, agent) -> Dict[str, np.typing.ArrayLike]:
+    def per_agent_reward_hook(self, agent: Agent) -> Dict[str, dict]:
+        return {}
+
+    @abc.abstractmethod
+    def post_step_hook(self) -> dict:
+        return {}
+
+    @abc.abstractmethod
+    def per_agent_raw_observations_hook(self, agent) -> Dict[str, np.typing.ArrayLike]:
         additional_raw_observations = {}
         if self.obs_prop.show_global_position_info:
             global_pos_obs = np.zeros(self._obs_shape)
@@ -703,19 +723,5 @@ class BaseFactory(gym.Env):
         return additional_raw_observations

     @abc.abstractmethod
-    def additional_per_agent_reward(self, agent: Agent) -> Dict[str, dict]:
-        return {}
-
-    @abc.abstractmethod
-    def render_additional_assets(self):
+    def render_assets_hook(self):
         return []
-
-    # Hooks for in between operations.
-    # Always call super!!!!!!
-
-    @abc.abstractmethod
-    def hook_pre_step(self) -> None:
-        pass
-
-    @abc.abstractmethod
-    def hook_post_step(self) -> dict:
-        return {}
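
Note: the hunks above converge on one naming convention: every extension point of BaseFactory is now a *_hook member that sub-environments override and chain via super(). A self-contained toy sketch of that template-method pattern (names mirror the commit, the logic is illustrative only, not the repo's):

import abc


class ToyBaseFactory(abc.ABC):

    def step(self):
        # The base class drives the step and calls the hooks in order.
        self.pre_step_hook()
        rewards, info = self.step_hook()
        info.update(self.post_step_hook())
        return rewards, info

    @abc.abstractmethod
    def pre_step_hook(self) -> None:
        pass

    @abc.abstractmethod
    def step_hook(self) -> (list, dict):
        return [], {}

    @abc.abstractmethod
    def post_step_hook(self) -> dict:
        return {}


class ToyDirtFactory(ToyBaseFactory):

    def pre_step_hook(self) -> None:
        super().pre_step_hook()  # always chain, as the base class comments demand

    def step_hook(self) -> (list, dict):
        rewards, info = super().step_hook()
        rewards.append({'value': -0.1, 'reason': 'dirt'})  # toy reward event
        return rewards, info

    def post_step_hook(self) -> dict:
        info = super().post_step_hook()
        info.update(dirt_steps=1)
        return info


print(ToyDirtFactory().step())  # ([{'value': -0.1, 'reason': 'dirt'}], {'dirt_steps': 1})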

View File

@@ -9,10 +9,11 @@ from environments.helpers import Constants as c
 import itertools

 ##########################################################################
-# ##################### Base Object Definition ######################### #
+# ##################### Base Object Building Blocks ######################### #
 ##########################################################################


+# TODO: Missing Documentation
 class Object:
     """Generell Objects for Organisation and Maintanance such as Actions etc..."""
@@ -53,8 +54,10 @@ class Object:
     def __eq__(self, other) -> bool:
         return other == self.identifier


+# Base
+# TODO: Missing Documentation
 class EnvObject(Object):
     """Objects that hold Information that are observable, but have no position on the env grid. Inventories etc..."""
@@ -78,27 +81,10 @@ class EnvObject(Object):
         self._register.delete_env_object(self)
         self._register = register
         return self._register == register


-class BoundingMixin(Object):
-
-    @property
-    def bound_entity(self):
-        return self._bound_entity
-
-    def __init__(self,entity_to_be_bound, *args, **kwargs):
-        super(BoundingMixin, self).__init__(*args, **kwargs)
-        assert entity_to_be_bound is not None
-        self._bound_entity = entity_to_be_bound
-
-    @property
-    def name(self):
-        return f'{super(BoundingMixin, self).name}({self._bound_entity.name})'
-
-    def belongs_to_entity(self, entity):
-        return entity == self.bound_entity
-
-
+# With Rendering
+# TODO: Missing Documentation
 class Entity(EnvObject):
     """Full Env Entity that lives on the env Grid. Doors, Items, Dirt etc..."""
@@ -133,8 +119,10 @@ class Entity(EnvObject):
     def __repr__(self):
         return super(Entity, self).__repr__() + f'(@{self.pos})'


+# With Position in Env
+# TODO: Missing Documentation
 class MoveableEntity(Entity):

     @property
@@ -169,6 +157,27 @@ class MoveableEntity(Entity):
             return c.VALID
         else:
             return c.NOT_VALID


+# Can Move
+# TODO: Missing Documentation
+class BoundingMixin(Object):
+
+    @property
+    def bound_entity(self):
+        return self._bound_entity
+
+    def __init__(self,entity_to_be_bound, *args, **kwargs):
+        super(BoundingMixin, self).__init__(*args, **kwargs)
+        assert entity_to_be_bound is not None
+        self._bound_entity = entity_to_be_bound
+
+    @property
+    def name(self):
+        return f'{super(BoundingMixin, self).name}({self._bound_entity.name})'
+
+    def belongs_to_entity(self, entity):
+        return entity == self.bound_entity
+
+
 ##########################################################################
@@ -216,7 +225,7 @@ class GlobalPosition(BoundingMixin, EnvObject):
         self._normalized = normalized


-class Tile(EnvObject):
+class Floor(EnvObject):

     @property
     def encoding(self):
@@ -243,7 +252,7 @@ class Floor(EnvObject):
         return self._pos

     def __init__(self, pos, *args, **kwargs):
-        super(Tile, self).__init__(*args, **kwargs)
+        super(Floor, self).__init__(*args, **kwargs)
         self._guests = dict()
         self._pos = tuple(pos)
@@ -277,7 +286,7 @@ class Floor(EnvObject):
         return dict(name=self.name, x=int(self.x), y=int(self.y))


-class Wall(Tile):
+class Wall(Floor):

     @property
     def can_collide(self):
@@ -302,7 +311,7 @@ class Door(Entity):
     @property
     def encoding(self):
         # This is important as it shadow is checked by occupation value
-        return c.OCCUPIED_CELL if self.is_closed else 2
+        return c.OCCUPIED_CELL if self.is_closed else 0.5

     @property
     def str_state(self):
@@ -396,5 +405,5 @@ class Agent(MoveableEntity):
     def summarize_state(self, **kwargs):
         state_dict = super().summarize_state(**kwargs)
-        state_dict.update(valid=bool(self.temp_action_result['valid']), action=str(self.temp_action_result['action']))
+        state_dict.update(valid=bool(self.step_result['action_valid']), action=str(self.step_result['action_name']))
         return state_dict
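
Note: the Door.encoding change is easy to misread. The in-code comment says shadow casting checks the occupation value, so a closed door must read exactly as c.OCCUPIED_CELL, while an open door needs a value that is neither free nor occupied, hence 0.5 rather than the old 2. A minimal sketch of that semantics (the concrete constant values, FREE_CELL = 0 and OCCUPIED_CELL = 1, are assumptions):

FREE_CELL, OCCUPIED_CELL = 0, 1  # assumed values of the repo's Constants

def door_encoding(is_closed: bool) -> float:
    # Closed doors must look occupied so the shadow caster blocks them;
    # open doors encode as 0.5: still visible, but neither free nor blocking.
    return OCCUPIED_CELL if is_closed else 0.5

assert door_encoding(True) == OCCUPIED_CELL
assert door_encoding(False) not in (FREE_CELL, OCCUPIED_CELL)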

View File

@@ -6,7 +6,7 @@ from typing import List, Union, Dict, Tuple
 import numpy as np
 import six

-from environments.factory.base.objects import Entity, Tile, Agent, Door, Action, Wall, PlaceHolder, GlobalPosition, \
+from environments.factory.base.objects import Entity, Floor, Agent, Door, Action, Wall, PlaceHolder, GlobalPosition, \
     Object, EnvObject
 from environments.utility_classes import MovementProperties
 from environments import helpers as h
@@ -271,12 +271,9 @@ class GlobalPositions(EnvObjectRegister):

     _accepted_objects = GlobalPosition

-    is_blocking_light = False
-    can_be_shadowed = False
-    can_collide = False
-
     def __init__(self, *args, **kwargs):
-        super(GlobalPositions, self).__init__(*args, is_per_agent=True, individual_slices=True, **kwargs)
+        super(GlobalPositions, self).__init__(*args, is_per_agent=True, individual_slices=True, is_blocking_light = False,
+                                              can_be_shadowed = False, can_collide = False, **kwargs)

     def as_array(self):
         # FIXME DEBUG!!! make this lazy?
@@ -377,7 +374,7 @@ class Entities(ObjectRegister):
         return found_entities


-class WallTiles(EntityRegister):
+class Walls(EntityRegister):
     _accepted_objects = Wall

     def as_array(self):
@@ -390,9 +387,9 @@ class Walls(EntityRegister):
         return self._array

     def __init__(self, *args, is_blocking_light=True, **kwargs):
-        super(WallTiles, self).__init__(*args, individual_slices=False,
+        super(Walls, self).__init__(*args, individual_slices=False,
                                         can_collide=True,
                                         is_blocking_light=is_blocking_light, **kwargs)
         self._value = c.OCCUPIED_CELL

     @classmethod
@@ -411,16 +408,16 @@ class Walls(EntityRegister):
     def summarize_states(self, n_steps=None):
         if n_steps == h.STEPS_START:
-            return super(WallTiles, self).summarize_states(n_steps=n_steps)
+            return super(Walls, self).summarize_states(n_steps=n_steps)
         else:
             return {}


-class FloorTiles(WallTiles):
-    _accepted_objects = Tile
+class Floors(Walls):
+    _accepted_objects = Floor

     def __init__(self, *args, is_blocking_light=False, **kwargs):
-        super(FloorTiles, self).__init__(*args, is_blocking_light=is_blocking_light, **kwargs)
+        super(Floors, self).__init__(*args, is_blocking_light=is_blocking_light, **kwargs)
         self._value = c.FREE_CELL

     @property
@@ -430,7 +427,7 @@ class Floors(Walls):
         return tiles

     @property
-    def empty_tiles(self) -> List[Tile]:
+    def empty_tiles(self) -> List[Floor]:
         tiles = [tile for tile in self if tile.is_empty()]
         random.shuffle(tiles)
         return tiles
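
Note: the renamed registers keep their from_argwhere_coordinates factory, so wall and floor construction reads as in the BaseFactory hunks above, only with shorter names. A usage sketch (the toy level array and the direct comparison against the Constants values are assumptions):

import numpy as np

from environments.factory.base.registers import Walls, Floors
from environments.helpers import Constants as c

level_array = np.array([[1, 1, 1],   # toy 3x3 level: 1 = wall cell,
                        [1, 0, 1],   # 0 = the single floor cell
                        [1, 1, 1]])

walls = Walls.from_argwhere_coordinates(
    np.argwhere(level_array == c.OCCUPIED_CELL), level_array.shape)
floors = Floors.from_argwhere_coordinates(
    np.argwhere(level_array == c.FREE_CELL), level_array.shape)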

View File

@@ -158,19 +158,19 @@ class BatteryFactory(BaseFactory):
         self.btry_prop = btry_prop
         super().__init__(*args, **kwargs)

-    def _additional_per_agent_raw_observations(self, agent) -> Dict[str, np.typing.ArrayLike]:
-        additional_raw_observations = super()._additional_per_agent_raw_observations(agent)
+    def per_agent_raw_observations_hook(self, agent) -> Dict[str, np.typing.ArrayLike]:
+        additional_raw_observations = super().per_agent_raw_observations_hook(agent)
         additional_raw_observations.update({c.BATTERIES: self[c.BATTERIES].as_array_by_entity(agent)})
         return additional_raw_observations

-    def _additional_observations(self) -> Dict[str, np.typing.ArrayLike]:
-        additional_observations = super()._additional_observations()
+    def observations_hook(self) -> Dict[str, np.typing.ArrayLike]:
+        additional_observations = super().observations_hook()
         additional_observations.update({c.CHARGE_PODS: self[c.CHARGE_PODS].as_array()})
         return additional_observations

     @property
-    def additional_entities(self):
-        super_entities = super().additional_entities
+    def entities_hook(self):
+        super_entities = super().entities_hook

         empty_tiles = self[c.FLOOR].empty_tiles[:self.btry_prop.charge_locations]
         charge_pods = ChargePods.from_tiles(
@@ -185,8 +185,8 @@ class BatteryFactory(BaseFactory):
         super_entities.update({c.BATTERIES: batteries, c.CHARGE_PODS: charge_pods})
         return super_entities

-    def do_additional_step(self) -> (List[dict], dict):
-        super_reward_info = super(BatteryFactory, self).do_additional_step()
+    def step_hook(self) -> (List[dict], dict):
+        super_reward_info = super(BatteryFactory, self).step_hook()

         # Decharge
         batteries = self[c.BATTERIES]
@@ -230,7 +230,7 @@ class BatteryFactory(BaseFactory):
             return action_result
         pass

-    def do_additional_reset(self) -> None:
+    def reset_hook(self) -> None:
         # There is Nothing to reset.
         pass
@@ -249,8 +249,8 @@ class BatteryFactory(BaseFactory):
             pass
         pass

-    def additional_per_agent_reward(self, agent: Agent) -> Dict[str, dict]:
-        reward_event_dict = super(BatteryFactory, self).additional_per_agent_reward(agent)
+    def per_agent_reward_hook(self, agent: Agent) -> Dict[str, dict]:
+        reward_event_dict = super(BatteryFactory, self).per_agent_reward_hook(agent)
         if self[c.BATTERIES].by_entity(agent).is_discharged:
             self.print(f'{agent.name} Battery is discharged!')
             info_dict = {f'{agent.name}_{c.BATTERY_DISCHARGED}': 1}
@@ -260,9 +260,9 @@ class BatteryFactory(BaseFactory):
             pass
         return reward_event_dict

-    def render_additional_assets(self):
+    def render_assets_hook(self):
         # noinspection PyUnresolvedReferences
-        additional_assets = super().render_additional_assets()
+        additional_assets = super().render_assets_hook()
         charge_pods = [RenderEntity(c.CHARGE_PODS, charge_pod.tile.pos) for charge_pod in self[c.CHARGE_PODS]]
         additional_assets.extend(charge_pods)
         return additional_assets

View File

@@ -147,17 +147,17 @@ class DestFactory(BaseFactory):
         super().__init__(*args, **kwargs)

     @property
-    def additional_actions(self) -> Union[Action, List[Action]]:
+    def actions_hook(self) -> Union[Action, List[Action]]:
         # noinspection PyUnresolvedReferences
-        super_actions = super().additional_actions
+        super_actions = super().actions_hook
         if self.dest_prop.dwell_time:
             super_actions.append(Action(enum_ident=a.WAIT_ON_DEST))
         return super_actions

     @property
-    def additional_entities(self) -> Dict[(Enum, Entities)]:
+    def entities_hook(self) -> Dict[(Enum, Entities)]:
         # noinspection PyUnresolvedReferences
-        super_entities = super().additional_entities
+        super_entities = super().entities_hook

         empty_tiles = self[c.FLOOR].empty_tiles[:self.dest_prop.n_dests]
         destinations = Destinations.from_tiles(
@@ -194,9 +194,9 @@ class DestFactory(BaseFactory):
         else:
             return super_action_result

-    def do_additional_reset(self) -> None:
+    def reset_hook(self) -> None:
         # noinspection PyUnresolvedReferences
-        super().do_additional_reset()
+        super().reset_hook()
         self._dest_spawn_timer = dict()

     def trigger_destination_spawn(self):
@@ -222,9 +222,9 @@ class DestFactory(BaseFactory):
         else:
             self.print('No Items are spawning, limit is reached.')

-    def do_additional_step(self) -> (List[dict], dict):
+    def step_hook(self) -> (List[dict], dict):
         # noinspection PyUnresolvedReferences
-        super_reward_info = super().do_additional_step()
+        super_reward_info = super().step_hook()
         for key, val in self._dest_spawn_timer.items():
             self._dest_spawn_timer[key] = min(self.dest_prop.spawn_frequency, self._dest_spawn_timer[key] + 1)
         for dest in list(self[c.DEST].values()):
@@ -244,14 +244,14 @@ class DestFactory(BaseFactory):
             self.trigger_destination_spawn()
         return super_reward_info

-    def _additional_observations(self) -> Dict[str, np.typing.ArrayLike]:
-        additional_observations = super()._additional_observations()
+    def observations_hook(self) -> Dict[str, np.typing.ArrayLike]:
+        additional_observations = super().observations_hook()
         additional_observations.update({c.DEST: self[c.DEST].as_array()})
         return additional_observations

-    def additional_per_agent_reward(self, agent: Agent) -> Dict[str, dict]:
+    def per_agent_reward_hook(self, agent: Agent) -> Dict[str, dict]:
         # noinspection PyUnresolvedReferences
-        reward_event_dict = super().additional_per_agent_reward(agent)
+        reward_event_dict = super().per_agent_reward_hook(agent)
         if len(self[c.DEST_REACHED]):
             for reached_dest in list(self[c.DEST_REACHED]):
                 if agent.pos == reached_dest.pos:
@@ -261,9 +261,9 @@ class DestFactory(BaseFactory):
                     reward_event_dict.update({c.DEST_REACHED: {'reward': r.DEST_REACHED, 'info': info_dict}})
         return reward_event_dict

-    def render_additional_assets(self, mode='human'):
+    def render_assets_hook(self, mode='human'):
         # noinspection PyUnresolvedReferences
-        additional_assets = super().render_additional_assets()
+        additional_assets = super().render_assets_hook()
         destinations = [RenderEntity(c.DEST, dest.pos) for dest in self[c.DEST]]
         additional_assets.extend(destinations)
         return additional_assets
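
Note: the battery and destination sub-envs both illustrate the contract of the renamed per_agent_reward_hook: it returns a dict of reward events, one entry per event key, each carrying a reward value and an info counter dict, and subclasses extend the dict returned by super(). Schematically (all keys and numbers below are illustrative, not taken from the code):

# Shape of the dict per_agent_reward_hook returns; names/values are toy.
reward_event_dict = {
    'dest_reached': {'reward': 1.0,
                     'info': {'Agent[0]_dest_reached': 1}},
    'battery_discharged': {'reward': -1.0,
                           'info': {'Agent[0]_battery_discharged': 1}},
}
# BaseFactory.build_reward_result gathers these events per agent into
# agent.step_result['rewards'] and folds the counters into the episode info.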

View File

@@ -1,5 +1,4 @@
 import time
-from enum import Enum
 from typing import List, Union, NamedTuple, Dict
 import random
@@ -12,7 +11,7 @@ from environments.helpers import EnvActions as BaseActions
 from environments.helpers import Rewards as BaseRewards
 from environments.factory.base.base_factory import BaseFactory
-from environments.factory.base.objects import Agent, Action, Entity, Tile
+from environments.factory.base.objects import Agent, Action, Entity, Floor
 from environments.factory.base.registers import Entities, EntityRegister
 from environments.factory.base.renderer import RenderEntity
@@ -43,7 +42,6 @@ class DirtProperties(NamedTuple):
     max_local_amount: int = 2            # Max dirt amount per tile.
     max_global_amount: int = 20          # Max dirt amount in the whole environment.
     dirt_smear_amount: float = 0.2       # Agents smear dirt, when not cleaning up in place.
-    agent_can_interact: bool = True      # Whether the agents can interact with the dirt in this environment.
     done_when_clean: bool = True
@@ -89,7 +87,7 @@ class DirtRegister(EntityRegister):
         self._dirt_properties: DirtProperties = dirt_properties

     def spawn_dirt(self, then_dirty_tiles) -> bool:
-        if isinstance(then_dirty_tiles, Tile):
+        if isinstance(then_dirty_tiles, Floor):
             then_dirty_tiles = [then_dirty_tiles]
         for tile in then_dirty_tiles:
             if not self.amount > self.dirt_properties.max_global_amount:
@@ -128,15 +126,14 @@ r = Rewards
 class DirtFactory(BaseFactory):

     @property
-    def additional_actions(self) -> Union[Action, List[Action]]:
-        super_actions = super().additional_actions
-        if self.dirt_prop.agent_can_interact:
-            super_actions.append(Action(str_ident=a.CLEAN_UP))
+    def actions_hook(self) -> Union[Action, List[Action]]:
+        super_actions = super().actions_hook
+        super_actions.append(Action(str_ident=a.CLEAN_UP))
         return super_actions

     @property
-    def additional_entities(self) -> Dict[(Enum, Entities)]:
-        super_entities = super().additional_entities
+    def entities_hook(self) -> Dict[(str, Entities)]:
+        super_entities = super().entities_hook
         dirt_register = DirtRegister(self.dirt_prop, self._level_shape)
         super_entities.update(({c.DIRT: dirt_register}))
         return super_entities
@@ -148,10 +145,11 @@ class DirtFactory(BaseFactory):
         self._dirt_rng = np.random.default_rng(env_seed)
         self._dirt: DirtRegister
         kwargs.update(env_seed=env_seed)
+        # TODO: Reset ---> document this
         super().__init__(*args, **kwargs)

-    def render_additional_assets(self, mode='human'):
-        additional_assets = super().render_additional_assets()
+    def render_assets_hook(self, mode='human'):
+        additional_assets = super().render_assets_hook()
         dirt = [RenderEntity('dirt', dirt.tile.pos, min(0.15 + dirt.amount, 1.5), 'scale')
                 for dirt in self[c.DIRT]]
         additional_assets.extend(dirt)
@@ -167,12 +165,12 @@ class DirtFactory(BaseFactory):
             dirt.set_new_amount(max(new_dirt_amount, c.FREE_CELL.value))
             valid = c.VALID
             self.print(f'{agent.name} did just clean up some dirt at {agent.pos}.')
-            info_dict = {f'{agent.name}_{a.CLEAN_UP}_VALID': 1}
+            info_dict = {f'{agent.name}_{a.CLEAN_UP}_VALID': 1, 'cleanup_valid': 1}
             reward = r.CLEAN_UP_VALID
         else:
             valid = c.NOT_VALID
             self.print(f'{agent.name} just tried to clean up some dirt at {agent.pos}, but failed.')
-            info_dict = {f'{agent.name}_{a.CLEAN_UP}_FAIL': 1}
+            info_dict = {f'{agent.name}_{a.CLEAN_UP}_FAIL': 1, 'cleanup_fail': 1}
             reward = r.CLEAN_UP_FAIL

         if valid and self.dirt_prop.done_when_clean and (len(self[c.DIRT]) == 0):
@@ -195,8 +193,8 @@ class DirtFactory(BaseFactory):
             n_dirt_tiles = max(0, int(new_spawn * len(free_for_dirt)))
             self[c.DIRT].spawn_dirt(free_for_dirt[:n_dirt_tiles])

-    def do_additional_step(self) -> (List[dict], dict):
-        super_reward_info = super().do_additional_step()
+    def step_hook(self) -> (List[dict], dict):
+        super_reward_info = super().step_hook()
         if smear_amount := self.dirt_prop.dirt_smear_amount:
             for agent in self[c.AGENT]:
                 if agent.temp_valid and agent.last_pos != c.NO_POS:
@@ -229,8 +227,8 @@ class DirtFactory(BaseFactory):
         else:
             return action_result

-    def do_additional_reset(self) -> None:
-        super().do_additional_reset()
+    def reset_hook(self) -> None:
+        super().reset_hook()
         self.trigger_dirt_spawn(initial_spawn=True)
         self._next_dirt_spawn = self.dirt_prop.spawn_frequency if self.dirt_prop.spawn_frequency else -1
@@ -242,13 +240,13 @@ class DirtFactory(BaseFactory):
                 return all_cleaned, super_dict
         return super_done, super_dict

-    def _additional_observations(self) -> Dict[str, np.typing.ArrayLike]:
-        additional_observations = super()._additional_observations()
+    def observations_hook(self) -> Dict[str, np.typing.ArrayLike]:
+        additional_observations = super().observations_hook()
         additional_observations.update({c.DIRT: self[c.DIRT].as_array()})
         return additional_observations

     def gather_additional_info(self, agent: Agent) -> dict:
-        event_reward_dict = super().additional_per_agent_reward(agent)
+        event_reward_dict = super().per_agent_reward_hook(agent)
         info_dict = dict()

         dirt = [dirt.amount for dirt in self[c.DIRT]]
@@ -280,8 +278,7 @@ if __name__ == '__main__':
                                 max_local_amount=1,
                                 spawn_frequency=0,
                                 max_spawn_ratio=0.05,
-                                dirt_smear_amount=0.0,
-                                agent_can_interact=True
+                                dirt_smear_amount=0.0
                                 )

     obs_props = ObservationProperties(render_agents=aro.COMBINED, omit_agent_self=True,
@@ -294,13 +291,13 @@ if __name__ == '__main__':
     global_timings = []
     for i in range(10):

-        factory = DirtFactory(n_agents=1, done_at_collision=False,
+        factory = DirtFactory(n_agents=10, done_at_collision=False,
                               level_name='rooms', max_steps=1000,
                               doors_have_area=False,
                               obs_prop=obs_props, parse_doors=True,
                               verbose=True,
                               mv_prop=move_props, dirt_prop=dirt_props,
-                              inject_agents=[TSPDirtAgent],
+                              # inject_agents=[TSPDirtAgent],
                               )

         # noinspection DuplicatedCode
@@ -318,11 +315,11 @@ if __name__ == '__main__':
         env_state = factory.reset()
         if render:
             factory.render()
-        tsp_agent = factory.get_injected_agents()[0]
+        # tsp_agent = factory.get_injected_agents()[0]

         rwrd = 0
         for agent_i_action in random_actions:
-            agent_i_action = tsp_agent.predict()
+            # agent_i_action = tsp_agent.predict()
             env_state, step_rwrd, done_bool, info_obj = factory.step(agent_i_action)
             rwrd += step_rwrd
             if render:
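
Note: the __main__ block above doubles as a smoke test. A condensed, hedged version of that loop (constructor arguments not shown in the diff are assumed to have defaults, and the per-agent action-count lookup is an assumption about the gym action space):

import random

from environments.factory.factory_dirt import DirtFactory, DirtProperties

factory = DirtFactory(n_agents=10, level_name='rooms', max_steps=1000,
                      done_at_collision=False,
                      dirt_prop=DirtProperties(done_when_clean=True))
n_actions = factory.action_space.n - 1   # assumed per-agent Discrete space
env_state = factory.reset()
for _ in range(factory.max_steps + 1):
    actions = [random.randint(0, n_actions) for _ in range(factory.n_agents)]
    env_state, step_rwrd, done_bool, info_obj = factory.step(actions)
    if done_bool:
        break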

View File

@@ -0,0 +1,58 @@
+from typing import Dict, List, Union
+
+import numpy as np
+
+from environments.factory.base.objects import Agent, Entity, Action
+from environments.factory.factory_dirt import Dirt, DirtRegister, DirtFactory
+from environments.factory.base.objects import Floor
+from environments.factory.base.registers import Floors, Entities, EntityRegister
+
+
+class Machines(EntityRegister):
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+
+class Machine(Entity):
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+
+class StationaryMachinesDirtFactory(DirtFactory):
+
+    def __init__(self, *args, **kwargs):
+        self._machine_coords = [(6, 6), (12, 13)]
+        super().__init__(*args, **kwargs)
+
+    def entities_hook(self) -> Dict[(str, Entities)]:
+        super_entities = super().entities_hook()
+        return super_entities
+
+    def reset_hook(self) -> None:
+        pass
+
+    def observations_hook(self) -> Dict[str, np.typing.ArrayLike]:
+        pass
+
+    def actions_hook(self) -> Union[Action, List[Action]]:
+        pass
+
+    def step_hook(self) -> (List[dict], dict):
+        pass
+
+    def per_agent_raw_observations_hook(self, agent) -> Dict[str, np.typing.ArrayLike]:
+        super_per_agent_raw_observations = super().per_agent_raw_observations_hook(agent)
+        return super_per_agent_raw_observations
+
+    def per_agent_reward_hook(self, agent: Agent) -> Dict[str, dict]:
+        pass
+
+    def pre_step_hook(self) -> None:
+        pass
+
+    def post_step_hook(self) -> dict:
+        pass
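
Note: every hook in this new sub-env is still a stub. One way the entities_hook stub might later be filled in, sketched purely by analogy with ChargePods and Destinations above (Machines.from_tiles and the by_pos lookup are assumptions, c/Constants is not yet imported in this file, and the base class declares entities_hook as a property while this file overrides it as a plain method):

# Hedged sketch only; names marked above are assumptions.
def entities_hook(self) -> Dict[(str, Entities)]:
    super_entities = super().entities_hook()
    machine_tiles = [self[c.FLOOR].by_pos(pos) for pos in self._machine_coords]
    machines = Machines.from_tiles(machine_tiles, self._level_shape)
    super_entities.update({'Machines': machines})
    return super_entities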

View File

@@ -9,7 +9,7 @@ from environments.helpers import Constants as BaseConstants
 from environments.helpers import EnvActions as BaseActions
 from environments.helpers import Rewards as BaseRewards
 from environments import helpers as h
-from environments.factory.base.objects import Agent, Entity, Action, Tile
+from environments.factory.base.objects import Agent, Entity, Action, Floor
 from environments.factory.base.registers import Entities, EntityRegister, BoundEnvObjRegister, ObjectRegister
 from environments.factory.base.renderer import RenderEntity
@@ -25,7 +25,7 @@ class Constants(BaseConstants):

 class Actions(BaseActions):
-    ITEM_ACTION = 'item_action'
+    ITEM_ACTION = 'ITEMACTION'


 class Rewards(BaseRewards):
@@ -62,7 +62,7 @@ class ItemRegister(EntityRegister):

     _accepted_objects = Item

-    def spawn_items(self, tiles: List[Tile]):
+    def spawn_items(self, tiles: List[Floor]):
         items = [Item(tile, self) for tile in tiles]
         self.register_additional_items(items)
@@ -193,16 +193,16 @@ class ItemFactory(BaseFactory):
         super().__init__(*args, **kwargs)

     @property
-    def additional_actions(self) -> Union[Action, List[Action]]:
+    def actions_hook(self) -> Union[Action, List[Action]]:
         # noinspection PyUnresolvedReferences
-        super_actions = super().additional_actions
+        super_actions = super().actions_hook
         super_actions.append(Action(str_ident=a.ITEM_ACTION))
         return super_actions

     @property
-    def additional_entities(self) -> Dict[(str, Entities)]:
+    def entities_hook(self) -> Dict[(str, Entities)]:
         # noinspection PyUnresolvedReferences
-        super_entities = super().additional_entities
+        super_entities = super().entities_hook

         empty_tiles = self[c.FLOOR].empty_tiles[:self.item_prop.n_drop_off_locations]
         drop_offs = DropOffLocations.from_tiles(
@@ -220,13 +220,13 @@ class ItemFactory(BaseFactory):
         super_entities.update({c.DROP_OFF: drop_offs, c.ITEM: item_register, c.INVENTORY: inventories})
         return super_entities

-    def _additional_per_agent_raw_observations(self, agent) -> Dict[str, np.typing.ArrayLike]:
-        additional_raw_observations = super()._additional_per_agent_raw_observations(agent)
+    def per_agent_raw_observations_hook(self, agent) -> Dict[str, np.typing.ArrayLike]:
+        additional_raw_observations = super().per_agent_raw_observations_hook(agent)
         additional_raw_observations.update({c.INVENTORY: self[c.INVENTORY].by_entity(agent).as_array()})
         return additional_raw_observations

-    def _additional_observations(self) -> Dict[str, np.typing.ArrayLike]:
-        additional_observations = super()._additional_observations()
+    def observations_hook(self) -> Dict[str, np.typing.ArrayLike]:
+        additional_observations = super().observations_hook()
         additional_observations.update({c.ITEM: self[c.ITEM].as_array()})
         additional_observations.update({c.DROP_OFF: self[c.DROP_OFF].as_array()})
         return additional_observations
@@ -240,21 +240,21 @@ class ItemFactory(BaseFactory):
                 valid = c.NOT_VALID
             if valid:
                 self.print(f'{agent.name} just dropped of an item at {drop_off.pos}.')
-                info_dict = {f'{agent.name}_DROPOFF_VALID': 1}
+                info_dict = {f'{agent.name}_DROPOFF_VALID': 1, 'DROPOFF_VALID': 1}
             else:
                 self.print(f'{agent.name} just tried to drop off at {agent.pos}, but failed.')
-                info_dict = {f'{agent.name}_DROPOFF_FAIL': 1}
+                info_dict = {f'{agent.name}_DROPOFF_FAIL': 1, 'DROPOFF_FAIL': 1}
             reward = dict(value=r.DROP_OFF_VALID if valid else r.DROP_OFF_FAIL, reason=a.ITEM_ACTION, info=info_dict)
             return valid, reward

         elif item := self[c.ITEM].by_pos(agent.pos):
             item.change_register(inventory)
             item.set_tile_to(self._NO_POS_TILE)
             self.print(f'{agent.name} just picked up an item at {agent.pos}')
-            info_dict = {f'{agent.name}_{a.ITEM_ACTION}_VALID': 1}
+            info_dict = {f'{agent.name}_{a.ITEM_ACTION}_VALID': 1, f'{a.ITEM_ACTION}_VALID': 1}
             return c.VALID, dict(value=r.PICK_UP_VALID, reason=a.ITEM_ACTION, info=info_dict)

         else:
             self.print(f'{agent.name} just tried to pick up an item at {agent.pos}, but failed.')
-            info_dict = {f'{agent.name}_{a.ITEM_ACTION}_FAIL': 1}
+            info_dict = {f'{agent.name}_{a.ITEM_ACTION}_FAIL': 1, f'{a.ITEM_ACTION}_FAIL': 1}
             return c.NOT_VALID, dict(value=r.PICK_UP_FAIL, reason=a.ITEM_ACTION, info=info_dict)

     def do_additional_actions(self, agent: Agent, action: Action) -> (dict, dict):
@@ -269,9 +269,9 @@ class ItemFactory(BaseFactory):
         else:
             return action_result

-    def do_additional_reset(self) -> None:
+    def reset_hook(self) -> None:
         # noinspection PyUnresolvedReferences
-        super().do_additional_reset()
+        super().reset_hook()
         self._next_item_spawn = self.item_prop.spawn_frequency
         self.trigger_item_spawn()
@@ -284,9 +284,9 @@ class ItemFactory(BaseFactory):
         else:
             self.print('No Items are spawning, limit is reached.')

-    def do_additional_step(self) -> (List[dict], dict):
+    def step_hook(self) -> (List[dict], dict):
         # noinspection PyUnresolvedReferences
-        super_reward_info = super().do_additional_step()
+        super_reward_info = super().step_hook()
         for item in list(self[c.ITEM].values()):
             if item.auto_despawn >= 1:
                 item.set_auto_despawn(item.auto_despawn-1)
@@ -301,9 +301,9 @@ class ItemFactory(BaseFactory):
         self._next_item_spawn = max(0, self._next_item_spawn-1)
         return super_reward_info

-    def render_additional_assets(self, mode='human'):
+    def render_assets_hook(self, mode='human'):
         # noinspection PyUnresolvedReferences
-        additional_assets = super().render_additional_assets()
+        additional_assets = super().render_assets_hook()
         items = [RenderEntity(c.ITEM, item.tile.pos) for item in self[c.ITEM] if item.tile != self._NO_POS_TILE]
         additional_assets.extend(items)
         drop_offs = [RenderEntity(c.DROP_OFF, drop_off.tile.pos) for drop_off in self[c.DROP_OFF]]
@@ -314,7 +314,7 @@ class ItemFactory(BaseFactory):

 if __name__ == '__main__':
     from environments.utility_classes import AgentRenderOptions as aro, ObservationProperties

-    render = False
+    render = True

     item_probs = ItemProperties(n_items=30, n_drop_off_locations=6)
@@ -336,18 +336,18 @@ if __name__ == '__main__':
     obs_space = factory.observation_space
     obs_space_named = factory.named_observation_space

-    for epoch in range(4):
+    for epoch in range(400):
         random_actions = [[random.randint(0, n_actions) for _
                            in range(factory.n_agents)] for _
                           in range(factory.max_steps + 1)]
         env_state = factory.reset()
-        r = 0
+        rwrd = 0
         for agent_i_action in random_actions:
             env_state, step_r, done_bool, info_obj = factory.step(agent_i_action)
-            r += step_r
+            rwrd += step_r
             if render:
                 factory.render()
             if done_bool:
                 break
-        print(f'Factory run {epoch} done, reward is:\n    {r}')
+        print(f'Factory run {epoch} done, reward is:\n    {rwrd}')
     pass
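
Note: a pattern that repeats through this commit: every per-agent info key (f'{agent.name}_DROPOFF_VALID') gains an agent-agnostic twin ('DROPOFF_VALID'), so monitors can aggregate across agents without parsing agent names out of keys. A self-contained illustration (the agent-name format is illustrative):

def count_event(info: dict, agent_name: str, event: str) -> None:
    # Per-agent counter, e.g. 'Agent[0]_DROPOFF_VALID' ...
    info[f'{agent_name}_{event}'] = info.get(f'{agent_name}_{event}', 0) + 1
    # ... plus the agent-agnostic twin this commit adds, e.g. 'DROPOFF_VALID'.
    info[event] = info.get(event, 0) + 1

totals = {}
count_event(totals, 'Agent[0]', 'DROPOFF_VALID')
count_event(totals, 'Agent[1]', 'DROPOFF_VALID')
assert totals['DROPOFF_VALID'] == 2 and totals['Agent[0]_DROPOFF_VALID'] == 1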

View File

@@ -1,5 +1,6 @@
 import pickle
 from collections import defaultdict
+from os import PathLike
 from pathlib import Path
 from typing import List, Dict, Union
@@ -9,14 +10,17 @@ from environments.helpers import IGNORED_DF_COLUMNS

 import pandas as pd

+from plotting.compare_runs import plot_single_run
+

 class EnvMonitor(BaseCallback):

     ext = 'png'

-    def __init__(self, env):
+    def __init__(self, env, filepath: Union[str, PathLike] = None):
         super(EnvMonitor, self).__init__()
         self.unwrapped = env
+        self._filepath = filepath
         self._monitor_df = pd.DataFrame()
         self._monitor_dicts = defaultdict(dict)
@@ -67,8 +71,10 @@ class EnvMonitor(BaseCallback):
             pass
         return

-    def save_run(self, filepath: Union[Path, str]):
+    def save_run(self, filepath: Union[Path, str], auto_plotting_keys=None):
         filepath = Path(filepath)
         filepath.parent.mkdir(exist_ok=True, parents=True)
         with filepath.open('wb') as f:
             pickle.dump(self._monitor_df.reset_index(), f, protocol=pickle.HIGHEST_PROTOCOL)
+        if auto_plotting_keys:
+            plot_single_run(filepath, column_keys=auto_plotting_keys)
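
Note: intended usage of the extended monitor, as a hedged sketch (the wrapped env, the output path, and the column names are placeholders; the keys must match info keys logged during the run):

monitor = EnvMonitor(factory)
# ... run episodes or training through `monitor` ...
monitor.save_run('study_out/monitor.pick',
                 auto_plotting_keys=['cleanup_valid', 'wall_collide'])
# save_run() pickles the monitor DataFrame as before and, when
# auto_plotting_keys is set, also renders a line plot via plot_single_run().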

View File

@@ -24,14 +24,12 @@ class EnvRecorder(BaseCallback):
             self._entities = [entities]
         else:
             self._entities = entities
-        self.started = False
-        self.closed = False
     def __getattr__(self, item):
         return getattr(self.unwrapped, item)
     def reset(self):
-        self.unwrapped._record_episodes = True
+        self.unwrapped.start_recording()
         return self.unwrapped.reset()
     def _on_training_start(self) -> None:
@@ -57,6 +55,14 @@ class EnvRecorder(BaseCallback):
         else:
             pass
+    def step(self, actions):
+        step_result = self.unwrapped.step(actions)
+        # 0, 1, 2, 3 = idx
+        # _, _, done_bool, info_obj = step_result
+        self._read_info(0, step_result[3])
+        self._read_done(0, step_result[2])
+        return step_result
     def save_records(self, filepath: Union[Path, str], save_occupation_map=False, save_trajectory_map=False):
         filepath = Path(filepath)
         filepath.parent.mkdir(exist_ok=True, parents=True)
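The new `step` override makes the recorder usable as a plain Gym-style wrapper outside of stable-baselines callbacks: it forwards to the wrapped env and taps done and info on the way through (`reset` likewise now calls the public `start_recording()` instead of poking a private flag). The pass-through pattern in isolation, with `_on_step` as a hypothetical hook name:

    class PassThroughWrapper:
        def __init__(self, env):
            self.env = env

        def step(self, actions):
            obs, reward, done, info = self.env.step(actions)
            self._on_step(reward, done, info)  # tap the results, e.g. append info to a buffer
            return obs, reward, done, info

        def _on_step(self, reward, done, info):
            pass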
View File
@@ -10,6 +10,45 @@ from environments.helpers import IGNORED_DF_COLUMNS, MODEL_MAP
 from plotting.plotting import prepare_plot
+def plot_single_run(run_path: Union[str, PathLike], use_tex: bool = False, column_keys=None):
+    run_path = Path(run_path)
+    df_list = list()
+    if run_path.is_dir():
+        monitor_file = next(run_path.glob('*monitor*.pick'))
+    elif run_path.exists() and run_path.is_file():
+        monitor_file = run_path
+    else:
+        raise ValueError
+    with monitor_file.open('rb') as f:
+        monitor_df = pickle.load(f)
+        monitor_df = monitor_df.fillna(0)
+        df_list.append(monitor_df)
+    df = pd.concat(df_list, ignore_index=True)
+    df = df.fillna(0).rename(columns={'episode': 'Episode'}).sort_values(['Episode'])
+    if column_keys is not None:
+        columns = [col for col in column_keys if col in df.columns]
+    else:
+        columns = [col for col in df.columns if col not in IGNORED_DF_COLUMNS]
+    roll_n = 50
+    non_overlapp_window = df.groupby(['Episode']).rolling(roll_n, min_periods=1).mean()
+    df_melted = df[columns + ['Episode']].reset_index().melt(id_vars=['Episode'],
+                                                             value_vars=columns, var_name="Measurement",
+                                                             value_name="Score")
+    if df_melted['Episode'].max() > 800:
+        skip_n = round(df_melted['Episode'].max() * 0.02)
+        df_melted = df_melted[df_melted['Episode'] % skip_n == 0]
+    prepare_plot(run_path.parent / f'{run_path.parent.name}_monitor_lineplot.png', df_melted, use_tex=use_tex)
+    print('Plotting done.')
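For reference, the new helper accepts either a run directory, where it picks the first '*monitor*.pick' it finds, or a direct path to a pickle file; `column_keys=None` falls back to every column not in IGNORED_DF_COLUMNS. (The rolling mean stored in `non_overlapp_window` is computed but never used, presumably a leftover from the seed-comparison variant below.) A minimal usage sketch with placeholder paths:

    from plotting.compare_runs import plot_single_run

    # Directory form: globs for '*monitor*.pick' inside the folder
    plot_single_run('study_out/run_0')

    # File form, restricted to selected monitor columns
    plot_single_run('study_out/run_0/monitor.pick', column_keys=['step_reward', 'collision'])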
 def compare_seed_runs(run_path: Union[str, PathLike], use_tex: bool = False):
     run_path = Path(run_path)
     df_list = list()
@@ -37,7 +76,10 @@ def compare_seed_runs(run_path: Union[str, PathLike], use_tex: bool = False):
         skip_n = round(df_melted['Episode'].max() * 0.02)
         df_melted = df_melted[df_melted['Episode'] % skip_n == 0]
-    prepare_plot(run_path / f'{run_path.name}_monitor_lineplot.png', df_melted, use_tex=use_tex)
+    if run_path.is_dir():
+        prepare_plot(run_path / f'{run_path}_monitor_lineplot.png', df_melted, use_tex=use_tex)
+    elif run_path.exists() and run_path.is_file():
+        prepare_plot(run_path.parent / f'{run_path.parent}_monitor_lineplot.png', df_melted, use_tex=use_tex)
     print('Plotting done.')
View File
@@ -1,4 +1,5 @@
 import seaborn as sns
+import matplotlib as mpl
 from matplotlib import pyplot as plt
 PALETTE = 10 * (
@@ -21,7 +22,14 @@ PALETTE = 10 * (
 def plot(filepath, ext='png'):
     plt.tight_layout()
     figure = plt.gcf()
-    figure.savefig(str(filepath), format=ext)
+    ax = plt.gca()
+    legends = [c for c in ax.get_children() if isinstance(c, mpl.legend.Legend)]
+    if legends:
+        figure.savefig(str(filepath), format=ext, bbox_extra_artists=(*legends,), bbox_inches='tight')
+    else:
+        figure.savefig(str(filepath), format=ext)
     plt.show()
     plt.clf()
@@ -30,7 +38,7 @@ def prepare_tex(df, hue, style, hue_order):
     sns.set(rc={'text.usetex': True}, style='whitegrid')
     lineplot = sns.lineplot(data=df, x='Episode', y='Score', ci=95, palette=PALETTE,
                             hue_order=hue_order, hue=hue, style=style)
-    # lineplot.set_title(f'{sorted(list(df["Measurement"].unique()))}')
+    lineplot.set_title(f'{sorted(list(df["Measurement"].unique()))}')
     plt.legend(bbox_to_anchor=(1.02, 1), loc='upper left', borderaxespad=0)
     plt.tight_layout()
     return lineplot
@@ -48,6 +56,19 @@ def prepare_plt(df, hue, style, hue_order):
     return lineplot
+def prepare_center_double_column_legend(df, hue, style, hue_order):
+    print('Struggling to plot Figure using LaTeX - going back to normal.')
+    plt.close('all')
+    sns.set(rc={'text.usetex': False}, style='whitegrid')
+    fig = plt.figure(figsize=(10, 11))
+    lineplot = sns.lineplot(data=df, x='Episode', y='Score', hue=hue, style=style,
+                            ci=95, palette=PALETTE, hue_order=hue_order, legend=False)
+    # plt.legend(bbox_to_anchor=(1.02, 1), loc='upper left', borderaxespad=0)
+    lineplot.legend(hue_order, ncol=3, loc='lower center', title='Parameter Combinations', bbox_to_anchor=(0.5, -0.43))
+    plt.tight_layout()
+    return lineplot
 def prepare_plot(filepath, results_df, ext='png', hue='Measurement', style=None, use_tex=False):
     df = results_df.copy()
     df[hue] = df[hue].str.replace('_', '-')
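The change to `plot` exists because legends placed outside the axes, as `prepare_tex`/`prepare_plt` do via `bbox_to_anchor=(1.02, 1)`, are clipped by a plain `savefig`; listing the legend in `bbox_extra_artists` together with `bbox_inches='tight'` tells matplotlib to grow the saved bounding box around it. A self-contained sketch of the idiom, with a placeholder file name:

    import matplotlib.pyplot as plt

    fig, ax = plt.subplots()
    ax.plot([0, 1, 2], [0, 1, 4], label='score')
    # Legend deliberately placed outside the axes, as in prepare_plt
    legend = ax.legend(bbox_to_anchor=(1.02, 1), loc='upper left', borderaxespad=0)
    # Without bbox_extra_artists the out-of-axes legend would be cut off
    fig.savefig('lineplot.png', bbox_extra_artists=(legend,), bbox_inches='tight')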
View File
@@ -4,7 +4,10 @@ from pathlib import Path
 import yaml
 from stable_baselines3 import A2C, PPO, DQN
+from environments.factory.factory_dirt import Constants as c
 from environments.factory.factory_dirt import DirtFactory
+from environments.logging.envmonitor import EnvMonitor
 from environments.logging.recorder import EnvRecorder
 warnings.filterwarnings('ignore', category=FutureWarning)
@@ -16,32 +19,35 @@ if __name__ == '__main__':
     determin = False
     render = True
     record = False
-    seed = 67
+    verbose = True
+    seed = 13
     n_agents = 1
     # out_path = Path('study_out/e_1_new_reward/no_obs/dirt/A2C_new_reward/0_A2C_new_reward')
-    out_path = Path('study_out/test/dirt')
+    out_path = Path('study_out/reload')
     model_path = out_path
     with (out_path / f'env_params.json').open('r') as f:
         env_kwargs = yaml.load(f, Loader=yaml.FullLoader)
-        env_kwargs.update(additional_agent_placeholder=None, n_agents=n_agents, max_steps=150)
-        env_kwargs.update(record_episodes=record, done_at_collision=True)
+        env_kwargs.update(n_agents=n_agents, done_at_collision=False, verbose=verbose)
+        if gain_amount := env_kwargs.get('dirt_prop', {}).get('gain_amount', None):
+            env_kwargs['dirt_prop']['max_spawn_amount'] = gain_amount
+            del env_kwargs['dirt_prop']['gain_amount']
     this_model = out_path / 'model.zip'
     model_cls = PPO  # next(val for key, val in h.MODEL_MAP.items() if key in out_path.parent.name)
     models = [model_cls.load(this_model)]
+    try:
+        # Legacy Cleanups
+        del env_kwargs['dirt_prop']['agent_can_interact']
+        env_kwargs['verbose'] = True
+    except KeyError:
+        pass
     # Init Env
     with DirtFactory(**env_kwargs) as env:
-        env = EnvRecorder(env)
+        env = EnvMonitor(env)
+        env = EnvRecorder(env) if record else env
         obs_shape = env.observation_space.shape
         # Evaluation Loop for i in range(n Episodes)
-        for episode in range(50):
+        for episode in range(500):
             env_state = env.reset()
             rew, done_bool = 0, False
             while not done_bool:
@@ -55,7 +61,17 @@ if __name__ == '__main__':
                 rew += step_r
                 if render:
                     env.render()
+                try:
+                    door = next(x for x in env.unwrapped.unwrapped[c.DOORS] if x.is_open)
+                    print('open door found')
+                except StopIteration:
+                    pass
                 if done_bool:
                     break
-            print(f'Factory run {episode} done, reward is:\n {rew}')
+            print(f'Factory run {episode} done, steps taken {env.unwrapped.unwrapped._steps}, reward is:\n {rew}')
+        env.save_run(out_path / 'reload_monitor.pick',
+                     auto_plotting_keys=['step_reward', 'cleanup_valid', 'cleanup_fail'])
+        if record:
+            env.save_records(out_path / 'reload_recorder.pick', save_occupation_map=True)
     print('all done')
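Two details of this script are easy to miss: `env_params.json` is parsed with the YAML loader, which is valid because JSON is a subset of YAML, and stale keys from configs saved by older code are migrated (`gain_amount` to `max_spawn_amount`) or dropped (`agent_can_interact`) before the env is rebuilt. The `try/del/except KeyError` idiom can be written more compactly with `dict.pop`; a sketch covering the same keys:

    dirt_prop = env_kwargs.get('dirt_prop', {})
    # Rename the legacy key if present
    if (gain_amount := dirt_prop.pop('gain_amount', None)) is not None:
        dirt_prop['max_spawn_amount'] = gain_amount
    # Drop a key that no longer exists; pop with a default never raises
    dirt_prop.pop('agent_can_interact', None)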
View File
@@ -1,3 +1,4 @@
+import itertools
 import sys
 from pathlib import Path
@@ -65,8 +66,8 @@ def load_model_run_baseline(policy_path, env_to_run):
             if done_bool:
                 break
         print(f'Factory run {episode} done, reward is:\n {rew}')
-    recorded_env_factory.save_run(filepath=policy_path / f'monitor.pick')
-    recorded_env_factory.save_records(filepath=policy_path / f'recorder.json')
+    recorded_env_factory.save_run(filepath=policy_path / f'baseline_monitor.pick')
+    recorded_env_factory.save_records(filepath=policy_path / f'baseline_recorder.json')
 def load_model_run_combined(root_path, env_to_run, env_kwargs):
@@ -89,134 +90,156 @@ def load_model_run_combined(root_path, env_to_run, env_kwargs):
                                                 env_factory.named_observation_space,
                                                 *[x.named_observation_space for x in models])
-    monitored_env_factory = EnvMonitor(env_factory)
-    recorded_env_factory = EnvRecorder(monitored_env_factory)
+    env = EnvMonitor(env_factory)
     # Evaluation Loop for i in range(n Episodes)
     for episode in range(5):
-        env_state = recorded_env_factory.reset()
+        env_state = env.reset()
         rew, done_bool = 0, False
         while not done_bool:
             translated_observations = observation_translator(env_state)
             actions = [model.predict(translated_observations[model_idx], deterministic=True)[0]
                        for model_idx, model in enumerate(models)]
             translated_actions = action_translator(actions)
-            env_state, step_r, done_bool, info_obj = recorded_env_factory.step(translated_actions)
+            env_state, step_r, done_bool, info_obj = env.step(translated_actions)
             rew += step_r
             if done_bool:
                 break
         print(f'Factory run {episode} done, reward is:\n {rew}')
-    recorded_env_factory.save_run(filepath=root_path / f'monitor.pick')
-    recorded_env_factory.save_records(filepath=root_path / f'recorder.json')
+    env.save_run(filepath=root_path / f'monitor_combined.pick')
+    # env.save_records(filepath=root_path / f'recorder_combined.json')
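With this change the combined runner keeps only the monitor wrapper and the `save_records` call is commented out. In the reload script above, the same wrappers compose as an onion, monitor innermost and an optional recorder on top; a sketch of that stacking (`raw_env` and `record` are placeholders):

    from environments.logging.envmonitor import EnvMonitor
    from environments.logging.recorder import EnvRecorder

    env = EnvMonitor(raw_env)                    # collects per-step metrics
    env = EnvRecorder(env) if record else env    # optional full episode recording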
 if __name__ == '__main__':
     # What to do:
     train = True
-    individual_run = True
+    individual_run = False
     combined_run = False
     multi_env = False
-    train_steps = 2e6
+    train_steps = 1e6
     frames_to_stack = 3
     # Define a global study save path
-    study_root_path = Path(__file__).parent.parent / 'study_out' / f'{Path(__file__).stem}'
-    def policy_model_kwargs():
-        return dict()
-    # Define Global Env Parameters
-    # Define properties object parameters
-    obs_props = ObservationProperties(render_agents=AgentRenderOptions.NOT,
-                                      additional_agent_placeholder=None,
-                                      omit_agent_self=True,
-                                      frames_to_stack=frames_to_stack,
-                                      pomdp_r=2, cast_shadows=True)
-    move_props = MovementProperties(allow_diagonal_movement=True,
-                                    allow_square_movement=True,
-                                    allow_no_op=False)
-    dirt_props = DirtProperties(initial_dirt_ratio=0.35, initial_dirt_spawn_r_var=0.1,
-                                clean_amount=0.34,
-                                max_spawn_amount=0.1, max_global_amount=20,
-                                max_local_amount=1, spawn_frequency=0, max_spawn_ratio=0.05,
-                                dirt_smear_amount=0.0, agent_can_interact=True)
-    item_props = ItemProperties(n_items=10, spawn_frequency=30, n_drop_off_locations=2,
-                                max_agent_inventory_capacity=15)
-    dest_props = DestProperties(n_dests=4, spawn_mode=DestModeOptions.GROUPED, spawn_frequency=1)
-    factory_kwargs = dict(n_agents=1, max_steps=500, parse_doors=True,
-                          level_name='rooms', doors_have_area=True,
-                          verbose=False,
-                          mv_prop=move_props,
-                          obs_prop=obs_props,
-                          done_at_collision=False
-                          )
-    # Bundle both environments with global kwargs and parameters
-    env_map = {}
-    env_map.update({'dirt': (DirtFactory, dict(dirt_prop=dirt_props,
-                                               **factory_kwargs.copy()))})
-    env_map.update({'item': (ItemFactory, dict(item_prop=item_props,
-                                               **factory_kwargs.copy()))})
-    # env_map.update({'dest': (DestFactory, dict(dest_prop=dest_props,
-    #                                            **factory_kwargs.copy()))})
-    env_map.update({'combined': (DirtDestItemFactory, dict(dest_prop=dest_props,
-                                                           item_prop=item_props,
-                                                           dirt_prop=dirt_props,
-                                                           **factory_kwargs.copy()))})
-    env_names = list(env_map.keys())
-    # Train starts here ############################################################
-    # Build Major Loop parameters, parameter versions, Env Classes and models
-    if train:
-        for env_key in (env_key for env_key in env_map if 'combined' != env_key):
-            model_cls = h.MODEL_MAP['PPO']
-            combination_path = study_root_path / env_key
-            env_class, env_kwargs = env_map[env_key]
-            # Output folder
-            if (combination_path / 'monitor.pick').exists():
-                continue
-            combination_path.mkdir(parents=True, exist_ok=True)
-            if not multi_env:
-                env_factory = encapsule_env_factory(env_class, env_kwargs)()
-            else:
-                env_factory = SubprocVecEnv([encapsule_env_factory(env_class, env_kwargs)
-                                             for _ in range(6)], start_method="spawn")
-            param_path = combination_path / f'env_params.json'
-            try:
-                env_factory.env_method('save_params', param_path)
-            except AttributeError:
-                env_factory.save_params(param_path)
-            # EnvMonitor Init
-            callbacks = [EnvMonitor(env_factory)]
-            # Model Init
-            model = model_cls("MlpPolicy", env_factory, **policy_model_kwargs(),
-                              verbose=1, seed=69, device='cpu')
-            # Model train
-            model.learn(total_timesteps=int(train_steps), callback=callbacks)
-            # Model save
-            try:
-                model.named_action_space = env_factory.unwrapped.named_action_space
-                model.named_observation_space = env_factory.unwrapped.named_observation_space
-            except AttributeError:
-                model.named_action_space = env_factory.get_attr("named_action_space")[0]
-                model.named_observation_space = env_factory.get_attr("named_observation_space")[0]
-            save_path = combination_path / f'model.zip'
-            model.save(save_path)
-            # Monitor Save
-            callbacks[0].save_run(combination_path / 'monitor.pick')
-            # Better safe than sorry: Clean up!
-            del env_factory, model
-            import gc
-            gc.collect()
+    parameters_of_interest = dict(
+        show_global_position_info=[True, False],
+        pomdp_r=[3],
+        cast_shadows=[True, False],
+        allow_diagonal_movement=[True],
+        parse_doors=[True, False],
+        doors_have_area=[True, False],
+        done_at_collision=[True, False]
+    )
+    keys, vals = zip(*parameters_of_interest.items())
+    # Then we find all permutations for those values
+    p = list(itertools.product(*vals))
+    # Finally we can create our list of dicts
+    result = [{keys[index]: entry[index] for index in range(len(entry))} for entry in p]
+    for u in result:
+        file_name = '_'.join('_'.join([str(y)[0] for y in x]) for x in u.items())
+        study_root_path = Path(__file__).parent.parent / 'study_out' / file_name
+        # Model Kwargs
+        policy_model_kwargs = dict(ent_coef=0.01)
+        # Define Global Env Parameters
+        # Define properties object parameters
+        obs_props = ObservationProperties(render_agents=AgentRenderOptions.NOT,
+                                          additional_agent_placeholder=None,
+                                          omit_agent_self=True,
+                                          frames_to_stack=frames_to_stack,
+                                          pomdp_r=u['pomdp_r'], cast_shadows=u['cast_shadows'],
+                                          show_global_position_info=u['show_global_position_info'])
+        move_props = MovementProperties(allow_diagonal_movement=u['allow_diagonal_movement'],
+                                        allow_square_movement=True,
+                                        allow_no_op=False)
+        dirt_props = DirtProperties(initial_dirt_ratio=0.35, initial_dirt_spawn_r_var=0.1,
+                                    clean_amount=0.34,
+                                    max_spawn_amount=0.1, max_global_amount=20,
+                                    max_local_amount=1, spawn_frequency=0, max_spawn_ratio=0.05,
+                                    dirt_smear_amount=0.0)
+        item_props = ItemProperties(n_items=10, spawn_frequency=30, n_drop_off_locations=2,
+                                    max_agent_inventory_capacity=15)
+        dest_props = DestProperties(n_dests=4, spawn_mode=DestModeOptions.GROUPED, spawn_frequency=1)
+        factory_kwargs = dict(n_agents=1, max_steps=500, parse_doors=u['parse_doors'],
+                              level_name='rooms', doors_have_area=u['doors_have_area'],
+                              verbose=False,
+                              mv_prop=move_props,
+                              obs_prop=obs_props,
+                              done_at_collision=u['done_at_collision']
+                              )
+        # Bundle both environments with global kwargs and parameters
+        env_map = {}
+        env_map.update({'dirt': (DirtFactory, dict(dirt_prop=dirt_props,
+                                                   **factory_kwargs.copy()),
+                                 ['cleanup_valid', 'cleanup_fail'])})
+        # env_map.update({'item': (ItemFactory, dict(item_prop=item_props,
+        #                                            **factory_kwargs.copy()),
+        #                          ['DROPOFF_FAIL', 'ITEMACTION_FAIL', 'DROPOFF_VALID', 'ITEMACTION_VALID'])})
+        # env_map.update({'dest': (DestFactory, dict(dest_prop=dest_props,
+        #                                            **factory_kwargs.copy()))})
+        env_map.update({'combined': (DirtDestItemFactory, dict(dest_prop=dest_props,
+                                                               item_prop=item_props,
+                                                               dirt_prop=dirt_props,
+                                                               **factory_kwargs.copy()))})
+        env_names = list(env_map.keys())
+        # Train starts here ############################################################
+        # Build Major Loop parameters, parameter versions, Env Classes and models
+        if train:
+            for env_key in (env_key for env_key in env_map if 'combined' != env_key):
+                model_cls = h.MODEL_MAP['PPO']
+                combination_path = study_root_path / env_key
+                env_class, env_kwargs, env_plot_keys = env_map[env_key]
+                # Output folder
+                if (combination_path / 'monitor.pick').exists():
+                    continue
+                combination_path.mkdir(parents=True, exist_ok=True)
+                if not multi_env:
+                    env_factory = encapsule_env_factory(env_class, env_kwargs)()
+                else:
+                    env_factory = SubprocVecEnv([encapsule_env_factory(env_class, env_kwargs)
+                                                 for _ in range(6)], start_method="spawn")
+                param_path = combination_path / f'env_params.json'
+                try:
+                    env_factory.env_method('save_params', param_path)
+                except AttributeError:
+                    env_factory.save_params(param_path)
+                # EnvMonitor Init
+                callbacks = [EnvMonitor(env_factory)]
+                # Model Init
+                model = model_cls("MlpPolicy", env_factory, **policy_model_kwargs,
+                                  verbose=1, seed=69, device='cpu')
+                # Model train
+                model.learn(total_timesteps=int(train_steps), callback=callbacks)
+                # Model save
+                try:
+                    model.named_action_space = env_factory.unwrapped.named_action_space
+                    model.named_observation_space = env_factory.unwrapped.named_observation_space
+                except AttributeError:
+                    model.named_action_space = env_factory.get_attr("named_action_space")[0]
+                    model.named_observation_space = env_factory.get_attr("named_observation_space")[0]
+                save_path = combination_path / f'model.zip'
+                model.save(save_path)
+                # Monitor Save
+                callbacks[0].save_run(combination_path / 'monitor.pick',
+                                      auto_plotting_keys=['step_reward', 'collision'] + env_plot_keys)
+                # Better safe than sorry: Clean up!
+                del env_factory, model
+                import gc
+                gc.collect()
     # Train ends here ############################################################
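The grid expansion at the top of the new `__main__` is plain `itertools.product` over the value lists, zipped back into per-run dicts; `file_name` then encodes the first character of every key and value. A standalone sketch of the same mechanics, using `dict(zip(...))` as an equivalent, slightly tidier construction:

    import itertools

    parameters_of_interest = dict(cast_shadows=[True, False], parse_doors=[True, False])
    keys, vals = zip(*parameters_of_interest.items())
    result = [dict(zip(keys, combo)) for combo in itertools.product(*vals)]

    for u in result:
        # e.g. 'c_T_p_T', 'c_T_p_F', ...: first letter of each key and value
        file_name = '_'.join('_'.join(str(y)[0] for y in x) for x in u.items())
        print(file_name)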