From c7b4c69d171aed99630e07eb3b0c452afbe110ff Mon Sep 17 00:00:00 2001 From: Steffen Illium Date: Fri, 30 Jul 2021 08:40:39 +0200 Subject: [PATCH] Doors with area debugged --- environments/factory/base/base_factory.py | 11 +++++++---- environments/factory/base/objects.py | 11 ++++++++--- environments/factory/base/registers.py | 13 +++++++++---- environments/factory/simple_factory.py | 2 +- main.py | 2 +- 5 files changed, 26 insertions(+), 13 deletions(-) diff --git a/environments/factory/base/base_factory.py b/environments/factory/base/base_factory.py index 2393296..e64e5ad 100644 --- a/environments/factory/base/base_factory.py +++ b/environments/factory/base/base_factory.py @@ -38,6 +38,8 @@ class BaseFactory(gym.Env): slices = self._slices.n - (self._agents.n - 1) elif not self.combin_agent_slices_in_obs and not self.omit_agent_slice_in_obs: slices = self._slices.n + else: + raise RuntimeError('This should not happen!') level_shape = (self.pomdp_r * 2 + 1, self.pomdp_r * 2 + 1) if self.pomdp_r else self._level_shape space = spaces.Box(low=0, high=1, shape=(slices, *level_shape), dtype=np.float32) @@ -168,7 +170,7 @@ class BaseFactory(gym.Env): # Door Init if self.parse_doors: tiles = [self._tiles.by_pos(x) for x in self._slices.by_enum(c.DOORS).occupied_tiles] - self._doors = Doors.from_tiles(tiles, context=self._tiles) + self._doors = Doors.from_tiles(tiles, context=self._tiles, has_area=self.doors_have_area) # Agent Init on random positions self._agents = Agents.from_tiles(np.random.choice(self._tiles, self.n_agents)) @@ -229,7 +231,7 @@ class BaseFactory(gym.Env): door = self._doors.get_near_position(agent.pos) else: door = self._doors.by_pos(agent.pos) - if door: + if door is not None: door.use() valid = c.VALID.value # When he doesn't... @@ -391,9 +393,10 @@ class BaseFactory(gym.Env): if self.parse_doors and agent.last_pos != h.NO_POS: if door := self._doors.by_pos(new_tile.pos): if door.can_collide: - pass - else: # door.is_closed: return agent.tile, c.NOT_VALID + else: # door.is_closed: + pass + if door := self._doors.by_pos(agent.pos): if door.is_open: pass diff --git a/environments/factory/base/objects.py b/environments/factory/base/objects.py index f5186e2..62767ec 100644 --- a/environments/factory/base/objects.py +++ b/environments/factory/base/objects.py @@ -192,7 +192,10 @@ class Door(Entity): @property def can_collide(self): - return False if self.is_open else True + if self.has_area: + return False if self.is_open else True + else: + return False @property def encoding(self): @@ -200,11 +203,13 @@ class Door(Entity): @property def access_area(self): - return [node for node in self.connectivity.nodes if node not in range(len(self.connectivity_subgroups))] + return [node for node in self.connectivity.nodes + if node not in range(len(self.connectivity_subgroups)) and node != self.pos] - def __init__(self, *args, context, closed_on_init=True, auto_close_interval=10): + def __init__(self, *args, context, closed_on_init=True, auto_close_interval=10, has_area=False): super(Door, self).__init__(*args) self._state = c.CLOSED_DOOR + self.has_area = has_area self.auto_close_interval = auto_close_interval self.time_to_close = -1 neighbor_pos = list(itertools.product([-1, 1, 0], repeat=2))[:-1] diff --git a/environments/factory/base/registers.py b/environments/factory/base/registers.py index 8ec439e..26ba575 100644 --- a/environments/factory/base/registers.py +++ b/environments/factory/base/registers.py @@ -99,9 +99,11 @@ class Register: class EntityRegister(Register): @classmethod - def from_argwhere_coordinates(cls, argwhere_coordinates): + def from_argwhere_coordinates(cls, argwhere_coordinates, **kwargs): tiles = cls() - tiles.register_additional_items([cls._accepted_objects(i, pos) for i, pos in enumerate(argwhere_coordinates)]) + tiles.register_additional_items( + [cls._accepted_objects(i, pos, **kwargs) for i, pos in enumerate(argwhere_coordinates)] + ) return tiles def __init__(self): @@ -164,8 +166,11 @@ class Agents(Register): class Doors(EntityRegister): _accepted_objects = Door - def get_near_position(self, position: (int, int)): - return [door for door in self if position in door.access_area][0] + def get_near_position(self, position: (int, int)) -> Union[None, Door]: + if found_doors := [door for door in self if position in door.access_area]: + return found_doors[0] + else: + return None def tick_doors(self): for door in self: diff --git a/environments/factory/simple_factory.py b/environments/factory/simple_factory.py index ef9371b..64302be 100644 --- a/environments/factory/simple_factory.py +++ b/environments/factory/simple_factory.py @@ -172,7 +172,7 @@ class SimpleFactory(BaseFactory): try: # penalty = current_dirt_amount - reward = dirt_distribution_score + reward = 0 except (ZeroDivisionError, RuntimeWarning): reward = 0 diff --git a/main.py b/main.py index 1ebecda..b1b5242 100644 --- a/main.py +++ b/main.py @@ -93,7 +93,7 @@ if __name__ == '__main__': # from sb3_contrib import QRDQN dirt_props = DirtProperties(clean_amount=1, gain_amount=0.1, max_global_amount=20, - max_local_amount=1, spawn_frequency=5, max_spawn_ratio=0.05, + max_local_amount=1, spawn_frequency=10, max_spawn_ratio=0.05, dirt_smear_amount=0.0) move_props = MovementProperties(allow_diagonal_movement=True, allow_square_movement=True,