Adjustments and Documentation

This commit is contained in:
Steffen Illium 2022-04-11 16:15:44 +02:00
parent 3e19970a60
commit 0218f8f4e9
12 changed files with 394 additions and 182 deletions

View File

@ -156,14 +156,14 @@ class BaseFactory(gym.Env):
np.argwhere(level_array == c.OCCUPIED_CELL),
self._level_shape
)
self._entities.register_additional_items({c.WALLS: walls})
self._entities.add_additional_items({c.WALLS: walls})
# Floor
floor = Floors.from_argwhere_coordinates(
np.argwhere(level_array == c.FREE_CELL),
self._level_shape
)
self._entities.register_additional_items({c.FLOOR: floor})
self._entities.add_additional_items({c.FLOOR: floor})
# NOPOS
self._NO_POS_TILE = Floor(c.NO_POS, None)
@ -177,12 +177,12 @@ class BaseFactory(gym.Env):
doors = Doors.from_tiles(door_tiles, self._level_shape, have_area=self.obs_prop.indicate_door_area,
entity_kwargs=dict(context=floor)
)
self._entities.register_additional_items({c.DOORS: doors})
self._entities.add_additional_items({c.DOORS: doors})
# Actions
self._actions = Actions(self.mv_prop, can_use_doors=self.parse_doors)
if additional_actions := self.actions_hook:
self._actions.register_additional_items(additional_actions)
self._actions.add_additional_items(additional_actions)
# Agents
agents_to_spawn = self.n_agents-len(self._injected_agents)
@ -196,10 +196,10 @@ class BaseFactory(gym.Env):
if self._injected_agents:
initialized_injections = list()
for i, injection in enumerate(self._injected_agents):
agents.register_item(injection(self, floor.empty_tiles[0], agents, static_problem=False))
agents.add_item(injection(self, floor.empty_tiles[0], agents, static_problem=False))
initialized_injections.append(agents[-1])
self._initialized_injections = initialized_injections
self._entities.register_additional_items({c.AGENT: agents})
self._entities.add_additional_items({c.AGENT: agents})
if self.obs_prop.additional_agent_placeholder is not None:
# TODO: Make this accept Lists for multiple placeholders
@ -210,18 +210,18 @@ class BaseFactory(gym.Env):
fill_value=self.obs_prop.additional_agent_placeholder)
)
self._entities.register_additional_items({c.AGENT_PLACEHOLDER: placeholder})
self._entities.add_additional_items({c.AGENT_PLACEHOLDER: placeholder})
# Additional Entities from SubEnvs
if additional_entities := self.entities_hook:
self._entities.register_additional_items(additional_entities)
self._entities.add_additional_items(additional_entities)
if self.obs_prop.show_global_position_info:
global_positions = GlobalPositions(self._level_shape)
# This moved into the GlobalPosition object
# obs_shape_2d = self._level_shape if not self._pomdp_r else ((self.pomdp_diameter,) * 2)
global_positions.spawn_global_position_objects(self[c.AGENT])
self._entities.register_additional_items({c.GLOBAL_POSITION: global_positions})
self._entities.add_additional_items({c.GLOBAL_POSITION: global_positions})
# Return
return self._entities
@ -535,7 +535,7 @@ class BaseFactory(gym.Env):
def _check_agent_move(self, agent, action: Action) -> (Floor, bool):
# Actions
x_diff, y_diff = h.ACTIONMAP[action.identifier]
x_diff, y_diff = a.resolve_movement_action_to_coords(action.identifier)
x_new = agent.x + x_diff
y_new = agent.y + y_diff

View File

@ -72,15 +72,15 @@ class EnvObject(Object):
def encoding(self):
return c.OCCUPIED_CELL
def __init__(self, register, **kwargs):
def __init__(self, collection, **kwargs):
super(EnvObject, self).__init__(**kwargs)
self._register = register
self._collection = collection
def change_register(self, register):
register.register_item(self)
self._register.delete_env_object(self)
self._register = register
return self._register == register
def change_parent_collection(self, other_collection):
    """
    Hand this object over from its current collection to ``other_collection``.

    The object is first added to ``other_collection`` and afterwards deleted from the collection
    it is currently assigned to. Finally, the internal collection reference is updated.

    :param other_collection: The collection that takes over this object.
    :return: Whether the stored collection reference equals ``other_collection`` after the change.
    :rtype: bool
    """
    # Add first, delete second: the object is never without a collection in between.
    other_collection.add_item(self)
    self._collection.delete_env_object(self)
    self._collection = other_collection
    is_assigned = self._collection == other_collection
    return is_assigned
# With Rendering
@ -153,7 +153,7 @@ class MoveableEntity(Entity):
curr_tile.leave(self)
self._tile = next_tile
self._last_tile = curr_tile
self._register.notify_change_to_value(self)
self._collection.notify_change_to_value(self)
return c.VALID
else:
return c.NOT_VALID
@ -371,13 +371,13 @@ class Door(Entity):
def _open(self):
self.connectivity.add_edges_from([(self.pos, x) for x in range(len(self.connectivity_subgroups))])
self._state = c.OPEN_DOOR
self._register.notify_change_to_value(self)
self._collection.notify_change_to_value(self)
self.time_to_close = self.auto_close_interval
def _close(self):
self.connectivity.remove_node(self.pos)
self._state = c.CLOSED_DOOR
self._register.notify_change_to_value(self)
self._collection.notify_change_to_value(self)
def is_linked(self, old_pos, new_pos):
try:

View File

@ -13,11 +13,11 @@ from environments import helpers as h
from environments.helpers import Constants as c
##########################################################################
# ##################### Base Register Definition ####################### #
# ################## Base Collections Definition ####################### #
##########################################################################
class ObjectRegister:
class ObjectCollection:
_accepted_objects = Object
@property
@ -25,59 +25,59 @@ class ObjectRegister:
return f'{self.__class__.__name__}'
def __init__(self, *args, **kwargs):
self._register = dict()
self._collection = dict()
def __len__(self):
return len(self._register)
return len(self._collection)
def __iter__(self):
return iter(self.values())
def register_item(self, other: _accepted_objects):
def add_item(self, other: _accepted_objects):
assert isinstance(other, self._accepted_objects), f'All item names have to be of type ' \
f'{self._accepted_objects}, ' \
f'but were {other.__class__}.,'
self._register.update({other.name: other})
self._collection.update({other.name: other})
return self
def register_additional_items(self, others: List[_accepted_objects]):
def add_additional_items(self, others: List[_accepted_objects]):
for other in others:
self.register_item(other)
self.add_item(other)
return self
def keys(self):
return self._register.keys()
return self._collection.keys()
def values(self):
return self._register.values()
return self._collection.values()
def items(self):
return self._register.items()
return self._collection.items()
def _get_index(self, item):
try:
return next(i for i, v in enumerate(self._register.values()) if v == item)
return next(i for i, v in enumerate(self._collection.values()) if v == item)
except StopIteration:
return None
def __getitem__(self, item):
if isinstance(item, (int, np.int64, np.int32)):
if item < 0:
item = len(self._register) - abs(item)
item = len(self._collection) - abs(item)
try:
return next(v for i, v in enumerate(self._register.values()) if i == item)
return next(v for i, v in enumerate(self._collection.values()) if i == item)
except StopIteration:
return None
try:
return self._register[item]
return self._collection[item]
except KeyError:
return None
def __repr__(self):
return f'{self.__class__.__name__}[{self._register}]'
return f'{self.__class__.__name__}[{self._collection}]'
class EnvObjectRegister(ObjectRegister):
class EnvObjectCollection(ObjectCollection):
_accepted_objects = EnvObject
@ -90,7 +90,7 @@ class EnvObjectRegister(ObjectRegister):
is_blocking_light: bool = False,
can_collide: bool = False,
can_be_shadowed: bool = True, **kwargs):
super(EnvObjectRegister, self).__init__(*args, **kwargs)
super(EnvObjectCollection, self).__init__(*args, **kwargs)
self._shape = obs_shape
self._array = None
self._individual_slices = individual_slices
@ -99,8 +99,8 @@ class EnvObjectRegister(ObjectRegister):
self.can_be_shadowed = can_be_shadowed
self.can_collide = can_collide
def register_item(self, other: EnvObject):
super(EnvObjectRegister, self).register_item(other)
def add_item(self, other: EnvObject):
super(EnvObjectCollection, self).add_item(other)
if self._array is None:
self._array = np.zeros((1, *self._shape))
else:
@ -145,13 +145,13 @@ class EnvObjectRegister(ObjectRegister):
if self._individual_slices:
self._array = np.delete(self._array, idx, axis=0)
else:
self.notify_change_to_free(self._register[name])
self.notify_change_to_free(self._collection[name])
# Dirty Hack to check if not being subclassed. In that case we need to refresh the array since positions
# in the observation array are result of enumeration. They can override each other.
# Todo: Find a better solution
if not issubclass(self.__class__, EntityRegister) and issubclass(self.__class__, EnvObjectRegister):
if not issubclass(self.__class__, EntityCollection) and issubclass(self.__class__, EnvObjectCollection):
self._refresh_arrays()
del self._register[name]
del self._collection[name]
def delete_env_object(self, env_object: EnvObject):
del self[env_object.name]
@ -160,19 +160,19 @@ class EnvObjectRegister(ObjectRegister):
del self[name]
class EntityRegister(EnvObjectRegister, ABC):
class EntityCollection(EnvObjectCollection, ABC):
_accepted_objects = Entity
@classmethod
def from_tiles(cls, tiles, *args, entity_kwargs=None, **kwargs):
# objects_name = cls._accepted_objects.__name__
register_obj = cls(*args, **kwargs)
entities = [cls._accepted_objects(tile, register_obj, str_ident=i,
collection = cls(*args, **kwargs)
entities = [cls._accepted_objects(tile, collection, str_ident=i,
**entity_kwargs if entity_kwargs is not None else {})
for i, tile in enumerate(tiles)]
register_obj.register_additional_items(entities)
return register_obj
collection.add_additional_items(entities)
return collection
@classmethod
def from_argwhere_coordinates(cls, positions: [(int, int)], tiles, *args, entity_kwargs=None, **kwargs, ):
@ -188,13 +188,13 @@ class EntityRegister(EnvObjectRegister, ABC):
return [entity.tile for entity in self]
def __init__(self, level_shape, *args, **kwargs):
super(EntityRegister, self).__init__(level_shape, *args, **kwargs)
super(EntityCollection, self).__init__(level_shape, *args, **kwargs)
self._lazy_eval_transforms = []
def __delitem__(self, name):
idx, obj = next((i, obj) for i, obj in enumerate(self) if obj.name == name)
obj.tile.leave(obj)
super(EntityRegister, self).__delitem__(name)
super(EntityCollection, self).__delitem__(name)
def as_array(self):
if self._lazy_eval_transforms:
@ -223,7 +223,7 @@ class EntityRegister(EnvObjectRegister, ABC):
return None
class BoundEnvObjRegister(EnvObjectRegister, ABC):
class BoundEnvObjCollection(EnvObjectCollection, ABC):
def __init__(self, entity_to_be_bound, *args, **kwargs):
super().__init__(*args, **kwargs)
@ -248,13 +248,13 @@ class BoundEnvObjRegister(EnvObjectRegister, ABC):
return self._array[self.idx_by_entity(entity)]
class MovingEntityObjectRegister(EntityRegister, ABC):
class MovingEntityObjectCollection(EntityCollection, ABC):
def __init__(self, *args, **kwargs):
super(MovingEntityObjectRegister, self).__init__(*args, **kwargs)
super(MovingEntityObjectCollection, self).__init__(*args, **kwargs)
def notify_change_to_value(self, entity):
super(MovingEntityObjectRegister, self).notify_change_to_value(entity)
super(MovingEntityObjectCollection, self).notify_change_to_value(entity)
if entity.last_pos != c.NO_POS:
try:
self._array_change_notifyer(entity, entity.last_pos, value=c.FREE_CELL)
@ -263,11 +263,11 @@ class MovingEntityObjectRegister(EntityRegister, ABC):
##########################################################################
# ################# Objects and Entity Registers ####################### #
# ################# Objects and Entity Collection ###################### #
##########################################################################
class GlobalPositions(EnvObjectRegister):
class GlobalPositions(EnvObjectCollection):
_accepted_objects = GlobalPosition
@ -288,7 +288,7 @@ class GlobalPositions(EnvObjectRegister):
global_positions = [self._accepted_objects(self._shape, agent, self)
for _, agent in enumerate(agents)]
# noinspection PyTypeChecker
self.register_additional_items(global_positions)
self.add_additional_items(global_positions)
def summarize_states(self, n_steps=None):
return {}
@ -306,7 +306,7 @@ class GlobalPositions(EnvObjectRegister):
return None
class PlaceHolders(EnvObjectRegister):
class PlaceHolders(EnvObjectCollection):
_accepted_objects = PlaceHolder
def __init__(self, *args, **kwargs):
@ -320,12 +320,12 @@ class PlaceHolders(EnvObjectRegister):
# objects_name = cls._accepted_objects.__name__
if isinstance(values, (str, numbers.Number)):
values = [values]
register_obj = cls(*args, **kwargs)
objects = [cls._accepted_objects(register_obj, str_ident=i, fill_value=value,
collection = cls(*args, **kwargs)
objects = [cls._accepted_objects(collection, str_ident=i, fill_value=value,
**object_kwargs if object_kwargs is not None else {})
for i, value in enumerate(values)]
register_obj.register_additional_items(objects)
return register_obj
collection.add_additional_items(objects)
return collection
# noinspection DuplicatedCode
def as_array(self):
@ -343,8 +343,8 @@ class PlaceHolders(EnvObjectRegister):
return self._array
class Entities(ObjectRegister):
_accepted_objects = EntityRegister
class Entities(ObjectCollection):
_accepted_objects = EntityCollection
@property
def arrays(self):
@ -352,7 +352,7 @@ class Entities(ObjectRegister):
@property
def names(self):
return list(self._register.keys())
return list(self._collection.keys())
def __init__(self):
super(Entities, self).__init__()
@ -360,21 +360,21 @@ class Entities(ObjectRegister):
def iter_individual_entitites(self):
    """
    Iterate over the individual elements of all stored values, one value after the other.

    :return: A lazy iterator that yields every element of every value in ``self.values()``.
    """
    def _flatten(groups):
        for group in groups:
            for element in group:
                yield element

    # The iterator over the values is created right away; the elements are produced lazily.
    return _flatten(iter(self.values()))
def register_item(self, other: dict):
def add_item(self, other: dict):
assert not any([key for key in other.keys() if key in self.keys()]), \
"This group of entities has already been registered!"
self._register.update(other)
"This group of entities has already been added!"
self._collection.update(other)
return self
def register_additional_items(self, others: Dict):
return self.register_item(others)
def add_additional_items(self, others: Dict):
return self.add_item(others)
def by_pos(self, pos: (int, int)):
    """
    Collect what the stored collections report for a given position.

    Every value in ``self.values()`` that offers a ``by_pos`` attribute is queried with ``pos``;
    results that are ``None`` are dropped.

    :param pos: The position to look up.
    :return: A list with all non-``None`` lookup results, in iteration order of the values.
    :rtype: list
    """
    found_entities = []
    for collection in self.values():
        # Not every stored value supports positional lookups.
        if not hasattr(collection, 'by_pos'):
            continue
        hit = collection.by_pos(pos)
        if hit is not None:
            found_entities.append(hit)
    return found_entities
class Walls(EntityRegister):
class Walls(EntityCollection):
_accepted_objects = Wall
def as_array(self):
@ -396,7 +396,7 @@ class Walls(EntityRegister):
def from_argwhere_coordinates(cls, argwhere_coordinates, *args, **kwargs):
tiles = cls(*args, **kwargs)
# noinspection PyTypeChecker
tiles.register_additional_items(
tiles.add_additional_items(
[cls._accepted_objects(pos, tiles)
for pos in argwhere_coordinates]
)
@ -441,7 +441,7 @@ class Floors(Walls):
return {}
class Agents(MovingEntityObjectRegister):
class Agents(MovingEntityObjectCollection):
_accepted_objects = Agent
def __init__(self, *args, **kwargs):
@ -455,10 +455,10 @@ class Agents(MovingEntityObjectRegister):
old_agent = self[key]
self[key].tile.leave(self[key])
agent._name = old_agent.name
self._register[agent.name] = agent
self._collection[agent.name] = agent
class Doors(EntityRegister):
class Doors(EntityCollection):
def __init__(self, *args, have_area: bool = False, **kwargs):
self.have_area = have_area
@ -490,7 +490,7 @@ class Doors(EntityRegister):
return super(Doors, self).as_array()
class Actions(ObjectRegister):
class Actions(ObjectCollection):
_accepted_objects = Action
@property
@ -507,22 +507,22 @@ class Actions(ObjectRegister):
# Move this to Baseclass, Env init?
if self.allow_square_movement:
self.register_additional_items([self._accepted_objects(str_ident=direction)
self.add_additional_items([self._accepted_objects(str_ident=direction)
for direction in h.EnvActions.square_move()])
if self.allow_diagonal_movement:
self.register_additional_items([self._accepted_objects(str_ident=direction)
self.add_additional_items([self._accepted_objects(str_ident=direction)
for direction in h.EnvActions.diagonal_move()])
self._movement_actions = self._register.copy()
self._movement_actions = self._collection.copy()
if self.can_use_doors:
self.register_additional_items([self._accepted_objects(str_ident=h.EnvActions.USE_DOOR)])
self.add_additional_items([self._accepted_objects(str_ident=h.EnvActions.USE_DOOR)])
if self.allow_no_op:
self.register_additional_items([self._accepted_objects(str_ident=h.EnvActions.NOOP)])
self.add_additional_items([self._accepted_objects(str_ident=h.EnvActions.NOOP)])
def is_moving_action(self, action: Union[int]):
return action in self.movement_actions.values()
class Zones(ObjectRegister):
class Zones(ObjectCollection):
@property
def accounting_zones(self):
@ -551,5 +551,5 @@ class Zones(ObjectRegister):
def __getitem__(self, item):
return self._zone_slices[item]
def register_additional_items(self, other: Union[str, List[str]]):
def add_additional_items(self, other: Union[str, List[str]]):
raise AttributeError('You are not allowed to add additional Zones in runtime.')

View File

@ -4,7 +4,7 @@ import numpy as np
from environments.factory.base.base_factory import BaseFactory
from environments.factory.base.objects import Agent, Action, Entity, EnvObject, BoundingMixin
from environments.factory.base.registers import EntityRegister, EnvObjectRegister
from environments.factory.base.registers import EntityCollection, EnvObjectCollection
from environments.factory.base.renderer import RenderEntity
from environments.helpers import Constants as BaseConstants
from environments.helpers import EnvActions as BaseActions
@ -68,7 +68,7 @@ class Battery(BoundingMixin, EnvObject):
if self.charge_level != 0:
# noinspection PyTypeChecker
self.charge_level = max(0, amount + self.charge_level)
self._register.notify_change_to_value(self)
self._collection.notify_change_to_value(self)
return c.VALID
else:
return c.NOT_VALID
@ -79,7 +79,7 @@ class Battery(BoundingMixin, EnvObject):
return attr_dict
class BatteriesRegister(EnvObjectRegister):
class BatteriesRegister(EnvObjectCollection):
_accepted_objects = Battery
@ -90,7 +90,7 @@ class BatteriesRegister(EnvObjectRegister):
def spawn_batteries(self, agents, initial_charge_level):
batteries = [self._accepted_objects(initial_charge_level, agent, self) for _, agent in enumerate(agents)]
self.register_additional_items(batteries)
self.add_additional_items(batteries)
def summarize_states(self, n_steps=None):
# as dict with additional nesting
@ -140,7 +140,7 @@ class ChargePod(Entity):
return summary
class ChargePods(EntityRegister):
class ChargePods(EntityCollection):
_accepted_objects = ChargePod

View File

@ -9,7 +9,7 @@ from environments.factory.base.base_factory import BaseFactory
from environments.helpers import Constants as BaseConstants
from environments.helpers import EnvActions as BaseActions
from environments.factory.base.objects import Agent, Entity, Action
from environments.factory.base.registers import Entities, EntityRegister
from environments.factory.base.registers import Entities, EntityCollection
from environments.factory.base.renderer import RenderEntity
@ -73,7 +73,7 @@ class Destination(Entity):
return state_summary
class Destinations(EntityRegister):
class Destinations(EntityCollection):
_accepted_objects = Destination
@ -208,13 +208,13 @@ class DestFactory(BaseFactory):
n_dest_to_spawn = len(destinations_to_spawn)
if self.dest_prop.spawn_mode != DestModeOptions.GROUPED:
destinations = [Destination(tile, c.DEST) for tile in self[c.FLOOR].empty_tiles[:n_dest_to_spawn]]
self[c.DEST].register_additional_items(destinations)
self[c.DEST].add_additional_items(destinations)
for dest in destinations_to_spawn:
del self._dest_spawn_timer[dest]
self.print(f'{n_dest_to_spawn} new destinations have been spawned')
elif self.dest_prop.spawn_mode == DestModeOptions.GROUPED and n_dest_to_spawn == self.dest_prop.n_dests:
destinations = [Destination(tile, self[c.DEST]) for tile in self[c.FLOOR].empty_tiles[:n_dest_to_spawn]]
self[c.DEST].register_additional_items(destinations)
self[c.DEST].add_additional_items(destinations)
for dest in destinations_to_spawn:
del self._dest_spawn_timer[dest]
self.print(f'{n_dest_to_spawn} new destinations have been spawned')
@ -231,7 +231,7 @@ class DestFactory(BaseFactory):
self._dest_spawn_timer[key] = min(self.dest_prop.spawn_frequency, self._dest_spawn_timer[key] + 1)
for dest in list(self[c.DEST].values()):
if dest.is_considered_reached:
dest.change_register(self[c.DEST])
dest.change_parent_collection(self[c.DEST])
self._dest_spawn_timer[dest.name] = 0
self.print(f'{dest.name} is reached now, removing...')
else:

View File

@ -11,7 +11,7 @@ from environments.helpers import EnvActions as BaseActions
from environments.factory.base.base_factory import BaseFactory
from environments.factory.base.objects import Agent, Action, Entity, Floor
from environments.factory.base.registers import Entities, EntityRegister
from environments.factory.base.registers import Entities, EntityCollection
from environments.factory.base.renderer import RenderEntity
from environments.utility_classes import ObservationProperties
@ -61,7 +61,7 @@ class Dirt(Entity):
def set_new_amount(self, amount):
self._amount = amount
self._register.notify_change_to_value(self)
self._collection.notify_change_to_value(self)
def summarize_state(self, **kwargs):
state_dict = super().summarize_state(**kwargs)
@ -69,7 +69,7 @@ class Dirt(Entity):
return state_dict
class DirtRegister(EntityRegister):
class DirtRegister(EntityCollection):
_accepted_objects = Dirt
@ -93,7 +93,7 @@ class DirtRegister(EntityRegister):
dirt = self.by_pos(tile.pos)
if dirt is None:
dirt = Dirt(tile, self, amount=self.dirt_properties.max_spawn_amount)
self.register_item(dirt)
self.add_item(dirt)
else:
new_value = dirt.amount + self.dirt_properties.max_spawn_amount
dirt.set_new_amount(min(new_value, self.dirt_properties.max_local_amount))

View File

@ -5,10 +5,10 @@ import numpy as np
from environments.factory.base.objects import Agent, Entity, Action
from environments.factory.factory_dirt import Dirt, DirtRegister, DirtFactory
from environments.factory.base.objects import Floor
from environments.factory.base.registers import Floors, Entities, EntityRegister
from environments.factory.base.registers import Floors, Entities, EntityCollection
class Machines(EntityRegister):
class Machines(EntityCollection):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

View File

@ -9,7 +9,7 @@ from environments.helpers import Constants as BaseConstants
from environments.helpers import EnvActions as BaseActions
from environments import helpers as h
from environments.factory.base.objects import Agent, Entity, Action, Floor
from environments.factory.base.registers import Entities, EntityRegister, BoundEnvObjRegister, ObjectRegister
from environments.factory.base.registers import Entities, EntityCollection, BoundEnvObjCollection, ObjectCollection
from environments.factory.base.renderer import RenderEntity
@ -53,17 +53,17 @@ class Item(Entity):
self._auto_despawn = auto_despawn
def set_tile_to(self, no_pos_tile):
assert self._register.__class__.__name__ != ItemRegister.__class__
assert self._collection.__class__.__name__ != ItemRegister.__class__
self._tile = no_pos_tile
class ItemRegister(EntityRegister):
class ItemRegister(EntityCollection):
_accepted_objects = Item
def spawn_items(self, tiles: List[Floor]):
items = [Item(tile, self) for tile in tiles]
self.register_additional_items(items)
self.add_additional_items(items)
def despawn_items(self, items: List[Item]):
items = [items] if isinstance(items, Item) else items
@ -71,7 +71,7 @@ class ItemRegister(EntityRegister):
del self[item]
class Inventory(BoundEnvObjRegister):
class Inventory(BoundEnvObjCollection):
@property
def name(self):
@ -98,7 +98,7 @@ class Inventory(BoundEnvObjRegister):
return item_to_pop
class Inventories(ObjectRegister):
class Inventories(ObjectCollection):
_accepted_objects = Inventory
is_blocking_light = False
@ -114,7 +114,7 @@ class Inventories(ObjectRegister):
def spawn_inventories(self, agents, capacity):
inventories = [self._accepted_objects(agent, capacity, self._obs_shape)
for _, agent in enumerate(agents)]
self.register_additional_items(inventories)
self.add_additional_items(inventories)
def idx_by_entity(self, entity):
try:
@ -161,7 +161,7 @@ class DropOffLocation(Entity):
return super().summarize_state(n_steps=n_steps)
class DropOffLocations(EntityRegister):
class DropOffLocations(EntityCollection):
_accepted_objects = DropOffLocation
@ -250,7 +250,7 @@ class ItemFactory(BaseFactory):
reason=a.ITEM_ACTION, info=info_dict)
return valid, reward
elif item := self[c.ITEM].by_pos(agent.pos):
item.change_register(inventory)
item.change_parent_collection(inventory)
item.set_tile_to(self._NO_POS_TILE)
self.print(f'{agent.name} just picked up an item at {agent.pos}')
info_dict = {f'{agent.name}_{a.ITEM_ACTION}_VALID': 1, f'{a.ITEM_ACTION}_VALID': 1}

View File

@ -7,47 +7,76 @@ import numpy as np
from numpy.typing import ArrayLike
from stable_baselines3 import PPO, DQN, A2C
MODEL_MAP = dict(PPO=PPO, DQN=DQN, A2C=A2C)
LEVELS_DIR = 'levels'
STEPS_START = 1
"""
This file is used for:
1. string based definition
Use a class like `Constants`, to define attributes, which then reveal strings.
These can be used for naming convention along the environments as well as keys for mappings such as dicts etc.
When defining new envs, use class inheritance.
TO_BE_AVERAGED = ['dirt_amount', 'dirty_tiles']
IGNORED_DF_COLUMNS = ['Episode', 'Run', 'train_step', 'step', 'index', 'dirt_amount',
'dirty_tile_count', 'terminal_observation', 'episode']
2. utility function definition
There are static utility functions which are not bound to a specific environment.
In this file they are defined to be used across the entire package.
"""
MODEL_MAP = dict(PPO=PPO, DQN=DQN, A2C=A2C) # For use in studies and experiments
LEVELS_DIR = 'levels' # for use in studies and experiments
STEPS_START = 1 # Defines where the step count starts, i.e. the number of the first step
# Not used anymore? Clean!
# TO_BE_AVERAGED = ['dirt_amount', 'dirty_tiles']
IGNORED_DF_COLUMNS = ['Episode', 'Run', # For plotting, which values are ignored when loading monitor files
'train_step', 'step', 'index', 'dirt_amount', 'dirty_tile_count', 'terminal_observation',
'episode']
# Constants
class Constants:
WALL = '#'
WALLS = 'Walls'
FLOOR = 'Floor'
DOOR = 'D'
DANGER_ZONE = 'x'
LEVEL = 'Level'
AGENT = 'Agent'
AGENT_PLACEHOLDER = 'AGENT_PLACEHOLDER'
GLOBAL_POSITION = 'GLOBAL_POSITION'
FREE_CELL = 0
OCCUPIED_CELL = 1
SHADOWED_CELL = -1
ACCESS_DOOR_CELL = 1/3
OPEN_DOOR_CELL = 2/3
CLOSED_DOOR_CELL = 3/3
NO_POS = (-9999, -9999)
DOORS = 'Doors'
CLOSED_DOOR = 'closed'
OPEN_DOOR = 'open'
ACCESS_DOOR = 'access'
"""
String based mapping. Use these to handle keys or define values, which can then be used globally.
Please use class inheritance when defining new environments.
"""
ACTION = 'action'
COLLISION = 'collision'
VALID = True
NOT_VALID = False
WALL = '#' # Wall tile identifier for resolving the string based map files.
DOOR = 'D' # Door identifier for resolving the string based map files.
DANGER_ZONE = 'x' # Danger Zone tile identifier for resolving the string based map files.
WALLS = 'Walls' # Identifier of Wall-objects and sets (collections).
FLOOR = 'Floor' # Identifier of Floor-objects and sets (collections).
DOORS = 'Doors' # Identifier of Door-objects and sets (collections).
LEVEL = 'Level' # Identifier of Level-objects and sets (collections).
AGENT = 'Agent' # Identifier of Agent-objects and sets (collections).
AGENT_PLACEHOLDER = 'AGENT_PLACEHOLDER' # Identifier of Placeholder-objects and sets (collections).
GLOBAL_POSITION = 'GLOBAL_POSITION' # Identifier of the global position slice
FREE_CELL = 0 # Free-Cell value used in observation
OCCUPIED_CELL = 1 # Occupied-Cell value used in observation
SHADOWED_CELL = -1 # Shadowed-Cell value used in observation
ACCESS_DOOR_CELL = 1/3 # Access-door-Cell value used in observation
OPEN_DOOR_CELL = 2/3 # Open-door-Cell value used in observation
CLOSED_DOOR_CELL = 3/3 # Closed-door-Cell value used in observation
NO_POS = (-9999, -9999) # Invalid Position value used in the environment (something is off-grid)
CLOSED_DOOR = 'closed' # Identifier to compare door-is-closed state
OPEN_DOOR = 'open' # Identifier to compare door-is-open state
# ACCESS_DOOR = 'access' # Identifier to compare access positions
ACTION = 'action' # Identifier of Action-objects and sets (collections).
COLLISION = 'collision' # Identifier to use in the context of collisions.
VALID = True # Identifier to rename boolean values in the context of actions.
NOT_VALID = False # Identifier to rename boolean values in the context of actions.
class EnvActions:
"""
String based mapping. Use these to identify actions; can be used globally.
Please use class inheritance when defining new environments with new actions.
"""
# Movements
NORTH = 'north'
EAST = 'east'
@ -63,24 +92,77 @@ class EnvActions:
NOOP = 'no_op'
USE_DOOR = 'use_door'
_ACTIONMAP = defaultdict(lambda: (0, 0),
{NORTH: (-1, 0), NORTHEAST: (-1, 1),
EAST: (0, 1), SOUTHEAST: (1, 1),
SOUTH: (1, 0), SOUTHWEST: (1, -1),
WEST: (0, -1), NORTHWEST: (-1, -1)
}
)
@classmethod
def is_move(cls, other):
return any([other == direction for direction in cls.movement_actions()])
def is_move(cls, action):
    """
    Classmethod; checks if the given action is a movement action or not.
    An action counts as movement if it equals one of the entries returned by `cls.movement_actions()`.

    :param action: Action to be checked
    :type action: str
    :return: Whether the given action is a movement action.
    :rtype: bool
    """
    # Generator instead of a list: `any` can stop at the first match and no throwaway list is built.
    return any(action == direction for direction in cls.movement_actions())
@classmethod
def square_move(cls):
    """
    Classmethod; list the movement actions that count as square or `manhattan` style movements.

    :return: The actions north, east, south and west, in this order.
    :rtype: list(str)
    """
    north, east, south, west = cls.NORTH, cls.EAST, cls.SOUTH, cls.WEST
    return [north, east, south, west]
@classmethod
def diagonal_move(cls):
    """
    Classmethod; list the movement actions that count as diagonal movements.

    :return: The actions northeast, southeast, southwest and northwest, in this order.
    :rtype: list(str)
    """
    north_east, south_east = cls.NORTHEAST, cls.SOUTHEAST
    south_west, north_west = cls.SOUTHWEST, cls.NORTHWEST
    return [north_east, south_east, south_west, north_west]
@classmethod
def movement_actions(cls):
    """
    Classmethod; list all available movement actions: square movements first, diagonal ones after.
    Please note that this is independent of the env. properties.

    :return: A list of movement actions.
    :rtype: list(str)
    """
    square, diagonal = cls.square_move(), cls.diagonal_move()
    all_moves = list(square)
    all_moves.extend(diagonal)
    return all_moves
@classmethod
def resolve_movement_action_to_coords(cls, action):
    """
    Classmethod; resolve a movement action to the coordinate delta it stands for.
    How does the current entity coordinate change if it performs the given action?
    The delta is looked up in `cls._ACTIONMAP`.
    Please note that this is independent of the env. properties.

    :param action: The action to resolve.
    :return: Delta coordinates.
    :rtype: tuple(int, int)
    """
    coordinate_delta = cls._ACTIONMAP[action]
    return coordinate_delta
class RewardsBase(NamedTuple):
"""
Value based mapping. Use these to define reward values for specific conditions (i.e. the action
in a given context), can be used globally.
Please use class inheritance when defining new environments with new rewards.
"""
MOVEMENTS_VALID: float = -0.001
MOVEMENTS_FAIL: float = -0.05
NOOP: float = -0.01
@ -89,23 +171,31 @@ class RewardsBase(NamedTuple):
COLLISION: float = -0.5
m = EnvActions
c = Constants
ACTIONMAP = defaultdict(lambda: (0, 0),
{m.NORTH: (-1, 0), m.NORTHEAST: (-1, 1),
m.EAST: (0, 1), m.SOUTHEAST: (1, 1),
m.SOUTH: (1, 0), m.SOUTHWEST: (1, -1),
m.WEST: (0, -1), m.NORTHWEST: (-1, -1)
}
)
class ObservationTranslator:
def __init__(self, obs_shape_2d: (int, int), this_named_observation_space: Dict[str, dict],
*per_agent_named_obs_space: Dict[str, dict],
*per_agent_named_obs_spaces: Dict[str, dict],
placeholder_fill_value: Union[int, str] = 'N'):
"""
This is a helper class, which converts agents observations from joined environments.
For example, agents trained in different environments may expect different observations.
This class translates from larger observations spaces to smaller.
A string identifier based approach is used.
Currently, it is not possible to mix different obs shapes.
:param obs_shape_2d: The shape of the observation the agents expect.
:type obs_shape_2d: tuple(int, int)
:param this_named_observation_space: `Named observation space` of the joined environment.
:type this_named_observation_space: Dict[str, dict]
:param per_agent_named_obs_spaces: `Named observation space` one for each agent. Overloaded.
:type per_agent_named_obs_spaces: Dict[str, dict]
:param placeholder_fill_value: Currently not fully implemented!!!
:type placeholder_fill_value: Union[int, str]
"""
assert len(obs_shape_2d) == 2
self.obs_shape = obs_shape_2d
if isinstance(placeholder_fill_value, str):
@ -119,7 +209,7 @@ class ObservationTranslator:
self.random_fill = None
self._this_named_obs_space = this_named_observation_space
self._per_agent_named_obs_space = list(per_agent_named_obs_space)
self._per_agent_named_obs_space = list(per_agent_named_obs_spaces)
def translate_observation(self, agent_idx: int, obs: np.ndarray):
target_obs_space = self._per_agent_named_obs_space[agent_idx]
@ -137,6 +227,19 @@ class ObservationTranslator:
class ActionTranslator:
def __init__(self, target_named_action_space: Dict[str, int], *per_agent_named_action_space: Dict[str, int]):
"""
This is a helper class, which converts agents action spaces to a joined environments action space.
For example, agents trained in different environments may have different action spaces.
This class translates from smaller individual agent action spaces to larger joined spaces.
A string identifier based approach is used.
:param target_named_action_space: Joined `Named action space` for the current environment.
:type target_named_action_space: Dict[str, int]
:param per_agent_named_action_space: `Named action space` one for each agent. Overloaded.
:type per_agent_named_action_space: Dict[str, int]
"""
self._target_named_action_space = target_named_action_space
self._per_agent_named_action_space = list(per_agent_named_action_space)
self._per_agent_idx_actions = [{idx: a for a, idx in x.items()} for x in self._per_agent_named_action_space]
@ -155,6 +258,16 @@ class ActionTranslator:
# Utility functions
def parse_level(path):
"""
Given the path to a string based `level` or `map` representation, this function reads the content.
Strips surrounding whitespace from each row, checks for equal length of each row and returns a list of lists.
:param path: Path to the `level` or `map` file on the hard drive. Must provide an `open` method.
:type path: pathlib.Path
:return: The read string representation of the `level` or `map`
:rtype: List[List[str]]
"""
with path.open('r') as lvl:
level = list(map(lambda x: list(x.strip()), lvl.readlines()))
if len(set([len(line) for line in level])) > 1:
@ -162,29 +275,56 @@ def parse_level(path):
return level
def one_hot_level(level, wall_char: str = Constants.WALL):
    """
    Given a string based level representation (list of lists, see function `parse_level`), this function creates a
    numpy array or `grid` that flags one symbol.

    Grid cells whose symbol equals `wall_char` are set to `Constants.OCCUPIED_CELL`; all other cells stay 0.
    Pass a different `wall_char` to filter for any other symbol.

    :param level: String based level representation (list of lists, see function `parse_level`).
    :type level: List[List[str]]
    :param wall_char: Symbol that marks the cells to flag in the resulting grid.
    :type wall_char: str
    :return: Numpy array of dtype int8 with the same shape as `level`.
    :rtype: np.ndarray
    """
    grid = np.array(level)
    binary_grid = np.zeros(grid.shape, dtype=np.int8)
    # Element-wise comparison yields a boolean mask that selects every cell holding `wall_char`.
    binary_grid[grid == wall_char] = Constants.OCCUPIED_CELL
    return binary_grid
def check_position(slice_to_check_against: ArrayLike, position_to_check: Tuple[int, int]):
    """
    Given a slice (2-D Arraylike object), check whether a position in it can be moved to.

    A position can be moved to if it lies inside the slice boundaries and the slice holds a falsy value
    (e.g. 0) at that position.

    :param slice_to_check_against: The slice to check for accessibility. Must expose `shape` and 2-D indexing.
    :type slice_to_check_against: np.typing._array_like.ArrayLike
    :param position_to_check: Position in slice that should be checked. Can be outside of slice boundaries.
    :type position_to_check: tuple(int, int)
    :return: `Constants.VALID` if the position can be moved to, `Constants.NOT_VALID` otherwise.
    :rtype: type of `Constants.VALID` (documented as bool so far)
    """
    x_pos, y_pos = position_to_check
    x_limit = slice_to_check_against.shape[0]
    y_limit = slice_to_check_against.shape[1]
    # Check if the position collides with the grid boundaries; every coordinate is compared against its
    # own axis, so non-square slices are handled correctly.
    within_bounds = 0 <= x_pos < x_limit and 0 <= y_pos < y_limit
    # Check for collision with level walls. The slice is only indexed once the position is known to be in
    # bounds, otherwise negative coordinates would silently wrap around.
    valid = within_bounds and not slice_to_check_against[x_pos, y_pos]
    return Constants.VALID if valid else Constants.NOT_VALID
def asset_str(agent):
"""
FIXME @ romue
"""
# What does this abomination do?
# if any([x is None for x in [cls._slices[j] for j in agent.collisions]]):
# print('error')
@ -192,33 +332,50 @@ def asset_str(agent):
action = step_result['action_name']
valid = step_result['action_valid']
col_names = [x.name for x in step_result['collisions']]
if any(c.AGENT in name for name in col_names):
if any(Constants.AGENT in name for name in col_names):
return 'agent_collision', 'blank'
elif not valid or c.LEVEL in col_names or c.AGENT in col_names:
return c.AGENT, 'invalid'
elif not valid or Constants.LEVEL in col_names or Constants.AGENT in col_names:
return Constants.AGENT, 'invalid'
elif valid and not EnvActions.is_move(action):
return c.AGENT, 'valid'
return Constants.AGENT, 'valid'
elif valid and EnvActions.is_move(action):
return c.AGENT, 'move'
return Constants.AGENT, 'move'
else:
return c.AGENT, 'idle'
return Constants.AGENT, 'idle'
else:
return c.AGENT, 'idle'
return Constants.AGENT, 'idle'
def points_to_graph(coordiniates_or_tiles, allow_euclidean_connections=True, allow_manhattan_connections=True):
    """
    Given a set of coordinates, this function constructs a non-directed graph by connecting adjacent points.

    There are three combinations of settings:
        Allow all neighbors:   Distance(a, b) <= sqrt(2)
        Allow only manhattan:  Distance(a, b) == 1
        Allow only euclidean:  Distance(a, b) == sqrt(2)

    The comparison is done on the squared distance, which is exact for integer grid coordinates and avoids
    testing floating point square roots for equality.
    Only edges are added, so points without any neighbor do not show up as nodes in the returned graph.

    :param coordiniates_or_tiles: A set of coordinates, or an object that exposes them as `positions`.
    :type coordiniates_or_tiles: Tiles
    :param allow_euclidean_connections: Whether to regard diagonally adjacent cells as neighbors
    :type allow_euclidean_connections: bool
    :param allow_manhattan_connections: Whether to regard directly adjacent cells as neighbors
    :type allow_manhattan_connections: bool
    :return: A graph with nodes that are connected as specified by the parameters.
    :rtype: nx.Graph
    """
    assert allow_euclidean_connections or allow_manhattan_connections
    if hasattr(coordiniates_or_tiles, 'positions'):
        coordiniates_or_tiles = coordiniates_or_tiles.positions
    graph = nx.Graph()
    for a, b in itertools.combinations(coordiniates_or_tiles, 2):
        # Squared euclidean distance: 1 for directly adjacent cells, 2 for diagonally adjacent cells.
        squared_distance = np.sum(np.square(np.subtract(a, b)))
        if allow_manhattan_connections and allow_euclidean_connections:
            connected = squared_distance <= 2
        elif allow_euclidean_connections:
            connected = squared_distance == 2
        elif allow_manhattan_connections:
            connected = squared_distance == 1
        else:
            # Only reachable when assertions are disabled; no connection type is allowed then.
            connected = False
        if connected:
            graph.add_edge(a, b)
    return graph

View File

@ -4,6 +4,22 @@ from gym.wrappers.frame_stack import FrameStack
class AgentRenderOptions(object):
"""
Class that specifies the available options for the way agents are represented in the env observation.
SEPERATE:
Each agent is represented in a separate slice as Constants.OCCUPIED_CELL value (one hot)
COMBINED:
For all agents, the value of Constants.OCCUPIED_CELL is added to a zero-value slice at the agents' positions (sum(SEPERATE))
LEVEL:
The combined slice is added to the LEVEL-slice. (Agents appear as obstacle / wall)
NOT:
The position of individual agents can not be read from the observation.
"""
SEPERATE = 'seperate'
COMBINED = 'combined'
LEVEL = 'lvl'
@ -11,24 +27,61 @@ class AgentRenderOptions(object):
class MovementProperties(NamedTuple):
    """
    Property holder; for setting multiple related parameters through a single parameter. Comes with default values.
    """

    #: Allow the manhattan style movement on a grid (move to cells that are connected by square edges).
    allow_square_movement: bool = True
    #: Allow diagonal movement on the grid (move to cells that are connected by square corners).
    allow_diagonal_movement: bool = False
    #: Allow the agent to just do nothing; not move (NO-OP).
    allow_no_op: bool = False
class ObservationProperties(NamedTuple):
    """
    Property holder; for setting multiple related parameters through a single parameter. Comes with default values.
    """

    #: How to represent agents in the observation space. This may also alter the obs-shape.
    render_agents: AgentRenderOptions = AgentRenderOptions.SEPERATE

    #: Observations are built per agent; this controls the representation of the current agent in its own
    #: observation. NOTE(review): going by the name, `True` leaves the agent out of its own observation - confirm.
    omit_agent_self: bool = True

    #: There might be cases in which you want to modify the agents obs-space, so that it can be used with
    #: additional obs. The additional slice can be filled with any number.
    additional_agent_placeholder: Union[None, str, int] = None

    #: Whether to cast shadows (make floor tiles and items hidden).
    cast_shadows: bool = True

    #: Frame stacking is a method to give some temporal information to the agents.
    #: This parameter controls how many "old-frames" are used (presumably the number of stacked frames - confirm).
    frames_to_stack: int = 0

    #: Specifies the radius (_r) of the agents field of view. Please note that the agents own grid cell is not
    #: taken into account. This means that the resulting field of view diameter = `pomdp_r * 2 + 1`.
    #: A `pomdp_r` of 0 always returns the full env == no partial observability.
    pomdp_r: int = 2

    #: Whether to place a visual encoding on walkable tiles around the doors. This is helpful when the doors can
    #: be operated from their surrounding area, so the agent can more easily get a notion of where to choose
    #: the door option. However, this is not necessary at all.
    indicate_door_area: bool = False

    #: Whether to add the agents normalized global position as float values (2,1) to a separate information
    #: slice. More optional information is to come.
    show_global_position_info: bool = False
class MarlFrameStack(gym.ObservationWrapper):
"""todo @romue404"""
def __init__(self, env):
super().__init__(env)

View File

@ -215,7 +215,7 @@ if __name__ == '__main__':
clean_amount=0.34,
max_spawn_amount=0.1, max_global_amount=20,
max_local_amount=1, spawn_frequency=0, max_spawn_ratio=0.05,
dirt_smear_amount=0.0, agent_can_interact=True)
dirt_smear_amount=0.0)
item_props = ItemProperties(n_items=10,
spawn_frequency=30, n_drop_off_locations=2,
max_agent_inventory_capacity=15)
@ -349,6 +349,7 @@ if __name__ == '__main__':
# Env Init & Model kwargs definition
if model_cls.__name__ in ["PPO", "A2C"]:
# env_factory = env_class(**env_kwargs)
env_factory = SubprocVecEnv([encapsule_env_factory(env_class, env_kwargs)
for _ in range(6)], start_method="spawn")
model_kwargs = policy_model_kwargs()

View File

@ -213,7 +213,8 @@ if __name__ == '__main__':
env_factory.save_params(param_path)
# EnvMonitor Init
callbacks = [EnvMonitor(env_factory)]
env_monitor = EnvMonitor(env_factory)
callbacks = [env_monitor]
# Model Init
model = model_cls("MlpPolicy", env_factory, **policy_model_kwargs,
@ -233,7 +234,7 @@ if __name__ == '__main__':
model.save(save_path)
# Monitor Save
callbacks[0].save_run(combination_path / 'monitor.pick',
env_monitor.save_run(combination_path / 'monitor.pick',
auto_plotting_keys=['step_reward', 'collision'] + env_plot_keys)
# Better be save then sorry: Clean up!