This commit is contained in:
Steffen Illium 2021-07-29 11:02:12 +02:00
parent 8631f11502
commit 042c850588
4 changed files with 22 additions and 5 deletions

View File

@ -91,7 +91,7 @@ class BaseFactory(gym.Env):
movement_properties: MovementProperties = MovementProperties(), parse_doors=False,
combin_agent_slices_in_obs: bool = False, frames_to_stack=0, record_episodes=False,
omit_agent_slice_in_obs=False, done_at_collision=False, cast_shadows=True,
verbose=False, **kwargs):
verbose=False, doors_have_area=True, **kwargs):
assert frames_to_stack != 1 and frames_to_stack >= 0, "'frames_to_stack' cannot be negative or 1."
# Attribute Assignment
@ -111,6 +111,7 @@ class BaseFactory(gym.Env):
self.done_at_collision = done_at_collision
self.record_episodes = record_episodes
self.parse_doors = parse_doors
self.doors_have_area = doors_have_area
# Actions
self._actions = Actions(self.movement_properties, can_use_doors=self.parse_doors)
@ -223,8 +224,12 @@ class BaseFactory(gym.Env):
elif self._actions.is_no_op(action):
valid = c.VALID.value
elif self._actions.is_door_usage(action):
# Check if agent raly stands on a door:
if door := self._doors.by_pos(agent.pos):
# Check if agent really is standing on a door:
if self.doors_have_area:
door = self._doors.get_near_position(agent.pos)
else:
door = self._doors.by_pos(agent.pos)
if door:
door.use()
valid = c.VALID.value
# When he doesn't...
@ -384,6 +389,11 @@ class BaseFactory(gym.Env):
return tile, valid
if self.parse_doors and agent.last_pos != h.NO_POS:
if door := self._doors.by_pos(new_tile.pos):
if door.can_collide:
pass
else: # door.is_closed:
return agent.tile, c.NOT_VALID
if door := self._doors.by_pos(agent.pos):
if door.is_open:
pass

View File

@ -192,12 +192,16 @@ class Door(Entity):
@property
def can_collide(self):
return False
return False if self.is_open else True
@property
def encoding(self):
return 1 if self.is_closed else -1
@property
def access_area(self):
return [node for node in self.connectivity.nodes if node not in range(len(self.connectivity_subgroups))]
def __init__(self, *args, context, closed_on_init=True, auto_close_interval=10):
super(Door, self).__init__(*args)
self._state = c.CLOSED_DOOR

View File

@ -164,6 +164,9 @@ class Agents(Register):
class Doors(EntityRegister):
_accepted_objects = Door
def get_near_position(self, position: (int, int)):
return [door for door in self if position in door.access_area][0]
def tick_doors(self):
for door in self:
door.tick()

View File

@ -236,7 +236,7 @@ if __name__ == '__main__':
allow_no_op=False)
factory = SimpleFactory(movement_properties=move_props, dirt_properties=dirt_props, n_agents=1,
combin_agent_slices_in_obs=False, level_name='rooms', parse_doors=True,
pomdp_radius=2, cast_shadows=True)
doors_have_area=True, pomdp_radius=2, cast_shadows=True)
n_actions = factory.action_space.n - 1
_ = factory.observation_space