From a16d7e709e118a40c5b802b0a2840eaeff6c6fc9 Mon Sep 17 00:00:00 2001 From: Steffen Illium Date: Tue, 18 Jan 2022 11:39:19 +0100 Subject: [PATCH] Door Area Indicators --- environments/factory/base/base_factory.py | 8 ++++++-- environments/factory/base/objects.py | 2 +- environments/factory/base/registers.py | 16 +++++++++++++++- environments/factory/factory_battery.py | 2 +- environments/factory/factory_dirt.py | 8 ++++++-- environments/helpers.py | 5 ++++- environments/utility_classes.py | 2 +- 7 files changed, 34 insertions(+), 9 deletions(-) diff --git a/environments/factory/base/base_factory.py b/environments/factory/base/base_factory.py index 613a5e3..f852d7c 100644 --- a/environments/factory/base/base_factory.py +++ b/environments/factory/base/base_factory.py @@ -68,6 +68,7 @@ class BaseFactory(gym.Env): @property def params(self) -> dict: d = {key: val for key, val in self.__dict__.items() if not key.startswith('_') and not key.startswith('__')} + d['class_name'] = self.__class__.__name__ return d def __enter__(self): @@ -83,7 +84,10 @@ class BaseFactory(gym.Env): rewards_base: RewardsBase = RewardsBase(), parse_doors=False, done_at_collision=False, inject_agents: Union[None, List] = None, verbose=False, doors_have_area=True, env_seed=time.time_ns(), individual_rewards=False, - **kwargs): + class_name='', **kwargs): + + if class_name: + print(f'You loaded parameters for {class_name}', f'this is: {self.__class__.__name__}') if isinstance(mv_prop, dict): mv_prop = MovementProperties(**mv_prop) @@ -167,7 +171,7 @@ class BaseFactory(gym.Env): parsed_doors = np.pad(parsed_doors, self.obs_prop.pomdp_r, 'constant', constant_values=0) if np.any(parsed_doors): door_tiles = [floor.by_pos(tuple(pos)) for pos in np.argwhere(parsed_doors == c.OCCUPIED_CELL)] - doors = Doors.from_tiles(door_tiles, self._level_shape, + doors = Doors.from_tiles(door_tiles, self._level_shape, have_area=self.doors_have_area, entity_kwargs=dict(context=floor) ) self._entities.register_additional_items({c.DOORS: doors}) diff --git a/environments/factory/base/objects.py b/environments/factory/base/objects.py index 994136f..d8837fd 100644 --- a/environments/factory/base/objects.py +++ b/environments/factory/base/objects.py @@ -311,7 +311,7 @@ class Door(Entity): @property def encoding(self): # This is important as it shadow is checked by occupation value - return c.OCCUPIED_CELL if self.is_closed else 0.5 + return c.CLOSED_DOOR_CELL if self.is_closed else c.OPEN_DOOR_CELL @property def str_state(self): diff --git a/environments/factory/base/registers.py b/environments/factory/base/registers.py index e8392e6..f6078dd 100644 --- a/environments/factory/base/registers.py +++ b/environments/factory/base/registers.py @@ -460,7 +460,9 @@ class Agents(MovingEntityObjectRegister): class Doors(EntityRegister): - def __init__(self, *args, **kwargs): + def __init__(self, *args, have_area: bool = False, **kwargs): + self.have_area = have_area + self._area_marked = False super(Doors, self).__init__(*args, is_blocking_light=True, can_collide=True, **kwargs) _accepted_objects = Door @@ -475,6 +477,18 @@ class Doors(EntityRegister): for door in self: door.tick() + def as_array(self): + if self.have_area and not self._area_marked: + for door in self: + for pos in door.access_area: + if self._individual_slices: + pass + else: + pos = (0, *pos) + self._lazy_eval_transforms.append((pos, c.ACCESS_DOOR_CELL)) + self._area_marked = True + return super(Doors, self).as_array() + class Actions(ObjectRegister): _accepted_objects = Action diff --git a/environments/factory/factory_battery.py b/environments/factory/factory_battery.py index 6b114fd..c09cb10 100644 --- a/environments/factory/factory_battery.py +++ b/environments/factory/factory_battery.py @@ -155,7 +155,7 @@ class BatteryFactory(BaseFactory): if isinstance(btry_prop, dict): btry_prop = BatteryProperties(**btry_prop) if isinstance(rewards_dest, dict): - rewards_dest = RewardsBtry(**rewards_dest) + rewards_dest = BatteryProperties(**rewards_dest) self.btry_prop = btry_prop self.rewards_dest = rewards_dest super().__init__(*args, **kwargs) diff --git a/environments/factory/factory_dirt.py b/environments/factory/factory_dirt.py index e4e0e85..2e51983 100644 --- a/environments/factory/factory_dirt.py +++ b/environments/factory/factory_dirt.py @@ -1,4 +1,5 @@ import time +from pathlib import Path from typing import List, Union, NamedTuple, Dict import random @@ -284,7 +285,8 @@ if __name__ == '__main__': ) obs_props = ObservationProperties(render_agents=aro.COMBINED, omit_agent_self=True, - pomdp_r=2, additional_agent_placeholder=None, cast_shadows=True) + pomdp_r=2, additional_agent_placeholder=None, cast_shadows=True, + indicate_door_area=True) move_props = {'allow_square_movement': True, 'allow_diagonal_movement': False, @@ -295,13 +297,15 @@ if __name__ == '__main__': factory = DirtFactory(n_agents=10, done_at_collision=False, level_name='rooms', max_steps=1000, - doors_have_area=False, + doors_have_area=True, obs_prop=obs_props, parse_doors=True, verbose=True, mv_prop=move_props, dirt_prop=dirt_props, # inject_agents=[TSPDirtAgent], ) + factory.save_params(Path('rewards_param')) + # noinspection DuplicatedCode n_actions = factory.action_space.n - 1 _ = factory.observation_space diff --git a/environments/helpers.py b/environments/helpers.py index 35d0935..954ff78 100644 --- a/environments/helpers.py +++ b/environments/helpers.py @@ -31,11 +31,15 @@ class Constants: FREE_CELL = 0 OCCUPIED_CELL = 1 SHADOWED_CELL = -1 + ACCESS_DOOR_CELL = 1/3 + OPEN_DOOR_CELL = 2/3 + CLOSED_DOOR_CELL = 3/3 NO_POS = (-9999, -9999) DOORS = 'Doors' CLOSED_DOOR = 'closed' OPEN_DOOR = 'open' + ACCESS_DOOR = 'access' ACTION = 'action' COLLISION = 'collision' @@ -87,7 +91,6 @@ class RewardsBase(NamedTuple): m = EnvActions c = Constants -r = RewardsBase ACTIONMAP = defaultdict(lambda: (0, 0), {m.NORTH: (-1, 0), m.NORTHEAST: (-1, 1), diff --git a/environments/utility_classes.py b/environments/utility_classes.py index 8ef1a7b..3dfc123 100644 --- a/environments/utility_classes.py +++ b/environments/utility_classes.py @@ -23,6 +23,7 @@ class ObservationProperties(NamedTuple): cast_shadows: bool = True frames_to_stack: int = 0 pomdp_r: int = 0 + indicate_door_area: bool = True show_global_position_info: bool = False @@ -34,4 +35,3 @@ class MarlFrameStack(gym.ObservationWrapper): if isinstance(self.env, FrameStack) and self.env.unwrapped.n_agents > 1: return observation[0:].swapaxes(0, 1) return observation -