Doors with area debugged
This commit is contained in:
parent
042c850588
commit
c7b4c69d17
@ -38,6 +38,8 @@ class BaseFactory(gym.Env):
|
||||
slices = self._slices.n - (self._agents.n - 1)
|
||||
elif not self.combin_agent_slices_in_obs and not self.omit_agent_slice_in_obs:
|
||||
slices = self._slices.n
|
||||
else:
|
||||
raise RuntimeError('This should not happen!')
|
||||
|
||||
level_shape = (self.pomdp_r * 2 + 1, self.pomdp_r * 2 + 1) if self.pomdp_r else self._level_shape
|
||||
space = spaces.Box(low=0, high=1, shape=(slices, *level_shape), dtype=np.float32)
|
||||
@ -168,7 +170,7 @@ class BaseFactory(gym.Env):
|
||||
# Door Init
|
||||
if self.parse_doors:
|
||||
tiles = [self._tiles.by_pos(x) for x in self._slices.by_enum(c.DOORS).occupied_tiles]
|
||||
self._doors = Doors.from_tiles(tiles, context=self._tiles)
|
||||
self._doors = Doors.from_tiles(tiles, context=self._tiles, has_area=self.doors_have_area)
|
||||
|
||||
# Agent Init on random positions
|
||||
self._agents = Agents.from_tiles(np.random.choice(self._tiles, self.n_agents))
|
||||
@ -229,7 +231,7 @@ class BaseFactory(gym.Env):
|
||||
door = self._doors.get_near_position(agent.pos)
|
||||
else:
|
||||
door = self._doors.by_pos(agent.pos)
|
||||
if door:
|
||||
if door is not None:
|
||||
door.use()
|
||||
valid = c.VALID.value
|
||||
# When he doesn't...
|
||||
@ -391,9 +393,10 @@ class BaseFactory(gym.Env):
|
||||
if self.parse_doors and agent.last_pos != h.NO_POS:
|
||||
if door := self._doors.by_pos(new_tile.pos):
|
||||
if door.can_collide:
|
||||
pass
|
||||
else: # door.is_closed:
|
||||
return agent.tile, c.NOT_VALID
|
||||
else: # door.is_closed:
|
||||
pass
|
||||
|
||||
if door := self._doors.by_pos(agent.pos):
|
||||
if door.is_open:
|
||||
pass
|
||||
|
@ -192,7 +192,10 @@ class Door(Entity):
|
||||
|
||||
@property
|
||||
def can_collide(self):
|
||||
return False if self.is_open else True
|
||||
if self.has_area:
|
||||
return False if self.is_open else True
|
||||
else:
|
||||
return False
|
||||
|
||||
@property
|
||||
def encoding(self):
|
||||
@ -200,11 +203,13 @@ class Door(Entity):
|
||||
|
||||
@property
|
||||
def access_area(self):
|
||||
return [node for node in self.connectivity.nodes if node not in range(len(self.connectivity_subgroups))]
|
||||
return [node for node in self.connectivity.nodes
|
||||
if node not in range(len(self.connectivity_subgroups)) and node != self.pos]
|
||||
|
||||
def __init__(self, *args, context, closed_on_init=True, auto_close_interval=10):
|
||||
def __init__(self, *args, context, closed_on_init=True, auto_close_interval=10, has_area=False):
|
||||
super(Door, self).__init__(*args)
|
||||
self._state = c.CLOSED_DOOR
|
||||
self.has_area = has_area
|
||||
self.auto_close_interval = auto_close_interval
|
||||
self.time_to_close = -1
|
||||
neighbor_pos = list(itertools.product([-1, 1, 0], repeat=2))[:-1]
|
||||
|
@ -99,9 +99,11 @@ class Register:
|
||||
class EntityRegister(Register):
|
||||
|
||||
@classmethod
|
||||
def from_argwhere_coordinates(cls, argwhere_coordinates):
|
||||
def from_argwhere_coordinates(cls, argwhere_coordinates, **kwargs):
|
||||
tiles = cls()
|
||||
tiles.register_additional_items([cls._accepted_objects(i, pos) for i, pos in enumerate(argwhere_coordinates)])
|
||||
tiles.register_additional_items(
|
||||
[cls._accepted_objects(i, pos, **kwargs) for i, pos in enumerate(argwhere_coordinates)]
|
||||
)
|
||||
return tiles
|
||||
|
||||
def __init__(self):
|
||||
@ -164,8 +166,11 @@ class Agents(Register):
|
||||
class Doors(EntityRegister):
|
||||
_accepted_objects = Door
|
||||
|
||||
def get_near_position(self, position: (int, int)):
|
||||
return [door for door in self if position in door.access_area][0]
|
||||
def get_near_position(self, position: (int, int)) -> Union[None, Door]:
|
||||
if found_doors := [door for door in self if position in door.access_area]:
|
||||
return found_doors[0]
|
||||
else:
|
||||
return None
|
||||
|
||||
def tick_doors(self):
|
||||
for door in self:
|
||||
|
@ -172,7 +172,7 @@ class SimpleFactory(BaseFactory):
|
||||
|
||||
try:
|
||||
# penalty = current_dirt_amount
|
||||
reward = dirt_distribution_score
|
||||
reward = 0
|
||||
except (ZeroDivisionError, RuntimeWarning):
|
||||
reward = 0
|
||||
|
||||
|
2
main.py
2
main.py
@ -93,7 +93,7 @@ if __name__ == '__main__':
|
||||
# from sb3_contrib import QRDQN
|
||||
|
||||
dirt_props = DirtProperties(clean_amount=1, gain_amount=0.1, max_global_amount=20,
|
||||
max_local_amount=1, spawn_frequency=5, max_spawn_ratio=0.05,
|
||||
max_local_amount=1, spawn_frequency=10, max_spawn_ratio=0.05,
|
||||
dirt_smear_amount=0.0)
|
||||
move_props = MovementProperties(allow_diagonal_movement=True,
|
||||
allow_square_movement=True,
|
||||
|
Loading…
x
Reference in New Issue
Block a user