Doors with area debugged

This commit is contained in:
Steffen Illium 2021-07-30 08:40:39 +02:00
parent 042c850588
commit c7b4c69d17
5 changed files with 26 additions and 13 deletions

View File

@ -38,6 +38,8 @@ class BaseFactory(gym.Env):
slices = self._slices.n - (self._agents.n - 1)
elif not self.combin_agent_slices_in_obs and not self.omit_agent_slice_in_obs:
slices = self._slices.n
else:
raise RuntimeError('This should not happen!')
level_shape = (self.pomdp_r * 2 + 1, self.pomdp_r * 2 + 1) if self.pomdp_r else self._level_shape
space = spaces.Box(low=0, high=1, shape=(slices, *level_shape), dtype=np.float32)
@ -168,7 +170,7 @@ class BaseFactory(gym.Env):
# Door Init
if self.parse_doors:
tiles = [self._tiles.by_pos(x) for x in self._slices.by_enum(c.DOORS).occupied_tiles]
self._doors = Doors.from_tiles(tiles, context=self._tiles)
self._doors = Doors.from_tiles(tiles, context=self._tiles, has_area=self.doors_have_area)
# Agent Init on random positions
self._agents = Agents.from_tiles(np.random.choice(self._tiles, self.n_agents))
@ -229,7 +231,7 @@ class BaseFactory(gym.Env):
door = self._doors.get_near_position(agent.pos)
else:
door = self._doors.by_pos(agent.pos)
if door:
if door is not None:
door.use()
valid = c.VALID.value
# When he doesn't...
@ -391,9 +393,10 @@ class BaseFactory(gym.Env):
if self.parse_doors and agent.last_pos != h.NO_POS:
if door := self._doors.by_pos(new_tile.pos):
if door.can_collide:
pass
else: # door.is_closed:
return agent.tile, c.NOT_VALID
else: # door.is_closed:
pass
if door := self._doors.by_pos(agent.pos):
if door.is_open:
pass

View File

@ -192,7 +192,10 @@ class Door(Entity):
@property
def can_collide(self):
return False if self.is_open else True
if self.has_area:
return False if self.is_open else True
else:
return False
@property
def encoding(self):
@ -200,11 +203,13 @@ class Door(Entity):
@property
def access_area(self):
return [node for node in self.connectivity.nodes if node not in range(len(self.connectivity_subgroups))]
return [node for node in self.connectivity.nodes
if node not in range(len(self.connectivity_subgroups)) and node != self.pos]
def __init__(self, *args, context, closed_on_init=True, auto_close_interval=10):
def __init__(self, *args, context, closed_on_init=True, auto_close_interval=10, has_area=False):
super(Door, self).__init__(*args)
self._state = c.CLOSED_DOOR
self.has_area = has_area
self.auto_close_interval = auto_close_interval
self.time_to_close = -1
neighbor_pos = list(itertools.product([-1, 1, 0], repeat=2))[:-1]

View File

@ -99,9 +99,11 @@ class Register:
class EntityRegister(Register):
@classmethod
def from_argwhere_coordinates(cls, argwhere_coordinates):
def from_argwhere_coordinates(cls, argwhere_coordinates, **kwargs):
tiles = cls()
tiles.register_additional_items([cls._accepted_objects(i, pos) for i, pos in enumerate(argwhere_coordinates)])
tiles.register_additional_items(
[cls._accepted_objects(i, pos, **kwargs) for i, pos in enumerate(argwhere_coordinates)]
)
return tiles
def __init__(self):
@ -164,8 +166,11 @@ class Agents(Register):
class Doors(EntityRegister):
_accepted_objects = Door
def get_near_position(self, position: (int, int)):
return [door for door in self if position in door.access_area][0]
def get_near_position(self, position: (int, int)) -> Union[None, Door]:
if found_doors := [door for door in self if position in door.access_area]:
return found_doors[0]
else:
return None
def tick_doors(self):
for door in self:

View File

@ -172,7 +172,7 @@ class SimpleFactory(BaseFactory):
try:
# penalty = current_dirt_amount
reward = dirt_distribution_score
reward = 0
except (ZeroDivisionError, RuntimeWarning):
reward = 0

View File

@ -93,7 +93,7 @@ if __name__ == '__main__':
# from sb3_contrib import QRDQN
dirt_props = DirtProperties(clean_amount=1, gain_amount=0.1, max_global_amount=20,
max_local_amount=1, spawn_frequency=5, max_spawn_ratio=0.05,
max_local_amount=1, spawn_frequency=10, max_spawn_ratio=0.05,
dirt_smear_amount=0.0)
move_props = MovementProperties(allow_diagonal_movement=True,
allow_square_movement=True,