Adjustments and Documentation
This commit is contained in:
parent
3e19970a60
commit
0218f8f4e9
@ -156,14 +156,14 @@ class BaseFactory(gym.Env):
|
||||
np.argwhere(level_array == c.OCCUPIED_CELL),
|
||||
self._level_shape
|
||||
)
|
||||
self._entities.register_additional_items({c.WALLS: walls})
|
||||
self._entities.add_additional_items({c.WALLS: walls})
|
||||
|
||||
# Floor
|
||||
floor = Floors.from_argwhere_coordinates(
|
||||
np.argwhere(level_array == c.FREE_CELL),
|
||||
self._level_shape
|
||||
)
|
||||
self._entities.register_additional_items({c.FLOOR: floor})
|
||||
self._entities.add_additional_items({c.FLOOR: floor})
|
||||
|
||||
# NOPOS
|
||||
self._NO_POS_TILE = Floor(c.NO_POS, None)
|
||||
@ -177,12 +177,12 @@ class BaseFactory(gym.Env):
|
||||
doors = Doors.from_tiles(door_tiles, self._level_shape, have_area=self.obs_prop.indicate_door_area,
|
||||
entity_kwargs=dict(context=floor)
|
||||
)
|
||||
self._entities.register_additional_items({c.DOORS: doors})
|
||||
self._entities.add_additional_items({c.DOORS: doors})
|
||||
|
||||
# Actions
|
||||
self._actions = Actions(self.mv_prop, can_use_doors=self.parse_doors)
|
||||
if additional_actions := self.actions_hook:
|
||||
self._actions.register_additional_items(additional_actions)
|
||||
self._actions.add_additional_items(additional_actions)
|
||||
|
||||
# Agents
|
||||
agents_to_spawn = self.n_agents-len(self._injected_agents)
|
||||
@ -196,10 +196,10 @@ class BaseFactory(gym.Env):
|
||||
if self._injected_agents:
|
||||
initialized_injections = list()
|
||||
for i, injection in enumerate(self._injected_agents):
|
||||
agents.register_item(injection(self, floor.empty_tiles[0], agents, static_problem=False))
|
||||
agents.add_item(injection(self, floor.empty_tiles[0], agents, static_problem=False))
|
||||
initialized_injections.append(agents[-1])
|
||||
self._initialized_injections = initialized_injections
|
||||
self._entities.register_additional_items({c.AGENT: agents})
|
||||
self._entities.add_additional_items({c.AGENT: agents})
|
||||
|
||||
if self.obs_prop.additional_agent_placeholder is not None:
|
||||
# TODO: Make this accept Lists for multiple placeholders
|
||||
@ -210,18 +210,18 @@ class BaseFactory(gym.Env):
|
||||
fill_value=self.obs_prop.additional_agent_placeholder)
|
||||
)
|
||||
|
||||
self._entities.register_additional_items({c.AGENT_PLACEHOLDER: placeholder})
|
||||
self._entities.add_additional_items({c.AGENT_PLACEHOLDER: placeholder})
|
||||
|
||||
# Additional Entities from SubEnvs
|
||||
if additional_entities := self.entities_hook:
|
||||
self._entities.register_additional_items(additional_entities)
|
||||
self._entities.add_additional_items(additional_entities)
|
||||
|
||||
if self.obs_prop.show_global_position_info:
|
||||
global_positions = GlobalPositions(self._level_shape)
|
||||
# This moved into the GlobalPosition object
|
||||
# obs_shape_2d = self._level_shape if not self._pomdp_r else ((self.pomdp_diameter,) * 2)
|
||||
global_positions.spawn_global_position_objects(self[c.AGENT])
|
||||
self._entities.register_additional_items({c.GLOBAL_POSITION: global_positions})
|
||||
self._entities.add_additional_items({c.GLOBAL_POSITION: global_positions})
|
||||
|
||||
# Return
|
||||
return self._entities
|
||||
@ -535,7 +535,7 @@ class BaseFactory(gym.Env):
|
||||
|
||||
def _check_agent_move(self, agent, action: Action) -> (Floor, bool):
|
||||
# Actions
|
||||
x_diff, y_diff = h.ACTIONMAP[action.identifier]
|
||||
x_diff, y_diff = a.resolve_movement_action_to_coords(action.identifier)
|
||||
x_new = agent.x + x_diff
|
||||
y_new = agent.y + y_diff
|
||||
|
||||
|
@ -72,15 +72,15 @@ class EnvObject(Object):
|
||||
def encoding(self):
|
||||
return c.OCCUPIED_CELL
|
||||
|
||||
def __init__(self, register, **kwargs):
|
||||
def __init__(self, collection, **kwargs):
|
||||
super(EnvObject, self).__init__(**kwargs)
|
||||
self._register = register
|
||||
self._collection = collection
|
||||
|
||||
def change_register(self, register):
|
||||
register.register_item(self)
|
||||
self._register.delete_env_object(self)
|
||||
self._register = register
|
||||
return self._register == register
|
||||
def change_parent_collection(self, other_collection):
|
||||
other_collection.add_item(self)
|
||||
self._collection.delete_env_object(self)
|
||||
self._collection = other_collection
|
||||
return self._collection == other_collection
|
||||
# With Rendering
|
||||
|
||||
|
||||
@ -153,7 +153,7 @@ class MoveableEntity(Entity):
|
||||
curr_tile.leave(self)
|
||||
self._tile = next_tile
|
||||
self._last_tile = curr_tile
|
||||
self._register.notify_change_to_value(self)
|
||||
self._collection.notify_change_to_value(self)
|
||||
return c.VALID
|
||||
else:
|
||||
return c.NOT_VALID
|
||||
@ -371,13 +371,13 @@ class Door(Entity):
|
||||
def _open(self):
|
||||
self.connectivity.add_edges_from([(self.pos, x) for x in range(len(self.connectivity_subgroups))])
|
||||
self._state = c.OPEN_DOOR
|
||||
self._register.notify_change_to_value(self)
|
||||
self._collection.notify_change_to_value(self)
|
||||
self.time_to_close = self.auto_close_interval
|
||||
|
||||
def _close(self):
|
||||
self.connectivity.remove_node(self.pos)
|
||||
self._state = c.CLOSED_DOOR
|
||||
self._register.notify_change_to_value(self)
|
||||
self._collection.notify_change_to_value(self)
|
||||
|
||||
def is_linked(self, old_pos, new_pos):
|
||||
try:
|
||||
|
@ -13,11 +13,11 @@ from environments import helpers as h
|
||||
from environments.helpers import Constants as c
|
||||
|
||||
##########################################################################
|
||||
# ##################### Base Register Definition ####################### #
|
||||
# ################## Base Collections Definition ####################### #
|
||||
##########################################################################
|
||||
|
||||
|
||||
class ObjectRegister:
|
||||
class ObjectCollection:
|
||||
_accepted_objects = Object
|
||||
|
||||
@property
|
||||
@ -25,59 +25,59 @@ class ObjectRegister:
|
||||
return f'{self.__class__.__name__}'
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
self._register = dict()
|
||||
self._collection = dict()
|
||||
|
||||
def __len__(self):
|
||||
return len(self._register)
|
||||
return len(self._collection)
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self.values())
|
||||
|
||||
def register_item(self, other: _accepted_objects):
|
||||
def add_item(self, other: _accepted_objects):
|
||||
assert isinstance(other, self._accepted_objects), f'All item names have to be of type ' \
|
||||
f'{self._accepted_objects}, ' \
|
||||
f'but were {other.__class__}.,'
|
||||
self._register.update({other.name: other})
|
||||
self._collection.update({other.name: other})
|
||||
return self
|
||||
|
||||
def register_additional_items(self, others: List[_accepted_objects]):
|
||||
def add_additional_items(self, others: List[_accepted_objects]):
|
||||
for other in others:
|
||||
self.register_item(other)
|
||||
self.add_item(other)
|
||||
return self
|
||||
|
||||
def keys(self):
|
||||
return self._register.keys()
|
||||
return self._collection.keys()
|
||||
|
||||
def values(self):
|
||||
return self._register.values()
|
||||
return self._collection.values()
|
||||
|
||||
def items(self):
|
||||
return self._register.items()
|
||||
return self._collection.items()
|
||||
|
||||
def _get_index(self, item):
|
||||
try:
|
||||
return next(i for i, v in enumerate(self._register.values()) if v == item)
|
||||
return next(i for i, v in enumerate(self._collection.values()) if v == item)
|
||||
except StopIteration:
|
||||
return None
|
||||
|
||||
def __getitem__(self, item):
|
||||
if isinstance(item, (int, np.int64, np.int32)):
|
||||
if item < 0:
|
||||
item = len(self._register) - abs(item)
|
||||
item = len(self._collection) - abs(item)
|
||||
try:
|
||||
return next(v for i, v in enumerate(self._register.values()) if i == item)
|
||||
return next(v for i, v in enumerate(self._collection.values()) if i == item)
|
||||
except StopIteration:
|
||||
return None
|
||||
try:
|
||||
return self._register[item]
|
||||
return self._collection[item]
|
||||
except KeyError:
|
||||
return None
|
||||
|
||||
def __repr__(self):
|
||||
return f'{self.__class__.__name__}[{self._register}]'
|
||||
return f'{self.__class__.__name__}[{self._collection}]'
|
||||
|
||||
|
||||
class EnvObjectRegister(ObjectRegister):
|
||||
class EnvObjectCollection(ObjectCollection):
|
||||
|
||||
_accepted_objects = EnvObject
|
||||
|
||||
@ -90,7 +90,7 @@ class EnvObjectRegister(ObjectRegister):
|
||||
is_blocking_light: bool = False,
|
||||
can_collide: bool = False,
|
||||
can_be_shadowed: bool = True, **kwargs):
|
||||
super(EnvObjectRegister, self).__init__(*args, **kwargs)
|
||||
super(EnvObjectCollection, self).__init__(*args, **kwargs)
|
||||
self._shape = obs_shape
|
||||
self._array = None
|
||||
self._individual_slices = individual_slices
|
||||
@ -99,8 +99,8 @@ class EnvObjectRegister(ObjectRegister):
|
||||
self.can_be_shadowed = can_be_shadowed
|
||||
self.can_collide = can_collide
|
||||
|
||||
def register_item(self, other: EnvObject):
|
||||
super(EnvObjectRegister, self).register_item(other)
|
||||
def add_item(self, other: EnvObject):
|
||||
super(EnvObjectCollection, self).add_item(other)
|
||||
if self._array is None:
|
||||
self._array = np.zeros((1, *self._shape))
|
||||
else:
|
||||
@ -145,13 +145,13 @@ class EnvObjectRegister(ObjectRegister):
|
||||
if self._individual_slices:
|
||||
self._array = np.delete(self._array, idx, axis=0)
|
||||
else:
|
||||
self.notify_change_to_free(self._register[name])
|
||||
self.notify_change_to_free(self._collection[name])
|
||||
# Dirty Hack to check if not being subclassed. In that case we need to refresh the array since positions
|
||||
# in the observation array are the result of enumeration. They can override each other.
|
||||
# Todo: Find a better solution
|
||||
if not issubclass(self.__class__, EntityRegister) and issubclass(self.__class__, EnvObjectRegister):
|
||||
if not issubclass(self.__class__, EntityCollection) and issubclass(self.__class__, EnvObjectCollection):
|
||||
self._refresh_arrays()
|
||||
del self._register[name]
|
||||
del self._collection[name]
|
||||
|
||||
def delete_env_object(self, env_object: EnvObject):
|
||||
del self[env_object.name]
|
||||
@ -160,19 +160,19 @@ class EnvObjectRegister(ObjectRegister):
|
||||
del self[name]
|
||||
|
||||
|
||||
class EntityRegister(EnvObjectRegister, ABC):
|
||||
class EntityCollection(EnvObjectCollection, ABC):
|
||||
|
||||
_accepted_objects = Entity
|
||||
|
||||
@classmethod
|
||||
def from_tiles(cls, tiles, *args, entity_kwargs=None, **kwargs):
|
||||
# objects_name = cls._accepted_objects.__name__
|
||||
register_obj = cls(*args, **kwargs)
|
||||
entities = [cls._accepted_objects(tile, register_obj, str_ident=i,
|
||||
collection = cls(*args, **kwargs)
|
||||
entities = [cls._accepted_objects(tile, collection, str_ident=i,
|
||||
**entity_kwargs if entity_kwargs is not None else {})
|
||||
for i, tile in enumerate(tiles)]
|
||||
register_obj.register_additional_items(entities)
|
||||
return register_obj
|
||||
collection.add_additional_items(entities)
|
||||
return collection
|
||||
|
||||
@classmethod
|
||||
def from_argwhere_coordinates(cls, positions: [(int, int)], tiles, *args, entity_kwargs=None, **kwargs, ):
|
||||
@ -188,13 +188,13 @@ class EntityRegister(EnvObjectRegister, ABC):
|
||||
return [entity.tile for entity in self]
|
||||
|
||||
def __init__(self, level_shape, *args, **kwargs):
|
||||
super(EntityRegister, self).__init__(level_shape, *args, **kwargs)
|
||||
super(EntityCollection, self).__init__(level_shape, *args, **kwargs)
|
||||
self._lazy_eval_transforms = []
|
||||
|
||||
def __delitem__(self, name):
|
||||
idx, obj = next((i, obj) for i, obj in enumerate(self) if obj.name == name)
|
||||
obj.tile.leave(obj)
|
||||
super(EntityRegister, self).__delitem__(name)
|
||||
super(EntityCollection, self).__delitem__(name)
|
||||
|
||||
def as_array(self):
|
||||
if self._lazy_eval_transforms:
|
||||
@ -223,7 +223,7 @@ class EntityRegister(EnvObjectRegister, ABC):
|
||||
return None
|
||||
|
||||
|
||||
class BoundEnvObjRegister(EnvObjectRegister, ABC):
|
||||
class BoundEnvObjCollection(EnvObjectCollection, ABC):
|
||||
|
||||
def __init__(self, entity_to_be_bound, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
@ -248,13 +248,13 @@ class BoundEnvObjRegister(EnvObjectRegister, ABC):
|
||||
return self._array[self.idx_by_entity(entity)]
|
||||
|
||||
|
||||
class MovingEntityObjectRegister(EntityRegister, ABC):
|
||||
class MovingEntityObjectCollection(EntityCollection, ABC):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(MovingEntityObjectRegister, self).__init__(*args, **kwargs)
|
||||
super(MovingEntityObjectCollection, self).__init__(*args, **kwargs)
|
||||
|
||||
def notify_change_to_value(self, entity):
|
||||
super(MovingEntityObjectRegister, self).notify_change_to_value(entity)
|
||||
super(MovingEntityObjectCollection, self).notify_change_to_value(entity)
|
||||
if entity.last_pos != c.NO_POS:
|
||||
try:
|
||||
self._array_change_notifyer(entity, entity.last_pos, value=c.FREE_CELL)
|
||||
@ -263,11 +263,11 @@ class MovingEntityObjectRegister(EntityRegister, ABC):
|
||||
|
||||
|
||||
##########################################################################
|
||||
# ################# Objects and Entity Registers ####################### #
|
||||
# ################# Objects and Entity Collection ###################### #
|
||||
##########################################################################
|
||||
|
||||
|
||||
class GlobalPositions(EnvObjectRegister):
|
||||
class GlobalPositions(EnvObjectCollection):
|
||||
|
||||
_accepted_objects = GlobalPosition
|
||||
|
||||
@ -288,7 +288,7 @@ class GlobalPositions(EnvObjectRegister):
|
||||
global_positions = [self._accepted_objects(self._shape, agent, self)
|
||||
for _, agent in enumerate(agents)]
|
||||
# noinspection PyTypeChecker
|
||||
self.register_additional_items(global_positions)
|
||||
self.add_additional_items(global_positions)
|
||||
|
||||
def summarize_states(self, n_steps=None):
|
||||
return {}
|
||||
@ -306,7 +306,7 @@ class GlobalPositions(EnvObjectRegister):
|
||||
return None
|
||||
|
||||
|
||||
class PlaceHolders(EnvObjectRegister):
|
||||
class PlaceHolders(EnvObjectCollection):
|
||||
_accepted_objects = PlaceHolder
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
@ -320,12 +320,12 @@ class PlaceHolders(EnvObjectRegister):
|
||||
# objects_name = cls._accepted_objects.__name__
|
||||
if isinstance(values, (str, numbers.Number)):
|
||||
values = [values]
|
||||
register_obj = cls(*args, **kwargs)
|
||||
objects = [cls._accepted_objects(register_obj, str_ident=i, fill_value=value,
|
||||
collection = cls(*args, **kwargs)
|
||||
objects = [cls._accepted_objects(collection, str_ident=i, fill_value=value,
|
||||
**object_kwargs if object_kwargs is not None else {})
|
||||
for i, value in enumerate(values)]
|
||||
register_obj.register_additional_items(objects)
|
||||
return register_obj
|
||||
collection.add_additional_items(objects)
|
||||
return collection
|
||||
|
||||
# noinspection DuplicatedCode
|
||||
def as_array(self):
|
||||
@ -343,8 +343,8 @@ class PlaceHolders(EnvObjectRegister):
|
||||
return self._array
|
||||
|
||||
|
||||
class Entities(ObjectRegister):
|
||||
_accepted_objects = EntityRegister
|
||||
class Entities(ObjectCollection):
|
||||
_accepted_objects = EntityCollection
|
||||
|
||||
@property
|
||||
def arrays(self):
|
||||
@ -352,7 +352,7 @@ class Entities(ObjectRegister):
|
||||
|
||||
@property
|
||||
def names(self):
|
||||
return list(self._register.keys())
|
||||
return list(self._collection.keys())
|
||||
|
||||
def __init__(self):
|
||||
super(Entities, self).__init__()
|
||||
@ -360,21 +360,21 @@ class Entities(ObjectRegister):
|
||||
def iter_individual_entitites(self):
|
||||
return iter((x for sublist in self.values() for x in sublist))
|
||||
|
||||
def register_item(self, other: dict):
|
||||
def add_item(self, other: dict):
|
||||
assert not any([key for key in other.keys() if key in self.keys()]), \
|
||||
"This group of entities has already been registered!"
|
||||
self._register.update(other)
|
||||
"This group of entities has already been added!"
|
||||
self._collection.update(other)
|
||||
return self
|
||||
|
||||
def register_additional_items(self, others: Dict):
|
||||
return self.register_item(others)
|
||||
def add_additional_items(self, others: Dict):
|
||||
return self.add_item(others)
|
||||
|
||||
def by_pos(self, pos: (int, int)):
|
||||
found_entities = [y for y in (x.by_pos(pos) for x in self.values() if hasattr(x, 'by_pos')) if y is not None]
|
||||
return found_entities
|
||||
|
||||
|
||||
class Walls(EntityRegister):
|
||||
class Walls(EntityCollection):
|
||||
_accepted_objects = Wall
|
||||
|
||||
def as_array(self):
|
||||
@ -396,7 +396,7 @@ class Walls(EntityRegister):
|
||||
def from_argwhere_coordinates(cls, argwhere_coordinates, *args, **kwargs):
|
||||
tiles = cls(*args, **kwargs)
|
||||
# noinspection PyTypeChecker
|
||||
tiles.register_additional_items(
|
||||
tiles.add_additional_items(
|
||||
[cls._accepted_objects(pos, tiles)
|
||||
for pos in argwhere_coordinates]
|
||||
)
|
||||
@ -441,7 +441,7 @@ class Floors(Walls):
|
||||
return {}
|
||||
|
||||
|
||||
class Agents(MovingEntityObjectRegister):
|
||||
class Agents(MovingEntityObjectCollection):
|
||||
_accepted_objects = Agent
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
@ -455,10 +455,10 @@ class Agents(MovingEntityObjectRegister):
|
||||
old_agent = self[key]
|
||||
self[key].tile.leave(self[key])
|
||||
agent._name = old_agent.name
|
||||
self._register[agent.name] = agent
|
||||
self._collection[agent.name] = agent
|
||||
|
||||
|
||||
class Doors(EntityRegister):
|
||||
class Doors(EntityCollection):
|
||||
|
||||
def __init__(self, *args, have_area: bool = False, **kwargs):
|
||||
self.have_area = have_area
|
||||
@ -490,7 +490,7 @@ class Doors(EntityRegister):
|
||||
return super(Doors, self).as_array()
|
||||
|
||||
|
||||
class Actions(ObjectRegister):
|
||||
class Actions(ObjectCollection):
|
||||
_accepted_objects = Action
|
||||
|
||||
@property
|
||||
@ -507,22 +507,22 @@ class Actions(ObjectRegister):
|
||||
|
||||
# Move this to Baseclass, Env init?
|
||||
if self.allow_square_movement:
|
||||
self.register_additional_items([self._accepted_objects(str_ident=direction)
|
||||
for direction in h.EnvActions.square_move()])
|
||||
self.add_additional_items([self._accepted_objects(str_ident=direction)
|
||||
for direction in h.EnvActions.square_move()])
|
||||
if self.allow_diagonal_movement:
|
||||
self.register_additional_items([self._accepted_objects(str_ident=direction)
|
||||
for direction in h.EnvActions.diagonal_move()])
|
||||
self._movement_actions = self._register.copy()
|
||||
self.add_additional_items([self._accepted_objects(str_ident=direction)
|
||||
for direction in h.EnvActions.diagonal_move()])
|
||||
self._movement_actions = self._collection.copy()
|
||||
if self.can_use_doors:
|
||||
self.register_additional_items([self._accepted_objects(str_ident=h.EnvActions.USE_DOOR)])
|
||||
self.add_additional_items([self._accepted_objects(str_ident=h.EnvActions.USE_DOOR)])
|
||||
if self.allow_no_op:
|
||||
self.register_additional_items([self._accepted_objects(str_ident=h.EnvActions.NOOP)])
|
||||
self.add_additional_items([self._accepted_objects(str_ident=h.EnvActions.NOOP)])
|
||||
|
||||
def is_moving_action(self, action: Union[int]):
|
||||
return action in self.movement_actions.values()
|
||||
|
||||
|
||||
class Zones(ObjectRegister):
|
||||
class Zones(ObjectCollection):
|
||||
|
||||
@property
|
||||
def accounting_zones(self):
|
||||
@ -551,5 +551,5 @@ class Zones(ObjectRegister):
|
||||
def __getitem__(self, item):
|
||||
return self._zone_slices[item]
|
||||
|
||||
def register_additional_items(self, other: Union[str, List[str]]):
|
||||
def add_additional_items(self, other: Union[str, List[str]]):
|
||||
raise AttributeError('You are not allowed to add additional Zones in runtime.')
|
||||
|
@ -4,7 +4,7 @@ import numpy as np
|
||||
|
||||
from environments.factory.base.base_factory import BaseFactory
|
||||
from environments.factory.base.objects import Agent, Action, Entity, EnvObject, BoundingMixin
|
||||
from environments.factory.base.registers import EntityRegister, EnvObjectRegister
|
||||
from environments.factory.base.registers import EntityCollection, EnvObjectCollection
|
||||
from environments.factory.base.renderer import RenderEntity
|
||||
from environments.helpers import Constants as BaseConstants
|
||||
from environments.helpers import EnvActions as BaseActions
|
||||
@ -68,7 +68,7 @@ class Battery(BoundingMixin, EnvObject):
|
||||
if self.charge_level != 0:
|
||||
# noinspection PyTypeChecker
|
||||
self.charge_level = max(0, amount + self.charge_level)
|
||||
self._register.notify_change_to_value(self)
|
||||
self._collection.notify_change_to_value(self)
|
||||
return c.VALID
|
||||
else:
|
||||
return c.NOT_VALID
|
||||
@ -79,7 +79,7 @@ class Battery(BoundingMixin, EnvObject):
|
||||
return attr_dict
|
||||
|
||||
|
||||
class BatteriesRegister(EnvObjectRegister):
|
||||
class BatteriesRegister(EnvObjectCollection):
|
||||
|
||||
_accepted_objects = Battery
|
||||
|
||||
@ -90,7 +90,7 @@ class BatteriesRegister(EnvObjectRegister):
|
||||
|
||||
def spawn_batteries(self, agents, initial_charge_level):
|
||||
batteries = [self._accepted_objects(initial_charge_level, agent, self) for _, agent in enumerate(agents)]
|
||||
self.register_additional_items(batteries)
|
||||
self.add_additional_items(batteries)
|
||||
|
||||
def summarize_states(self, n_steps=None):
|
||||
# as dict with additional nesting
|
||||
@ -140,7 +140,7 @@ class ChargePod(Entity):
|
||||
return summary
|
||||
|
||||
|
||||
class ChargePods(EntityRegister):
|
||||
class ChargePods(EntityCollection):
|
||||
|
||||
_accepted_objects = ChargePod
|
||||
|
||||
|
@ -9,7 +9,7 @@ from environments.factory.base.base_factory import BaseFactory
|
||||
from environments.helpers import Constants as BaseConstants
|
||||
from environments.helpers import EnvActions as BaseActions
|
||||
from environments.factory.base.objects import Agent, Entity, Action
|
||||
from environments.factory.base.registers import Entities, EntityRegister
|
||||
from environments.factory.base.registers import Entities, EntityCollection
|
||||
|
||||
from environments.factory.base.renderer import RenderEntity
|
||||
|
||||
@ -73,7 +73,7 @@ class Destination(Entity):
|
||||
return state_summary
|
||||
|
||||
|
||||
class Destinations(EntityRegister):
|
||||
class Destinations(EntityCollection):
|
||||
|
||||
_accepted_objects = Destination
|
||||
|
||||
@ -208,13 +208,13 @@ class DestFactory(BaseFactory):
|
||||
n_dest_to_spawn = len(destinations_to_spawn)
|
||||
if self.dest_prop.spawn_mode != DestModeOptions.GROUPED:
|
||||
destinations = [Destination(tile, c.DEST) for tile in self[c.FLOOR].empty_tiles[:n_dest_to_spawn]]
|
||||
self[c.DEST].register_additional_items(destinations)
|
||||
self[c.DEST].add_additional_items(destinations)
|
||||
for dest in destinations_to_spawn:
|
||||
del self._dest_spawn_timer[dest]
|
||||
self.print(f'{n_dest_to_spawn} new destinations have been spawned')
|
||||
elif self.dest_prop.spawn_mode == DestModeOptions.GROUPED and n_dest_to_spawn == self.dest_prop.n_dests:
|
||||
destinations = [Destination(tile, self[c.DEST]) for tile in self[c.FLOOR].empty_tiles[:n_dest_to_spawn]]
|
||||
self[c.DEST].register_additional_items(destinations)
|
||||
self[c.DEST].add_additional_items(destinations)
|
||||
for dest in destinations_to_spawn:
|
||||
del self._dest_spawn_timer[dest]
|
||||
self.print(f'{n_dest_to_spawn} new destinations have been spawned')
|
||||
@ -231,7 +231,7 @@ class DestFactory(BaseFactory):
|
||||
self._dest_spawn_timer[key] = min(self.dest_prop.spawn_frequency, self._dest_spawn_timer[key] + 1)
|
||||
for dest in list(self[c.DEST].values()):
|
||||
if dest.is_considered_reached:
|
||||
dest.change_register(self[c.DEST])
|
||||
dest.change_parent_collection(self[c.DEST])
|
||||
self._dest_spawn_timer[dest.name] = 0
|
||||
self.print(f'{dest.name} is reached now, removing...')
|
||||
else:
|
||||
|
@ -11,7 +11,7 @@ from environments.helpers import EnvActions as BaseActions
|
||||
|
||||
from environments.factory.base.base_factory import BaseFactory
|
||||
from environments.factory.base.objects import Agent, Action, Entity, Floor
|
||||
from environments.factory.base.registers import Entities, EntityRegister
|
||||
from environments.factory.base.registers import Entities, EntityCollection
|
||||
|
||||
from environments.factory.base.renderer import RenderEntity
|
||||
from environments.utility_classes import ObservationProperties
|
||||
@ -61,7 +61,7 @@ class Dirt(Entity):
|
||||
|
||||
def set_new_amount(self, amount):
|
||||
self._amount = amount
|
||||
self._register.notify_change_to_value(self)
|
||||
self._collection.notify_change_to_value(self)
|
||||
|
||||
def summarize_state(self, **kwargs):
|
||||
state_dict = super().summarize_state(**kwargs)
|
||||
@ -69,7 +69,7 @@ class Dirt(Entity):
|
||||
return state_dict
|
||||
|
||||
|
||||
class DirtRegister(EntityRegister):
|
||||
class DirtRegister(EntityCollection):
|
||||
|
||||
_accepted_objects = Dirt
|
||||
|
||||
@ -93,7 +93,7 @@ class DirtRegister(EntityRegister):
|
||||
dirt = self.by_pos(tile.pos)
|
||||
if dirt is None:
|
||||
dirt = Dirt(tile, self, amount=self.dirt_properties.max_spawn_amount)
|
||||
self.register_item(dirt)
|
||||
self.add_item(dirt)
|
||||
else:
|
||||
new_value = dirt.amount + self.dirt_properties.max_spawn_amount
|
||||
dirt.set_new_amount(min(new_value, self.dirt_properties.max_local_amount))
|
||||
|
@ -5,10 +5,10 @@ import numpy as np
|
||||
from environments.factory.base.objects import Agent, Entity, Action
|
||||
from environments.factory.factory_dirt import Dirt, DirtRegister, DirtFactory
|
||||
from environments.factory.base.objects import Floor
|
||||
from environments.factory.base.registers import Floors, Entities, EntityRegister
|
||||
from environments.factory.base.registers import Floors, Entities, EntityCollection
|
||||
|
||||
|
||||
class Machines(EntityRegister):
|
||||
class Machines(EntityCollection):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
@ -9,7 +9,7 @@ from environments.helpers import Constants as BaseConstants
|
||||
from environments.helpers import EnvActions as BaseActions
|
||||
from environments import helpers as h
|
||||
from environments.factory.base.objects import Agent, Entity, Action, Floor
|
||||
from environments.factory.base.registers import Entities, EntityRegister, BoundEnvObjRegister, ObjectRegister
|
||||
from environments.factory.base.registers import Entities, EntityCollection, BoundEnvObjCollection, ObjectCollection
|
||||
|
||||
from environments.factory.base.renderer import RenderEntity
|
||||
|
||||
@ -53,17 +53,17 @@ class Item(Entity):
|
||||
self._auto_despawn = auto_despawn
|
||||
|
||||
def set_tile_to(self, no_pos_tile):
|
||||
assert self._register.__class__.__name__ != ItemRegister.__class__
|
||||
assert self._collection.__class__.__name__ != ItemRegister.__class__
|
||||
self._tile = no_pos_tile
|
||||
|
||||
|
||||
class ItemRegister(EntityRegister):
|
||||
class ItemRegister(EntityCollection):
|
||||
|
||||
_accepted_objects = Item
|
||||
|
||||
def spawn_items(self, tiles: List[Floor]):
|
||||
items = [Item(tile, self) for tile in tiles]
|
||||
self.register_additional_items(items)
|
||||
self.add_additional_items(items)
|
||||
|
||||
def despawn_items(self, items: List[Item]):
|
||||
items = [items] if isinstance(items, Item) else items
|
||||
@ -71,7 +71,7 @@ class ItemRegister(EntityRegister):
|
||||
del self[item]
|
||||
|
||||
|
||||
class Inventory(BoundEnvObjRegister):
|
||||
class Inventory(BoundEnvObjCollection):
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
@ -98,7 +98,7 @@ class Inventory(BoundEnvObjRegister):
|
||||
return item_to_pop
|
||||
|
||||
|
||||
class Inventories(ObjectRegister):
|
||||
class Inventories(ObjectCollection):
|
||||
|
||||
_accepted_objects = Inventory
|
||||
is_blocking_light = False
|
||||
@ -114,7 +114,7 @@ class Inventories(ObjectRegister):
|
||||
def spawn_inventories(self, agents, capacity):
|
||||
inventories = [self._accepted_objects(agent, capacity, self._obs_shape)
|
||||
for _, agent in enumerate(agents)]
|
||||
self.register_additional_items(inventories)
|
||||
self.add_additional_items(inventories)
|
||||
|
||||
def idx_by_entity(self, entity):
|
||||
try:
|
||||
@ -161,7 +161,7 @@ class DropOffLocation(Entity):
|
||||
return super().summarize_state(n_steps=n_steps)
|
||||
|
||||
|
||||
class DropOffLocations(EntityRegister):
|
||||
class DropOffLocations(EntityCollection):
|
||||
|
||||
_accepted_objects = DropOffLocation
|
||||
|
||||
@ -250,7 +250,7 @@ class ItemFactory(BaseFactory):
|
||||
reason=a.ITEM_ACTION, info=info_dict)
|
||||
return valid, reward
|
||||
elif item := self[c.ITEM].by_pos(agent.pos):
|
||||
item.change_register(inventory)
|
||||
item.change_parent_collection(inventory)
|
||||
item.set_tile_to(self._NO_POS_TILE)
|
||||
self.print(f'{agent.name} just picked up an item at {agent.pos}')
|
||||
info_dict = {f'{agent.name}_{a.ITEM_ACTION}_VALID': 1, f'{a.ITEM_ACTION}_VALID': 1}
|
||||
|
@ -7,47 +7,76 @@ import numpy as np
|
||||
from numpy.typing import ArrayLike
|
||||
from stable_baselines3 import PPO, DQN, A2C
|
||||
|
||||
MODEL_MAP = dict(PPO=PPO, DQN=DQN, A2C=A2C)
|
||||
|
||||
LEVELS_DIR = 'levels'
|
||||
STEPS_START = 1
|
||||
|
||||
TO_BE_AVERAGED = ['dirt_amount', 'dirty_tiles']
|
||||
IGNORED_DF_COLUMNS = ['Episode', 'Run', 'train_step', 'step', 'index', 'dirt_amount',
|
||||
'dirty_tile_count', 'terminal_observation', 'episode']
|
||||
"""
|
||||
This file is used for:
|
||||
1. string based definition
|
||||
Use a class like `Constants` to define attributes which hold strings.
|
||||
These can be used as a naming convention across the environments as well as keys for mappings such as dicts etc.
|
||||
When defining new envs, use class inheritance.
|
||||
|
||||
2. utility function definition
|
||||
There are static utility functions which are not bound to a specific environment.
|
||||
In this file they are defined to be used across the entire package.
|
||||
"""
|
||||
|
||||
|
||||
MODEL_MAP = dict(PPO=PPO, DQN=DQN, A2C=A2C) # For use in studies and experiments
|
||||
|
||||
|
||||
LEVELS_DIR = 'levels' # for use in studies and experiments
|
||||
STEPS_START = 1 # Defines where the step count starts, i.e. the number of the first step
|
||||
|
||||
# Not used anymore? Clean!
|
||||
# TO_BE_AVERAGED = ['dirt_amount', 'dirty_tiles']
|
||||
IGNORED_DF_COLUMNS = ['Episode', 'Run', # For plotting, which values are ignored when loading monitor files
|
||||
'train_step', 'step', 'index', 'dirt_amount', 'dirty_tile_count', 'terminal_observation',
|
||||
'episode']
|
||||
|
||||
|
||||
# Constants
|
||||
class Constants:
|
||||
WALL = '#'
|
||||
WALLS = 'Walls'
|
||||
FLOOR = 'Floor'
|
||||
DOOR = 'D'
|
||||
DANGER_ZONE = 'x'
|
||||
LEVEL = 'Level'
|
||||
AGENT = 'Agent'
|
||||
AGENT_PLACEHOLDER = 'AGENT_PLACEHOLDER'
|
||||
GLOBAL_POSITION = 'GLOBAL_POSITION'
|
||||
FREE_CELL = 0
|
||||
OCCUPIED_CELL = 1
|
||||
SHADOWED_CELL = -1
|
||||
ACCESS_DOOR_CELL = 1/3
|
||||
OPEN_DOOR_CELL = 2/3
|
||||
CLOSED_DOOR_CELL = 3/3
|
||||
NO_POS = (-9999, -9999)
|
||||
|
||||
DOORS = 'Doors'
|
||||
CLOSED_DOOR = 'closed'
|
||||
OPEN_DOOR = 'open'
|
||||
ACCESS_DOOR = 'access'
|
||||
"""
|
||||
String based mapping. Use these to handle keys or define values, which can then be used globally.
|
||||
Please use class inheritance when defining new environments.
|
||||
"""
|
||||
|
||||
ACTION = 'action'
|
||||
COLLISION = 'collision'
|
||||
VALID = True
|
||||
NOT_VALID = False
|
||||
WALL = '#' # Wall tile identifier for resolving the string based map files.
|
||||
DOOR = 'D' # Door identifier for resolving the string based map files.
|
||||
DANGER_ZONE = 'x' # Dange Zone tile identifier for resolving the string based map files.
|
||||
|
||||
WALLS = 'Walls' # Identifier of Wall-objects and sets (collections).
|
||||
FLOOR = 'Floor' # Identifier of Floor-objects and sets (collections).
|
||||
DOORS = 'Doors' # Identifier of Door-objects and sets (collections).
|
||||
LEVEL = 'Level' # Identifier of Level-objects and sets (collections).
|
||||
AGENT = 'Agent' # Identifier of Agent-objects and sets (collections).
|
||||
AGENT_PLACEHOLDER = 'AGENT_PLACEHOLDER' # Identifier of Placeholder-objects and sets (collections).
|
||||
GLOBAL_POSITION = 'GLOBAL_POSITION' # Identifier of the global position slice
|
||||
|
||||
FREE_CELL = 0 # Free-Cell value used in observation
|
||||
OCCUPIED_CELL = 1 # Occupied-Cell value used in observation
|
||||
SHADOWED_CELL = -1 # Shadowed-Cell value used in observation
|
||||
ACCESS_DOOR_CELL = 1/3 # Access-door-Cell value used in observation
|
||||
OPEN_DOOR_CELL = 2/3 # Open-door-Cell value used in observation
|
||||
CLOSED_DOOR_CELL = 3/3 # Closed-door-Cell value used in observation
|
||||
|
||||
NO_POS = (-9999, -9999) # Invalid Position value used in the environment (something is off-grid)
|
||||
|
||||
CLOSED_DOOR = 'closed' # Identifier to compare door-is-closed state
|
||||
OPEN_DOOR = 'open' # Identifier to compare door-is-open state
|
||||
# ACCESS_DOOR = 'access' # Identifier to compare access positions
|
||||
|
||||
ACTION = 'action' # Identifier of Action-objects and sets (collections).
|
||||
COLLISION = 'collision' # Identifier to use in the context of collitions.
|
||||
VALID = True # Identifier to rename boolean values in the context of actions.
|
||||
NOT_VALID = False # Identifier to rename boolean values in the context of actions.
|
||||
|
||||
|
||||
class EnvActions:
|
||||
"""
|
||||
String based mapping. Use these to identify actions; can be used globally.
|
||||
Please use class inheritance when defining new environments with new actions.
|
||||
"""
|
||||
# Movements
|
||||
NORTH = 'north'
|
||||
EAST = 'east'
|
||||
@ -63,24 +92,77 @@ class EnvActions:
|
||||
NOOP = 'no_op'
|
||||
USE_DOOR = 'use_door'
|
||||
|
||||
_ACTIONMAP = defaultdict(lambda: (0, 0),
|
||||
{NORTH: (-1, 0), NORTHEAST: (-1, 1),
|
||||
EAST: (0, 1), SOUTHEAST: (1, 1),
|
||||
SOUTH: (1, 0), SOUTHWEST: (1, -1),
|
||||
WEST: (0, -1), NORTHWEST: (-1, -1)
|
||||
}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def is_move(cls, other):
|
||||
return any([other == direction for direction in cls.movement_actions()])
|
||||
def is_move(cls, action):
|
||||
"""
|
||||
Classmethod; checks if given action is a movement action or not. Depending on the env. configuration,
|
||||
Movement actions are either `manhattan` (square) style movements (up,down, left, right) and/or diagonal.
|
||||
|
||||
:param action: Action to be checked
|
||||
:type action: str
|
||||
:return: Whether the given action is a movement action.
|
||||
:rtype: bool
|
||||
"""
|
||||
return any([action == direction for direction in cls.movement_actions()])
|
||||
|
||||
@classmethod
|
||||
def square_move(cls):
|
||||
"""
|
||||
Classmethod; return a list of movement actions that are considered square or `manhattan` style movements.
|
||||
|
||||
:return: A list of movement actions.
|
||||
:rtype: list(str)
|
||||
"""
|
||||
return [cls.NORTH, cls.EAST, cls.SOUTH, cls.WEST]
|
||||
|
||||
@classmethod
|
||||
def diagonal_move(cls):
|
||||
"""
|
||||
Classmethod; return a list of movement actions that are considered diagonal movements.
|
||||
|
||||
:return: A list of movement actions.
|
||||
:rtype: list(str)
|
||||
"""
|
||||
return [cls.NORTHEAST, cls.SOUTHEAST, cls.SOUTHWEST, cls.NORTHWEST]
|
||||
|
||||
@classmethod
|
||||
def movement_actions(cls):
|
||||
"""
|
||||
Classmethod; return a list of all available movement actions.
|
||||
Please note that this is independent of the env. properties.
|
||||
|
||||
:return: A list of movement actions.
|
||||
:rtype: list(str)
|
||||
"""
|
||||
return list(itertools.chain(cls.square_move(), cls.diagonal_move()))
|
||||
|
||||
@classmethod
|
||||
def resolve_movement_action_to_coords(cls, action):
|
||||
"""
|
||||
Classmethod; resolve movement actions. Given a movement action, return the delta in coordinates it stands for.
|
||||
How does the current entity coordinate change if it performs the given action?
|
||||
Please note that this is independent of the env. properties.
|
||||
|
||||
:return: Delta coordinates.
|
||||
:rtype: tuple(int, int)
|
||||
"""
|
||||
return cls._ACTIONMAP[action]
|
||||
|
||||
|
||||
class RewardsBase(NamedTuple):
|
||||
"""
|
||||
Value based mapping. Use these to define reward values for specific conditions (i.e. the action
|
||||
in a given context), can be used globally.
|
||||
Please use class inheritance when defining new environments with new rewards.
|
||||
"""
|
||||
MOVEMENTS_VALID: float = -0.001
|
||||
MOVEMENTS_FAIL: float = -0.05
|
||||
NOOP: float = -0.01
|
||||
@ -89,23 +171,31 @@ class RewardsBase(NamedTuple):
|
||||
COLLISION: float = -0.5
|
||||
|
||||
|
||||
m = EnvActions
|
||||
c = Constants
|
||||
|
||||
ACTIONMAP = defaultdict(lambda: (0, 0),
|
||||
{m.NORTH: (-1, 0), m.NORTHEAST: (-1, 1),
|
||||
m.EAST: (0, 1), m.SOUTHEAST: (1, 1),
|
||||
m.SOUTH: (1, 0), m.SOUTHWEST: (1, -1),
|
||||
m.WEST: (0, -1), m.NORTHWEST: (-1, -1)
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class ObservationTranslator:
|
||||
|
||||
def __init__(self, obs_shape_2d: (int, int), this_named_observation_space: Dict[str, dict],
|
||||
*per_agent_named_obs_space: Dict[str, dict],
|
||||
*per_agent_named_obs_spaces: Dict[str, dict],
|
||||
placeholder_fill_value: Union[int, str] = 'N'):
|
||||
"""
|
||||
This is a helper class, which converts agents observations from joined environments.
|
||||
For example, agents trained in different environments may expect different observations.
|
||||
This class translates from larger observations spaces to smaller.
|
||||
A string identifier based approach is used.
|
||||
Currently, it is not possible to mix different obs shapes.
|
||||
|
||||
:param obs_shape_2d: The shape of the observation the agents expect.
|
||||
:type obs_shape_2d: tuple(int, int)
|
||||
|
||||
:param this_named_observation_space: `Named observation space` of the joined environment.
|
||||
:type this_named_observation_space: Dict[str, dict]
|
||||
|
||||
:param per_agent_named_obs_spaces: `Named observation space` one for each agent. Overloaded.
|
||||
:type per_agent_named_obs_spaces: Dict[str, dict]
|
||||
|
||||
:param placeholder_fill_value: Currently not fully implemented!!!
|
||||
:type placeholder_fill_value: Union[int, str]
|
||||
"""
|
||||
|
||||
assert len(obs_shape_2d) == 2
|
||||
self.obs_shape = obs_shape_2d
|
||||
if isinstance(placeholder_fill_value, str):
|
||||
@ -119,7 +209,7 @@ class ObservationTranslator:
|
||||
self.random_fill = None
|
||||
|
||||
self._this_named_obs_space = this_named_observation_space
|
||||
self._per_agent_named_obs_space = list(per_agent_named_obs_space)
|
||||
self._per_agent_named_obs_space = list(per_agent_named_obs_spaces)
|
||||
|
||||
def translate_observation(self, agent_idx: int, obs: np.ndarray):
|
||||
target_obs_space = self._per_agent_named_obs_space[agent_idx]
|
||||
@ -137,6 +227,19 @@ class ObservationTranslator:
|
||||
class ActionTranslator:
|
||||
|
||||
def __init__(self, target_named_action_space: Dict[str, int], *per_agent_named_action_space: Dict[str, int]):
|
||||
"""
|
||||
This is a helper class, which converts agents action spaces to a joined environments action space.
|
||||
For example, agents trained in different environments may have different action spaces.
|
||||
This class translates from smaller individual agent action spaces to larger joined spaces.
|
||||
A string identifier based approach is used.
|
||||
|
||||
:param target_named_action_space: Joined `Named action space` for the current environment.
|
||||
:type target_named_action_space: Dict[str, int]
|
||||
|
||||
:param per_agent_named_action_space: `Named action space` one for each agent. Overloaded.
|
||||
:type per_agent_named_action_space: Dict[str, int]
|
||||
"""
|
||||
|
||||
self._target_named_action_space = target_named_action_space
|
||||
self._per_agent_named_action_space = list(per_agent_named_action_space)
|
||||
self._per_agent_idx_actions = [{idx: a for a, idx in x.items()} for x in self._per_agent_named_action_space]
|
||||
@ -155,6 +258,16 @@ class ActionTranslator:
|
||||
|
||||
# Utility functions
|
||||
def parse_level(path):
|
||||
"""
|
||||
Given the path to a string based `level` or `map` representation, this function reads the content.
|
||||
Cleans `space`, checks for equal length of each row and returns a list of lists.
|
||||
|
||||
:param path: Path to the `level` or `map` file on the hard drive.
|
||||
:type path: os.Pathlike
|
||||
|
||||
:return: The read string representation of the `level` or `map`
|
||||
:rtype: List[List[str]]
|
||||
"""
|
||||
with path.open('r') as lvl:
|
||||
level = list(map(lambda x: list(x.strip()), lvl.readlines()))
|
||||
if len(set([len(line) for line in level])) > 1:
|
||||
@ -162,29 +275,56 @@ def parse_level(path):
|
||||
return level
|
||||
|
||||
|
||||
def one_hot_level(level, wall_char: str = c.WALL):
|
||||
def one_hot_level(level, wall_char: str = Constants.WALL):
|
||||
"""
|
||||
Given a string based level representation (list of lists, see function `parse_level`), this function creates a
|
||||
binary numpy array or `grid`. Grid values that equal `wall_char` become of `Constants.OCCUPIED_CELL` value.
|
||||
Can be changed to filter for any symbol.
|
||||
|
||||
:param level: String based level representation (list of lists, see function `parse_level`).
|
||||
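:type level: List[List[str]]
:param wall_char: Symbol whose grid cells become `Constants.OCCUPIED_CELL`; defaults to `Constants.WALL`.
:type wall_char: str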
|
||||
|
||||
:return: Binary numpy array
|
||||
:rtype: np.typing._array_like.ArrayLike
|
||||
"""
|
||||
|
||||
grid = np.array(level)
|
||||
binary_grid = np.zeros(grid.shape, dtype=np.int8)
|
||||
binary_grid[grid == wall_char] = c.OCCUPIED_CELL
|
||||
binary_grid[grid == wall_char] = Constants.OCCUPIED_CELL
|
||||
return binary_grid
|
||||
|
||||
|
||||
def check_position(slice_to_check_against: ArrayLike, position_to_check: Tuple[int, int]):
|
||||
"""
|
||||
Given a slice (2-D array-like object), check whether a position lies within its bounds and is not occupied.
|
||||
|
||||
:param slice_to_check_against: The slice to check for accessibility
|
||||
:type slice_to_check_against: np.typing._array_like.ArrayLike
|
||||
|
||||
:param position_to_check: Position in slice that should be checked. Can be outside of slice boundaries.
|
||||
:type position_to_check: tuple(int, int)
|
||||
|
||||
:return: Whether a position can be moved to.
|
||||
:rtype: bool
|
||||
"""
|
||||
x_pos, y_pos = position_to_check
|
||||
|
||||
# Check if agent collides with grid boundaries
|
||||
valid = not (
|
||||
x_pos < 0 or y_pos < 0
|
||||
or x_pos >= slice_to_check_against.shape[0]
|
||||
or y_pos >= slice_to_check_against.shape[0]
|
||||
or y_pos >= slice_to_check_against.shape[1]
|
||||
)
|
||||
|
||||
# Check for collision with level walls
|
||||
valid = valid and not slice_to_check_against[x_pos, y_pos]
|
||||
return c.VALID if valid else c.NOT_VALID
|
||||
return Constants.VALID if valid else Constants.NOT_VALID
|
||||
|
||||
|
||||
def asset_str(agent):
|
||||
"""
|
||||
FIXME @ romue
|
||||
"""
|
||||
# What does this abomination do?
|
||||
# if any([x is None for x in [cls._slices[j] for j in agent.collisions]]):
|
||||
# print('error')
|
||||
@ -192,33 +332,50 @@ def asset_str(agent):
|
||||
action = step_result['action_name']
|
||||
valid = step_result['action_valid']
|
||||
col_names = [x.name for x in step_result['collisions']]
|
||||
if any(c.AGENT in name for name in col_names):
|
||||
if any(Constants.AGENT in name for name in col_names):
|
||||
return 'agent_collision', 'blank'
|
||||
elif not valid or c.LEVEL in col_names or c.AGENT in col_names:
|
||||
return c.AGENT, 'invalid'
|
||||
elif not valid or Constants.LEVEL in col_names or Constants.AGENT in col_names:
|
||||
return Constants.AGENT, 'invalid'
|
||||
elif valid and not EnvActions.is_move(action):
|
||||
return c.AGENT, 'valid'
|
||||
return Constants.AGENT, 'valid'
|
||||
elif valid and EnvActions.is_move(action):
|
||||
return c.AGENT, 'move'
|
||||
return Constants.AGENT, 'move'
|
||||
else:
|
||||
return c.AGENT, 'idle'
|
||||
return Constants.AGENT, 'idle'
|
||||
else:
|
||||
return c.AGENT, 'idle'
|
||||
return Constants.AGENT, 'idle'
|
||||
|
||||
|
||||
def points_to_graph(coordiniates_or_tiles, allow_euclidean_connections=True, allow_manhattan_connections=True):
|
||||
"""
|
||||
Given a set of coordinates, this function constructs an undirected graph by connecting adjacent points.
|
||||
There are three combinations of settings:
|
||||
Allow all neighbors: Distance(a, b) <= sqrt(2)
|
||||
Allow only manhattan: Distance(a, b) == 1
|
||||
Allow only euclidean: Distance(a, b) == sqrt(2)
|
||||
|
||||
|
||||
:param coordiniates_or_tiles: A set of coordinates.
|
||||
:type coordiniates_or_tiles: Tiles
|
||||
:param allow_euclidean_connections: Whether to regard diagonally adjacent cells as neighbors
|
||||
:type allow_euclidean_connections: bool
|
||||
:param allow_manhattan_connections: Whether to regard directly adjacent cells as neighbors
|
||||
:type allow_manhattan_connections: bool
|
||||
|
||||
:return: A graph with nodes that are connected as specified by the parameters.
|
||||
:rtype: nx.Graph
|
||||
"""
|
||||
assert allow_euclidean_connections or allow_manhattan_connections
|
||||
if hasattr(coordiniates_or_tiles, 'positions'):
|
||||
coordiniates_or_tiles = coordiniates_or_tiles.positions
|
||||
possible_connections = itertools.combinations(coordiniates_or_tiles, 2)
|
||||
graph = nx.Graph()
|
||||
for a, b in possible_connections:
|
||||
diff = abs(np.subtract(a, b))
|
||||
if not max(diff) > 1:
|
||||
if allow_manhattan_connections and allow_euclidean_connections:
|
||||
graph.add_edge(a, b)
|
||||
elif not allow_manhattan_connections and allow_euclidean_connections and all(diff):
|
||||
graph.add_edge(a, b)
|
||||
elif allow_manhattan_connections and not allow_euclidean_connections and not all(diff) and any(diff):
|
||||
graph.add_edge(a, b)
|
||||
diff = np.linalg.norm(np.asarray(a)-np.asarray(b))
|
||||
if allow_manhattan_connections and allow_euclidean_connections and diff <= np.sqrt(2):
|
||||
graph.add_edge(a, b)
|
||||
elif not allow_manhattan_connections and allow_euclidean_connections and diff == np.sqrt(2):
|
||||
graph.add_edge(a, b)
|
||||
elif allow_manhattan_connections and not allow_euclidean_connections and diff == 1:
|
||||
graph.add_edge(a, b)
|
||||
return graph
|
||||
|
@ -4,6 +4,22 @@ from gym.wrappers.frame_stack import FrameStack
|
||||
|
||||
|
||||
class AgentRenderOptions(object):
|
||||
"""
|
||||
Class that specifies the available options for the way agents are represented in the env observation.
|
||||
|
||||
SEPERATE:
|
||||
Each agent is represented in a separate slice as Constants.OCCUPIED_CELL value (one hot)
|
||||
|
||||
COMBINED:
|
||||
For all agents, the value of Constants.OCCUPIED_CELL is added to a zero-value slice at the agent's position (sum(SEPERATE))
|
||||
|
||||
LEVEL:
|
||||
The combined slice is added to the LEVEL-slice. (Agents appear as obstacle / wall)
|
||||
|
||||
NOT:
|
||||
The position of individual agents can not be read from the observation.
|
||||
"""
|
||||
|
||||
SEPERATE = 'seperate'
|
||||
COMBINED = 'combined'
|
||||
LEVEL = 'lvl'
|
||||
@ -11,24 +27,61 @@ class AgentRenderOptions(object):
|
||||
|
||||
|
||||
class MovementProperties(NamedTuple):
|
||||
"""
|
||||
Property holder; for setting multiple related parameters through a single parameter. Comes with default values.
|
||||
"""
|
||||
|
||||
"""Allow the manhattan style movement on a grid (move to cells that are connected by square edges)."""
|
||||
allow_square_movement: bool = True
|
||||
|
||||
"""Allow diagonal movement on the grid (move to cells that are connected by square corners)."""
|
||||
allow_diagonal_movement: bool = False
|
||||
|
||||
"""Allow the agent to just do nothing; not move (NO-OP)."""
|
||||
allow_no_op: bool = False
|
||||
|
||||
|
||||
class ObservationProperties(NamedTuple):
|
||||
# Todo: Add Description
|
||||
"""
|
||||
Property holder; for setting multiple related parameters through a single parameter. Comes with default values.
|
||||
"""
|
||||
|
||||
"""How to represent agents in the observation space. This may also alter the obs-shape."""
|
||||
render_agents: AgentRenderOptions = AgentRenderOptions.SEPERATE
|
||||
|
||||
"""Observations are built per agent; whether the current agent should be represented in its own observation."""
|
||||
omit_agent_self: bool = True
|
||||
|
||||
"""There might be the case that you want to modify the agents obs-space, so that it can be used with additional obs.
|
||||
The additional slice can be filled with any number"""
|
||||
additional_agent_placeholder: Union[None, str, int] = None
|
||||
|
||||
"""Whether to cast shadows (make floor tiles and items hidden)."""
|
||||
cast_shadows: bool = True
|
||||
|
||||
"""Frame Stacking is a method to give some temporal information to the agents.
|
||||
This parameter controls how many "old-frames" are stacked."""
|
||||
frames_to_stack: int = 0
|
||||
pomdp_r: int = 0
|
||||
|
||||
"""Specifies the radius (_r) of the agent's field of view. Please note that the agent's own grid cell is not taken
|
||||
into account. This means that the resulting field of view diameter = `pomdp_r * 2 + 1`.
|
||||
A 'pomdp_r' of 0 always returns the full env == no partial observability."""
|
||||
pomdp_r: int = 2
|
||||
|
||||
"""Whether to place a visual encoding on walkable tiles around the doors. This is helpful when the doors can be
|
||||
operated from their surrounding area. So the agent can more easily get a notion of where to choose the door option.
|
||||
However, this is not necessary at all.
|
||||
"""
|
||||
indicate_door_area: bool = False
|
||||
|
||||
"""Whether to add the agents normalized global position as float values (2,1) to a separate information slice.
|
||||
More optional information is to come.
|
||||
"""
|
||||
show_global_position_info: bool = False
|
||||
|
||||
|
||||
class MarlFrameStack(gym.ObservationWrapper):
|
||||
"""todo @romue404"""
|
||||
def __init__(self, env):
|
||||
super().__init__(env)
|
||||
|
||||
|
@ -215,7 +215,7 @@ if __name__ == '__main__':
|
||||
clean_amount=0.34,
|
||||
max_spawn_amount=0.1, max_global_amount=20,
|
||||
max_local_amount=1, spawn_frequency=0, max_spawn_ratio=0.05,
|
||||
dirt_smear_amount=0.0, agent_can_interact=True)
|
||||
dirt_smear_amount=0.0)
|
||||
item_props = ItemProperties(n_items=10,
|
||||
spawn_frequency=30, n_drop_off_locations=2,
|
||||
max_agent_inventory_capacity=15)
|
||||
@ -349,6 +349,7 @@ if __name__ == '__main__':
|
||||
# Env Init & Model kwargs definition
|
||||
if model_cls.__name__ in ["PPO", "A2C"]:
|
||||
# env_factory = env_class(**env_kwargs)
|
||||
|
||||
env_factory = SubprocVecEnv([encapsule_env_factory(env_class, env_kwargs)
|
||||
for _ in range(6)], start_method="spawn")
|
||||
model_kwargs = policy_model_kwargs()
|
||||
|
@ -213,7 +213,8 @@ if __name__ == '__main__':
|
||||
env_factory.save_params(param_path)
|
||||
|
||||
# EnvMonitor Init
|
||||
callbacks = [EnvMonitor(env_factory)]
|
||||
env_monitor = EnvMonitor(env_factory)
|
||||
callbacks = [env_monitor]
|
||||
|
||||
# Model Init
|
||||
model = model_cls("MlpPolicy", env_factory, **policy_model_kwargs,
|
||||
@ -233,7 +234,7 @@ if __name__ == '__main__':
|
||||
model.save(save_path)
|
||||
|
||||
# Monitor Save
|
||||
callbacks[0].save_run(combination_path / 'monitor.pick',
|
||||
env_monitor.save_run(combination_path / 'monitor.pick',
|
||||
auto_plotting_keys=['step_reward', 'collision'] + env_plot_keys)
|
||||
|
||||
# Better be safe than sorry: Clean up!
|
||||
|
Loading…
x
Reference in New Issue
Block a user