From 78bf19f7f4c6a2c1477c70d0723b26728bce10a1 Mon Sep 17 00:00:00 2001 From: Steffen Illium Date: Thu, 23 Dec 2021 13:19:31 +0100 Subject: [PATCH] Item and Dirt Factory Working again --- algorithms/q_learner.py | 2 +- environments/factory/base/base_factory.py | 45 +++--- environments/factory/base/objects.py | 66 +++----- environments/factory/base/registers.py | 145 +++++++++--------- environments/factory/base/shadow_casting.py | 14 +- environments/factory/factory_battery.py | 7 +- environments/factory/factory_dest.py | 6 +- environments/factory/factory_dirt.py | 40 +++-- environments/factory/factory_item.py | 161 ++++++++------------ environments/helpers.py | 90 +++++------ environments/logging/recorder.py | 2 +- 11 files changed, 257 insertions(+), 321 deletions(-) diff --git a/algorithms/q_learner.py b/algorithms/q_learner.py index 53e891e..04f3a86 100644 --- a/algorithms/q_learner.py +++ b/algorithms/q_learner.py @@ -17,7 +17,7 @@ class QLearner(BaseLearner): self.q_net = q_net self.target_q_net = target_q_net self.target_q_net.eval() - #soft_update(self.q_net, self.target_q_net, tau=1.0) + #soft_update(cls.q_net, cls.target_q_net, tau=1.0) self.buffer = BaseBuffer(buffer_size) self.target_update = target_update self.eps = eps_start diff --git a/environments/factory/base/base_factory.py b/environments/factory/base/base_factory.py index a968cfd..9419c31 100644 --- a/environments/factory/base/base_factory.py +++ b/environments/factory/base/base_factory.py @@ -1,8 +1,6 @@ import abc -import enum import time from collections import defaultdict -from enum import Enum from itertools import chain from pathlib import Path from typing import List, Union, Iterable, Dict @@ -13,8 +11,8 @@ from gym import spaces from gym.wrappers import FrameStack from environments.factory.base.shadow_casting import Map -from environments.helpers import Constants as c, Constants from environments import helpers as h +from environments.helpers import Constants as c from environments.factory.base.objects import Agent, Tile, Action from environments.factory.base.registers import Actions, Entities, Agents, Doors, FloorTiles, WallTiles, PlaceHolders, \ GlobalPositions @@ -53,10 +51,9 @@ class BaseFactory(gym.Env): _, named_obs = self._build_observations() if self.n_agents > 1: # Only return the first named obs space, as their structure at the moment is same. - return [{key.name: val for key, val in named_ob.items()} for named_ob in named_obs.values()][0] + return named_obs[list(named_obs.keys())[0]] else: - return {key.name: val for key, val in named_obs.items()} - + return named_obs @property def pomdp_diameter(self): @@ -143,27 +140,27 @@ class BaseFactory(gym.Env): # Walls walls = WallTiles.from_argwhere_coordinates( - np.argwhere(level_array == c.OCCUPIED_CELL.value), + np.argwhere(level_array == c.OCCUPIED_CELL), self._level_shape ) self._entities.register_additional_items({c.WALLS: walls}) # Floor floor = FloorTiles.from_argwhere_coordinates( - np.argwhere(level_array == c.FREE_CELL.value), + np.argwhere(level_array == c.FREE_CELL), self._level_shape ) self._entities.register_additional_items({c.FLOOR: floor}) # NOPOS - self._NO_POS_TILE = Tile(c.NO_POS.value, None) + self._NO_POS_TILE = Tile(c.NO_POS, None) # Doors if self.parse_doors: parsed_doors = h.one_hot_level(self._parsed_level, c.DOOR) parsed_doors = np.pad(parsed_doors, self.obs_prop.pomdp_r, 'constant', constant_values=0) if np.any(parsed_doors): - door_tiles = [floor.by_pos(tuple(pos)) for pos in np.argwhere(parsed_doors == c.OCCUPIED_CELL.value)] + door_tiles = [floor.by_pos(tuple(pos)) for pos in np.argwhere(parsed_doors == c.OCCUPIED_CELL)] doors = Doors.from_tiles(door_tiles, self._level_shape, entity_kwargs=dict(context=floor) ) @@ -209,7 +206,7 @@ class BaseFactory(gym.Env): if self.obs_prop.show_global_position_info: global_positions = GlobalPositions(self._level_shape) obs_shape_2d = self._level_shape if not self._pomdp_r else ((self.pomdp_diameter,) * 2) - global_positions.spawn_GlobalPositionObjects(obs_shape_2d, self[c.AGENT]) + global_positions.spawn_global_position_objects(obs_shape_2d, self[c.AGENT]) self._entities.register_additional_items({c.GLOBAL_POSITION: global_positions}) # Return @@ -239,8 +236,8 @@ class BaseFactory(gym.Env): for action, agent in zip(actions, self[c.AGENT]): agent.clear_temp_state() action_obj = self._actions[int(action)] - # self.print(f'Action #{action} has been resolved to: {action_obj}') - if h.MovingAction.is_member(action_obj): + # cls.print(f'Action #{action} has been resolved to: {action_obj}') + if h.EnvActions.is_move(action_obj): valid = self._move_or_colide(agent, action_obj) elif h.EnvActions.NOOP == agent.temp_action: valid = c.VALID @@ -338,12 +335,12 @@ class BaseFactory(gym.Env): obs_dict[c.AGENT_PLACEHOLDER] = placeholder_obs obs_dict[c.DOORS] = door_obs obs_dict.update(add_obs_dict) - observations = np.vstack(list(obs_dict.values())) + obsn = np.vstack(list(obs_dict.values())) if self.obs_prop.pomdp_r: - observations = self._do_pomdp_cutout(agent, observations) + obsn = self._do_pomdp_cutout(agent, obsn) - raw_obs = self._additional_raw_observations(agent) - observations = np.vstack((observations, *list(raw_obs.values()))) + raw_obs = self._additional_per_agent_raw_observations(agent) + obsn = np.vstack((obsn, *list(raw_obs.values()))) keys = list(chain(obs_dict.keys(), raw_obs.keys())) idxs = np.cumsum([x.shape[0] for x in chain(obs_dict.values(), raw_obs.values())]) - 1 @@ -365,7 +362,7 @@ class BaseFactory(gym.Env): print(e) raise e if self.obs_prop.cast_shadows: - obs_block_light = observations[light_block_obs] != c.OCCUPIED_CELL.value + obs_block_light = obsn[light_block_obs] != c.OCCUPIED_CELL door_shadowing = False if self.parse_doors: if doors := self[c.DOORS]: @@ -395,11 +392,11 @@ class BaseFactory(gym.Env): light_block_map[xs, ys] = 0 agent.temp_light_map = light_block_map.copy() - observations[shadowed_obs] = ((observations[shadowed_obs] * light_block_map) + 0.) - (1 - light_block_map) + obsn[shadowed_obs] = ((obsn[shadowed_obs] * light_block_map) + 0.) - (1 - light_block_map) else: pass - per_agent_obsn[agent.name] = observations + per_agent_obsn[agent.name] = obsn if self.n_agents == 1: agent_name = self[c.AGENT][0].name @@ -450,7 +447,7 @@ class BaseFactory(gym.Env): tiles_with_collisions.append(tile) return tiles_with_collisions - def _move_or_colide(self, agent: Agent, action: Action) -> Constants: + def _move_or_colide(self, agent: Agent, action: Action) -> bool: new_tile, valid = self._check_agent_move(agent, action) if valid: # Does not collide width level boundaries @@ -624,7 +621,7 @@ class BaseFactory(gym.Env): return [] @property - def additional_entities(self) -> Dict[(Enum, Entities)]: + def additional_entities(self) -> Dict[(str, Entities)]: """ When heriting from this Base Class, you musst implement this methode!!! @@ -652,11 +649,11 @@ class BaseFactory(gym.Env): return False @abc.abstractmethod - def _additional_observations(self) -> Dict[Constants, np.typing.ArrayLike]: + def _additional_observations(self) -> Dict[str, np.typing.ArrayLike]: return {} @abc.abstractmethod - def _additional_raw_observations(self, agent) -> Dict[Constants, np.typing.ArrayLike]: + def _additional_per_agent_raw_observations(self, agent) -> Dict[str, np.typing.ArrayLike]: additional_raw_observations = {} if self.obs_prop.show_global_position_info: additional_raw_observations.update({c.GLOBAL_POSITION: self[c.GLOBAL_POSITION].by_entity(agent).as_array()}) diff --git a/environments/factory/base/objects.py b/environments/factory/base/objects.py index b7e392f..fd77efd 100644 --- a/environments/factory/base/objects.py +++ b/environments/factory/base/objects.py @@ -1,5 +1,4 @@ from collections import defaultdict -from enum import Enum from typing import Union import networkx as nx @@ -29,24 +28,18 @@ class Object: @property def identifier(self): - if self._enum_ident is not None: - return self._enum_ident - elif self._str_ident is not None: + if self._str_ident is not None: return self._str_ident else: return self._name - def __init__(self, str_ident: Union[str, None] = None, enum_ident: Union[Enum, None] = None, - is_blocking_light=False, **kwargs): + def __init__(self, str_ident: Union[str, None] = None, is_blocking_light=False, **kwargs): self._str_ident = str_ident - self._enum_ident = enum_ident - if self._enum_ident is not None and self._str_ident is None: - self._name = f'{self.__class__.__name__}[{self._enum_ident.name}]' - elif self._str_ident is not None and self._enum_ident is None: + if self._str_ident is not None: self._name = f'{self.__class__.__name__}[{self._str_ident}]' - elif self._str_ident is None and self._enum_ident is None: + elif self._str_ident is None: self._name = f'{self.__class__.__name__}#{Object._u_idx[self.__class__.__name__]}' Object._u_idx[self.__class__.__name__] += 1 else: @@ -60,16 +53,7 @@ class Object: return f'{self.name}' def __eq__(self, other) -> bool: - if self._enum_ident is not None: - if isinstance(other, Enum): - return other == self._enum_ident - elif isinstance(other, Object): - return other._enum_ident == self._enum_ident - else: - raise ValueError('Must be evaluated against an Enunm Identifier or Object with such.') - else: - assert isinstance(other, Object), ' This Object can only be compared to other Objects.' - return other.name == self.name + return other == self.identifier class EnvObject(Object): @@ -80,14 +64,17 @@ class EnvObject(Object): @property def encoding(self): - return c.OCCUPIED_CELL.value + return c.OCCUPIED_CELL def __init__(self, register, **kwargs): super(EnvObject, self).__init__(**kwargs) self._register = register + def change_register(self, register): + self._register = register -class BoundingMixin: + +class BoundingMixin(Object): @property def bound_entity(self): @@ -163,7 +150,7 @@ class MoveableEntity(Entity): if self._last_tile: return self._last_tile.pos else: - return c.NO_POS.value + return c.NO_POS @property def direction_of_view(self): @@ -218,30 +205,27 @@ class PlaceHolder(Object): return "PlaceHolder" -class GlobalPosition(EnvObject): +class GlobalPosition(EnvObject, BoundingMixin): - def belongs_to_entity(self, entity): - return self._agent == entity + @property + def encoding(self): + if self._normalized: + return tuple(np.diff(self._bound_entity.pos, self._level_shape)) + else: + return self.bound_entity.pos + + def __init__(self, level_shape, *args, normalized: bool = True, **kwargs): + super(GlobalPosition, self).__init__(self, *args, **kwargs) - def __init__(self, level_shape, obs_shape, agent, normalized: bool = True): - super(GlobalPosition, self).__init__(self) - self._obs_shape = (1, *obs_shape) if len(obs_shape) == 2 else obs_shape - self._agent = agent self._level_shape = level_shape self._normalized = normalized - def as_array(self): - pos_array = np.zeros(self._obs_shape) - for xy in range(1): - pos_array[0, 0, xy] = self._agent.pos[xy] / self._level_shape[xy] - return pos_array - class Tile(EnvObject): @property def encoding(self): - return c.FREE_CELL.value + return c.FREE_CELL @property def guests_that_can_collide(self): @@ -302,7 +286,7 @@ class Wall(Tile): @property def encoding(self): - return c.OCCUPIED_CELL.value + return c.OCCUPIED_CELL pass @@ -319,7 +303,7 @@ class Door(Entity): @property def encoding(self): # This is important as it shadow is checked by occupation value - return c.OCCUPIED_CELL.value if self.is_closed else 2 + return c.OCCUPIED_CELL if self.is_closed else 2 @property def str_state(self): @@ -403,7 +387,7 @@ class Agent(MoveableEntity): # noinspection PyAttributeOutsideInit def clear_temp_state(self): - # for attr in self.__dict__: + # for attr in cls.__dict__: # if attr.startswith('temp'): self.temp_collisions = [] self.temp_valid = None diff --git a/environments/factory/base/registers.py b/environments/factory/base/registers.py index aff0105..89bbb5a 100644 --- a/environments/factory/base/registers.py +++ b/environments/factory/base/registers.py @@ -4,6 +4,7 @@ from abc import ABC from typing import List, Union, Dict, Tuple import numpy as np +import six from environments.factory.base.objects import Entity, Tile, Agent, Door, Action, Wall, PlaceHolder, GlobalPosition, \ Object, EnvObject @@ -56,7 +57,7 @@ class ObjectRegister: def _get_index(self, item): try: return next(i for i, v in enumerate(self._register.values()) if v == item) - except (StopIteration, AssertionError): + except StopIteration: return None def __getitem__(self, item): @@ -73,24 +74,30 @@ class ObjectRegister: return None def __repr__(self): - return f'{self.__class__.__name__}({self._register})' + return f'{self.__class__.__name__}[{self._register}]' class EnvObjectRegister(ObjectRegister): _accepted_objects = EnvObject - def __init__(self, obs_shape: (int, int), *args, **kwargs): + @property + def encodings(self): + return [x.encoding for x in self] + + def __init__(self, obs_shape: (int, int), *args, individual_slices: bool = False, **kwargs): super(EnvObjectRegister, self).__init__(*args, **kwargs) self._shape = obs_shape self._array = None - self.hide_from_obs_builder = False + self._individual_slices = individual_slices self._lazy_eval_transforms = [] def register_item(self, other: EnvObject): super(EnvObjectRegister, self).register_item(other) if self._array is None: self._array = np.zeros((1, *self._shape)) + if self._individual_slices: + self._array = np.vstack((self._array, np.zeros((1, *self._shape)))) self.notify_change_to_value(other) def as_array(self): @@ -105,7 +112,7 @@ class EnvObjectRegister(ObjectRegister): return [val.summarize_state(n_steps=n_steps) for val in self.values()] def notify_change_to_free(self, env_object: EnvObject): - self._array_change_notifyer(env_object, value=c.FREE_CELL.value) + self._array_change_notifyer(env_object, value=c.FREE_CELL) def notify_change_to_value(self, env_object: EnvObject): self._array_change_notifyer(env_object) @@ -114,9 +121,28 @@ class EnvObjectRegister(ObjectRegister): pos = self._get_index(env_object) value = value if value is not None else env_object.encoding self._lazy_eval_transforms.append((pos, value)) + if self._individual_slices: + idx = (self._get_index(env_object) * np.prod(self._shape[1:]), value) + self._lazy_eval_transforms.append((idx, value)) + else: + self._lazy_eval_transforms.append((pos, value)) + + def _refresh_arrays(self): + poss, values = zip(*[(idx, x.encoding) for idx,x in enumerate(self.values())]) + for pos, value in zip(poss, values): + self._lazy_eval_transforms.append((pos, value)) def __delitem__(self, name): - self.notify_change_to_free(self._register[name]) + idx, obj = next((i, obj) for i, obj in enumerate(self) if obj.name == name) + if self._individual_slices: + self._array = np.delete(self._array, idx, axis=0) + else: + self.notify_change_to_free(self._register[name]) + # Dirty Hack to check if not beeing subclassed. In that case we need to refresh the array since positions + # in the observation array are result of enumeration. They can overide each other. + # Todo: Find a better solution + if not issubclass(self.__class__, EntityRegister) and issubclass(self.__class__, EnvObjectRegister): + self._refresh_arrays() del self._register[name] def delete_env_object(self, env_object: EnvObject): @@ -153,26 +179,19 @@ class EntityRegister(EnvObjectRegister, ABC): def tiles(self): return [entity.tile for entity in self] - @property - def encodings(self): - return [x.encoding for x in self] - def __init__(self, level_shape, *args, is_blocking_light: bool = False, can_be_shadowed: bool = True, - individual_slices: bool = False, **kwargs): + **kwargs): super(EntityRegister, self).__init__(level_shape, *args, **kwargs) self._lazy_eval_transforms = [] self.can_be_shadowed = can_be_shadowed - self.individual_slices = individual_slices self.is_blocking_light = is_blocking_light def __delitem__(self, name): idx, obj = next((i, obj) for i, obj in enumerate(self) if obj.name == name) obj.tile.leave(obj) super(EntityRegister, self).__delitem__(name) - if self.individual_slices: - self._array = np.delete(self._array, idx, axis=0) def as_array(self): if self._lazy_eval_transforms: @@ -188,7 +207,7 @@ class EntityRegister(EnvObjectRegister, ABC): pos = pos if pos is not None else entity.pos value = value if value is not None else entity.encoding x, y = pos - if self.individual_slices: + if self._individual_slices: idx = (self._get_index(entity), x, y) else: idx = (0, x, y) @@ -203,19 +222,12 @@ class EntityRegister(EnvObjectRegister, ABC): class BoundRegisterMixin(EnvObjectRegister, ABC): - @classmethod - def from_entities_to_bind(self, entitites): - def from_values(cls, values: Union[str, numbers.Number, List[Union[str, numbers.Number]]], - *args, object_kwargs=None, **kwargs): - # objects_name = cls._accepted_objects.__name__ - if isinstance(values, (str, numbers.Number)): - values = [values] - register_obj = cls(*args, **kwargs) - objects = [cls._accepted_objects(register_obj, str_ident=i, fill_value=value, - **object_kwargs if object_kwargs is not None else {}) - for i, value in enumerate(values)] - register_obj.register_additional_items(objects) - return register_obj + def __init__(self, entity_to_be_bound, *args, **kwargs): + super().__init__(*args, **kwargs) + self._bound_entity = entity_to_be_bound + + def belongs_to_entity(self, entity): + return self._bound_entity == entity class MovingEntityObjectRegister(EntityRegister, ABC): @@ -225,9 +237,9 @@ class MovingEntityObjectRegister(EntityRegister, ABC): def notify_change_to_value(self, entity): super(MovingEntityObjectRegister, self).notify_change_to_value(entity) - if entity.last_pos != c.NO_POS.value: + if entity.last_pos != c.NO_POS: try: - self._array_change_notifyer(entity, entity.last_pos, value=c.FREE_CELL.value) + self._array_change_notifyer(entity, entity.last_pos, value=c.FREE_CELL) except AttributeError: pass @@ -238,20 +250,26 @@ class MovingEntityObjectRegister(EntityRegister, ABC): class GlobalPositions(EnvObjectRegister): + _accepted_objects = GlobalPosition + is_blocking_light = False can_be_shadowed = False - hide_from_obs_builder = True def __init__(self, *args, **kwargs): super(GlobalPositions, self).__init__(*args, is_per_agent=True, individual_slices=True, **kwargs) def as_array(self): - # Todo make this lazy? + # FIXME DEBUG!!! make this lazy? return np.stack([gp.as_array() for inv_idx, gp in enumerate(self)]) - def spawn_GlobalPositionObjects(self, obs_shape, agents): - global_positions = [self._accepted_objects(self._shape, obs_shape, agent) + def as_array_by_entity(self, entity): + # FIXME DEBUG!!! make this lazy? + return np.stack([gp.as_array() for inv_idx, gp in enumerate(self)]) + + def spawn_global_position_objects(self, agents): + # Todo, change to 'from xy'-form + global_positions = [self._accepted_objects(self._shape, agent) for _, agent in enumerate(agents)] # noinspection PyTypeChecker self.register_additional_items(global_positions) @@ -276,7 +294,7 @@ class PlaceHolders(EnvObjectRegister): _accepted_objects = PlaceHolder def __init__(self, *args, **kwargs): - assert not 'individual_slices' in kwargs, 'Keyword - "individual_slices": "True" and must not be altered' + assert 'individual_slices' not in kwargs, 'Keyword - "individual_slices": "True" and must not be altered' kwargs.update(individual_slices=False) super().__init__(*args, **kwargs) @@ -316,10 +334,6 @@ class Entities(ObjectRegister): def arrays(self): return {key: val.as_array() for key, val in self.items()} - @property - def obs_arrays(self): - return {key: val.as_array() for key, val in self.items() if not val.hide_from_obs_builder} - @property def names(self): return list(self._register.keys()) @@ -347,24 +361,20 @@ class Entities(ObjectRegister): class WallTiles(EntityRegister): _accepted_objects = Wall _light_blocking = True - hide_from_obs_builder = True def as_array(self): if not np.any(self._array): # Which is Faster? - # indices = [x.pos for x in self] - # np.put(self._array, [np.ravel_multi_index((0, *x), self._array.shape) for x in indices], self.encodings) + # indices = [x.pos for x in cls] + # np.put(cls._array, [np.ravel_multi_index((0, *x), cls._array.shape) for x in indices], cls.encodings) x, y = zip(*[x.pos for x in self]) - self._array[0, x, y] = self.encoding + self._array[0, x, y] = self._value return self._array def __init__(self, *args, **kwargs): super(WallTiles, self).__init__(*args, is_blocking_light=self._light_blocking, individual_slices=False, **kwargs) - - @property - def encoding(self): - return c.OCCUPIED_CELL.value + self._value = c.OCCUPIED_CELL @classmethod def from_argwhere_coordinates(cls, argwhere_coordinates, *args, **kwargs): @@ -393,10 +403,7 @@ class FloorTiles(WallTiles): def __init__(self, *args, **kwargs): super(FloorTiles, self).__init__(*args, **kwargs) - - @property - def encoding(self): - return c.FREE_CELL.value + self._value = c.FREE_CELL @property def occupied_tiles(self): @@ -422,23 +429,8 @@ class FloorTiles(WallTiles): class Agents(MovingEntityObjectRegister): _accepted_objects = Agent - def __init__(self, *args, hide_from_obs_builder=False, **kwargs): + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.hide_from_obs_builder = hide_from_obs_builder - - @DeprecationWarning - def Xas_array(self): - # Super Safe Version - # self._array[:] = c.FREE_CELL.value - indices = list(zip(range(len(self)), *zip(*[x.last_pos for x in self]))) - np.put(self._array, [np.ravel_multi_index(x, self._array.shape) for x in indices], c.FREE_CELL.value) - indices = list(zip(range(len(self)), *zip(*[x.pos for x in self]))) - np.put(self._array, [np.ravel_multi_index(x, self._array.shape) for x in indices], self.encodings) - - if self.individual_slices: - return self._array - else: - return self._array.sum(axis=0, keepdims=True) @property def positions(self): @@ -484,17 +476,18 @@ class Actions(ObjectRegister): self.can_use_doors = can_use_doors super(Actions, self).__init__() + # Move this to Baseclass, Env init? if self.allow_square_movement: - self.register_additional_items([self._accepted_objects(enum_ident=direction) - for direction in h.MovingAction.square()]) + self.register_additional_items([self._accepted_objects(str_ident=direction) + for direction in h.EnvActions.square_move()]) if self.allow_diagonal_movement: - self.register_additional_items([self._accepted_objects(enum_ident=direction) - for direction in h.MovingAction.diagonal()]) + self.register_additional_items([self._accepted_objects(str_ident=direction) + for direction in h.EnvActions.diagonal_move()]) self._movement_actions = self._register.copy() if self.can_use_doors: - self.register_additional_items([self._accepted_objects(enum_ident=h.EnvActions.USE_DOOR)]) + self.register_additional_items([self._accepted_objects(str_ident=h.EnvActions.USE_DOOR)]) if self.allow_no_op: - self.register_additional_items([self._accepted_objects(enum_ident=h.EnvActions.NOOP)]) + self.register_additional_items([self._accepted_objects(str_ident=h.EnvActions.NOOP)]) def is_moving_action(self, action: Union[int]): return action in self.movement_actions.values() @@ -504,7 +497,7 @@ class Zones(ObjectRegister): @property def accounting_zones(self): - return [self[idx] for idx, name in self.items() if name != c.DANGER_ZONE.value] + return [self[idx] for idx, name in self.items() if name != c.DANGER_ZONE] def __init__(self, parsed_level): raise NotImplementedError('This needs a Rework') @@ -513,9 +506,9 @@ class Zones(ObjectRegister): self._accounting_zones = list() self._danger_zones = list() for symbol in np.unique(parsed_level): - if symbol == c.WALL.value: + if symbol == c.WALL: continue - elif symbol == c.DANGER_ZONE.value: + elif symbol == c.DANGER_ZONE: self + symbol slices.append(h.one_hot_level(parsed_level, symbol)) self._danger_zones.append(symbol) diff --git a/environments/factory/base/shadow_casting.py b/environments/factory/base/shadow_casting.py index fd6471a..3fdd0b6 100644 --- a/environments/factory/base/shadow_casting.py +++ b/environments/factory/base/shadow_casting.py @@ -16,14 +16,14 @@ class Map(object): def __init__(self, map_array: np.typing.ArrayLike, diamond_slope: float = 0.9): self.data = map_array self.width, self.height = map_array.shape - self.light = np.full_like(self.data, c.FREE_CELL.value) - self.flag = c.FREE_CELL.value + self.light = np.full_like(self.data, c.FREE_CELL) + self.flag = c.FREE_CELL self.d_slope = diamond_slope def blocked(self, x, y): return (x < 0 or y < 0 or x >= self.width or y >= self.height - or self.data[x, y] == c.OCCUPIED_CELL.value) + or self.data[x, y] == c.OCCUPIED_CELL) def lit(self, x, y): return self.light[x, y] == self.flag @@ -46,14 +46,14 @@ class Map(object): # Translate the dx, dy coordinates into map coordinates: X, Y = cx + dx * xx + dy * xy, cy + dx * yx + dy * yy # l_slope and r_slope store the slopes of the left and right - # extremities of the square we're considering: + # extremities of the square_move we're considering: l_slope, r_slope = (dx-self.d_slope)/(dy+self.d_slope), (dx+self.d_slope)/(dy-self.d_slope) if start < r_slope: continue elif end > l_slope: break else: - # Our light beam is touching this square; light it: + # Our light beam is touching this square_move; light it: if dx*dx + dy*dy < radius_squared: self.set_lit(X, Y) if blocked: @@ -66,12 +66,12 @@ class Map(object): start = new_start else: if self.blocked(X, Y) and j < radius: - # This is a blocking square, start a child scan: + # This is a blocking square_move, start a child scan: blocked = True self._cast_light(cx, cy, j+1, start, l_slope, radius, xx, xy, yx, yy, id+1) new_start = r_slope - # Row is scanned; do next row unless last square was blocked: + # Row is scanned; do next row unless last square_move was blocked: if blocked: break diff --git a/environments/factory/factory_battery.py b/environments/factory/factory_battery.py index ee27bc5..f6c57bd 100644 --- a/environments/factory/factory_battery.py +++ b/environments/factory/factory_battery.py @@ -65,7 +65,6 @@ class BatteriesRegister(EnvObjectRegister): _accepted_objects = Battery is_blocking_light = False can_be_shadowed = False - hide_from_obs_builder = True def __init__(self, *args, **kwargs): super(BatteriesRegister, self).__init__(*args, is_per_agent=True, individual_slices=True, **kwargs) @@ -98,7 +97,7 @@ class BatteriesRegister(EnvObjectRegister): def summarize_states(self, n_steps=None): # as dict with additional nesting - # return dict(items=super(Inventories, self).summarize_states()) + # return dict(items=super(Inventories, cls).summarize_states()) return super(BatteriesRegister, self).summarize_states(n_steps=n_steps) @@ -156,8 +155,8 @@ class BatteryFactory(BaseFactory): self.btry_prop = btry_prop super().__init__(*args, **kwargs) - def _additional_raw_observations(self, agent) -> Dict[Constants, np.typing.ArrayLike]: - additional_raw_observations = super()._additional_raw_observations(agent) + def _additional_per_agent_raw_observations(self, agent) -> Dict[Constants, np.typing.ArrayLike]: + additional_raw_observations = super()._additional_per_agent_raw_observations(agent) additional_raw_observations.update({c.BATTERIES: self[c.BATTERIES].by_entity(agent).as_array()}) return additional_raw_observations diff --git a/environments/factory/factory_dest.py b/environments/factory/factory_dest.py index b34ef20..24e4b7f 100644 --- a/environments/factory/factory_dest.py +++ b/environments/factory/factory_dest.py @@ -14,6 +14,8 @@ from environments.factory.base.registers import Entities, EntityRegister from environments.factory.base.renderer import RenderEntity + + DESTINATION = 1 DESTINATION_DONE = 0.5 @@ -70,8 +72,8 @@ class Destinations(EntityRegister): def as_array(self): self._array[:] = c.FREE_CELL.value # ToDo: Switch to new Style Array Put - # indices = list(zip(range(len(self)), *zip(*[x.pos for x in self]))) - # np.put(self._array, [np.ravel_multi_index(x, self._array.shape) for x in indices], self.encodings) + # indices = list(zip(range(len(cls)), *zip(*[x.pos for x in cls]))) + # np.put(cls._array, [np.ravel_multi_index(x, cls._array.shape) for x in indices], cls.encodings) for item in self: if item.pos != c.NO_POS.value: self._array[0, item.x, item.y] = item.encoding diff --git a/environments/factory/factory_dirt.py b/environments/factory/factory_dirt.py index ff28840..65d3390 100644 --- a/environments/factory/factory_dirt.py +++ b/environments/factory/factory_dirt.py @@ -5,24 +5,31 @@ import random import numpy as np -from algorithms.TSP_dirt_agent import TSPDirtAgent -from environments.helpers import Constants as c, Constants -from environments import helpers as h +# from algorithms.TSP_dirt_agent import TSPDirtAgent +from environments.helpers import Constants as BaseConstants +from environments.helpers import EnvActions as BaseActions + from environments.factory.base.base_factory import BaseFactory from environments.factory.base.objects import Agent, Action, Entity, Tile -from environments.factory.base.registers import Entities, MovingEntityObjectRegister, EntityRegister +from environments.factory.base.registers import Entities, EntityRegister from environments.factory.base.renderer import RenderEntity from environments.utility_classes import ObservationProperties -CLEAN_UP_ACTION = h.EnvActions.CLEAN_UP + +class Constants(BaseConstants): + DIRT = 'Dirt' + + +class EnvActions(BaseActions): + CLEAN_UP = 'clean_up' class DirtProperties(NamedTuple): - initial_dirt_ratio: float = 0.3 # On INIT, on max how much tiles does the dirt spawn in percent. + initial_dirt_ratio: float = 0.3 # On INIT, on max how many tiles does the dirt spawn in percent. initial_dirt_spawn_r_var: float = 0.05 # How much does the dirt spawn amount vary? clean_amount: float = 1 # How much does the robot clean with one actions. - max_spawn_ratio: float = 0.20 # On max how much tiles does the dirt spawn in percent. + max_spawn_ratio: float = 0.20 # On max how many tiles does the dirt spawn in percent. max_spawn_amount: float = 0.3 # How much dirt does spawn per tile at max. spawn_frequency: int = 0 # Spawn Frequency in Steps. max_local_amount: int = 2 # Max dirt amount per tile. @@ -77,7 +84,7 @@ class DirtRegister(EntityRegister): super(DirtRegister, self).__init__(*args) self._dirt_properties: DirtProperties = dirt_properties - def spawn_dirt(self, then_dirty_tiles) -> c: + def spawn_dirt(self, then_dirty_tiles) -> bool: if isinstance(then_dirty_tiles, Tile): then_dirty_tiles = [then_dirty_tiles] for tile in then_dirty_tiles: @@ -108,6 +115,9 @@ def entropy(x): return -(x * np.log(x + 1e-8)).sum() +c = Constants + + # noinspection PyAttributeOutsideInit, PyAbstractClass class DirtFactory(BaseFactory): @@ -115,7 +125,7 @@ class DirtFactory(BaseFactory): def additional_actions(self) -> Union[Action, List[Action]]: super_actions = super().additional_actions if self.dirt_prop.agent_can_interact: - super_actions.append(Action(enum_ident=CLEAN_UP_ACTION)) + super_actions.append(Action(str_ident=EnvActions.CLEAN_UP)) return super_actions @property @@ -194,7 +204,7 @@ class DirtFactory(BaseFactory): def do_additional_actions(self, agent: Agent, action: Action) -> Union[None, c]: valid = super().do_additional_actions(agent, action) if valid is None: - if action == CLEAN_UP_ACTION: + if action == EnvActions.CLEAN_UP: if self.dirt_prop.agent_can_interact: valid = self.clean_up(agent) return valid @@ -215,7 +225,7 @@ class DirtFactory(BaseFactory): done = self.dirt_prop.done_when_clean and (len(self[c.DIRT]) == 0) return super_done or done - def _additional_observations(self) -> Dict[Constants, np.typing.ArrayLike]: + def _additional_observations(self) -> Dict[str, np.typing.ArrayLike]: additional_observations = super()._additional_observations() additional_observations.update({c.DIRT: self[c.DIRT].as_array()}) return additional_observations @@ -227,14 +237,14 @@ class DirtFactory(BaseFactory): dirty_tile_count = len(dirt) # if dirty_tile_count: # dirt_distribution_score = entropy(softmax(np.asarray(dirt)) / dirty_tile_count) - #else: + # else: # dirt_distribution_score = 0 info_dict.update(dirt_amount=current_dirt_amount) info_dict.update(dirty_tile_count=dirty_tile_count) # info_dict.update(dirt_distribution_score=dirt_distribution_score) - if agent.temp_action == CLEAN_UP_ACTION: + if agent.temp_action == EnvActions.CLEAN_UP: if agent.temp_valid: # Reward if pickup succeds, # 0.5 on every pickup @@ -257,7 +267,7 @@ class DirtFactory(BaseFactory): if __name__ == '__main__': - from environments.utility_classes import AgentRenderOptions as ARO + from environments.utility_classes import AgentRenderOptions as aro render = True dirt_props = DirtProperties( @@ -273,7 +283,7 @@ if __name__ == '__main__': agent_can_interact=True ) - obs_props = ObservationProperties(render_agents=ARO.COMBINED, omit_agent_self=True, + obs_props = ObservationProperties(render_agents=aro.COMBINED, omit_agent_self=True, pomdp_r=2, additional_agent_placeholder=None, cast_shadows=True) move_props = {'allow_square_movement': True, diff --git a/environments/factory/factory_item.py b/environments/factory/factory_item.py index 1ddaa93..5bbc867 100644 --- a/environments/factory/factory_item.py +++ b/environments/factory/factory_item.py @@ -1,22 +1,30 @@ import time -from collections import deque, UserList -from enum import Enum +from collections import deque from typing import List, Union, NamedTuple, Dict import numpy as np import random from environments.factory.base.base_factory import BaseFactory -from environments.helpers import Constants as c, Constants +from environments.helpers import Constants as BaseConstants +from environments.helpers import EnvActions as BaseActions from environments import helpers as h -from environments.factory.base.objects import Agent, Entity, Action, Tile, MoveableEntity -from environments.factory.base.registers import Entities, EntityRegister, EnvObjectRegister, MovingEntityObjectRegister, \ - BoundRegisterMixin +from environments.factory.base.objects import Agent, Entity, Action, Tile +from environments.factory.base.registers import Entities, EntityRegister, BoundRegisterMixin, ObjectRegister from environments.factory.base.renderer import RenderEntity -NO_ITEM = 0 -ITEM_DROP_OFF = 1 +class Constants(BaseConstants): + NO_ITEM = 0 + ITEM_DROP_OFF = 1 + # Item Env + ITEM = 'Item' + INVENTORY = 'Inventory' + DROP_OFF = 'Drop_Off' + + +class EnvActions(BaseActions): + ITEM_ACTION = 'item_action' class Item(Entity): @@ -41,13 +49,9 @@ class Item(Entity): def set_auto_despawn(self, auto_despawn): self._auto_despawn = auto_despawn - def despawn(self): - # Todo: Move this to base class? - curr_tile = self.tile - curr_tile.leave(self) - self._tile = None - self._register.notify_change_to_value(self) - return True + def set_tile_to(self, no_pos_tile): + assert self._register.__class__.__name__ != ItemRegister.__class__ + self._tile = no_pos_tile class ItemRegister(EntityRegister): @@ -64,58 +68,38 @@ class ItemRegister(EntityRegister): del self[item] -class Inventory(EntityRegister, BoundRegisterMixin): - - @property - def is_blocking_light(self): - return False +class Inventory(BoundRegisterMixin): @property def name(self): - return f'{self.__class__.__name__}({self.agent.name})' + return f'{self.__class__.__name__}({self._bound_entity.name})' - def __init__(self, obs_shape: (int, int), agent: Agent, capacity: int): - super(Inventory, self).__init__() - self.agent = agent - self._obs_shape = obs_shape - - self._array = np.zeros((1, *self._obs_shape)) - - self.capacity = min(capacity, self._array.size) + def __init__(self, agent: Agent, capacity: int, *args, **kwargs): + super(Inventory, self).__init__(agent, *args, is_blocking_light=False, can_be_shadowed=False, **kwargs) + self.capacity = capacity def as_array(self): - self._array[:] = c.FREE_CELL.value - # ToDo: Make this Lazy - for item_idx, item in enumerate(self): - x_diff, y_diff = divmod(item_idx, self._array.shape[1]) - self._array[0, int(x_diff), int(y_diff)] = item.encoding - return self._array + if self._array is None: + self._array = np.zeros((1, *self._shape)) + return super(Inventory, self).as_array() - def __repr__(self): - return f'{self.__class__.__name__}[{self.agent.name}]({self.data})' - - def append(self, item) -> None: - if len(self) < self.capacity: - super(Inventory, self).append(item) - else: - raise RuntimeError('Inventory is full') - - def belongs_to_entity(self, entity): - return self.agent == entity - - def summarize_state(self, **kwargs): + def summarize_states(self, **kwargs): attr_dict = {key: str(val) for key, val in self.__dict__.items() if not key.startswith('_') and key != 'data'} - attr_dict.update(dict(items={val.name: val.summarize_state(**kwargs) for val in self})) + attr_dict.update(dict(items={key: val.summarize_state(**kwargs) for key, val in self.items()})) attr_dict.update(dict(name=self.name)) return attr_dict + def pop(self): + item_to_pop = self[0] + self.delete_env_object(item_to_pop) + return item_to_pop -class Inventories(EnvObjectRegister): + +class Inventories(ObjectRegister): _accepted_objects = Inventory is_blocking_light = False can_be_shadowed = False - hide_from_obs_builder = True def __init__(self, obs_shape, *args, **kwargs): super(Inventories, self).__init__(*args, is_per_agent=True, individual_slices=True, **kwargs) @@ -125,7 +109,7 @@ class Inventories(EnvObjectRegister): return np.stack([inventory.as_array() for inv_idx, inventory in enumerate(self)]) def spawn_inventories(self, agents, capacity): - inventories = [self._accepted_objects(self._obs_shape, agent, capacity) + inventories = [self._accepted_objects(agent, capacity, self._obs_shape) for _, agent in enumerate(agents)] self.register_additional_items(inventories) @@ -141,10 +125,8 @@ class Inventories(EnvObjectRegister): except StopIteration: return None - def summarize_states(self, n_steps=None): - # as dict with additional nesting - # return dict(items=super(Inventories, self).summarize_states()) - return super(Inventories, self).summarize_states(n_steps=n_steps) + def summarize_states(self, **kwargs): + return {key: val.summarize_states(**kwargs) for key, val in self.items()} class DropOffLocation(Entity): @@ -155,7 +137,7 @@ class DropOffLocation(Entity): @property def encoding(self): - return ITEM_DROP_OFF + return Constants.ITEM_DROP_OFF def __init__(self, *args, storage_size_until_full: int = 5, auto_item_despawn_interval: int = 5, **kwargs): super(DropOffLocation, self).__init__(*args, **kwargs) @@ -184,24 +166,17 @@ class DropOffLocations(EntityRegister): _accepted_objects = DropOffLocation - @DeprecationWarning - def Xas_array(self): - # Todo: Which is faster? - # indices = list(zip(range(len(self)), *zip(*[x.pos for x in self]))) - # np.put(self._array, [np.ravel_multi_index(x, self._array.shape) for x in indices], self.encodings) - self._array[:] = c.FREE_CELL.value - indices = list(zip([0, ] * len(self), *zip(*[x.pos for x in self]))) - np.put(self._array, [np.ravel_multi_index(x, self._array.shape) for x in indices], self.encodings) - return self._array - class ItemProperties(NamedTuple): - n_items: int = 5 # How many items are there at the same time - spawn_frequency: int = 10 # Spawn Frequency in Steps - n_drop_off_locations: int = 5 # How many DropOff locations are there at the same time - max_dropoff_storage_size: int = 0 # How many items are needed until the drop off is full + n_items: int = 5 # How many items are there at the same time + spawn_frequency: int = 10 # Spawn Frequency in Steps + n_drop_off_locations: int = 5 # How many DropOff locations are there at the same time + max_dropoff_storage_size: int = 0 # How many items are needed until the dropoff is full max_agent_inventory_capacity: int = 5 # How many items are needed until the agent inventory is full - agent_can_interact: bool = True # Whether agents have the possibility to interact with the domain items + + +c = Constants +a = EnvActions # noinspection PyAttributeOutsideInit, PyAbstractClass @@ -220,11 +195,11 @@ class ItemFactory(BaseFactory): def additional_actions(self) -> Union[Action, List[Action]]: # noinspection PyUnresolvedReferences super_actions = super().additional_actions - super_actions.append(Action(enum_ident=h.EnvActions.ITEM_ACTION)) + super_actions.append(Action(str_ident=a.ITEM_ACTION)) return super_actions @property - def additional_entities(self) -> Dict[(Enum, Entities)]: + def additional_entities(self) -> Dict[(str, Entities)]: # noinspection PyUnresolvedReferences super_entities = super().additional_entities @@ -238,19 +213,18 @@ class ItemFactory(BaseFactory): empty_tiles = self[c.FLOOR].empty_tiles[:self.item_prop.n_items] item_register.spawn_items(empty_tiles) - inventories = Inventories(self._level_shape if not self._pomdp_r else ((self.pomdp_diameter,) * 2), - self._level_shape) + inventories = Inventories(self._obs_shape, self._level_shape) inventories.spawn_inventories(self[c.AGENT], self.item_prop.max_agent_inventory_capacity) super_entities.update({c.DROP_OFF: drop_offs, c.ITEM: item_register, c.INVENTORY: inventories}) return super_entities - def _additional_raw_observations(self, agent) -> Dict[Constants, np.typing.ArrayLike]: - additional_raw_observations = super()._additional_raw_observations(agent) + def _additional_per_agent_raw_observations(self, agent) -> Dict[str, np.typing.ArrayLike]: + additional_raw_observations = super()._additional_per_agent_raw_observations(agent) additional_raw_observations.update({c.INVENTORY: self[c.INVENTORY].by_entity(agent).as_array()}) return additional_raw_observations - def _additional_observations(self) -> Dict[Constants, np.typing.ArrayLike]: + def _additional_observations(self) -> Dict[str, np.typing.ArrayLike]: additional_observations = super()._additional_observations() additional_observations.update({c.ITEM: self[c.ITEM].as_array()}) additional_observations.update({c.DROP_OFF: self[c.DROP_OFF].as_array()}) @@ -260,14 +234,16 @@ class ItemFactory(BaseFactory): inventory = self[c.INVENTORY].by_entity(agent) if drop_off := self[c.DROP_OFF].by_pos(agent.pos): if inventory: - valid = drop_off.place_item(inventory.pop(0)) + valid = drop_off.place_item(inventory.pop()) return valid else: return c.NOT_VALID elif item := self[c.ITEM].by_pos(agent.pos): try: - inventory.append(item) - item.despawn() + inventory.register_item(item) + item.change_register(inventory) + self[c.ITEM].delete_env_object(item) + item.set_tile_to(self._NO_POS_TILE) return c.VALID except RuntimeError: return c.NOT_VALID @@ -278,12 +254,9 @@ class ItemFactory(BaseFactory): # noinspection PyUnresolvedReferences valid = super().do_additional_actions(agent, action) if valid is None: - if action == h.EnvActions.ITEM_ACTION: - if self.item_prop.agent_can_interact: - valid = self.do_item_action(agent) - return valid - else: - return c.NOT_VALID + if action == a.ITEM_ACTION: + valid = self.do_item_action(agent) + return valid else: return None else: @@ -324,7 +297,7 @@ class ItemFactory(BaseFactory): def calculate_additional_reward(self, agent: Agent) -> (int, dict): # noinspection PyUnresolvedReferences reward, info_dict = super().calculate_additional_reward(agent) - if h.EnvActions.ITEM_ACTION == agent.temp_action: + if a.ITEM_ACTION == agent.temp_action: if agent.temp_valid: if drop_off := self[c.DROP_OFF].by_pos(agent.pos): info_dict.update({f'{agent.name}_item_drop_off': 1}) @@ -352,21 +325,21 @@ class ItemFactory(BaseFactory): def render_additional_assets(self, mode='human'): # noinspection PyUnresolvedReferences additional_assets = super().render_additional_assets() - items = [RenderEntity(c.ITEM.value, item.tile.pos) for item in self[c.ITEM]] + items = [RenderEntity(c.ITEM, item.tile.pos) for item in self[c.ITEM] if item.tile != self._NO_POS_TILE] additional_assets.extend(items) - drop_offs = [RenderEntity(c.DROP_OFF.value, drop_off.tile.pos) for drop_off in self[c.DROP_OFF]] + drop_offs = [RenderEntity(c.DROP_OFF, drop_off.tile.pos) for drop_off in self[c.DROP_OFF]] additional_assets.extend(drop_offs) return additional_assets if __name__ == '__main__': - from environments.utility_classes import AgentRenderOptions as ARO, ObservationProperties + from environments.utility_classes import AgentRenderOptions as aro, ObservationProperties render = True - item_probs = ItemProperties() + item_probs = ItemProperties(n_items=30) - obs_props = ObservationProperties(render_agents=ARO.SEPERATE, omit_agent_self=True, pomdp_r=2) + obs_props = ObservationProperties(render_agents=aro.SEPERATE, omit_agent_self=True, pomdp_r=2) move_props = {'allow_square_movement': True, 'allow_diagonal_movement': True, diff --git a/environments/helpers.py b/environments/helpers.py index b6a2550..6cc555a 100644 --- a/environments/helpers.py +++ b/environments/helpers.py @@ -1,7 +1,5 @@ import itertools from collections import defaultdict -from enum import Enum -from pathlib import Path from typing import Tuple, Union, Dict, List import networkx as nx @@ -20,7 +18,7 @@ IGNORED_DF_COLUMNS = ['Episode', 'Run', 'train_step', 'step', 'index', 'dirt_amo # Constants -class Constants(Enum): +class Constants: WALL = '#' WALLS = 'Walls' FLOOR = 'Floor' @@ -44,14 +42,6 @@ class Constants(Enum): VALID = 'valid' NOT_VALID = 'not_valid' - # Dirt Env - DIRT = 'Dirt' - - # Item Env - ITEM = 'Item' - INVENTORY = 'Inventory' - DROP_OFF = 'Drop_Off' - # Battery Env CHARGE_POD = 'Charge_Pod' BATTERIES = 'BATTERIES' @@ -60,14 +50,9 @@ class Constants(Enum): DESTINATION = 'Destination' REACHEDDESTINATION = 'ReachedDestination' - def __bool__(self): - if 'not_' in self.value: - return False - else: - return bool(self.value) - -class MovingAction(Enum): +class EnvActions: + # Movements NORTH = 'north' EAST = 'east' SOUTH = 'south' @@ -77,29 +62,31 @@ class MovingAction(Enum): SOUTHWEST = 'south_west' NORTHWEST = 'north_west' - @classmethod - def is_member(cls, other): - return any([other == direction for direction in cls]) - - @classmethod - def square(cls): - return [cls.NORTH, cls.EAST, cls.SOUTH, cls.WEST] - - @classmethod - def diagonal(cls): - return [cls.NORTHEAST, cls.SOUTHEAST, cls.SOUTHWEST, cls.NORTHWEST] - - -class EnvActions(Enum): - NOOP = 'no_op' + # Other + NOOP = 'no_op' USE_DOOR = 'use_door' - CLEAN_UP = 'clean_up' - ITEM_ACTION = 'item_action' + CHARGE = 'charge' WAIT_ON_DEST = 'wait' + @classmethod + def is_move(cls, other): + return any([other == direction for direction in cls.movement_actions()]) -m = MovingAction + @classmethod + def square_move(cls): + return [cls.NORTH, cls.EAST, cls.SOUTH, cls.WEST] + + @classmethod + def diagonal_move(cls): + return [cls.NORTHEAST, cls.SOUTHEAST, cls.SOUTHWEST, cls.NORTHWEST] + + @classmethod + def movement_actions(cls): + return list(itertools.chain(cls.square_move(), cls.diagonal_move())) + + +m = EnvActions c = Constants ACTIONMAP = defaultdict(lambda: (0, 0), {m.NORTH: (-1, 0), m.NORTHEAST: (-1, +1), @@ -171,13 +158,10 @@ def parse_level(path): return level -def one_hot_level(level, wall_char: Union[c, str] = c.WALL): +def one_hot_level(level, wall_char: str = c.WALL): grid = np.array(level) binary_grid = np.zeros(grid.shape, dtype=np.int8) - if wall_char in c: - binary_grid[grid == wall_char.value] = c.OCCUPIED_CELL.value - else: - binary_grid[grid == wall_char] = c.OCCUPIED_CELL.value + binary_grid[grid == wall_char] = c.OCCUPIED_CELL return binary_grid @@ -198,19 +182,19 @@ def check_position(slice_to_check_against: ArrayLike, position_to_check: Tuple[i def asset_str(agent): # What does this abonimation do? - # if any([x is None for x in [self._slices[j] for j in agent.collisions]]): + # if any([x is None for x in [cls._slices[j] for j in agent.collisions]]): # print('error') col_names = [x.name for x in agent.temp_collisions] - if any(c.AGENT.value in name for name in col_names): + if any(c.AGENT in name for name in col_names): return 'agent_collision', 'blank' - elif not agent.temp_valid or c.LEVEL.name in col_names or c.AGENT.name in col_names: - return c.AGENT.value, 'invalid' - elif agent.temp_valid and not MovingAction.is_member(agent.temp_action): - return c.AGENT.value, 'valid' - elif agent.temp_valid and MovingAction.is_member(agent.temp_action): - return c.AGENT.value, 'move' + elif not agent.temp_valid or c.LEVEL in col_names or c.AGENT in col_names: + return c.AGENT, 'invalid' + elif agent.temp_valid and not EnvActions.is_move(agent.temp_action): + return c.AGENT, 'valid' + elif agent.temp_valid and EnvActions.is_move(agent.temp_action): + return c.AGENT, 'move' else: - return c.AGENT.value, 'idle' + return c.AGENT, 'idle' def points_to_graph(coordiniates_or_tiles, allow_euclidean_connections=True, allow_manhattan_connections=True): @@ -229,9 +213,3 @@ def points_to_graph(coordiniates_or_tiles, allow_euclidean_connections=True, all elif allow_manhattan_connections and not allow_euclidean_connections and not all(diff) and any(diff): graph.add_edge(a, b) return graph - - -if __name__ == '__main__': - parsed_level = parse_level(Path(__file__).parent / 'factory' / 'levels' / 'simple.txt') - y = one_hot_level(parsed_level) - print(np.argwhere(y == 0)) diff --git a/environments/logging/recorder.py b/environments/logging/recorder.py index 38569ca..7f53da8 100644 --- a/environments/logging/recorder.py +++ b/environments/logging/recorder.py @@ -60,7 +60,7 @@ class EnvRecorder(BaseCallback): def save_records(self, filepath: Union[Path, str], save_occupation_map=False, save_trajectory_map=False): filepath = Path(filepath) filepath.parent.mkdir(exist_ok=True, parents=True) - # self.out_file.unlink(missing_ok=True) + # cls.out_file.unlink(missing_ok=True) with filepath.open('w') as f: out_dict = {'episodes': self._recorder_out_list, 'header': self.unwrapped.params} try: