Doors with area debugged
This commit is contained in:
parent
042c850588
commit
c7b4c69d17
@ -38,6 +38,8 @@ class BaseFactory(gym.Env):
|
|||||||
slices = self._slices.n - (self._agents.n - 1)
|
slices = self._slices.n - (self._agents.n - 1)
|
||||||
elif not self.combin_agent_slices_in_obs and not self.omit_agent_slice_in_obs:
|
elif not self.combin_agent_slices_in_obs and not self.omit_agent_slice_in_obs:
|
||||||
slices = self._slices.n
|
slices = self._slices.n
|
||||||
|
else:
|
||||||
|
raise RuntimeError('This should not happen!')
|
||||||
|
|
||||||
level_shape = (self.pomdp_r * 2 + 1, self.pomdp_r * 2 + 1) if self.pomdp_r else self._level_shape
|
level_shape = (self.pomdp_r * 2 + 1, self.pomdp_r * 2 + 1) if self.pomdp_r else self._level_shape
|
||||||
space = spaces.Box(low=0, high=1, shape=(slices, *level_shape), dtype=np.float32)
|
space = spaces.Box(low=0, high=1, shape=(slices, *level_shape), dtype=np.float32)
|
||||||
@ -168,7 +170,7 @@ class BaseFactory(gym.Env):
|
|||||||
# Door Init
|
# Door Init
|
||||||
if self.parse_doors:
|
if self.parse_doors:
|
||||||
tiles = [self._tiles.by_pos(x) for x in self._slices.by_enum(c.DOORS).occupied_tiles]
|
tiles = [self._tiles.by_pos(x) for x in self._slices.by_enum(c.DOORS).occupied_tiles]
|
||||||
self._doors = Doors.from_tiles(tiles, context=self._tiles)
|
self._doors = Doors.from_tiles(tiles, context=self._tiles, has_area=self.doors_have_area)
|
||||||
|
|
||||||
# Agent Init on random positions
|
# Agent Init on random positions
|
||||||
self._agents = Agents.from_tiles(np.random.choice(self._tiles, self.n_agents))
|
self._agents = Agents.from_tiles(np.random.choice(self._tiles, self.n_agents))
|
||||||
@ -229,7 +231,7 @@ class BaseFactory(gym.Env):
|
|||||||
door = self._doors.get_near_position(agent.pos)
|
door = self._doors.get_near_position(agent.pos)
|
||||||
else:
|
else:
|
||||||
door = self._doors.by_pos(agent.pos)
|
door = self._doors.by_pos(agent.pos)
|
||||||
if door:
|
if door is not None:
|
||||||
door.use()
|
door.use()
|
||||||
valid = c.VALID.value
|
valid = c.VALID.value
|
||||||
# When he doesn't...
|
# When he doesn't...
|
||||||
@ -391,9 +393,10 @@ class BaseFactory(gym.Env):
|
|||||||
if self.parse_doors and agent.last_pos != h.NO_POS:
|
if self.parse_doors and agent.last_pos != h.NO_POS:
|
||||||
if door := self._doors.by_pos(new_tile.pos):
|
if door := self._doors.by_pos(new_tile.pos):
|
||||||
if door.can_collide:
|
if door.can_collide:
|
||||||
pass
|
|
||||||
else: # door.is_closed:
|
|
||||||
return agent.tile, c.NOT_VALID
|
return agent.tile, c.NOT_VALID
|
||||||
|
else: # door.is_closed:
|
||||||
|
pass
|
||||||
|
|
||||||
if door := self._doors.by_pos(agent.pos):
|
if door := self._doors.by_pos(agent.pos):
|
||||||
if door.is_open:
|
if door.is_open:
|
||||||
pass
|
pass
|
||||||
|
@ -192,7 +192,10 @@ class Door(Entity):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def can_collide(self):
|
def can_collide(self):
|
||||||
return False if self.is_open else True
|
if self.has_area:
|
||||||
|
return False if self.is_open else True
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def encoding(self):
|
def encoding(self):
|
||||||
@ -200,11 +203,13 @@ class Door(Entity):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def access_area(self):
|
def access_area(self):
|
||||||
return [node for node in self.connectivity.nodes if node not in range(len(self.connectivity_subgroups))]
|
return [node for node in self.connectivity.nodes
|
||||||
|
if node not in range(len(self.connectivity_subgroups)) and node != self.pos]
|
||||||
|
|
||||||
def __init__(self, *args, context, closed_on_init=True, auto_close_interval=10):
|
def __init__(self, *args, context, closed_on_init=True, auto_close_interval=10, has_area=False):
|
||||||
super(Door, self).__init__(*args)
|
super(Door, self).__init__(*args)
|
||||||
self._state = c.CLOSED_DOOR
|
self._state = c.CLOSED_DOOR
|
||||||
|
self.has_area = has_area
|
||||||
self.auto_close_interval = auto_close_interval
|
self.auto_close_interval = auto_close_interval
|
||||||
self.time_to_close = -1
|
self.time_to_close = -1
|
||||||
neighbor_pos = list(itertools.product([-1, 1, 0], repeat=2))[:-1]
|
neighbor_pos = list(itertools.product([-1, 1, 0], repeat=2))[:-1]
|
||||||
|
@ -99,9 +99,11 @@ class Register:
|
|||||||
class EntityRegister(Register):
|
class EntityRegister(Register):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_argwhere_coordinates(cls, argwhere_coordinates):
|
def from_argwhere_coordinates(cls, argwhere_coordinates, **kwargs):
|
||||||
tiles = cls()
|
tiles = cls()
|
||||||
tiles.register_additional_items([cls._accepted_objects(i, pos) for i, pos in enumerate(argwhere_coordinates)])
|
tiles.register_additional_items(
|
||||||
|
[cls._accepted_objects(i, pos, **kwargs) for i, pos in enumerate(argwhere_coordinates)]
|
||||||
|
)
|
||||||
return tiles
|
return tiles
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
@ -164,8 +166,11 @@ class Agents(Register):
|
|||||||
class Doors(EntityRegister):
|
class Doors(EntityRegister):
|
||||||
_accepted_objects = Door
|
_accepted_objects = Door
|
||||||
|
|
||||||
def get_near_position(self, position: (int, int)):
|
def get_near_position(self, position: (int, int)) -> Union[None, Door]:
|
||||||
return [door for door in self if position in door.access_area][0]
|
if found_doors := [door for door in self if position in door.access_area]:
|
||||||
|
return found_doors[0]
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
def tick_doors(self):
|
def tick_doors(self):
|
||||||
for door in self:
|
for door in self:
|
||||||
|
@ -172,7 +172,7 @@ class SimpleFactory(BaseFactory):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
# penalty = current_dirt_amount
|
# penalty = current_dirt_amount
|
||||||
reward = dirt_distribution_score
|
reward = 0
|
||||||
except (ZeroDivisionError, RuntimeWarning):
|
except (ZeroDivisionError, RuntimeWarning):
|
||||||
reward = 0
|
reward = 0
|
||||||
|
|
||||||
|
2
main.py
2
main.py
@ -93,7 +93,7 @@ if __name__ == '__main__':
|
|||||||
# from sb3_contrib import QRDQN
|
# from sb3_contrib import QRDQN
|
||||||
|
|
||||||
dirt_props = DirtProperties(clean_amount=1, gain_amount=0.1, max_global_amount=20,
|
dirt_props = DirtProperties(clean_amount=1, gain_amount=0.1, max_global_amount=20,
|
||||||
max_local_amount=1, spawn_frequency=5, max_spawn_ratio=0.05,
|
max_local_amount=1, spawn_frequency=10, max_spawn_ratio=0.05,
|
||||||
dirt_smear_amount=0.0)
|
dirt_smear_amount=0.0)
|
||||||
move_props = MovementProperties(allow_diagonal_movement=True,
|
move_props = MovementProperties(allow_diagonal_movement=True,
|
||||||
allow_square_movement=True,
|
allow_square_movement=True,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user