Adjustments and Documentation
This commit is contained in:
parent
3e19970a60
commit
0218f8f4e9
@ -156,14 +156,14 @@ class BaseFactory(gym.Env):
|
|||||||
np.argwhere(level_array == c.OCCUPIED_CELL),
|
np.argwhere(level_array == c.OCCUPIED_CELL),
|
||||||
self._level_shape
|
self._level_shape
|
||||||
)
|
)
|
||||||
self._entities.register_additional_items({c.WALLS: walls})
|
self._entities.add_additional_items({c.WALLS: walls})
|
||||||
|
|
||||||
# Floor
|
# Floor
|
||||||
floor = Floors.from_argwhere_coordinates(
|
floor = Floors.from_argwhere_coordinates(
|
||||||
np.argwhere(level_array == c.FREE_CELL),
|
np.argwhere(level_array == c.FREE_CELL),
|
||||||
self._level_shape
|
self._level_shape
|
||||||
)
|
)
|
||||||
self._entities.register_additional_items({c.FLOOR: floor})
|
self._entities.add_additional_items({c.FLOOR: floor})
|
||||||
|
|
||||||
# NOPOS
|
# NOPOS
|
||||||
self._NO_POS_TILE = Floor(c.NO_POS, None)
|
self._NO_POS_TILE = Floor(c.NO_POS, None)
|
||||||
@ -177,12 +177,12 @@ class BaseFactory(gym.Env):
|
|||||||
doors = Doors.from_tiles(door_tiles, self._level_shape, have_area=self.obs_prop.indicate_door_area,
|
doors = Doors.from_tiles(door_tiles, self._level_shape, have_area=self.obs_prop.indicate_door_area,
|
||||||
entity_kwargs=dict(context=floor)
|
entity_kwargs=dict(context=floor)
|
||||||
)
|
)
|
||||||
self._entities.register_additional_items({c.DOORS: doors})
|
self._entities.add_additional_items({c.DOORS: doors})
|
||||||
|
|
||||||
# Actions
|
# Actions
|
||||||
self._actions = Actions(self.mv_prop, can_use_doors=self.parse_doors)
|
self._actions = Actions(self.mv_prop, can_use_doors=self.parse_doors)
|
||||||
if additional_actions := self.actions_hook:
|
if additional_actions := self.actions_hook:
|
||||||
self._actions.register_additional_items(additional_actions)
|
self._actions.add_additional_items(additional_actions)
|
||||||
|
|
||||||
# Agents
|
# Agents
|
||||||
agents_to_spawn = self.n_agents-len(self._injected_agents)
|
agents_to_spawn = self.n_agents-len(self._injected_agents)
|
||||||
@ -196,10 +196,10 @@ class BaseFactory(gym.Env):
|
|||||||
if self._injected_agents:
|
if self._injected_agents:
|
||||||
initialized_injections = list()
|
initialized_injections = list()
|
||||||
for i, injection in enumerate(self._injected_agents):
|
for i, injection in enumerate(self._injected_agents):
|
||||||
agents.register_item(injection(self, floor.empty_tiles[0], agents, static_problem=False))
|
agents.add_item(injection(self, floor.empty_tiles[0], agents, static_problem=False))
|
||||||
initialized_injections.append(agents[-1])
|
initialized_injections.append(agents[-1])
|
||||||
self._initialized_injections = initialized_injections
|
self._initialized_injections = initialized_injections
|
||||||
self._entities.register_additional_items({c.AGENT: agents})
|
self._entities.add_additional_items({c.AGENT: agents})
|
||||||
|
|
||||||
if self.obs_prop.additional_agent_placeholder is not None:
|
if self.obs_prop.additional_agent_placeholder is not None:
|
||||||
# TODO: Make this accept Lists for multiple placeholders
|
# TODO: Make this accept Lists for multiple placeholders
|
||||||
@ -210,18 +210,18 @@ class BaseFactory(gym.Env):
|
|||||||
fill_value=self.obs_prop.additional_agent_placeholder)
|
fill_value=self.obs_prop.additional_agent_placeholder)
|
||||||
)
|
)
|
||||||
|
|
||||||
self._entities.register_additional_items({c.AGENT_PLACEHOLDER: placeholder})
|
self._entities.add_additional_items({c.AGENT_PLACEHOLDER: placeholder})
|
||||||
|
|
||||||
# Additional Entities from SubEnvs
|
# Additional Entities from SubEnvs
|
||||||
if additional_entities := self.entities_hook:
|
if additional_entities := self.entities_hook:
|
||||||
self._entities.register_additional_items(additional_entities)
|
self._entities.add_additional_items(additional_entities)
|
||||||
|
|
||||||
if self.obs_prop.show_global_position_info:
|
if self.obs_prop.show_global_position_info:
|
||||||
global_positions = GlobalPositions(self._level_shape)
|
global_positions = GlobalPositions(self._level_shape)
|
||||||
# This moved into the GlobalPosition object
|
# This moved into the GlobalPosition object
|
||||||
# obs_shape_2d = self._level_shape if not self._pomdp_r else ((self.pomdp_diameter,) * 2)
|
# obs_shape_2d = self._level_shape if not self._pomdp_r else ((self.pomdp_diameter,) * 2)
|
||||||
global_positions.spawn_global_position_objects(self[c.AGENT])
|
global_positions.spawn_global_position_objects(self[c.AGENT])
|
||||||
self._entities.register_additional_items({c.GLOBAL_POSITION: global_positions})
|
self._entities.add_additional_items({c.GLOBAL_POSITION: global_positions})
|
||||||
|
|
||||||
# Return
|
# Return
|
||||||
return self._entities
|
return self._entities
|
||||||
@ -535,7 +535,7 @@ class BaseFactory(gym.Env):
|
|||||||
|
|
||||||
def _check_agent_move(self, agent, action: Action) -> (Floor, bool):
|
def _check_agent_move(self, agent, action: Action) -> (Floor, bool):
|
||||||
# Actions
|
# Actions
|
||||||
x_diff, y_diff = h.ACTIONMAP[action.identifier]
|
x_diff, y_diff = a.resolve_movement_action_to_coords(action.identifier)
|
||||||
x_new = agent.x + x_diff
|
x_new = agent.x + x_diff
|
||||||
y_new = agent.y + y_diff
|
y_new = agent.y + y_diff
|
||||||
|
|
||||||
|
@ -72,15 +72,15 @@ class EnvObject(Object):
|
|||||||
def encoding(self):
|
def encoding(self):
|
||||||
return c.OCCUPIED_CELL
|
return c.OCCUPIED_CELL
|
||||||
|
|
||||||
def __init__(self, register, **kwargs):
|
def __init__(self, collection, **kwargs):
|
||||||
super(EnvObject, self).__init__(**kwargs)
|
super(EnvObject, self).__init__(**kwargs)
|
||||||
self._register = register
|
self._collection = collection
|
||||||
|
|
||||||
def change_register(self, register):
|
def change_parent_collection(self, other_collection):
|
||||||
register.register_item(self)
|
other_collection.add_item(self)
|
||||||
self._register.delete_env_object(self)
|
self._collection.delete_env_object(self)
|
||||||
self._register = register
|
self._collection = other_collection
|
||||||
return self._register == register
|
return self._collection == other_collection
|
||||||
# With Rendering
|
# With Rendering
|
||||||
|
|
||||||
|
|
||||||
@ -153,7 +153,7 @@ class MoveableEntity(Entity):
|
|||||||
curr_tile.leave(self)
|
curr_tile.leave(self)
|
||||||
self._tile = next_tile
|
self._tile = next_tile
|
||||||
self._last_tile = curr_tile
|
self._last_tile = curr_tile
|
||||||
self._register.notify_change_to_value(self)
|
self._collection.notify_change_to_value(self)
|
||||||
return c.VALID
|
return c.VALID
|
||||||
else:
|
else:
|
||||||
return c.NOT_VALID
|
return c.NOT_VALID
|
||||||
@ -371,13 +371,13 @@ class Door(Entity):
|
|||||||
def _open(self):
|
def _open(self):
|
||||||
self.connectivity.add_edges_from([(self.pos, x) for x in range(len(self.connectivity_subgroups))])
|
self.connectivity.add_edges_from([(self.pos, x) for x in range(len(self.connectivity_subgroups))])
|
||||||
self._state = c.OPEN_DOOR
|
self._state = c.OPEN_DOOR
|
||||||
self._register.notify_change_to_value(self)
|
self._collection.notify_change_to_value(self)
|
||||||
self.time_to_close = self.auto_close_interval
|
self.time_to_close = self.auto_close_interval
|
||||||
|
|
||||||
def _close(self):
|
def _close(self):
|
||||||
self.connectivity.remove_node(self.pos)
|
self.connectivity.remove_node(self.pos)
|
||||||
self._state = c.CLOSED_DOOR
|
self._state = c.CLOSED_DOOR
|
||||||
self._register.notify_change_to_value(self)
|
self._collection.notify_change_to_value(self)
|
||||||
|
|
||||||
def is_linked(self, old_pos, new_pos):
|
def is_linked(self, old_pos, new_pos):
|
||||||
try:
|
try:
|
||||||
|
@ -13,11 +13,11 @@ from environments import helpers as h
|
|||||||
from environments.helpers import Constants as c
|
from environments.helpers import Constants as c
|
||||||
|
|
||||||
##########################################################################
|
##########################################################################
|
||||||
# ##################### Base Register Definition ####################### #
|
# ################## Base Collections Definition ####################### #
|
||||||
##########################################################################
|
##########################################################################
|
||||||
|
|
||||||
|
|
||||||
class ObjectRegister:
|
class ObjectCollection:
|
||||||
_accepted_objects = Object
|
_accepted_objects = Object
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@ -25,59 +25,59 @@ class ObjectRegister:
|
|||||||
return f'{self.__class__.__name__}'
|
return f'{self.__class__.__name__}'
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
self._register = dict()
|
self._collection = dict()
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return len(self._register)
|
return len(self._collection)
|
||||||
|
|
||||||
def __iter__(self):
|
def __iter__(self):
|
||||||
return iter(self.values())
|
return iter(self.values())
|
||||||
|
|
||||||
def register_item(self, other: _accepted_objects):
|
def add_item(self, other: _accepted_objects):
|
||||||
assert isinstance(other, self._accepted_objects), f'All item names have to be of type ' \
|
assert isinstance(other, self._accepted_objects), f'All item names have to be of type ' \
|
||||||
f'{self._accepted_objects}, ' \
|
f'{self._accepted_objects}, ' \
|
||||||
f'but were {other.__class__}.,'
|
f'but were {other.__class__}.,'
|
||||||
self._register.update({other.name: other})
|
self._collection.update({other.name: other})
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def register_additional_items(self, others: List[_accepted_objects]):
|
def add_additional_items(self, others: List[_accepted_objects]):
|
||||||
for other in others:
|
for other in others:
|
||||||
self.register_item(other)
|
self.add_item(other)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def keys(self):
|
def keys(self):
|
||||||
return self._register.keys()
|
return self._collection.keys()
|
||||||
|
|
||||||
def values(self):
|
def values(self):
|
||||||
return self._register.values()
|
return self._collection.values()
|
||||||
|
|
||||||
def items(self):
|
def items(self):
|
||||||
return self._register.items()
|
return self._collection.items()
|
||||||
|
|
||||||
def _get_index(self, item):
|
def _get_index(self, item):
|
||||||
try:
|
try:
|
||||||
return next(i for i, v in enumerate(self._register.values()) if v == item)
|
return next(i for i, v in enumerate(self._collection.values()) if v == item)
|
||||||
except StopIteration:
|
except StopIteration:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def __getitem__(self, item):
|
def __getitem__(self, item):
|
||||||
if isinstance(item, (int, np.int64, np.int32)):
|
if isinstance(item, (int, np.int64, np.int32)):
|
||||||
if item < 0:
|
if item < 0:
|
||||||
item = len(self._register) - abs(item)
|
item = len(self._collection) - abs(item)
|
||||||
try:
|
try:
|
||||||
return next(v for i, v in enumerate(self._register.values()) if i == item)
|
return next(v for i, v in enumerate(self._collection.values()) if i == item)
|
||||||
except StopIteration:
|
except StopIteration:
|
||||||
return None
|
return None
|
||||||
try:
|
try:
|
||||||
return self._register[item]
|
return self._collection[item]
|
||||||
except KeyError:
|
except KeyError:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return f'{self.__class__.__name__}[{self._register}]'
|
return f'{self.__class__.__name__}[{self._collection}]'
|
||||||
|
|
||||||
|
|
||||||
class EnvObjectRegister(ObjectRegister):
|
class EnvObjectCollection(ObjectCollection):
|
||||||
|
|
||||||
_accepted_objects = EnvObject
|
_accepted_objects = EnvObject
|
||||||
|
|
||||||
@ -90,7 +90,7 @@ class EnvObjectRegister(ObjectRegister):
|
|||||||
is_blocking_light: bool = False,
|
is_blocking_light: bool = False,
|
||||||
can_collide: bool = False,
|
can_collide: bool = False,
|
||||||
can_be_shadowed: bool = True, **kwargs):
|
can_be_shadowed: bool = True, **kwargs):
|
||||||
super(EnvObjectRegister, self).__init__(*args, **kwargs)
|
super(EnvObjectCollection, self).__init__(*args, **kwargs)
|
||||||
self._shape = obs_shape
|
self._shape = obs_shape
|
||||||
self._array = None
|
self._array = None
|
||||||
self._individual_slices = individual_slices
|
self._individual_slices = individual_slices
|
||||||
@ -99,8 +99,8 @@ class EnvObjectRegister(ObjectRegister):
|
|||||||
self.can_be_shadowed = can_be_shadowed
|
self.can_be_shadowed = can_be_shadowed
|
||||||
self.can_collide = can_collide
|
self.can_collide = can_collide
|
||||||
|
|
||||||
def register_item(self, other: EnvObject):
|
def add_item(self, other: EnvObject):
|
||||||
super(EnvObjectRegister, self).register_item(other)
|
super(EnvObjectCollection, self).add_item(other)
|
||||||
if self._array is None:
|
if self._array is None:
|
||||||
self._array = np.zeros((1, *self._shape))
|
self._array = np.zeros((1, *self._shape))
|
||||||
else:
|
else:
|
||||||
@ -145,13 +145,13 @@ class EnvObjectRegister(ObjectRegister):
|
|||||||
if self._individual_slices:
|
if self._individual_slices:
|
||||||
self._array = np.delete(self._array, idx, axis=0)
|
self._array = np.delete(self._array, idx, axis=0)
|
||||||
else:
|
else:
|
||||||
self.notify_change_to_free(self._register[name])
|
self.notify_change_to_free(self._collection[name])
|
||||||
# Dirty Hack to check if not being subclassed. In that case we need to refresh the array since positions
|
# Dirty Hack to check if not being subclassed. In that case we need to refresh the array since positions
|
||||||
# in the observation array are the result of enumeration. They can override each other.
|
# in the observation array are the result of enumeration. They can override each other.
|
||||||
# Todo: Find a better solution
|
# Todo: Find a better solution
|
||||||
if not issubclass(self.__class__, EntityRegister) and issubclass(self.__class__, EnvObjectRegister):
|
if not issubclass(self.__class__, EntityCollection) and issubclass(self.__class__, EnvObjectCollection):
|
||||||
self._refresh_arrays()
|
self._refresh_arrays()
|
||||||
del self._register[name]
|
del self._collection[name]
|
||||||
|
|
||||||
def delete_env_object(self, env_object: EnvObject):
|
def delete_env_object(self, env_object: EnvObject):
|
||||||
del self[env_object.name]
|
del self[env_object.name]
|
||||||
@ -160,19 +160,19 @@ class EnvObjectRegister(ObjectRegister):
|
|||||||
del self[name]
|
del self[name]
|
||||||
|
|
||||||
|
|
||||||
class EntityRegister(EnvObjectRegister, ABC):
|
class EntityCollection(EnvObjectCollection, ABC):
|
||||||
|
|
||||||
_accepted_objects = Entity
|
_accepted_objects = Entity
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_tiles(cls, tiles, *args, entity_kwargs=None, **kwargs):
|
def from_tiles(cls, tiles, *args, entity_kwargs=None, **kwargs):
|
||||||
# objects_name = cls._accepted_objects.__name__
|
# objects_name = cls._accepted_objects.__name__
|
||||||
register_obj = cls(*args, **kwargs)
|
collection = cls(*args, **kwargs)
|
||||||
entities = [cls._accepted_objects(tile, register_obj, str_ident=i,
|
entities = [cls._accepted_objects(tile, collection, str_ident=i,
|
||||||
**entity_kwargs if entity_kwargs is not None else {})
|
**entity_kwargs if entity_kwargs is not None else {})
|
||||||
for i, tile in enumerate(tiles)]
|
for i, tile in enumerate(tiles)]
|
||||||
register_obj.register_additional_items(entities)
|
collection.add_additional_items(entities)
|
||||||
return register_obj
|
return collection
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_argwhere_coordinates(cls, positions: [(int, int)], tiles, *args, entity_kwargs=None, **kwargs, ):
|
def from_argwhere_coordinates(cls, positions: [(int, int)], tiles, *args, entity_kwargs=None, **kwargs, ):
|
||||||
@ -188,13 +188,13 @@ class EntityRegister(EnvObjectRegister, ABC):
|
|||||||
return [entity.tile for entity in self]
|
return [entity.tile for entity in self]
|
||||||
|
|
||||||
def __init__(self, level_shape, *args, **kwargs):
|
def __init__(self, level_shape, *args, **kwargs):
|
||||||
super(EntityRegister, self).__init__(level_shape, *args, **kwargs)
|
super(EntityCollection, self).__init__(level_shape, *args, **kwargs)
|
||||||
self._lazy_eval_transforms = []
|
self._lazy_eval_transforms = []
|
||||||
|
|
||||||
def __delitem__(self, name):
|
def __delitem__(self, name):
|
||||||
idx, obj = next((i, obj) for i, obj in enumerate(self) if obj.name == name)
|
idx, obj = next((i, obj) for i, obj in enumerate(self) if obj.name == name)
|
||||||
obj.tile.leave(obj)
|
obj.tile.leave(obj)
|
||||||
super(EntityRegister, self).__delitem__(name)
|
super(EntityCollection, self).__delitem__(name)
|
||||||
|
|
||||||
def as_array(self):
|
def as_array(self):
|
||||||
if self._lazy_eval_transforms:
|
if self._lazy_eval_transforms:
|
||||||
@ -223,7 +223,7 @@ class EntityRegister(EnvObjectRegister, ABC):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
class BoundEnvObjRegister(EnvObjectRegister, ABC):
|
class BoundEnvObjCollection(EnvObjectCollection, ABC):
|
||||||
|
|
||||||
def __init__(self, entity_to_be_bound, *args, **kwargs):
|
def __init__(self, entity_to_be_bound, *args, **kwargs):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
@ -248,13 +248,13 @@ class BoundEnvObjRegister(EnvObjectRegister, ABC):
|
|||||||
return self._array[self.idx_by_entity(entity)]
|
return self._array[self.idx_by_entity(entity)]
|
||||||
|
|
||||||
|
|
||||||
class MovingEntityObjectRegister(EntityRegister, ABC):
|
class MovingEntityObjectCollection(EntityCollection, ABC):
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super(MovingEntityObjectRegister, self).__init__(*args, **kwargs)
|
super(MovingEntityObjectCollection, self).__init__(*args, **kwargs)
|
||||||
|
|
||||||
def notify_change_to_value(self, entity):
|
def notify_change_to_value(self, entity):
|
||||||
super(MovingEntityObjectRegister, self).notify_change_to_value(entity)
|
super(MovingEntityObjectCollection, self).notify_change_to_value(entity)
|
||||||
if entity.last_pos != c.NO_POS:
|
if entity.last_pos != c.NO_POS:
|
||||||
try:
|
try:
|
||||||
self._array_change_notifyer(entity, entity.last_pos, value=c.FREE_CELL)
|
self._array_change_notifyer(entity, entity.last_pos, value=c.FREE_CELL)
|
||||||
@ -263,11 +263,11 @@ class MovingEntityObjectRegister(EntityRegister, ABC):
|
|||||||
|
|
||||||
|
|
||||||
##########################################################################
|
##########################################################################
|
||||||
# ################# Objects and Entity Registers ####################### #
|
# ################# Objects and Entity Collection ###################### #
|
||||||
##########################################################################
|
##########################################################################
|
||||||
|
|
||||||
|
|
||||||
class GlobalPositions(EnvObjectRegister):
|
class GlobalPositions(EnvObjectCollection):
|
||||||
|
|
||||||
_accepted_objects = GlobalPosition
|
_accepted_objects = GlobalPosition
|
||||||
|
|
||||||
@ -288,7 +288,7 @@ class GlobalPositions(EnvObjectRegister):
|
|||||||
global_positions = [self._accepted_objects(self._shape, agent, self)
|
global_positions = [self._accepted_objects(self._shape, agent, self)
|
||||||
for _, agent in enumerate(agents)]
|
for _, agent in enumerate(agents)]
|
||||||
# noinspection PyTypeChecker
|
# noinspection PyTypeChecker
|
||||||
self.register_additional_items(global_positions)
|
self.add_additional_items(global_positions)
|
||||||
|
|
||||||
def summarize_states(self, n_steps=None):
|
def summarize_states(self, n_steps=None):
|
||||||
return {}
|
return {}
|
||||||
@ -306,7 +306,7 @@ class GlobalPositions(EnvObjectRegister):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
class PlaceHolders(EnvObjectRegister):
|
class PlaceHolders(EnvObjectCollection):
|
||||||
_accepted_objects = PlaceHolder
|
_accepted_objects = PlaceHolder
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
@ -320,12 +320,12 @@ class PlaceHolders(EnvObjectRegister):
|
|||||||
# objects_name = cls._accepted_objects.__name__
|
# objects_name = cls._accepted_objects.__name__
|
||||||
if isinstance(values, (str, numbers.Number)):
|
if isinstance(values, (str, numbers.Number)):
|
||||||
values = [values]
|
values = [values]
|
||||||
register_obj = cls(*args, **kwargs)
|
collection = cls(*args, **kwargs)
|
||||||
objects = [cls._accepted_objects(register_obj, str_ident=i, fill_value=value,
|
objects = [cls._accepted_objects(collection, str_ident=i, fill_value=value,
|
||||||
**object_kwargs if object_kwargs is not None else {})
|
**object_kwargs if object_kwargs is not None else {})
|
||||||
for i, value in enumerate(values)]
|
for i, value in enumerate(values)]
|
||||||
register_obj.register_additional_items(objects)
|
collection.add_additional_items(objects)
|
||||||
return register_obj
|
return collection
|
||||||
|
|
||||||
# noinspection DuplicatedCode
|
# noinspection DuplicatedCode
|
||||||
def as_array(self):
|
def as_array(self):
|
||||||
@ -343,8 +343,8 @@ class PlaceHolders(EnvObjectRegister):
|
|||||||
return self._array
|
return self._array
|
||||||
|
|
||||||
|
|
||||||
class Entities(ObjectRegister):
|
class Entities(ObjectCollection):
|
||||||
_accepted_objects = EntityRegister
|
_accepted_objects = EntityCollection
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def arrays(self):
|
def arrays(self):
|
||||||
@ -352,7 +352,7 @@ class Entities(ObjectRegister):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def names(self):
|
def names(self):
|
||||||
return list(self._register.keys())
|
return list(self._collection.keys())
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(Entities, self).__init__()
|
super(Entities, self).__init__()
|
||||||
@ -360,21 +360,21 @@ class Entities(ObjectRegister):
|
|||||||
def iter_individual_entitites(self):
|
def iter_individual_entitites(self):
|
||||||
return iter((x for sublist in self.values() for x in sublist))
|
return iter((x for sublist in self.values() for x in sublist))
|
||||||
|
|
||||||
def register_item(self, other: dict):
|
def add_item(self, other: dict):
|
||||||
assert not any([key for key in other.keys() if key in self.keys()]), \
|
assert not any([key for key in other.keys() if key in self.keys()]), \
|
||||||
"This group of entities has already been registered!"
|
"This group of entities has already been added!"
|
||||||
self._register.update(other)
|
self._collection.update(other)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def register_additional_items(self, others: Dict):
|
def add_additional_items(self, others: Dict):
|
||||||
return self.register_item(others)
|
return self.add_item(others)
|
||||||
|
|
||||||
def by_pos(self, pos: (int, int)):
|
def by_pos(self, pos: (int, int)):
|
||||||
found_entities = [y for y in (x.by_pos(pos) for x in self.values() if hasattr(x, 'by_pos')) if y is not None]
|
found_entities = [y for y in (x.by_pos(pos) for x in self.values() if hasattr(x, 'by_pos')) if y is not None]
|
||||||
return found_entities
|
return found_entities
|
||||||
|
|
||||||
|
|
||||||
class Walls(EntityRegister):
|
class Walls(EntityCollection):
|
||||||
_accepted_objects = Wall
|
_accepted_objects = Wall
|
||||||
|
|
||||||
def as_array(self):
|
def as_array(self):
|
||||||
@ -396,7 +396,7 @@ class Walls(EntityRegister):
|
|||||||
def from_argwhere_coordinates(cls, argwhere_coordinates, *args, **kwargs):
|
def from_argwhere_coordinates(cls, argwhere_coordinates, *args, **kwargs):
|
||||||
tiles = cls(*args, **kwargs)
|
tiles = cls(*args, **kwargs)
|
||||||
# noinspection PyTypeChecker
|
# noinspection PyTypeChecker
|
||||||
tiles.register_additional_items(
|
tiles.add_additional_items(
|
||||||
[cls._accepted_objects(pos, tiles)
|
[cls._accepted_objects(pos, tiles)
|
||||||
for pos in argwhere_coordinates]
|
for pos in argwhere_coordinates]
|
||||||
)
|
)
|
||||||
@ -441,7 +441,7 @@ class Floors(Walls):
|
|||||||
return {}
|
return {}
|
||||||
|
|
||||||
|
|
||||||
class Agents(MovingEntityObjectRegister):
|
class Agents(MovingEntityObjectCollection):
|
||||||
_accepted_objects = Agent
|
_accepted_objects = Agent
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
@ -455,10 +455,10 @@ class Agents(MovingEntityObjectRegister):
|
|||||||
old_agent = self[key]
|
old_agent = self[key]
|
||||||
self[key].tile.leave(self[key])
|
self[key].tile.leave(self[key])
|
||||||
agent._name = old_agent.name
|
agent._name = old_agent.name
|
||||||
self._register[agent.name] = agent
|
self._collection[agent.name] = agent
|
||||||
|
|
||||||
|
|
||||||
class Doors(EntityRegister):
|
class Doors(EntityCollection):
|
||||||
|
|
||||||
def __init__(self, *args, have_area: bool = False, **kwargs):
|
def __init__(self, *args, have_area: bool = False, **kwargs):
|
||||||
self.have_area = have_area
|
self.have_area = have_area
|
||||||
@ -490,7 +490,7 @@ class Doors(EntityRegister):
|
|||||||
return super(Doors, self).as_array()
|
return super(Doors, self).as_array()
|
||||||
|
|
||||||
|
|
||||||
class Actions(ObjectRegister):
|
class Actions(ObjectCollection):
|
||||||
_accepted_objects = Action
|
_accepted_objects = Action
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@ -507,22 +507,22 @@ class Actions(ObjectRegister):
|
|||||||
|
|
||||||
# Move this to Baseclass, Env init?
|
# Move this to Baseclass, Env init?
|
||||||
if self.allow_square_movement:
|
if self.allow_square_movement:
|
||||||
self.register_additional_items([self._accepted_objects(str_ident=direction)
|
self.add_additional_items([self._accepted_objects(str_ident=direction)
|
||||||
for direction in h.EnvActions.square_move()])
|
for direction in h.EnvActions.square_move()])
|
||||||
if self.allow_diagonal_movement:
|
if self.allow_diagonal_movement:
|
||||||
self.register_additional_items([self._accepted_objects(str_ident=direction)
|
self.add_additional_items([self._accepted_objects(str_ident=direction)
|
||||||
for direction in h.EnvActions.diagonal_move()])
|
for direction in h.EnvActions.diagonal_move()])
|
||||||
self._movement_actions = self._register.copy()
|
self._movement_actions = self._collection.copy()
|
||||||
if self.can_use_doors:
|
if self.can_use_doors:
|
||||||
self.register_additional_items([self._accepted_objects(str_ident=h.EnvActions.USE_DOOR)])
|
self.add_additional_items([self._accepted_objects(str_ident=h.EnvActions.USE_DOOR)])
|
||||||
if self.allow_no_op:
|
if self.allow_no_op:
|
||||||
self.register_additional_items([self._accepted_objects(str_ident=h.EnvActions.NOOP)])
|
self.add_additional_items([self._accepted_objects(str_ident=h.EnvActions.NOOP)])
|
||||||
|
|
||||||
def is_moving_action(self, action: Union[int]):
|
def is_moving_action(self, action: Union[int]):
|
||||||
return action in self.movement_actions.values()
|
return action in self.movement_actions.values()
|
||||||
|
|
||||||
|
|
||||||
class Zones(ObjectRegister):
|
class Zones(ObjectCollection):
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def accounting_zones(self):
|
def accounting_zones(self):
|
||||||
@ -551,5 +551,5 @@ class Zones(ObjectRegister):
|
|||||||
def __getitem__(self, item):
|
def __getitem__(self, item):
|
||||||
return self._zone_slices[item]
|
return self._zone_slices[item]
|
||||||
|
|
||||||
def register_additional_items(self, other: Union[str, List[str]]):
|
def add_additional_items(self, other: Union[str, List[str]]):
|
||||||
raise AttributeError('You are not allowed to add additional Zones in runtime.')
|
raise AttributeError('You are not allowed to add additional Zones in runtime.')
|
||||||
|
@ -4,7 +4,7 @@ import numpy as np
|
|||||||
|
|
||||||
from environments.factory.base.base_factory import BaseFactory
|
from environments.factory.base.base_factory import BaseFactory
|
||||||
from environments.factory.base.objects import Agent, Action, Entity, EnvObject, BoundingMixin
|
from environments.factory.base.objects import Agent, Action, Entity, EnvObject, BoundingMixin
|
||||||
from environments.factory.base.registers import EntityRegister, EnvObjectRegister
|
from environments.factory.base.registers import EntityCollection, EnvObjectCollection
|
||||||
from environments.factory.base.renderer import RenderEntity
|
from environments.factory.base.renderer import RenderEntity
|
||||||
from environments.helpers import Constants as BaseConstants
|
from environments.helpers import Constants as BaseConstants
|
||||||
from environments.helpers import EnvActions as BaseActions
|
from environments.helpers import EnvActions as BaseActions
|
||||||
@ -68,7 +68,7 @@ class Battery(BoundingMixin, EnvObject):
|
|||||||
if self.charge_level != 0:
|
if self.charge_level != 0:
|
||||||
# noinspection PyTypeChecker
|
# noinspection PyTypeChecker
|
||||||
self.charge_level = max(0, amount + self.charge_level)
|
self.charge_level = max(0, amount + self.charge_level)
|
||||||
self._register.notify_change_to_value(self)
|
self._collection.notify_change_to_value(self)
|
||||||
return c.VALID
|
return c.VALID
|
||||||
else:
|
else:
|
||||||
return c.NOT_VALID
|
return c.NOT_VALID
|
||||||
@ -79,7 +79,7 @@ class Battery(BoundingMixin, EnvObject):
|
|||||||
return attr_dict
|
return attr_dict
|
||||||
|
|
||||||
|
|
||||||
class BatteriesRegister(EnvObjectRegister):
|
class BatteriesRegister(EnvObjectCollection):
|
||||||
|
|
||||||
_accepted_objects = Battery
|
_accepted_objects = Battery
|
||||||
|
|
||||||
@ -90,7 +90,7 @@ class BatteriesRegister(EnvObjectRegister):
|
|||||||
|
|
||||||
def spawn_batteries(self, agents, initial_charge_level):
|
def spawn_batteries(self, agents, initial_charge_level):
|
||||||
batteries = [self._accepted_objects(initial_charge_level, agent, self) for _, agent in enumerate(agents)]
|
batteries = [self._accepted_objects(initial_charge_level, agent, self) for _, agent in enumerate(agents)]
|
||||||
self.register_additional_items(batteries)
|
self.add_additional_items(batteries)
|
||||||
|
|
||||||
def summarize_states(self, n_steps=None):
|
def summarize_states(self, n_steps=None):
|
||||||
# as dict with additional nesting
|
# as dict with additional nesting
|
||||||
@ -140,7 +140,7 @@ class ChargePod(Entity):
|
|||||||
return summary
|
return summary
|
||||||
|
|
||||||
|
|
||||||
class ChargePods(EntityRegister):
|
class ChargePods(EntityCollection):
|
||||||
|
|
||||||
_accepted_objects = ChargePod
|
_accepted_objects = ChargePod
|
||||||
|
|
||||||
|
@ -9,7 +9,7 @@ from environments.factory.base.base_factory import BaseFactory
|
|||||||
from environments.helpers import Constants as BaseConstants
|
from environments.helpers import Constants as BaseConstants
|
||||||
from environments.helpers import EnvActions as BaseActions
|
from environments.helpers import EnvActions as BaseActions
|
||||||
from environments.factory.base.objects import Agent, Entity, Action
|
from environments.factory.base.objects import Agent, Entity, Action
|
||||||
from environments.factory.base.registers import Entities, EntityRegister
|
from environments.factory.base.registers import Entities, EntityCollection
|
||||||
|
|
||||||
from environments.factory.base.renderer import RenderEntity
|
from environments.factory.base.renderer import RenderEntity
|
||||||
|
|
||||||
@ -73,7 +73,7 @@ class Destination(Entity):
|
|||||||
return state_summary
|
return state_summary
|
||||||
|
|
||||||
|
|
||||||
class Destinations(EntityRegister):
|
class Destinations(EntityCollection):
|
||||||
|
|
||||||
_accepted_objects = Destination
|
_accepted_objects = Destination
|
||||||
|
|
||||||
@ -208,13 +208,13 @@ class DestFactory(BaseFactory):
|
|||||||
n_dest_to_spawn = len(destinations_to_spawn)
|
n_dest_to_spawn = len(destinations_to_spawn)
|
||||||
if self.dest_prop.spawn_mode != DestModeOptions.GROUPED:
|
if self.dest_prop.spawn_mode != DestModeOptions.GROUPED:
|
||||||
destinations = [Destination(tile, c.DEST) for tile in self[c.FLOOR].empty_tiles[:n_dest_to_spawn]]
|
destinations = [Destination(tile, c.DEST) for tile in self[c.FLOOR].empty_tiles[:n_dest_to_spawn]]
|
||||||
self[c.DEST].register_additional_items(destinations)
|
self[c.DEST].add_additional_items(destinations)
|
||||||
for dest in destinations_to_spawn:
|
for dest in destinations_to_spawn:
|
||||||
del self._dest_spawn_timer[dest]
|
del self._dest_spawn_timer[dest]
|
||||||
self.print(f'{n_dest_to_spawn} new destinations have been spawned')
|
self.print(f'{n_dest_to_spawn} new destinations have been spawned')
|
||||||
elif self.dest_prop.spawn_mode == DestModeOptions.GROUPED and n_dest_to_spawn == self.dest_prop.n_dests:
|
elif self.dest_prop.spawn_mode == DestModeOptions.GROUPED and n_dest_to_spawn == self.dest_prop.n_dests:
|
||||||
destinations = [Destination(tile, self[c.DEST]) for tile in self[c.FLOOR].empty_tiles[:n_dest_to_spawn]]
|
destinations = [Destination(tile, self[c.DEST]) for tile in self[c.FLOOR].empty_tiles[:n_dest_to_spawn]]
|
||||||
self[c.DEST].register_additional_items(destinations)
|
self[c.DEST].add_additional_items(destinations)
|
||||||
for dest in destinations_to_spawn:
|
for dest in destinations_to_spawn:
|
||||||
del self._dest_spawn_timer[dest]
|
del self._dest_spawn_timer[dest]
|
||||||
self.print(f'{n_dest_to_spawn} new destinations have been spawned')
|
self.print(f'{n_dest_to_spawn} new destinations have been spawned')
|
||||||
@ -231,7 +231,7 @@ class DestFactory(BaseFactory):
|
|||||||
self._dest_spawn_timer[key] = min(self.dest_prop.spawn_frequency, self._dest_spawn_timer[key] + 1)
|
self._dest_spawn_timer[key] = min(self.dest_prop.spawn_frequency, self._dest_spawn_timer[key] + 1)
|
||||||
for dest in list(self[c.DEST].values()):
|
for dest in list(self[c.DEST].values()):
|
||||||
if dest.is_considered_reached:
|
if dest.is_considered_reached:
|
||||||
dest.change_register(self[c.DEST])
|
dest.change_parent_collection(self[c.DEST])
|
||||||
self._dest_spawn_timer[dest.name] = 0
|
self._dest_spawn_timer[dest.name] = 0
|
||||||
self.print(f'{dest.name} is reached now, removing...')
|
self.print(f'{dest.name} is reached now, removing...')
|
||||||
else:
|
else:
|
||||||
|
@ -11,7 +11,7 @@ from environments.helpers import EnvActions as BaseActions
|
|||||||
|
|
||||||
from environments.factory.base.base_factory import BaseFactory
|
from environments.factory.base.base_factory import BaseFactory
|
||||||
from environments.factory.base.objects import Agent, Action, Entity, Floor
|
from environments.factory.base.objects import Agent, Action, Entity, Floor
|
||||||
from environments.factory.base.registers import Entities, EntityRegister
|
from environments.factory.base.registers import Entities, EntityCollection
|
||||||
|
|
||||||
from environments.factory.base.renderer import RenderEntity
|
from environments.factory.base.renderer import RenderEntity
|
||||||
from environments.utility_classes import ObservationProperties
|
from environments.utility_classes import ObservationProperties
|
||||||
@ -61,7 +61,7 @@ class Dirt(Entity):
|
|||||||
|
|
||||||
def set_new_amount(self, amount):
|
def set_new_amount(self, amount):
|
||||||
self._amount = amount
|
self._amount = amount
|
||||||
self._register.notify_change_to_value(self)
|
self._collection.notify_change_to_value(self)
|
||||||
|
|
||||||
def summarize_state(self, **kwargs):
|
def summarize_state(self, **kwargs):
|
||||||
state_dict = super().summarize_state(**kwargs)
|
state_dict = super().summarize_state(**kwargs)
|
||||||
@ -69,7 +69,7 @@ class Dirt(Entity):
|
|||||||
return state_dict
|
return state_dict
|
||||||
|
|
||||||
|
|
||||||
class DirtRegister(EntityRegister):
|
class DirtRegister(EntityCollection):
|
||||||
|
|
||||||
_accepted_objects = Dirt
|
_accepted_objects = Dirt
|
||||||
|
|
||||||
@ -93,7 +93,7 @@ class DirtRegister(EntityRegister):
|
|||||||
dirt = self.by_pos(tile.pos)
|
dirt = self.by_pos(tile.pos)
|
||||||
if dirt is None:
|
if dirt is None:
|
||||||
dirt = Dirt(tile, self, amount=self.dirt_properties.max_spawn_amount)
|
dirt = Dirt(tile, self, amount=self.dirt_properties.max_spawn_amount)
|
||||||
self.register_item(dirt)
|
self.add_item(dirt)
|
||||||
else:
|
else:
|
||||||
new_value = dirt.amount + self.dirt_properties.max_spawn_amount
|
new_value = dirt.amount + self.dirt_properties.max_spawn_amount
|
||||||
dirt.set_new_amount(min(new_value, self.dirt_properties.max_local_amount))
|
dirt.set_new_amount(min(new_value, self.dirt_properties.max_local_amount))
|
||||||
|
@ -5,10 +5,10 @@ import numpy as np
|
|||||||
from environments.factory.base.objects import Agent, Entity, Action
|
from environments.factory.base.objects import Agent, Entity, Action
|
||||||
from environments.factory.factory_dirt import Dirt, DirtRegister, DirtFactory
|
from environments.factory.factory_dirt import Dirt, DirtRegister, DirtFactory
|
||||||
from environments.factory.base.objects import Floor
|
from environments.factory.base.objects import Floor
|
||||||
from environments.factory.base.registers import Floors, Entities, EntityRegister
|
from environments.factory.base.registers import Floors, Entities, EntityCollection
|
||||||
|
|
||||||
|
|
||||||
class Machines(EntityRegister):
|
class Machines(EntityCollection):
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
|
@ -9,7 +9,7 @@ from environments.helpers import Constants as BaseConstants
|
|||||||
from environments.helpers import EnvActions as BaseActions
|
from environments.helpers import EnvActions as BaseActions
|
||||||
from environments import helpers as h
|
from environments import helpers as h
|
||||||
from environments.factory.base.objects import Agent, Entity, Action, Floor
|
from environments.factory.base.objects import Agent, Entity, Action, Floor
|
||||||
from environments.factory.base.registers import Entities, EntityRegister, BoundEnvObjRegister, ObjectRegister
|
from environments.factory.base.registers import Entities, EntityCollection, BoundEnvObjCollection, ObjectCollection
|
||||||
|
|
||||||
from environments.factory.base.renderer import RenderEntity
|
from environments.factory.base.renderer import RenderEntity
|
||||||
|
|
||||||
@ -53,17 +53,17 @@ class Item(Entity):
|
|||||||
self._auto_despawn = auto_despawn
|
self._auto_despawn = auto_despawn
|
||||||
|
|
||||||
def set_tile_to(self, no_pos_tile):
|
def set_tile_to(self, no_pos_tile):
|
||||||
assert self._register.__class__.__name__ != ItemRegister.__class__
|
assert self._collection.__class__.__name__ != ItemRegister.__class__
|
||||||
self._tile = no_pos_tile
|
self._tile = no_pos_tile
|
||||||
|
|
||||||
|
|
||||||
class ItemRegister(EntityRegister):
|
class ItemRegister(EntityCollection):
|
||||||
|
|
||||||
_accepted_objects = Item
|
_accepted_objects = Item
|
||||||
|
|
||||||
def spawn_items(self, tiles: List[Floor]):
|
def spawn_items(self, tiles: List[Floor]):
|
||||||
items = [Item(tile, self) for tile in tiles]
|
items = [Item(tile, self) for tile in tiles]
|
||||||
self.register_additional_items(items)
|
self.add_additional_items(items)
|
||||||
|
|
||||||
def despawn_items(self, items: List[Item]):
|
def despawn_items(self, items: List[Item]):
|
||||||
items = [items] if isinstance(items, Item) else items
|
items = [items] if isinstance(items, Item) else items
|
||||||
@ -71,7 +71,7 @@ class ItemRegister(EntityRegister):
|
|||||||
del self[item]
|
del self[item]
|
||||||
|
|
||||||
|
|
||||||
class Inventory(BoundEnvObjRegister):
|
class Inventory(BoundEnvObjCollection):
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def name(self):
|
def name(self):
|
||||||
@ -98,7 +98,7 @@ class Inventory(BoundEnvObjRegister):
|
|||||||
return item_to_pop
|
return item_to_pop
|
||||||
|
|
||||||
|
|
||||||
class Inventories(ObjectRegister):
|
class Inventories(ObjectCollection):
|
||||||
|
|
||||||
_accepted_objects = Inventory
|
_accepted_objects = Inventory
|
||||||
is_blocking_light = False
|
is_blocking_light = False
|
||||||
@ -114,7 +114,7 @@ class Inventories(ObjectRegister):
|
|||||||
def spawn_inventories(self, agents, capacity):
|
def spawn_inventories(self, agents, capacity):
|
||||||
inventories = [self._accepted_objects(agent, capacity, self._obs_shape)
|
inventories = [self._accepted_objects(agent, capacity, self._obs_shape)
|
||||||
for _, agent in enumerate(agents)]
|
for _, agent in enumerate(agents)]
|
||||||
self.register_additional_items(inventories)
|
self.add_additional_items(inventories)
|
||||||
|
|
||||||
def idx_by_entity(self, entity):
|
def idx_by_entity(self, entity):
|
||||||
try:
|
try:
|
||||||
@ -161,7 +161,7 @@ class DropOffLocation(Entity):
|
|||||||
return super().summarize_state(n_steps=n_steps)
|
return super().summarize_state(n_steps=n_steps)
|
||||||
|
|
||||||
|
|
||||||
class DropOffLocations(EntityRegister):
|
class DropOffLocations(EntityCollection):
|
||||||
|
|
||||||
_accepted_objects = DropOffLocation
|
_accepted_objects = DropOffLocation
|
||||||
|
|
||||||
@ -250,7 +250,7 @@ class ItemFactory(BaseFactory):
|
|||||||
reason=a.ITEM_ACTION, info=info_dict)
|
reason=a.ITEM_ACTION, info=info_dict)
|
||||||
return valid, reward
|
return valid, reward
|
||||||
elif item := self[c.ITEM].by_pos(agent.pos):
|
elif item := self[c.ITEM].by_pos(agent.pos):
|
||||||
item.change_register(inventory)
|
item.change_parent_collection(inventory)
|
||||||
item.set_tile_to(self._NO_POS_TILE)
|
item.set_tile_to(self._NO_POS_TILE)
|
||||||
self.print(f'{agent.name} just picked up an item at {agent.pos}')
|
self.print(f'{agent.name} just picked up an item at {agent.pos}')
|
||||||
info_dict = {f'{agent.name}_{a.ITEM_ACTION}_VALID': 1, f'{a.ITEM_ACTION}_VALID': 1}
|
info_dict = {f'{agent.name}_{a.ITEM_ACTION}_VALID': 1, f'{a.ITEM_ACTION}_VALID': 1}
|
||||||
|
@ -7,47 +7,76 @@ import numpy as np
|
|||||||
from numpy.typing import ArrayLike
|
from numpy.typing import ArrayLike
|
||||||
from stable_baselines3 import PPO, DQN, A2C
|
from stable_baselines3 import PPO, DQN, A2C
|
||||||
|
|
||||||
MODEL_MAP = dict(PPO=PPO, DQN=DQN, A2C=A2C)
|
|
||||||
|
|
||||||
LEVELS_DIR = 'levels'
|
"""
|
||||||
STEPS_START = 1
|
This file is used for:
|
||||||
|
1. string based definition
|
||||||
TO_BE_AVERAGED = ['dirt_amount', 'dirty_tiles']
|
Use a class like `Constants`, to define attributes, which then reveal strings.
|
||||||
IGNORED_DF_COLUMNS = ['Episode', 'Run', 'train_step', 'step', 'index', 'dirt_amount',
|
These can be used for naming convention along the environments as well as keys for mappings such as dicts etc.
|
||||||
'dirty_tile_count', 'terminal_observation', 'episode']
|
When defining new envs, use class inheritance.
|
||||||
|
|
||||||
|
2. utility function definition
|
||||||
|
There are static utility functions which are not bound to a specific environment.
|
||||||
|
In this file they are defined to be used across the entire package.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
MODEL_MAP = dict(PPO=PPO, DQN=DQN, A2C=A2C) # For use in studies and experiments
|
||||||
|
|
||||||
|
|
||||||
|
LEVELS_DIR = 'levels' # for use in studies and experiments
|
||||||
|
STEPS_START = 1 # Define where to start the step count, i.e. which step is the first one
|
||||||
|
|
||||||
|
# Not used anymore? Clean!
|
||||||
|
# TO_BE_AVERAGED = ['dirt_amount', 'dirty_tiles']
|
||||||
|
IGNORED_DF_COLUMNS = ['Episode', 'Run', # For plotting, which values are ignored when loading monitor files
|
||||||
|
'train_step', 'step', 'index', 'dirt_amount', 'dirty_tile_count', 'terminal_observation',
|
||||||
|
'episode']
|
||||||
|
|
||||||
|
|
||||||
# Constants
|
|
||||||
class Constants:
|
class Constants:
|
||||||
WALL = '#'
|
|
||||||
WALLS = 'Walls'
|
|
||||||
FLOOR = 'Floor'
|
|
||||||
DOOR = 'D'
|
|
||||||
DANGER_ZONE = 'x'
|
|
||||||
LEVEL = 'Level'
|
|
||||||
AGENT = 'Agent'
|
|
||||||
AGENT_PLACEHOLDER = 'AGENT_PLACEHOLDER'
|
|
||||||
GLOBAL_POSITION = 'GLOBAL_POSITION'
|
|
||||||
FREE_CELL = 0
|
|
||||||
OCCUPIED_CELL = 1
|
|
||||||
SHADOWED_CELL = -1
|
|
||||||
ACCESS_DOOR_CELL = 1/3
|
|
||||||
OPEN_DOOR_CELL = 2/3
|
|
||||||
CLOSED_DOOR_CELL = 3/3
|
|
||||||
NO_POS = (-9999, -9999)
|
|
||||||
|
|
||||||
DOORS = 'Doors'
|
"""
|
||||||
CLOSED_DOOR = 'closed'
|
String based mapping. Use these to handle keys or define values, which can then be used globally.
|
||||||
OPEN_DOOR = 'open'
|
Please use class inheritance when defining new environments.
|
||||||
ACCESS_DOOR = 'access'
|
"""
|
||||||
|
|
||||||
ACTION = 'action'
|
WALL = '#' # Wall tile identifier for resolving the string based map files.
|
||||||
COLLISION = 'collision'
|
DOOR = 'D' # Door identifier for resolving the string based map files.
|
||||||
VALID = True
|
DANGER_ZONE = 'x' # Danger Zone tile identifier for resolving the string based map files.
|
||||||
NOT_VALID = False
|
|
||||||
|
WALLS = 'Walls' # Identifier of Wall-objects and sets (collections).
|
||||||
|
FLOOR = 'Floor' # Identifier of Floor-objects and sets (collections).
|
||||||
|
DOORS = 'Doors' # Identifier of Door-objects and sets (collections).
|
||||||
|
LEVEL = 'Level' # Identifier of Level-objects and sets (collections).
|
||||||
|
AGENT = 'Agent' # Identifier of Agent-objects and sets (collections).
|
||||||
|
AGENT_PLACEHOLDER = 'AGENT_PLACEHOLDER' # Identifier of Placeholder-objects and sets (collections).
|
||||||
|
GLOBAL_POSITION = 'GLOBAL_POSITION' # Identifier of the global position slice
|
||||||
|
|
||||||
|
FREE_CELL = 0 # Free-Cell value used in observation
|
||||||
|
OCCUPIED_CELL = 1 # Occupied-Cell value used in observation
|
||||||
|
SHADOWED_CELL = -1 # Shadowed-Cell value used in observation
|
||||||
|
ACCESS_DOOR_CELL = 1/3 # Access-door-Cell value used in observation
|
||||||
|
OPEN_DOOR_CELL = 2/3 # Open-door-Cell value used in observation
|
||||||
|
CLOSED_DOOR_CELL = 3/3 # Closed-door-Cell value used in observation
|
||||||
|
|
||||||
|
NO_POS = (-9999, -9999) # Invalid Position value used in the environment (something is off-grid)
|
||||||
|
|
||||||
|
CLOSED_DOOR = 'closed' # Identifier to compare door-is-closed state
|
||||||
|
OPEN_DOOR = 'open' # Identifier to compare door-is-open state
|
||||||
|
# ACCESS_DOOR = 'access' # Identifier to compare access positions
|
||||||
|
|
||||||
|
ACTION = 'action' # Identifier of Action-objects and sets (collections).
|
||||||
|
COLLISION = 'collision' # Identifier to use in the context of collisions.
|
||||||
|
VALID = True # Identifier to rename boolean values in the context of actions.
|
||||||
|
NOT_VALID = False # Identifier to rename boolean values in the context of actions.
|
||||||
|
|
||||||
|
|
||||||
class EnvActions:
|
class EnvActions:
|
||||||
|
"""
|
||||||
|
String based mapping. Use these to identify actions; they can be used globally.
|
||||||
|
Please use class inheritance when defining new environments with new actions.
|
||||||
|
"""
|
||||||
# Movements
|
# Movements
|
||||||
NORTH = 'north'
|
NORTH = 'north'
|
||||||
EAST = 'east'
|
EAST = 'east'
|
||||||
@ -63,24 +92,77 @@ class EnvActions:
|
|||||||
NOOP = 'no_op'
|
NOOP = 'no_op'
|
||||||
USE_DOOR = 'use_door'
|
USE_DOOR = 'use_door'
|
||||||
|
|
||||||
|
_ACTIONMAP = defaultdict(lambda: (0, 0),
|
||||||
|
{NORTH: (-1, 0), NORTHEAST: (-1, 1),
|
||||||
|
EAST: (0, 1), SOUTHEAST: (1, 1),
|
||||||
|
SOUTH: (1, 0), SOUTHWEST: (1, -1),
|
||||||
|
WEST: (0, -1), NORTHWEST: (-1, -1)
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def is_move(cls, other):
|
def is_move(cls, action):
|
||||||
return any([other == direction for direction in cls.movement_actions()])
|
"""
|
||||||
|
Classmethod; checks if given action is a movement action or not. Depending on the env. configuration,
|
||||||
|
Movement actions are either `manhattan` (square) style movements (up,down, left, right) and/or diagonal.
|
||||||
|
|
||||||
|
:param action: Action to be checked
|
||||||
|
:type action: str
|
||||||
|
:return: Whether the given action is a movement action.
|
||||||
|
:rtype: bool
|
||||||
|
"""
|
||||||
|
return any([action == direction for direction in cls.movement_actions()])
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def square_move(cls):
|
def square_move(cls):
|
||||||
|
"""
|
||||||
|
Classmethod; return a list of movement actions that are considered square or `manhattan` style movements.
|
||||||
|
|
||||||
|
:return: A list of movement actions.
|
||||||
|
:rtype: list(str)
|
||||||
|
"""
|
||||||
return [cls.NORTH, cls.EAST, cls.SOUTH, cls.WEST]
|
return [cls.NORTH, cls.EAST, cls.SOUTH, cls.WEST]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def diagonal_move(cls):
|
def diagonal_move(cls):
|
||||||
|
"""
|
||||||
|
Classmethod; return a list of movement actions that are considered diagonal movements.
|
||||||
|
|
||||||
|
:return: A list of movement actions.
|
||||||
|
:rtype: list(str)
|
||||||
|
"""
|
||||||
return [cls.NORTHEAST, cls.SOUTHEAST, cls.SOUTHWEST, cls.NORTHWEST]
|
return [cls.NORTHEAST, cls.SOUTHEAST, cls.SOUTHWEST, cls.NORTHWEST]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def movement_actions(cls):
|
def movement_actions(cls):
|
||||||
|
"""
|
||||||
|
Classmethod; return a list of all available movement actions.
|
||||||
|
Please note that this is independent of the env. properties.
|
||||||
|
|
||||||
|
:return: A list of movement actions.
|
||||||
|
:rtype: list(str)
|
||||||
|
"""
|
||||||
return list(itertools.chain(cls.square_move(), cls.diagonal_move()))
|
return list(itertools.chain(cls.square_move(), cls.diagonal_move()))
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def resolve_movement_action_to_coords(cls, action):
|
||||||
|
"""
|
||||||
|
Classmethod; resolve movement actions. Given a movement action, return the delta in coordinates it stands for.
|
||||||
|
How does the current entity coordinate change if it performs the given action?
|
||||||
|
Please note that this is independent of the env. properties.
|
||||||
|
|
||||||
|
:return: Delta coordinates.
|
||||||
|
:rtype: tuple(int, int)
|
||||||
|
"""
|
||||||
|
return cls._ACTIONMAP[action]
|
||||||
|
|
||||||
|
|
||||||
class RewardsBase(NamedTuple):
|
class RewardsBase(NamedTuple):
|
||||||
|
"""
|
||||||
|
Value based mapping. Use these to define reward values for specific conditions (i.e. the action
|
||||||
|
in a given context), can be used globally.
|
||||||
|
Please use class inheritance when defining new environments with new rewards.
|
||||||
|
"""
|
||||||
MOVEMENTS_VALID: float = -0.001
|
MOVEMENTS_VALID: float = -0.001
|
||||||
MOVEMENTS_FAIL: float = -0.05
|
MOVEMENTS_FAIL: float = -0.05
|
||||||
NOOP: float = -0.01
|
NOOP: float = -0.01
|
||||||
@ -89,23 +171,31 @@ class RewardsBase(NamedTuple):
|
|||||||
COLLISION: float = -0.5
|
COLLISION: float = -0.5
|
||||||
|
|
||||||
|
|
||||||
m = EnvActions
|
|
||||||
c = Constants
|
|
||||||
|
|
||||||
ACTIONMAP = defaultdict(lambda: (0, 0),
|
|
||||||
{m.NORTH: (-1, 0), m.NORTHEAST: (-1, 1),
|
|
||||||
m.EAST: (0, 1), m.SOUTHEAST: (1, 1),
|
|
||||||
m.SOUTH: (1, 0), m.SOUTHWEST: (1, -1),
|
|
||||||
m.WEST: (0, -1), m.NORTHWEST: (-1, -1)
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class ObservationTranslator:
|
class ObservationTranslator:
|
||||||
|
|
||||||
def __init__(self, obs_shape_2d: (int, int), this_named_observation_space: Dict[str, dict],
|
def __init__(self, obs_shape_2d: (int, int), this_named_observation_space: Dict[str, dict],
|
||||||
*per_agent_named_obs_space: Dict[str, dict],
|
*per_agent_named_obs_spaces: Dict[str, dict],
|
||||||
placeholder_fill_value: Union[int, str] = 'N'):
|
placeholder_fill_value: Union[int, str] = 'N'):
|
||||||
|
"""
|
||||||
|
This is a helper class, which converts agents observations from joined environments.
|
||||||
|
For example, agents trained in different environments may expect different observations.
|
||||||
|
This class translates from larger observations spaces to smaller.
|
||||||
|
A string identifier based approach is used.
|
||||||
|
Currently, it is not possible to mix different obs shapes.
|
||||||
|
|
||||||
|
:param obs_shape_2d: The shape of the observation the agents expect.
|
||||||
|
:type obs_shape_2d: tuple(int, int)
|
||||||
|
|
||||||
|
:param this_named_observation_space: `Named observation space` of the joined environment.
|
||||||
|
:type this_named_observation_space: Dict[str, dict]
|
||||||
|
|
||||||
|
:param per_agent_named_obs_spaces: `Named observation space` one for each agent. Overloaded.
|
||||||
|
:type per_agent_named_obs_spaces: Dict[str, dict]
|
||||||
|
|
||||||
|
:param placeholder_fill_value: Currently not fully implemented!!!
|
||||||
|
:type placeholder_fill_value: Union[int, str]
|
||||||
|
"""
|
||||||
|
|
||||||
assert len(obs_shape_2d) == 2
|
assert len(obs_shape_2d) == 2
|
||||||
self.obs_shape = obs_shape_2d
|
self.obs_shape = obs_shape_2d
|
||||||
if isinstance(placeholder_fill_value, str):
|
if isinstance(placeholder_fill_value, str):
|
||||||
@ -119,7 +209,7 @@ class ObservationTranslator:
|
|||||||
self.random_fill = None
|
self.random_fill = None
|
||||||
|
|
||||||
self._this_named_obs_space = this_named_observation_space
|
self._this_named_obs_space = this_named_observation_space
|
||||||
self._per_agent_named_obs_space = list(per_agent_named_obs_space)
|
self._per_agent_named_obs_space = list(per_agent_named_obs_spaces)
|
||||||
|
|
||||||
def translate_observation(self, agent_idx: int, obs: np.ndarray):
|
def translate_observation(self, agent_idx: int, obs: np.ndarray):
|
||||||
target_obs_space = self._per_agent_named_obs_space[agent_idx]
|
target_obs_space = self._per_agent_named_obs_space[agent_idx]
|
||||||
@ -137,6 +227,19 @@ class ObservationTranslator:
|
|||||||
class ActionTranslator:
|
class ActionTranslator:
|
||||||
|
|
||||||
def __init__(self, target_named_action_space: Dict[str, int], *per_agent_named_action_space: Dict[str, int]):
|
def __init__(self, target_named_action_space: Dict[str, int], *per_agent_named_action_space: Dict[str, int]):
|
||||||
|
"""
|
||||||
|
This is a helper class, which converts agents action spaces to a joined environments action space.
|
||||||
|
For example, agents trained in different environments may have different action spaces.
|
||||||
|
This class translates from smaller individual agent action spaces to larger joined spaces.
|
||||||
|
A string identifier based approach is used.
|
||||||
|
|
||||||
|
:param target_named_action_space: Joined `Named action space` for the current environment.
|
||||||
|
:type target_named_action_space: Dict[str, dict]
|
||||||
|
|
||||||
|
:param per_agent_named_action_space: `Named action space` one for each agent. Overloaded.
|
||||||
|
:type per_agent_named_action_space: Dict[str, dict]
|
||||||
|
"""
|
||||||
|
|
||||||
self._target_named_action_space = target_named_action_space
|
self._target_named_action_space = target_named_action_space
|
||||||
self._per_agent_named_action_space = list(per_agent_named_action_space)
|
self._per_agent_named_action_space = list(per_agent_named_action_space)
|
||||||
self._per_agent_idx_actions = [{idx: a for a, idx in x.items()} for x in self._per_agent_named_action_space]
|
self._per_agent_idx_actions = [{idx: a for a, idx in x.items()} for x in self._per_agent_named_action_space]
|
||||||
@ -155,6 +258,16 @@ class ActionTranslator:
|
|||||||
|
|
||||||
# Utility functions
|
# Utility functions
|
||||||
def parse_level(path):
|
def parse_level(path):
|
||||||
|
"""
|
||||||
|
Given the path to a string based `level` or `map` representation, this function reads the content.
|
||||||
|
Cleans `space`, checks for equal length of each row and returns a list of lists.
|
||||||
|
|
||||||
|
:param path: Path to the `level` or `map` file on harddrive.
|
||||||
|
:type path: pathlib.Path
|
||||||
|
|
||||||
|
:return: The read string representation of the `level` or `map`
|
||||||
|
:rtype: List[List[str]]
|
||||||
|
"""
|
||||||
with path.open('r') as lvl:
|
with path.open('r') as lvl:
|
||||||
level = list(map(lambda x: list(x.strip()), lvl.readlines()))
|
level = list(map(lambda x: list(x.strip()), lvl.readlines()))
|
||||||
if len(set([len(line) for line in level])) > 1:
|
if len(set([len(line) for line in level])) > 1:
|
||||||
@ -162,29 +275,56 @@ def parse_level(path):
|
|||||||
return level
|
return level
|
||||||
|
|
||||||
|
|
||||||
def one_hot_level(level, wall_char: str = c.WALL):
|
def one_hot_level(level, wall_char: str = Constants.WALL):
|
||||||
|
"""
|
||||||
|
Given a string based level representation (list of lists, see function `parse_level`), this function creates a
|
||||||
|
binary numpy array or `grid`. Grid values that equal `wall_char` become of `Constants.OCCUPIED_CELL` value.
|
||||||
|
Can be changed to filter for any symbol.
|
||||||
|
|
||||||
|
:param level: String based level representation (list of lists, see function `parse_level`).
|
||||||
|
:type level: List[List[str]]
|
||||||
|
|
||||||
|
:return: Binary numpy array
|
||||||
|
:rtype: np.typing._array_like.ArrayLike
|
||||||
|
"""
|
||||||
|
|
||||||
grid = np.array(level)
|
grid = np.array(level)
|
||||||
binary_grid = np.zeros(grid.shape, dtype=np.int8)
|
binary_grid = np.zeros(grid.shape, dtype=np.int8)
|
||||||
binary_grid[grid == wall_char] = c.OCCUPIED_CELL
|
binary_grid[grid == wall_char] = Constants.OCCUPIED_CELL
|
||||||
return binary_grid
|
return binary_grid
|
||||||
|
|
||||||
|
|
||||||
def check_position(slice_to_check_against: ArrayLike, position_to_check: Tuple[int, int]):
|
def check_position(slice_to_check_against: ArrayLike, position_to_check: Tuple[int, int]):
|
||||||
|
"""
|
||||||
|
Given a slice (2-D Arraylike object), check whether a position lies within its boundaries and is not occupied.
|
||||||
|
|
||||||
|
:param slice_to_check_against: The slice to check for accessibility
|
||||||
|
:type slice_to_check_against: np.typing._array_like.ArrayLike
|
||||||
|
|
||||||
|
:param position_to_check: Position in slice that should be checked. Can be outside of slice boundaries.
|
||||||
|
:type position_to_check: tuple(int, int)
|
||||||
|
|
||||||
|
:return: Whether a position can be moved to.
|
||||||
|
:rtype: bool
|
||||||
|
"""
|
||||||
x_pos, y_pos = position_to_check
|
x_pos, y_pos = position_to_check
|
||||||
|
|
||||||
# Check if agent collides with grid boundaries
|
# Check if agent collides with grid boundaries
|
||||||
valid = not (
|
valid = not (
|
||||||
x_pos < 0 or y_pos < 0
|
x_pos < 0 or y_pos < 0
|
||||||
or x_pos >= slice_to_check_against.shape[0]
|
or x_pos >= slice_to_check_against.shape[0]
|
||||||
or y_pos >= slice_to_check_against.shape[0]
|
or y_pos >= slice_to_check_against.shape[1]
|
||||||
)
|
)
|
||||||
|
|
||||||
# Check for collision with level walls
|
# Check for collision with level walls
|
||||||
valid = valid and not slice_to_check_against[x_pos, y_pos]
|
valid = valid and not slice_to_check_against[x_pos, y_pos]
|
||||||
return c.VALID if valid else c.NOT_VALID
|
return Constants.VALID if valid else Constants.NOT_VALID
|
||||||
|
|
||||||
|
|
||||||
def asset_str(agent):
|
def asset_str(agent):
|
||||||
|
"""
|
||||||
|
FIXME @ romue
|
||||||
|
"""
|
||||||
# What does this abomination do?
|
# What does this abomination do?
|
||||||
# if any([x is None for x in [cls._slices[j] for j in agent.collisions]]):
|
# if any([x is None for x in [cls._slices[j] for j in agent.collisions]]):
|
||||||
# print('error')
|
# print('error')
|
||||||
@ -192,33 +332,50 @@ def asset_str(agent):
|
|||||||
action = step_result['action_name']
|
action = step_result['action_name']
|
||||||
valid = step_result['action_valid']
|
valid = step_result['action_valid']
|
||||||
col_names = [x.name for x in step_result['collisions']]
|
col_names = [x.name for x in step_result['collisions']]
|
||||||
if any(c.AGENT in name for name in col_names):
|
if any(Constants.AGENT in name for name in col_names):
|
||||||
return 'agent_collision', 'blank'
|
return 'agent_collision', 'blank'
|
||||||
elif not valid or c.LEVEL in col_names or c.AGENT in col_names:
|
elif not valid or Constants.LEVEL in col_names or Constants.AGENT in col_names:
|
||||||
return c.AGENT, 'invalid'
|
return Constants.AGENT, 'invalid'
|
||||||
elif valid and not EnvActions.is_move(action):
|
elif valid and not EnvActions.is_move(action):
|
||||||
return c.AGENT, 'valid'
|
return Constants.AGENT, 'valid'
|
||||||
elif valid and EnvActions.is_move(action):
|
elif valid and EnvActions.is_move(action):
|
||||||
return c.AGENT, 'move'
|
return Constants.AGENT, 'move'
|
||||||
else:
|
else:
|
||||||
return c.AGENT, 'idle'
|
return Constants.AGENT, 'idle'
|
||||||
else:
|
else:
|
||||||
return c.AGENT, 'idle'
|
return Constants.AGENT, 'idle'
|
||||||
|
|
||||||
|
|
||||||
def points_to_graph(coordiniates_or_tiles, allow_euclidean_connections=True, allow_manhattan_connections=True):
|
def points_to_graph(coordiniates_or_tiles, allow_euclidean_connections=True, allow_manhattan_connections=True):
|
||||||
|
"""
|
||||||
|
Given a set of coordinates, this function constructs an undirected graph by connecting adjacent points.
|
||||||
|
There are three combinations of settings:
|
||||||
|
Allow all neighbors: Distance(a, b) <= sqrt(2)
|
||||||
|
Allow only manhattan: Distance(a, b) == 1
|
||||||
|
Allow only euclidean: Distance(a, b) == sqrt(2)
|
||||||
|
|
||||||
|
|
||||||
|
:param coordiniates_or_tiles: A set of coordinates.
|
||||||
|
:type coordiniates_or_tiles: Tiles
|
||||||
|
:param allow_euclidean_connections: Whether to regard diagonally adjacent cells as neighbors
|
||||||
|
:type allow_euclidean_connections: bool
|
||||||
|
:param allow_manhattan_connections: Whether to regard directly adjacent cells as neighbors
|
||||||
|
:type allow_manhattan_connections: bool
|
||||||
|
|
||||||
|
:return: A graph with nodes that are connected as specified by the parameters.
|
||||||
|
:rtype: nx.Graph
|
||||||
|
"""
|
||||||
assert allow_euclidean_connections or allow_manhattan_connections
|
assert allow_euclidean_connections or allow_manhattan_connections
|
||||||
if hasattr(coordiniates_or_tiles, 'positions'):
|
if hasattr(coordiniates_or_tiles, 'positions'):
|
||||||
coordiniates_or_tiles = coordiniates_or_tiles.positions
|
coordiniates_or_tiles = coordiniates_or_tiles.positions
|
||||||
possible_connections = itertools.combinations(coordiniates_or_tiles, 2)
|
possible_connections = itertools.combinations(coordiniates_or_tiles, 2)
|
||||||
graph = nx.Graph()
|
graph = nx.Graph()
|
||||||
for a, b in possible_connections:
|
for a, b in possible_connections:
|
||||||
diff = abs(np.subtract(a, b))
|
diff = np.linalg.norm(np.asarray(a)-np.asarray(b))
|
||||||
if not max(diff) > 1:
|
if allow_manhattan_connections and allow_euclidean_connections and diff <= np.sqrt(2):
|
||||||
if allow_manhattan_connections and allow_euclidean_connections:
|
graph.add_edge(a, b)
|
||||||
graph.add_edge(a, b)
|
elif not allow_manhattan_connections and allow_euclidean_connections and diff == np.sqrt(2):
|
||||||
elif not allow_manhattan_connections and allow_euclidean_connections and all(diff):
|
graph.add_edge(a, b)
|
||||||
graph.add_edge(a, b)
|
elif allow_manhattan_connections and not allow_euclidean_connections and diff == 1:
|
||||||
elif allow_manhattan_connections and not allow_euclidean_connections and not all(diff) and any(diff):
|
graph.add_edge(a, b)
|
||||||
graph.add_edge(a, b)
|
|
||||||
return graph
|
return graph
|
||||||
|
@ -4,6 +4,22 @@ from gym.wrappers.frame_stack import FrameStack
|
|||||||
|
|
||||||
|
|
||||||
class AgentRenderOptions(object):
|
class AgentRenderOptions(object):
|
||||||
|
"""
|
||||||
|
Class that specifies the available options for the way agents are represented in the env observation.
|
||||||
|
|
||||||
|
SEPERATE:
|
||||||
|
Each agent is represented in a separate slice as Constant.OCCUPIED_CELL value (one hot)
|
||||||
|
|
||||||
|
COMBINED:
|
||||||
|
For all agents, the value of Constant.OCCUPIED_CELL is added to a zero-value slice at the agent's position (sum(SEPERATE))
|
||||||
|
|
||||||
|
LEVEL:
|
||||||
|
The combined slice is added to the LEVEL-slice. (Agents appear as obstacle / wall)
|
||||||
|
|
||||||
|
NOT:
|
||||||
|
The position of individual agents can not be read from the observation.
|
||||||
|
"""
|
||||||
|
|
||||||
SEPERATE = 'seperate'
|
SEPERATE = 'seperate'
|
||||||
COMBINED = 'combined'
|
COMBINED = 'combined'
|
||||||
LEVEL = 'lvl'
|
LEVEL = 'lvl'
|
||||||
@ -11,24 +27,61 @@ class AgentRenderOptions(object):
|
|||||||
|
|
||||||
|
|
||||||
class MovementProperties(NamedTuple):
|
class MovementProperties(NamedTuple):
|
||||||
|
"""
|
||||||
|
Property holder; for setting multiple related parameters through a single parameter. Comes with default values.
|
||||||
|
"""
|
||||||
|
|
||||||
|
"""Allow the manhattan style movement on a grid (move to cells that are connected by square edges)."""
|
||||||
allow_square_movement: bool = True
|
allow_square_movement: bool = True
|
||||||
|
|
||||||
|
"""Allow diagonal movement on the grid (move to cells that are connected by square corners)."""
|
||||||
allow_diagonal_movement: bool = False
|
allow_diagonal_movement: bool = False
|
||||||
|
|
||||||
|
"""Allow the agent to just do nothing; not move (NO-OP)."""
|
||||||
allow_no_op: bool = False
|
allow_no_op: bool = False
|
||||||
|
|
||||||
|
|
||||||
class ObservationProperties(NamedTuple):
|
class ObservationProperties(NamedTuple):
|
||||||
# Todo: Add Description
|
"""
|
||||||
|
Property holder; for setting multiple related parameters through a single parameter. Comes with default values.
|
||||||
|
"""
|
||||||
|
|
||||||
|
"""How to represent agents in the observation space. This may also alter the obs-shape."""
|
||||||
render_agents: AgentRenderOptions = AgentRenderOptions.SEPERATE
|
render_agents: AgentRenderOptions = AgentRenderOptions.SEPERATE
|
||||||
|
|
||||||
|
"""Observations are built per agent; whether the current agent should be represented in its own observation."""
|
||||||
omit_agent_self: bool = True
|
omit_agent_self: bool = True
|
||||||
|
|
||||||
|
"""There might be cases where you want to modify the agents' obs-space, so that it can be used with additional obs.
|
||||||
|
The additional slice can be filled with any number"""
|
||||||
additional_agent_placeholder: Union[None, str, int] = None
|
additional_agent_placeholder: Union[None, str, int] = None
|
||||||
|
|
||||||
|
"""Whether to cast shadows (make floor tiles and items hidden)."""
|
||||||
cast_shadows: bool = True
|
cast_shadows: bool = True
|
||||||
|
|
||||||
|
"""Frame stacking is a method to give some temporal information to the agents.
|
||||||
|
This parameter controls how many "old-frames" are stacked."""
|
||||||
frames_to_stack: int = 0
|
frames_to_stack: int = 0
|
||||||
pomdp_r: int = 0
|
|
||||||
|
"""Specifies the radius (_r) of the agent's field of view. Please note that the agent's own grid cell is not taken
|
||||||
|
into account. This means that the resulting field of view diameter = `pomdp_r * 2 + 1`.
|
||||||
|
A 'pomdp_r' of 0 always returns the full env == no partial observability."""
|
||||||
|
pomdp_r: int = 2
|
||||||
|
|
||||||
|
"""Whether to place a visual encoding on walkable tiles around the doors. This is helpful when the doors can be
|
||||||
|
operated from their surrounding area. So the agent can more easily get a notion of where to choose the door option.
|
||||||
|
However, this is not necessary at all.
|
||||||
|
"""
|
||||||
indicate_door_area: bool = False
|
indicate_door_area: bool = False
|
||||||
|
|
||||||
|
"""Whether to add the agents normalized global position as float values (2,1) to a separate information slice.
|
||||||
|
More optional information is to come.
|
||||||
|
"""
|
||||||
show_global_position_info: bool = False
|
show_global_position_info: bool = False
|
||||||
|
|
||||||
|
|
||||||
class MarlFrameStack(gym.ObservationWrapper):
|
class MarlFrameStack(gym.ObservationWrapper):
|
||||||
|
"""todo @romue404"""
|
||||||
def __init__(self, env):
|
def __init__(self, env):
|
||||||
super().__init__(env)
|
super().__init__(env)
|
||||||
|
|
||||||
|
@ -215,7 +215,7 @@ if __name__ == '__main__':
|
|||||||
clean_amount=0.34,
|
clean_amount=0.34,
|
||||||
max_spawn_amount=0.1, max_global_amount=20,
|
max_spawn_amount=0.1, max_global_amount=20,
|
||||||
max_local_amount=1, spawn_frequency=0, max_spawn_ratio=0.05,
|
max_local_amount=1, spawn_frequency=0, max_spawn_ratio=0.05,
|
||||||
dirt_smear_amount=0.0, agent_can_interact=True)
|
dirt_smear_amount=0.0)
|
||||||
item_props = ItemProperties(n_items=10,
|
item_props = ItemProperties(n_items=10,
|
||||||
spawn_frequency=30, n_drop_off_locations=2,
|
spawn_frequency=30, n_drop_off_locations=2,
|
||||||
max_agent_inventory_capacity=15)
|
max_agent_inventory_capacity=15)
|
||||||
@ -349,6 +349,7 @@ if __name__ == '__main__':
|
|||||||
# Env Init & Model kwargs definition
|
# Env Init & Model kwargs definition
|
||||||
if model_cls.__name__ in ["PPO", "A2C"]:
|
if model_cls.__name__ in ["PPO", "A2C"]:
|
||||||
# env_factory = env_class(**env_kwargs)
|
# env_factory = env_class(**env_kwargs)
|
||||||
|
|
||||||
env_factory = SubprocVecEnv([encapsule_env_factory(env_class, env_kwargs)
|
env_factory = SubprocVecEnv([encapsule_env_factory(env_class, env_kwargs)
|
||||||
for _ in range(6)], start_method="spawn")
|
for _ in range(6)], start_method="spawn")
|
||||||
model_kwargs = policy_model_kwargs()
|
model_kwargs = policy_model_kwargs()
|
||||||
|
@ -213,7 +213,8 @@ if __name__ == '__main__':
|
|||||||
env_factory.save_params(param_path)
|
env_factory.save_params(param_path)
|
||||||
|
|
||||||
# EnvMonitor Init
|
# EnvMonitor Init
|
||||||
callbacks = [EnvMonitor(env_factory)]
|
env_monitor = EnvMonitor(env_factory)
|
||||||
|
callbacks = [env_monitor]
|
||||||
|
|
||||||
# Model Init
|
# Model Init
|
||||||
model = model_cls("MlpPolicy", env_factory, **policy_model_kwargs,
|
model = model_cls("MlpPolicy", env_factory, **policy_model_kwargs,
|
||||||
@ -233,7 +234,7 @@ if __name__ == '__main__':
|
|||||||
model.save(save_path)
|
model.save(save_path)
|
||||||
|
|
||||||
# Monitor Save
|
# Monitor Save
|
||||||
callbacks[0].save_run(combination_path / 'monitor.pick',
|
env_monitor.save_run(combination_path / 'monitor.pick',
|
||||||
auto_plotting_keys=['step_reward', 'collision'] + env_plot_keys)
|
auto_plotting_keys=['step_reward', 'collision'] + env_plot_keys)
|
||||||
|
|
||||||
# Better be safe than sorry: Clean up!
|
# Better be safe than sorry: Clean up!
|
||||||
|
Loading…
x
Reference in New Issue
Block a user