Adjustments and Documentation

This commit is contained in:
Steffen Illium 2022-04-11 16:15:44 +02:00
parent 3e19970a60
commit 0218f8f4e9
12 changed files with 394 additions and 182 deletions

View File

@ -156,14 +156,14 @@ class BaseFactory(gym.Env):
np.argwhere(level_array == c.OCCUPIED_CELL), np.argwhere(level_array == c.OCCUPIED_CELL),
self._level_shape self._level_shape
) )
self._entities.register_additional_items({c.WALLS: walls}) self._entities.add_additional_items({c.WALLS: walls})
# Floor # Floor
floor = Floors.from_argwhere_coordinates( floor = Floors.from_argwhere_coordinates(
np.argwhere(level_array == c.FREE_CELL), np.argwhere(level_array == c.FREE_CELL),
self._level_shape self._level_shape
) )
self._entities.register_additional_items({c.FLOOR: floor}) self._entities.add_additional_items({c.FLOOR: floor})
# NOPOS # NOPOS
self._NO_POS_TILE = Floor(c.NO_POS, None) self._NO_POS_TILE = Floor(c.NO_POS, None)
@ -177,12 +177,12 @@ class BaseFactory(gym.Env):
doors = Doors.from_tiles(door_tiles, self._level_shape, have_area=self.obs_prop.indicate_door_area, doors = Doors.from_tiles(door_tiles, self._level_shape, have_area=self.obs_prop.indicate_door_area,
entity_kwargs=dict(context=floor) entity_kwargs=dict(context=floor)
) )
self._entities.register_additional_items({c.DOORS: doors}) self._entities.add_additional_items({c.DOORS: doors})
# Actions # Actions
self._actions = Actions(self.mv_prop, can_use_doors=self.parse_doors) self._actions = Actions(self.mv_prop, can_use_doors=self.parse_doors)
if additional_actions := self.actions_hook: if additional_actions := self.actions_hook:
self._actions.register_additional_items(additional_actions) self._actions.add_additional_items(additional_actions)
# Agents # Agents
agents_to_spawn = self.n_agents-len(self._injected_agents) agents_to_spawn = self.n_agents-len(self._injected_agents)
@ -196,10 +196,10 @@ class BaseFactory(gym.Env):
if self._injected_agents: if self._injected_agents:
initialized_injections = list() initialized_injections = list()
for i, injection in enumerate(self._injected_agents): for i, injection in enumerate(self._injected_agents):
agents.register_item(injection(self, floor.empty_tiles[0], agents, static_problem=False)) agents.add_item(injection(self, floor.empty_tiles[0], agents, static_problem=False))
initialized_injections.append(agents[-1]) initialized_injections.append(agents[-1])
self._initialized_injections = initialized_injections self._initialized_injections = initialized_injections
self._entities.register_additional_items({c.AGENT: agents}) self._entities.add_additional_items({c.AGENT: agents})
if self.obs_prop.additional_agent_placeholder is not None: if self.obs_prop.additional_agent_placeholder is not None:
# TODO: Make this accept Lists for multiple placeholders # TODO: Make this accept Lists for multiple placeholders
@ -210,18 +210,18 @@ class BaseFactory(gym.Env):
fill_value=self.obs_prop.additional_agent_placeholder) fill_value=self.obs_prop.additional_agent_placeholder)
) )
self._entities.register_additional_items({c.AGENT_PLACEHOLDER: placeholder}) self._entities.add_additional_items({c.AGENT_PLACEHOLDER: placeholder})
# Additional Entities from SubEnvs # Additional Entities from SubEnvs
if additional_entities := self.entities_hook: if additional_entities := self.entities_hook:
self._entities.register_additional_items(additional_entities) self._entities.add_additional_items(additional_entities)
if self.obs_prop.show_global_position_info: if self.obs_prop.show_global_position_info:
global_positions = GlobalPositions(self._level_shape) global_positions = GlobalPositions(self._level_shape)
# This moved into the GlobalPosition object # This moved into the GlobalPosition object
# obs_shape_2d = self._level_shape if not self._pomdp_r else ((self.pomdp_diameter,) * 2) # obs_shape_2d = self._level_shape if not self._pomdp_r else ((self.pomdp_diameter,) * 2)
global_positions.spawn_global_position_objects(self[c.AGENT]) global_positions.spawn_global_position_objects(self[c.AGENT])
self._entities.register_additional_items({c.GLOBAL_POSITION: global_positions}) self._entities.add_additional_items({c.GLOBAL_POSITION: global_positions})
# Return # Return
return self._entities return self._entities
@ -535,7 +535,7 @@ class BaseFactory(gym.Env):
def _check_agent_move(self, agent, action: Action) -> (Floor, bool): def _check_agent_move(self, agent, action: Action) -> (Floor, bool):
# Actions # Actions
x_diff, y_diff = h.ACTIONMAP[action.identifier] x_diff, y_diff = a.resolve_movement_action_to_coords(action.identifier)
x_new = agent.x + x_diff x_new = agent.x + x_diff
y_new = agent.y + y_diff y_new = agent.y + y_diff

View File

@ -72,15 +72,15 @@ class EnvObject(Object):
def encoding(self): def encoding(self):
return c.OCCUPIED_CELL return c.OCCUPIED_CELL
def __init__(self, register, **kwargs): def __init__(self, collection, **kwargs):
super(EnvObject, self).__init__(**kwargs) super(EnvObject, self).__init__(**kwargs)
self._register = register self._collection = collection
def change_register(self, register): def change_parent_collection(self, other_collection):
register.register_item(self) other_collection.add_item(self)
self._register.delete_env_object(self) self._collection.delete_env_object(self)
self._register = register self._collection = other_collection
return self._register == register return self._collection == other_collection
# With Rendering # With Rendering
@ -153,7 +153,7 @@ class MoveableEntity(Entity):
curr_tile.leave(self) curr_tile.leave(self)
self._tile = next_tile self._tile = next_tile
self._last_tile = curr_tile self._last_tile = curr_tile
self._register.notify_change_to_value(self) self._collection.notify_change_to_value(self)
return c.VALID return c.VALID
else: else:
return c.NOT_VALID return c.NOT_VALID
@ -371,13 +371,13 @@ class Door(Entity):
def _open(self): def _open(self):
self.connectivity.add_edges_from([(self.pos, x) for x in range(len(self.connectivity_subgroups))]) self.connectivity.add_edges_from([(self.pos, x) for x in range(len(self.connectivity_subgroups))])
self._state = c.OPEN_DOOR self._state = c.OPEN_DOOR
self._register.notify_change_to_value(self) self._collection.notify_change_to_value(self)
self.time_to_close = self.auto_close_interval self.time_to_close = self.auto_close_interval
def _close(self): def _close(self):
self.connectivity.remove_node(self.pos) self.connectivity.remove_node(self.pos)
self._state = c.CLOSED_DOOR self._state = c.CLOSED_DOOR
self._register.notify_change_to_value(self) self._collection.notify_change_to_value(self)
def is_linked(self, old_pos, new_pos): def is_linked(self, old_pos, new_pos):
try: try:

View File

@ -13,11 +13,11 @@ from environments import helpers as h
from environments.helpers import Constants as c from environments.helpers import Constants as c
########################################################################## ##########################################################################
# ##################### Base Register Definition ####################### # # ################## Base Collections Definition ####################### #
########################################################################## ##########################################################################
class ObjectRegister: class ObjectCollection:
_accepted_objects = Object _accepted_objects = Object
@property @property
@ -25,59 +25,59 @@ class ObjectRegister:
return f'{self.__class__.__name__}' return f'{self.__class__.__name__}'
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
self._register = dict() self._collection = dict()
def __len__(self): def __len__(self):
return len(self._register) return len(self._collection)
def __iter__(self): def __iter__(self):
return iter(self.values()) return iter(self.values())
def register_item(self, other: _accepted_objects): def add_item(self, other: _accepted_objects):
assert isinstance(other, self._accepted_objects), f'All item names have to be of type ' \ assert isinstance(other, self._accepted_objects), f'All item names have to be of type ' \
f'{self._accepted_objects}, ' \ f'{self._accepted_objects}, ' \
f'but were {other.__class__}.,' f'but were {other.__class__}.,'
self._register.update({other.name: other}) self._collection.update({other.name: other})
return self return self
def register_additional_items(self, others: List[_accepted_objects]): def add_additional_items(self, others: List[_accepted_objects]):
for other in others: for other in others:
self.register_item(other) self.add_item(other)
return self return self
def keys(self): def keys(self):
return self._register.keys() return self._collection.keys()
def values(self): def values(self):
return self._register.values() return self._collection.values()
def items(self): def items(self):
return self._register.items() return self._collection.items()
def _get_index(self, item): def _get_index(self, item):
try: try:
return next(i for i, v in enumerate(self._register.values()) if v == item) return next(i for i, v in enumerate(self._collection.values()) if v == item)
except StopIteration: except StopIteration:
return None return None
def __getitem__(self, item): def __getitem__(self, item):
if isinstance(item, (int, np.int64, np.int32)): if isinstance(item, (int, np.int64, np.int32)):
if item < 0: if item < 0:
item = len(self._register) - abs(item) item = len(self._collection) - abs(item)
try: try:
return next(v for i, v in enumerate(self._register.values()) if i == item) return next(v for i, v in enumerate(self._collection.values()) if i == item)
except StopIteration: except StopIteration:
return None return None
try: try:
return self._register[item] return self._collection[item]
except KeyError: except KeyError:
return None return None
def __repr__(self): def __repr__(self):
return f'{self.__class__.__name__}[{self._register}]' return f'{self.__class__.__name__}[{self._collection}]'
class EnvObjectRegister(ObjectRegister): class EnvObjectCollection(ObjectCollection):
_accepted_objects = EnvObject _accepted_objects = EnvObject
@ -90,7 +90,7 @@ class EnvObjectRegister(ObjectRegister):
is_blocking_light: bool = False, is_blocking_light: bool = False,
can_collide: bool = False, can_collide: bool = False,
can_be_shadowed: bool = True, **kwargs): can_be_shadowed: bool = True, **kwargs):
super(EnvObjectRegister, self).__init__(*args, **kwargs) super(EnvObjectCollection, self).__init__(*args, **kwargs)
self._shape = obs_shape self._shape = obs_shape
self._array = None self._array = None
self._individual_slices = individual_slices self._individual_slices = individual_slices
@ -99,8 +99,8 @@ class EnvObjectRegister(ObjectRegister):
self.can_be_shadowed = can_be_shadowed self.can_be_shadowed = can_be_shadowed
self.can_collide = can_collide self.can_collide = can_collide
def register_item(self, other: EnvObject): def add_item(self, other: EnvObject):
super(EnvObjectRegister, self).register_item(other) super(EnvObjectCollection, self).add_item(other)
if self._array is None: if self._array is None:
self._array = np.zeros((1, *self._shape)) self._array = np.zeros((1, *self._shape))
else: else:
@ -145,13 +145,13 @@ class EnvObjectRegister(ObjectRegister):
if self._individual_slices: if self._individual_slices:
self._array = np.delete(self._array, idx, axis=0) self._array = np.delete(self._array, idx, axis=0)
else: else:
self.notify_change_to_free(self._register[name]) self.notify_change_to_free(self._collection[name])
# Dirty Hack to check if not being subclassed. In that case we need to refresh the array since positions # Dirty Hack to check if not being subclassed. In that case we need to refresh the array since positions
# in the observation array are the result of enumeration. They can override each other. # in the observation array are the result of enumeration. They can override each other.
# Todo: Find a better solution # Todo: Find a better solution
if not issubclass(self.__class__, EntityRegister) and issubclass(self.__class__, EnvObjectRegister): if not issubclass(self.__class__, EntityCollection) and issubclass(self.__class__, EnvObjectCollection):
self._refresh_arrays() self._refresh_arrays()
del self._register[name] del self._collection[name]
def delete_env_object(self, env_object: EnvObject): def delete_env_object(self, env_object: EnvObject):
del self[env_object.name] del self[env_object.name]
@ -160,19 +160,19 @@ class EnvObjectRegister(ObjectRegister):
del self[name] del self[name]
class EntityRegister(EnvObjectRegister, ABC): class EntityCollection(EnvObjectCollection, ABC):
_accepted_objects = Entity _accepted_objects = Entity
@classmethod @classmethod
def from_tiles(cls, tiles, *args, entity_kwargs=None, **kwargs): def from_tiles(cls, tiles, *args, entity_kwargs=None, **kwargs):
# objects_name = cls._accepted_objects.__name__ # objects_name = cls._accepted_objects.__name__
register_obj = cls(*args, **kwargs) collection = cls(*args, **kwargs)
entities = [cls._accepted_objects(tile, register_obj, str_ident=i, entities = [cls._accepted_objects(tile, collection, str_ident=i,
**entity_kwargs if entity_kwargs is not None else {}) **entity_kwargs if entity_kwargs is not None else {})
for i, tile in enumerate(tiles)] for i, tile in enumerate(tiles)]
register_obj.register_additional_items(entities) collection.add_additional_items(entities)
return register_obj return collection
@classmethod @classmethod
def from_argwhere_coordinates(cls, positions: [(int, int)], tiles, *args, entity_kwargs=None, **kwargs, ): def from_argwhere_coordinates(cls, positions: [(int, int)], tiles, *args, entity_kwargs=None, **kwargs, ):
@ -188,13 +188,13 @@ class EntityRegister(EnvObjectRegister, ABC):
return [entity.tile for entity in self] return [entity.tile for entity in self]
def __init__(self, level_shape, *args, **kwargs): def __init__(self, level_shape, *args, **kwargs):
super(EntityRegister, self).__init__(level_shape, *args, **kwargs) super(EntityCollection, self).__init__(level_shape, *args, **kwargs)
self._lazy_eval_transforms = [] self._lazy_eval_transforms = []
def __delitem__(self, name): def __delitem__(self, name):
idx, obj = next((i, obj) for i, obj in enumerate(self) if obj.name == name) idx, obj = next((i, obj) for i, obj in enumerate(self) if obj.name == name)
obj.tile.leave(obj) obj.tile.leave(obj)
super(EntityRegister, self).__delitem__(name) super(EntityCollection, self).__delitem__(name)
def as_array(self): def as_array(self):
if self._lazy_eval_transforms: if self._lazy_eval_transforms:
@ -223,7 +223,7 @@ class EntityRegister(EnvObjectRegister, ABC):
return None return None
class BoundEnvObjRegister(EnvObjectRegister, ABC): class BoundEnvObjCollection(EnvObjectCollection, ABC):
def __init__(self, entity_to_be_bound, *args, **kwargs): def __init__(self, entity_to_be_bound, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
@ -248,13 +248,13 @@ class BoundEnvObjRegister(EnvObjectRegister, ABC):
return self._array[self.idx_by_entity(entity)] return self._array[self.idx_by_entity(entity)]
class MovingEntityObjectRegister(EntityRegister, ABC): class MovingEntityObjectCollection(EntityCollection, ABC):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super(MovingEntityObjectRegister, self).__init__(*args, **kwargs) super(MovingEntityObjectCollection, self).__init__(*args, **kwargs)
def notify_change_to_value(self, entity): def notify_change_to_value(self, entity):
super(MovingEntityObjectRegister, self).notify_change_to_value(entity) super(MovingEntityObjectCollection, self).notify_change_to_value(entity)
if entity.last_pos != c.NO_POS: if entity.last_pos != c.NO_POS:
try: try:
self._array_change_notifyer(entity, entity.last_pos, value=c.FREE_CELL) self._array_change_notifyer(entity, entity.last_pos, value=c.FREE_CELL)
@ -263,11 +263,11 @@ class MovingEntityObjectRegister(EntityRegister, ABC):
########################################################################## ##########################################################################
# ################# Objects and Entity Registers ####################### # # ################# Objects and Entity Collection ###################### #
########################################################################## ##########################################################################
class GlobalPositions(EnvObjectRegister): class GlobalPositions(EnvObjectCollection):
_accepted_objects = GlobalPosition _accepted_objects = GlobalPosition
@ -288,7 +288,7 @@ class GlobalPositions(EnvObjectRegister):
global_positions = [self._accepted_objects(self._shape, agent, self) global_positions = [self._accepted_objects(self._shape, agent, self)
for _, agent in enumerate(agents)] for _, agent in enumerate(agents)]
# noinspection PyTypeChecker # noinspection PyTypeChecker
self.register_additional_items(global_positions) self.add_additional_items(global_positions)
def summarize_states(self, n_steps=None): def summarize_states(self, n_steps=None):
return {} return {}
@ -306,7 +306,7 @@ class GlobalPositions(EnvObjectRegister):
return None return None
class PlaceHolders(EnvObjectRegister): class PlaceHolders(EnvObjectCollection):
_accepted_objects = PlaceHolder _accepted_objects = PlaceHolder
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
@ -320,12 +320,12 @@ class PlaceHolders(EnvObjectRegister):
# objects_name = cls._accepted_objects.__name__ # objects_name = cls._accepted_objects.__name__
if isinstance(values, (str, numbers.Number)): if isinstance(values, (str, numbers.Number)):
values = [values] values = [values]
register_obj = cls(*args, **kwargs) collection = cls(*args, **kwargs)
objects = [cls._accepted_objects(register_obj, str_ident=i, fill_value=value, objects = [cls._accepted_objects(collection, str_ident=i, fill_value=value,
**object_kwargs if object_kwargs is not None else {}) **object_kwargs if object_kwargs is not None else {})
for i, value in enumerate(values)] for i, value in enumerate(values)]
register_obj.register_additional_items(objects) collection.add_additional_items(objects)
return register_obj return collection
# noinspection DuplicatedCode # noinspection DuplicatedCode
def as_array(self): def as_array(self):
@ -343,8 +343,8 @@ class PlaceHolders(EnvObjectRegister):
return self._array return self._array
class Entities(ObjectRegister): class Entities(ObjectCollection):
_accepted_objects = EntityRegister _accepted_objects = EntityCollection
@property @property
def arrays(self): def arrays(self):
@ -352,7 +352,7 @@ class Entities(ObjectRegister):
@property @property
def names(self): def names(self):
return list(self._register.keys()) return list(self._collection.keys())
def __init__(self): def __init__(self):
super(Entities, self).__init__() super(Entities, self).__init__()
@ -360,21 +360,21 @@ class Entities(ObjectRegister):
def iter_individual_entitites(self): def iter_individual_entitites(self):
return iter((x for sublist in self.values() for x in sublist)) return iter((x for sublist in self.values() for x in sublist))
def register_item(self, other: dict): def add_item(self, other: dict):
assert not any([key for key in other.keys() if key in self.keys()]), \ assert not any([key for key in other.keys() if key in self.keys()]), \
"This group of entities has already been registered!" "This group of entities has already been added!"
self._register.update(other) self._collection.update(other)
return self return self
def register_additional_items(self, others: Dict): def add_additional_items(self, others: Dict):
return self.register_item(others) return self.add_item(others)
def by_pos(self, pos: (int, int)): def by_pos(self, pos: (int, int)):
found_entities = [y for y in (x.by_pos(pos) for x in self.values() if hasattr(x, 'by_pos')) if y is not None] found_entities = [y for y in (x.by_pos(pos) for x in self.values() if hasattr(x, 'by_pos')) if y is not None]
return found_entities return found_entities
class Walls(EntityRegister): class Walls(EntityCollection):
_accepted_objects = Wall _accepted_objects = Wall
def as_array(self): def as_array(self):
@ -396,7 +396,7 @@ class Walls(EntityRegister):
def from_argwhere_coordinates(cls, argwhere_coordinates, *args, **kwargs): def from_argwhere_coordinates(cls, argwhere_coordinates, *args, **kwargs):
tiles = cls(*args, **kwargs) tiles = cls(*args, **kwargs)
# noinspection PyTypeChecker # noinspection PyTypeChecker
tiles.register_additional_items( tiles.add_additional_items(
[cls._accepted_objects(pos, tiles) [cls._accepted_objects(pos, tiles)
for pos in argwhere_coordinates] for pos in argwhere_coordinates]
) )
@ -441,7 +441,7 @@ class Floors(Walls):
return {} return {}
class Agents(MovingEntityObjectRegister): class Agents(MovingEntityObjectCollection):
_accepted_objects = Agent _accepted_objects = Agent
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
@ -455,10 +455,10 @@ class Agents(MovingEntityObjectRegister):
old_agent = self[key] old_agent = self[key]
self[key].tile.leave(self[key]) self[key].tile.leave(self[key])
agent._name = old_agent.name agent._name = old_agent.name
self._register[agent.name] = agent self._collection[agent.name] = agent
class Doors(EntityRegister): class Doors(EntityCollection):
def __init__(self, *args, have_area: bool = False, **kwargs): def __init__(self, *args, have_area: bool = False, **kwargs):
self.have_area = have_area self.have_area = have_area
@ -490,7 +490,7 @@ class Doors(EntityRegister):
return super(Doors, self).as_array() return super(Doors, self).as_array()
class Actions(ObjectRegister): class Actions(ObjectCollection):
_accepted_objects = Action _accepted_objects = Action
@property @property
@ -507,22 +507,22 @@ class Actions(ObjectRegister):
# Move this to Baseclass, Env init? # Move this to Baseclass, Env init?
if self.allow_square_movement: if self.allow_square_movement:
self.register_additional_items([self._accepted_objects(str_ident=direction) self.add_additional_items([self._accepted_objects(str_ident=direction)
for direction in h.EnvActions.square_move()]) for direction in h.EnvActions.square_move()])
if self.allow_diagonal_movement: if self.allow_diagonal_movement:
self.register_additional_items([self._accepted_objects(str_ident=direction) self.add_additional_items([self._accepted_objects(str_ident=direction)
for direction in h.EnvActions.diagonal_move()]) for direction in h.EnvActions.diagonal_move()])
self._movement_actions = self._register.copy() self._movement_actions = self._collection.copy()
if self.can_use_doors: if self.can_use_doors:
self.register_additional_items([self._accepted_objects(str_ident=h.EnvActions.USE_DOOR)]) self.add_additional_items([self._accepted_objects(str_ident=h.EnvActions.USE_DOOR)])
if self.allow_no_op: if self.allow_no_op:
self.register_additional_items([self._accepted_objects(str_ident=h.EnvActions.NOOP)]) self.add_additional_items([self._accepted_objects(str_ident=h.EnvActions.NOOP)])
def is_moving_action(self, action: Union[int]): def is_moving_action(self, action: Union[int]):
return action in self.movement_actions.values() return action in self.movement_actions.values()
class Zones(ObjectRegister): class Zones(ObjectCollection):
@property @property
def accounting_zones(self): def accounting_zones(self):
@ -551,5 +551,5 @@ class Zones(ObjectRegister):
def __getitem__(self, item): def __getitem__(self, item):
return self._zone_slices[item] return self._zone_slices[item]
def register_additional_items(self, other: Union[str, List[str]]): def add_additional_items(self, other: Union[str, List[str]]):
raise AttributeError('You are not allowed to add additional Zones in runtime.') raise AttributeError('You are not allowed to add additional Zones in runtime.')

View File

@ -4,7 +4,7 @@ import numpy as np
from environments.factory.base.base_factory import BaseFactory from environments.factory.base.base_factory import BaseFactory
from environments.factory.base.objects import Agent, Action, Entity, EnvObject, BoundingMixin from environments.factory.base.objects import Agent, Action, Entity, EnvObject, BoundingMixin
from environments.factory.base.registers import EntityRegister, EnvObjectRegister from environments.factory.base.registers import EntityCollection, EnvObjectCollection
from environments.factory.base.renderer import RenderEntity from environments.factory.base.renderer import RenderEntity
from environments.helpers import Constants as BaseConstants from environments.helpers import Constants as BaseConstants
from environments.helpers import EnvActions as BaseActions from environments.helpers import EnvActions as BaseActions
@ -68,7 +68,7 @@ class Battery(BoundingMixin, EnvObject):
if self.charge_level != 0: if self.charge_level != 0:
# noinspection PyTypeChecker # noinspection PyTypeChecker
self.charge_level = max(0, amount + self.charge_level) self.charge_level = max(0, amount + self.charge_level)
self._register.notify_change_to_value(self) self._collection.notify_change_to_value(self)
return c.VALID return c.VALID
else: else:
return c.NOT_VALID return c.NOT_VALID
@ -79,7 +79,7 @@ class Battery(BoundingMixin, EnvObject):
return attr_dict return attr_dict
class BatteriesRegister(EnvObjectRegister): class BatteriesRegister(EnvObjectCollection):
_accepted_objects = Battery _accepted_objects = Battery
@ -90,7 +90,7 @@ class BatteriesRegister(EnvObjectRegister):
def spawn_batteries(self, agents, initial_charge_level): def spawn_batteries(self, agents, initial_charge_level):
batteries = [self._accepted_objects(initial_charge_level, agent, self) for _, agent in enumerate(agents)] batteries = [self._accepted_objects(initial_charge_level, agent, self) for _, agent in enumerate(agents)]
self.register_additional_items(batteries) self.add_additional_items(batteries)
def summarize_states(self, n_steps=None): def summarize_states(self, n_steps=None):
# as dict with additional nesting # as dict with additional nesting
@ -140,7 +140,7 @@ class ChargePod(Entity):
return summary return summary
class ChargePods(EntityRegister): class ChargePods(EntityCollection):
_accepted_objects = ChargePod _accepted_objects = ChargePod

View File

@ -9,7 +9,7 @@ from environments.factory.base.base_factory import BaseFactory
from environments.helpers import Constants as BaseConstants from environments.helpers import Constants as BaseConstants
from environments.helpers import EnvActions as BaseActions from environments.helpers import EnvActions as BaseActions
from environments.factory.base.objects import Agent, Entity, Action from environments.factory.base.objects import Agent, Entity, Action
from environments.factory.base.registers import Entities, EntityRegister from environments.factory.base.registers import Entities, EntityCollection
from environments.factory.base.renderer import RenderEntity from environments.factory.base.renderer import RenderEntity
@ -73,7 +73,7 @@ class Destination(Entity):
return state_summary return state_summary
class Destinations(EntityRegister): class Destinations(EntityCollection):
_accepted_objects = Destination _accepted_objects = Destination
@ -208,13 +208,13 @@ class DestFactory(BaseFactory):
n_dest_to_spawn = len(destinations_to_spawn) n_dest_to_spawn = len(destinations_to_spawn)
if self.dest_prop.spawn_mode != DestModeOptions.GROUPED: if self.dest_prop.spawn_mode != DestModeOptions.GROUPED:
destinations = [Destination(tile, c.DEST) for tile in self[c.FLOOR].empty_tiles[:n_dest_to_spawn]] destinations = [Destination(tile, c.DEST) for tile in self[c.FLOOR].empty_tiles[:n_dest_to_spawn]]
self[c.DEST].register_additional_items(destinations) self[c.DEST].add_additional_items(destinations)
for dest in destinations_to_spawn: for dest in destinations_to_spawn:
del self._dest_spawn_timer[dest] del self._dest_spawn_timer[dest]
self.print(f'{n_dest_to_spawn} new destinations have been spawned') self.print(f'{n_dest_to_spawn} new destinations have been spawned')
elif self.dest_prop.spawn_mode == DestModeOptions.GROUPED and n_dest_to_spawn == self.dest_prop.n_dests: elif self.dest_prop.spawn_mode == DestModeOptions.GROUPED and n_dest_to_spawn == self.dest_prop.n_dests:
destinations = [Destination(tile, self[c.DEST]) for tile in self[c.FLOOR].empty_tiles[:n_dest_to_spawn]] destinations = [Destination(tile, self[c.DEST]) for tile in self[c.FLOOR].empty_tiles[:n_dest_to_spawn]]
self[c.DEST].register_additional_items(destinations) self[c.DEST].add_additional_items(destinations)
for dest in destinations_to_spawn: for dest in destinations_to_spawn:
del self._dest_spawn_timer[dest] del self._dest_spawn_timer[dest]
self.print(f'{n_dest_to_spawn} new destinations have been spawned') self.print(f'{n_dest_to_spawn} new destinations have been spawned')
@ -231,7 +231,7 @@ class DestFactory(BaseFactory):
self._dest_spawn_timer[key] = min(self.dest_prop.spawn_frequency, self._dest_spawn_timer[key] + 1) self._dest_spawn_timer[key] = min(self.dest_prop.spawn_frequency, self._dest_spawn_timer[key] + 1)
for dest in list(self[c.DEST].values()): for dest in list(self[c.DEST].values()):
if dest.is_considered_reached: if dest.is_considered_reached:
dest.change_register(self[c.DEST]) dest.change_parent_collection(self[c.DEST])
self._dest_spawn_timer[dest.name] = 0 self._dest_spawn_timer[dest.name] = 0
self.print(f'{dest.name} is reached now, removing...') self.print(f'{dest.name} is reached now, removing...')
else: else:

View File

@ -11,7 +11,7 @@ from environments.helpers import EnvActions as BaseActions
from environments.factory.base.base_factory import BaseFactory from environments.factory.base.base_factory import BaseFactory
from environments.factory.base.objects import Agent, Action, Entity, Floor from environments.factory.base.objects import Agent, Action, Entity, Floor
from environments.factory.base.registers import Entities, EntityRegister from environments.factory.base.registers import Entities, EntityCollection
from environments.factory.base.renderer import RenderEntity from environments.factory.base.renderer import RenderEntity
from environments.utility_classes import ObservationProperties from environments.utility_classes import ObservationProperties
@ -61,7 +61,7 @@ class Dirt(Entity):
def set_new_amount(self, amount): def set_new_amount(self, amount):
self._amount = amount self._amount = amount
self._register.notify_change_to_value(self) self._collection.notify_change_to_value(self)
def summarize_state(self, **kwargs): def summarize_state(self, **kwargs):
state_dict = super().summarize_state(**kwargs) state_dict = super().summarize_state(**kwargs)
@ -69,7 +69,7 @@ class Dirt(Entity):
return state_dict return state_dict
class DirtRegister(EntityRegister): class DirtRegister(EntityCollection):
_accepted_objects = Dirt _accepted_objects = Dirt
@ -93,7 +93,7 @@ class DirtRegister(EntityRegister):
dirt = self.by_pos(tile.pos) dirt = self.by_pos(tile.pos)
if dirt is None: if dirt is None:
dirt = Dirt(tile, self, amount=self.dirt_properties.max_spawn_amount) dirt = Dirt(tile, self, amount=self.dirt_properties.max_spawn_amount)
self.register_item(dirt) self.add_item(dirt)
else: else:
new_value = dirt.amount + self.dirt_properties.max_spawn_amount new_value = dirt.amount + self.dirt_properties.max_spawn_amount
dirt.set_new_amount(min(new_value, self.dirt_properties.max_local_amount)) dirt.set_new_amount(min(new_value, self.dirt_properties.max_local_amount))

View File

@ -5,10 +5,10 @@ import numpy as np
from environments.factory.base.objects import Agent, Entity, Action from environments.factory.base.objects import Agent, Entity, Action
from environments.factory.factory_dirt import Dirt, DirtRegister, DirtFactory from environments.factory.factory_dirt import Dirt, DirtRegister, DirtFactory
from environments.factory.base.objects import Floor from environments.factory.base.objects import Floor
from environments.factory.base.registers import Floors, Entities, EntityRegister from environments.factory.base.registers import Floors, Entities, EntityCollection
class Machines(EntityRegister): class Machines(EntityCollection):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)

View File

@ -9,7 +9,7 @@ from environments.helpers import Constants as BaseConstants
from environments.helpers import EnvActions as BaseActions from environments.helpers import EnvActions as BaseActions
from environments import helpers as h from environments import helpers as h
from environments.factory.base.objects import Agent, Entity, Action, Floor from environments.factory.base.objects import Agent, Entity, Action, Floor
from environments.factory.base.registers import Entities, EntityRegister, BoundEnvObjRegister, ObjectRegister from environments.factory.base.registers import Entities, EntityCollection, BoundEnvObjCollection, ObjectCollection
from environments.factory.base.renderer import RenderEntity from environments.factory.base.renderer import RenderEntity
@ -53,17 +53,17 @@ class Item(Entity):
self._auto_despawn = auto_despawn self._auto_despawn = auto_despawn
def set_tile_to(self, no_pos_tile): def set_tile_to(self, no_pos_tile):
assert self._register.__class__.__name__ != ItemRegister.__class__ assert self._collection.__class__.__name__ != ItemRegister.__class__
self._tile = no_pos_tile self._tile = no_pos_tile
class ItemRegister(EntityRegister): class ItemRegister(EntityCollection):
_accepted_objects = Item _accepted_objects = Item
def spawn_items(self, tiles: List[Floor]): def spawn_items(self, tiles: List[Floor]):
items = [Item(tile, self) for tile in tiles] items = [Item(tile, self) for tile in tiles]
self.register_additional_items(items) self.add_additional_items(items)
def despawn_items(self, items: List[Item]): def despawn_items(self, items: List[Item]):
items = [items] if isinstance(items, Item) else items items = [items] if isinstance(items, Item) else items
@ -71,7 +71,7 @@ class ItemRegister(EntityRegister):
del self[item] del self[item]
class Inventory(BoundEnvObjRegister): class Inventory(BoundEnvObjCollection):
@property @property
def name(self): def name(self):
@ -98,7 +98,7 @@ class Inventory(BoundEnvObjRegister):
return item_to_pop return item_to_pop
class Inventories(ObjectRegister): class Inventories(ObjectCollection):
_accepted_objects = Inventory _accepted_objects = Inventory
is_blocking_light = False is_blocking_light = False
@ -114,7 +114,7 @@ class Inventories(ObjectRegister):
def spawn_inventories(self, agents, capacity): def spawn_inventories(self, agents, capacity):
inventories = [self._accepted_objects(agent, capacity, self._obs_shape) inventories = [self._accepted_objects(agent, capacity, self._obs_shape)
for _, agent in enumerate(agents)] for _, agent in enumerate(agents)]
self.register_additional_items(inventories) self.add_additional_items(inventories)
def idx_by_entity(self, entity): def idx_by_entity(self, entity):
try: try:
@ -161,7 +161,7 @@ class DropOffLocation(Entity):
return super().summarize_state(n_steps=n_steps) return super().summarize_state(n_steps=n_steps)
class DropOffLocations(EntityRegister): class DropOffLocations(EntityCollection):
_accepted_objects = DropOffLocation _accepted_objects = DropOffLocation
@ -250,7 +250,7 @@ class ItemFactory(BaseFactory):
reason=a.ITEM_ACTION, info=info_dict) reason=a.ITEM_ACTION, info=info_dict)
return valid, reward return valid, reward
elif item := self[c.ITEM].by_pos(agent.pos): elif item := self[c.ITEM].by_pos(agent.pos):
item.change_register(inventory) item.change_parent_collection(inventory)
item.set_tile_to(self._NO_POS_TILE) item.set_tile_to(self._NO_POS_TILE)
self.print(f'{agent.name} just picked up an item at {agent.pos}') self.print(f'{agent.name} just picked up an item at {agent.pos}')
info_dict = {f'{agent.name}_{a.ITEM_ACTION}_VALID': 1, f'{a.ITEM_ACTION}_VALID': 1} info_dict = {f'{agent.name}_{a.ITEM_ACTION}_VALID': 1, f'{a.ITEM_ACTION}_VALID': 1}

View File

@ -7,47 +7,76 @@ import numpy as np
from numpy.typing import ArrayLike from numpy.typing import ArrayLike
from stable_baselines3 import PPO, DQN, A2C from stable_baselines3 import PPO, DQN, A2C
MODEL_MAP = dict(PPO=PPO, DQN=DQN, A2C=A2C)
LEVELS_DIR = 'levels' """
STEPS_START = 1 This file is used for:
1. string based definition
TO_BE_AVERAGED = ['dirt_amount', 'dirty_tiles'] Use a class like `Constants`, to define attributes, which then reveal strings.
IGNORED_DF_COLUMNS = ['Episode', 'Run', 'train_step', 'step', 'index', 'dirt_amount', These can be used for naming convention along the environments as well as keys for mappings such as dicts etc.
'dirty_tile_count', 'terminal_observation', 'episode'] When defining new envs, use class inheritance.
2. utility function definition
There are static utility functions which are not bound to a specific environment.
In this file they are defined to be used across the entire package.
"""
MODEL_MAP = dict(PPO=PPO, DQN=DQN, A2C=A2C) # For use in studies and experiments
LEVELS_DIR = 'levels' # for use in studies and experiments
STEPS_START = 1 # Define where to the stepcount; which is the first step
# Not used anymore? Clean!
# TO_BE_AVERAGED = ['dirt_amount', 'dirty_tiles']
IGNORED_DF_COLUMNS = ['Episode', 'Run', # For plotting, which values are ignored when loading monitor files
'train_step', 'step', 'index', 'dirt_amount', 'dirty_tile_count', 'terminal_observation',
'episode']
# Constants
class Constants: class Constants:
WALL = '#'
WALLS = 'Walls'
FLOOR = 'Floor'
DOOR = 'D'
DANGER_ZONE = 'x'
LEVEL = 'Level'
AGENT = 'Agent'
AGENT_PLACEHOLDER = 'AGENT_PLACEHOLDER'
GLOBAL_POSITION = 'GLOBAL_POSITION'
FREE_CELL = 0
OCCUPIED_CELL = 1
SHADOWED_CELL = -1
ACCESS_DOOR_CELL = 1/3
OPEN_DOOR_CELL = 2/3
CLOSED_DOOR_CELL = 3/3
NO_POS = (-9999, -9999)
DOORS = 'Doors' """
CLOSED_DOOR = 'closed' String based mapping. Use these to handle keys or define values, which can be then be used globaly.
OPEN_DOOR = 'open' Please use class inheritance when defining new environments.
ACCESS_DOOR = 'access' """
ACTION = 'action' WALL = '#' # Wall tile identifier for resolving the string based map files.
COLLISION = 'collision' DOOR = 'D' # Door identifier for resolving the string based map files.
VALID = True DANGER_ZONE = 'x' # Dange Zone tile identifier for resolving the string based map files.
NOT_VALID = False
WALLS = 'Walls' # Identifier of Wall-objects and sets (collections).
FLOOR = 'Floor' # Identifier of Floor-objects and sets (collections).
DOORS = 'Doors' # Identifier of Door-objects and sets (collections).
LEVEL = 'Level' # Identifier of Level-objects and sets (collections).
AGENT = 'Agent' # Identifier of Agent-objects and sets (collections).
AGENT_PLACEHOLDER = 'AGENT_PLACEHOLDER' # Identifier of Placeholder-objects and sets (collections).
GLOBAL_POSITION = 'GLOBAL_POSITION' # Identifier of the global position slice
FREE_CELL = 0 # Free-Cell value used in observation
OCCUPIED_CELL = 1 # Occupied-Cell value used in observation
SHADOWED_CELL = -1 # Shadowed-Cell value used in observation
ACCESS_DOOR_CELL = 1/3 # Access-door-Cell value used in observation
OPEN_DOOR_CELL = 2/3 # Open-door-Cell value used in observation
CLOSED_DOOR_CELL = 3/3 # Closed-door-Cell value used in observation
NO_POS = (-9999, -9999) # Invalid Position value used in the environment (something is off-grid)
CLOSED_DOOR = 'closed' # Identifier to compare door-is-closed state
OPEN_DOOR = 'open' # Identifier to compare door-is-open state
# ACCESS_DOOR = 'access' # Identifier to compare access positions
ACTION = 'action' # Identifier of Action-objects and sets (collections).
COLLISION = 'collision' # Identifier to use in the context of collitions.
VALID = True # Identifier to rename boolean values in the context of actions.
NOT_VALID = False # Identifier to rename boolean values in the context of actions.
class EnvActions: class EnvActions:
"""
String based mapping. Use these to identify actions; they can be used globally.
Please use class inheritance when defining new environments with new actions.
"""
# Movements # Movements
NORTH = 'north' NORTH = 'north'
EAST = 'east' EAST = 'east'
@ -63,24 +92,77 @@ class EnvActions:
NOOP = 'no_op' NOOP = 'no_op'
USE_DOOR = 'use_door' USE_DOOR = 'use_door'
_ACTIONMAP = defaultdict(lambda: (0, 0),
{NORTH: (-1, 0), NORTHEAST: (-1, 1),
EAST: (0, 1), SOUTHEAST: (1, 1),
SOUTH: (1, 0), SOUTHWEST: (1, -1),
WEST: (0, -1), NORTHWEST: (-1, -1)
}
)
@classmethod @classmethod
def is_move(cls, other): def is_move(cls, action):
return any([other == direction for direction in cls.movement_actions()]) """
Classmethod; checks if given action is a movement action or not. Depending on the env. configuration,
Movement actions are either `manhattan` (square) style movements (up,down, left, right) and/or diagonal.
:param action: Action to be checked
:type action: str
:return: Whether the given action is a movement action.
:rtype: bool
"""
return any([action == direction for direction in cls.movement_actions()])
@classmethod @classmethod
def square_move(cls): def square_move(cls):
"""
Classmethod; return a list of movement actions that are considered square or `manhattan` style movements.
:return: A list of movement actions.
:rtype: list(str)
"""
return [cls.NORTH, cls.EAST, cls.SOUTH, cls.WEST] return [cls.NORTH, cls.EAST, cls.SOUTH, cls.WEST]
@classmethod @classmethod
def diagonal_move(cls): def diagonal_move(cls):
"""
Classmethod; return a list of movement actions that are considered diagonal movements.
:return: A list of movement actions.
:rtype: list(str)
"""
return [cls.NORTHEAST, cls.SOUTHEAST, cls.SOUTHWEST, cls.NORTHWEST] return [cls.NORTHEAST, cls.SOUTHEAST, cls.SOUTHWEST, cls.NORTHWEST]
@classmethod @classmethod
def movement_actions(cls): def movement_actions(cls):
"""
Classmethod; return a list of all available movement actions.
Please note that this is independent of the env. properties.
:return: A list of movement actions.
:rtype: list(str)
"""
return list(itertools.chain(cls.square_move(), cls.diagonal_move())) return list(itertools.chain(cls.square_move(), cls.diagonal_move()))
@classmethod
def resolve_movement_action_to_coords(cls, action):
"""
Classmethod; resolve movement actions. Given a movement action, return the delta in coordinates it stands for.
How does the current entity coordinate change if it performs the given action?
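Please note that this is independent of the env. properties.
:param action: The movement action to resolve; actions without a mapping resolve to (0, 0).
:type action: str
:return: Delta coordinates.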
:rtype: tuple(int, int)
"""
return cls._ACTIONMAP[action]
class RewardsBase(NamedTuple): class RewardsBase(NamedTuple):
"""
Value based mapping. Use these to define reward values for specific conditions (i.e. the action
in a given context); they can be used globally.
Please use class inheritance when defining new environments with new rewards.
"""
MOVEMENTS_VALID: float = -0.001 MOVEMENTS_VALID: float = -0.001
MOVEMENTS_FAIL: float = -0.05 MOVEMENTS_FAIL: float = -0.05
NOOP: float = -0.01 NOOP: float = -0.01
@ -89,23 +171,31 @@ class RewardsBase(NamedTuple):
COLLISION: float = -0.5 COLLISION: float = -0.5
m = EnvActions
c = Constants
ACTIONMAP = defaultdict(lambda: (0, 0),
{m.NORTH: (-1, 0), m.NORTHEAST: (-1, 1),
m.EAST: (0, 1), m.SOUTHEAST: (1, 1),
m.SOUTH: (1, 0), m.SOUTHWEST: (1, -1),
m.WEST: (0, -1), m.NORTHWEST: (-1, -1)
}
)
class ObservationTranslator: class ObservationTranslator:
def __init__(self, obs_shape_2d: (int, int), this_named_observation_space: Dict[str, dict], def __init__(self, obs_shape_2d: (int, int), this_named_observation_space: Dict[str, dict],
*per_agent_named_obs_space: Dict[str, dict], *per_agent_named_obs_spaces: Dict[str, dict],
placeholder_fill_value: Union[int, str] = 'N'): placeholder_fill_value: Union[int, str] = 'N'):
"""
This is a helper class, which converts agents observations from joined environments.
For example, agents trained in different environments may expect different observations.
This class translates from larger observations spaces to smaller.
A string identifier based approach is used.
Currently, it is not possible to mix different obs shapes.
:param obs_shape_2d: The shape of the observation the agents expect.
:type obs_shape_2d: tuple(int, int)
:param this_named_observation_space: `Named observation space` of the joined environment.
:type this_named_observation_space: Dict[str, dict]
:param per_agent_named_obs_spaces: `Named observation space`, one for each agent (variadic argument).
:type per_agent_named_obs_spaces: Dict[str, dict]
:param placeholder_fill_value: Currently not fully implemented!!!
:type placeholder_fill_value: Union[int, str]
"""
assert len(obs_shape_2d) == 2 assert len(obs_shape_2d) == 2
self.obs_shape = obs_shape_2d self.obs_shape = obs_shape_2d
if isinstance(placeholder_fill_value, str): if isinstance(placeholder_fill_value, str):
@ -119,7 +209,7 @@ class ObservationTranslator:
self.random_fill = None self.random_fill = None
self._this_named_obs_space = this_named_observation_space self._this_named_obs_space = this_named_observation_space
self._per_agent_named_obs_space = list(per_agent_named_obs_space) self._per_agent_named_obs_space = list(per_agent_named_obs_spaces)
def translate_observation(self, agent_idx: int, obs: np.ndarray): def translate_observation(self, agent_idx: int, obs: np.ndarray):
target_obs_space = self._per_agent_named_obs_space[agent_idx] target_obs_space = self._per_agent_named_obs_space[agent_idx]
@ -137,6 +227,19 @@ class ObservationTranslator:
class ActionTranslator: class ActionTranslator:
def __init__(self, target_named_action_space: Dict[str, int], *per_agent_named_action_space: Dict[str, int]): def __init__(self, target_named_action_space: Dict[str, int], *per_agent_named_action_space: Dict[str, int]):
"""
This is a helper class, which converts agents action spaces to a joined environments action space.
For example, agents trained in different environments may have different action spaces.
This class translates from smaller individual agent action spaces to larger joined spaces.
A string identifier based approach is used.
:param target_named_action_space: Joined `Named action space` for the current environment.
:type target_named_action_space: Dict[str, int]
:param per_agent_named_action_space: `Named action space`, one for each agent (variadic argument).
:type per_agent_named_action_space: Dict[str, int]
"""
self._target_named_action_space = target_named_action_space self._target_named_action_space = target_named_action_space
self._per_agent_named_action_space = list(per_agent_named_action_space) self._per_agent_named_action_space = list(per_agent_named_action_space)
self._per_agent_idx_actions = [{idx: a for a, idx in x.items()} for x in self._per_agent_named_action_space] self._per_agent_idx_actions = [{idx: a for a, idx in x.items()} for x in self._per_agent_named_action_space]
@ -155,6 +258,16 @@ class ActionTranslator:
# Utility functions # Utility functions
def parse_level(path): def parse_level(path):
"""
Given the path to a string based `level` or `map` representation, this function reads the content.
Strips surrounding whitespace from each row, checks for equal length of each row and returns a list of lists.
:param path: Path to the `level` or `map` file on the hard drive.
:type path: pathlib.Path
:return: The read string representation of the `level` or `map`
:rtype: List[List[str]]
"""
with path.open('r') as lvl: with path.open('r') as lvl:
level = list(map(lambda x: list(x.strip()), lvl.readlines())) level = list(map(lambda x: list(x.strip()), lvl.readlines()))
if len(set([len(line) for line in level])) > 1: if len(set([len(line) for line in level])) > 1:
@ -162,29 +275,56 @@ def parse_level(path):
return level return level
def one_hot_level(level, wall_char: str = c.WALL): def one_hot_level(level, wall_char: str = Constants.WALL):
"""
Given a string based level representation (list of lists, see function `parse_level`), this function creates a
binary numpy array or `grid`. Grid values that equal `wall_char` become of `Constants.OCCUPIED_CELL` value.
Can be changed to filter for any symbol.
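:param level: String based level representation (list of lists, see function `parse_level`).
:type level: List[List[str]]
:param wall_char: The symbol whose grid cells are set to `Constants.OCCUPIED_CELL`.
:type wall_char: str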
:return: Binary numpy array
:rtype: np.typing._array_like.ArrayLike
"""
grid = np.array(level) grid = np.array(level)
binary_grid = np.zeros(grid.shape, dtype=np.int8) binary_grid = np.zeros(grid.shape, dtype=np.int8)
binary_grid[grid == wall_char] = c.OCCUPIED_CELL binary_grid[grid == wall_char] = Constants.OCCUPIED_CELL
return binary_grid return binary_grid
def check_position(slice_to_check_against: ArrayLike, position_to_check: Tuple[int, int]): def check_position(slice_to_check_against: ArrayLike, position_to_check: Tuple[int, int]):
"""
Given a slice (2-D array-like object), check whether a position lies inside the slice and its cell value is falsy (i.e. free).
:param slice_to_check_against: The slice to check for accessibility
:type slice_to_check_against: np.typing._array_like.ArrayLike
:param position_to_check: Position in slice that should be checked. Can be outside of the slice boundaries.
:type position_to_check: tuple(int, int)
:return: Whether a position can be moved to.
:rtype: bool
"""
x_pos, y_pos = position_to_check x_pos, y_pos = position_to_check
# Check if agent colides with grid boundrys # Check if agent colides with grid boundrys
valid = not ( valid = not (
x_pos < 0 or y_pos < 0 x_pos < 0 or y_pos < 0
or x_pos >= slice_to_check_against.shape[0] or x_pos >= slice_to_check_against.shape[0]
or y_pos >= slice_to_check_against.shape[0] or y_pos >= slice_to_check_against.shape[1]
) )
# Check for collision with level walls # Check for collision with level walls
valid = valid and not slice_to_check_against[x_pos, y_pos] valid = valid and not slice_to_check_against[x_pos, y_pos]
return c.VALID if valid else c.NOT_VALID return Constants.VALID if valid else Constants.NOT_VALID
def asset_str(agent): def asset_str(agent):
"""
FIXME @ romue
"""
# What does this abonimation do? # What does this abonimation do?
# if any([x is None for x in [cls._slices[j] for j in agent.collisions]]): # if any([x is None for x in [cls._slices[j] for j in agent.collisions]]):
# print('error') # print('error')
@ -192,33 +332,50 @@ def asset_str(agent):
action = step_result['action_name'] action = step_result['action_name']
valid = step_result['action_valid'] valid = step_result['action_valid']
col_names = [x.name for x in step_result['collisions']] col_names = [x.name for x in step_result['collisions']]
if any(c.AGENT in name for name in col_names): if any(Constants.AGENT in name for name in col_names):
return 'agent_collision', 'blank' return 'agent_collision', 'blank'
elif not valid or c.LEVEL in col_names or c.AGENT in col_names: elif not valid or Constants.LEVEL in col_names or Constants.AGENT in col_names:
return c.AGENT, 'invalid' return Constants.AGENT, 'invalid'
elif valid and not EnvActions.is_move(action): elif valid and not EnvActions.is_move(action):
return c.AGENT, 'valid' return Constants.AGENT, 'valid'
elif valid and EnvActions.is_move(action): elif valid and EnvActions.is_move(action):
return c.AGENT, 'move' return Constants.AGENT, 'move'
else: else:
return c.AGENT, 'idle' return Constants.AGENT, 'idle'
else: else:
return c.AGENT, 'idle' return Constants.AGENT, 'idle'
def points_to_graph(coordiniates_or_tiles, allow_euclidean_connections=True, allow_manhattan_connections=True): def points_to_graph(coordiniates_or_tiles, allow_euclidean_connections=True, allow_manhattan_connections=True):
"""
Given a set of coordinates, this function constructs an undirected graph by connecting adjacent points.
There are three combinations of settings:
Allow all neighbors: Distance(a, b) <= sqrt(2)
Allow only manhattan: Distance(a, b) == 1
Allow only euclidean: Distance(a, b) == sqrt(2)
:param coordiniates_or_tiles: A set of coordinates.
:type coordiniates_or_tiles: Tiles
:param allow_euclidean_connections: Whether to regard diagonally adjacent cells as neighbors
:type allow_euclidean_connections: bool
:param allow_manhattan_connections: Whether to regard directly adjacent cells as neighbors
:type allow_manhattan_connections: bool
:return: A graph with nodes that are connected as specified by the parameters.
:rtype: nx.Graph
"""
assert allow_euclidean_connections or allow_manhattan_connections assert allow_euclidean_connections or allow_manhattan_connections
if hasattr(coordiniates_or_tiles, 'positions'): if hasattr(coordiniates_or_tiles, 'positions'):
coordiniates_or_tiles = coordiniates_or_tiles.positions coordiniates_or_tiles = coordiniates_or_tiles.positions
possible_connections = itertools.combinations(coordiniates_or_tiles, 2) possible_connections = itertools.combinations(coordiniates_or_tiles, 2)
graph = nx.Graph() graph = nx.Graph()
for a, b in possible_connections: for a, b in possible_connections:
diff = abs(np.subtract(a, b)) diff = np.linalg.norm(np.asarray(a)-np.asarray(b))
if not max(diff) > 1: if allow_manhattan_connections and allow_euclidean_connections and diff <= np.sqrt(2):
if allow_manhattan_connections and allow_euclidean_connections: graph.add_edge(a, b)
graph.add_edge(a, b) elif not allow_manhattan_connections and allow_euclidean_connections and diff == np.sqrt(2):
elif not allow_manhattan_connections and allow_euclidean_connections and all(diff): graph.add_edge(a, b)
graph.add_edge(a, b) elif allow_manhattan_connections and not allow_euclidean_connections and diff == 1:
elif allow_manhattan_connections and not allow_euclidean_connections and not all(diff) and any(diff): graph.add_edge(a, b)
graph.add_edge(a, b)
return graph return graph

View File

@ -4,6 +4,22 @@ from gym.wrappers.frame_stack import FrameStack
class AgentRenderOptions(object): class AgentRenderOptions(object):
"""
Class that specifies the available options for the way agents are represented in the env observation.
SEPERATE:
Each agent is represented in a separate slice as Constants.OCCUPIED_CELL value (one hot)
COMBINED:
For all agents, the value of Constants.OCCUPIED_CELL is added to a zero-value slice at the agent's position (sum(SEPERATE))
LEVEL:
The combined slice is added to the LEVEL-slice. (Agents appear as obstacle / wall)
NOT:
The position of individual agents can not be read from the observation.
"""
SEPERATE = 'seperate' SEPERATE = 'seperate'
COMBINED = 'combined' COMBINED = 'combined'
LEVEL = 'lvl' LEVEL = 'lvl'
@ -11,24 +27,61 @@ class AgentRenderOptions(object):
class MovementProperties(NamedTuple): class MovementProperties(NamedTuple):
"""
Property holder; for setting multiple related parameters through a single parameter. Comes with default values.
"""
"""Allow the manhattan style movement on a grid (move to cells that are connected by square edges)."""
allow_square_movement: bool = True allow_square_movement: bool = True
"""Allow diagonal movement on the grid (move to cells that are connected by square corners)."""
allow_diagonal_movement: bool = False allow_diagonal_movement: bool = False
"""Allow the agent to just do nothing; not move (NO-OP)."""
allow_no_op: bool = False allow_no_op: bool = False
class ObservationProperties(NamedTuple): class ObservationProperties(NamedTuple):
# Todo: Add Description """
Property holder; for setting multiple related parameters through a single parameter. Comes with default values.
"""
"""How to represent agents in the observation space. This may also alter the obs-shape."""
render_agents: AgentRenderOptions = AgentRenderOptions.SEPERATE render_agents: AgentRenderOptions = AgentRenderOptions.SEPERATE
"""Observations are built per agent; whether to omit the current agent from its own observation."""
omit_agent_self: bool = True omit_agent_self: bool = True
"""There might be cases where you want to modify the agent's obs-space, so that it can be used with additional obs.
The additional slice can be filled with any number."""
additional_agent_placeholder: Union[None, str, int] = None additional_agent_placeholder: Union[None, str, int] = None
"""Whether to cast shadows (make floor tiles and items hidden)."""
cast_shadows: bool = True cast_shadows: bool = True
"""Frame stacking is a method to give some temporal information to the agents.
This parameter controls how many "old-frames" are stacked."""
frames_to_stack: int = 0 frames_to_stack: int = 0
pomdp_r: int = 0
"""Specifies the radius (_r) of the agent's field of view. Please note that the agent's own grid cell is not taken
into account. This means that the resulting field of view diameter = `pomdp_r * 2 + 1`.
A 'pomdp_r' of 0 always returns the full env == no partial observability."""
pomdp_r: int = 2
"""Whether to place a visual encoding on walkable tiles around the doors. This is helpful when the doors can be
operated from their surrounding area, so the agent can more easily get a notion of where to choose the door option.
However, this is not necessary at all.
"""
indicate_door_area: bool = False indicate_door_area: bool = False
"""Whether to add the agent's normalized global position as float values (2,1) to a separate information slice.
More optional information is to come.
"""
show_global_position_info: bool = False show_global_position_info: bool = False
class MarlFrameStack(gym.ObservationWrapper): class MarlFrameStack(gym.ObservationWrapper):
"""todo @romue404"""
def __init__(self, env): def __init__(self, env):
super().__init__(env) super().__init__(env)

View File

@ -215,7 +215,7 @@ if __name__ == '__main__':
clean_amount=0.34, clean_amount=0.34,
max_spawn_amount=0.1, max_global_amount=20, max_spawn_amount=0.1, max_global_amount=20,
max_local_amount=1, spawn_frequency=0, max_spawn_ratio=0.05, max_local_amount=1, spawn_frequency=0, max_spawn_ratio=0.05,
dirt_smear_amount=0.0, agent_can_interact=True) dirt_smear_amount=0.0)
item_props = ItemProperties(n_items=10, item_props = ItemProperties(n_items=10,
spawn_frequency=30, n_drop_off_locations=2, spawn_frequency=30, n_drop_off_locations=2,
max_agent_inventory_capacity=15) max_agent_inventory_capacity=15)
@ -349,6 +349,7 @@ if __name__ == '__main__':
# Env Init & Model kwargs definition # Env Init & Model kwargs definition
if model_cls.__name__ in ["PPO", "A2C"]: if model_cls.__name__ in ["PPO", "A2C"]:
# env_factory = env_class(**env_kwargs) # env_factory = env_class(**env_kwargs)
env_factory = SubprocVecEnv([encapsule_env_factory(env_class, env_kwargs) env_factory = SubprocVecEnv([encapsule_env_factory(env_class, env_kwargs)
for _ in range(6)], start_method="spawn") for _ in range(6)], start_method="spawn")
model_kwargs = policy_model_kwargs() model_kwargs = policy_model_kwargs()

View File

@ -213,7 +213,8 @@ if __name__ == '__main__':
env_factory.save_params(param_path) env_factory.save_params(param_path)
# EnvMonitor Init # EnvMonitor Init
callbacks = [EnvMonitor(env_factory)] env_monitor = EnvMonitor(env_factory)
callbacks = [env_monitor]
# Model Init # Model Init
model = model_cls("MlpPolicy", env_factory, **policy_model_kwargs, model = model_cls("MlpPolicy", env_factory, **policy_model_kwargs,
@ -233,7 +234,7 @@ if __name__ == '__main__':
model.save(save_path) model.save(save_path)
# Monitor Save # Monitor Save
callbacks[0].save_run(combination_path / 'monitor.pick', env_monitor.save_run(combination_path / 'monitor.pick',
auto_plotting_keys=['step_reward', 'collision'] + env_plot_keys) auto_plotting_keys=['step_reward', 'collision'] + env_plot_keys)
# Better be save then sorry: Clean up! # Better be save then sorry: Clean up!