Agent Trained on Doors
This commit is contained in:
@ -22,10 +22,11 @@ class BaseFactory(gym.Env):
|
||||
@property
|
||||
def observation_space(self):
|
||||
agent_slice = self.n_agents if self.omit_agent_slice_in_obs else 0
|
||||
agent_slice = 1 if self.combin_agent_slices_in_obs else agent_slice
|
||||
agent_slice = (self.n_agents - 1) if self.combin_agent_slices_in_obs else agent_slice
|
||||
if self.pomdp_radius:
|
||||
return spaces.Box(low=0, high=1, shape=(self._state.shape[0] - agent_slice, self.pomdp_radius * 2 + 1,
|
||||
self.pomdp_radius * 2 + 1), dtype=np.float32)
|
||||
shape = (self._state.shape[0] - agent_slice, self.pomdp_radius * 2 + 1, self.pomdp_radius * 2 + 1)
|
||||
space = spaces.Box(low=0, high=1, shape=shape, dtype=np.float32)
|
||||
return space
|
||||
else:
|
||||
shape = [x-agent_slice if idx == 0 else x for idx, x in enumerate(self._state.shape)]
|
||||
space = spaces.Box(low=0, high=1, shape=shape, dtype=np.float32)
|
||||
@ -194,6 +195,14 @@ class BaseFactory(gym.Env):
|
||||
if self.done_at_collision and collision_vec.any():
|
||||
done = True
|
||||
|
||||
# Step the door close interval
|
||||
agents_pos = [agent.pos for agent in self._agent_states]
|
||||
for door_i, door in enumerate(self._door_states):
|
||||
if door.is_open and door.time_to_close and door.pos not in agents_pos:
|
||||
door.time_to_close -= 1
|
||||
elif door.is_open and not door.time_to_close and door.pos not in agents_pos:
|
||||
door.use()
|
||||
|
||||
reward, info = self.calculate_reward(self._agent_states)
|
||||
|
||||
if self._steps >= self.max_steps:
|
||||
@ -256,7 +265,7 @@ class BaseFactory(gym.Env):
|
||||
x_new = x + x_diff
|
||||
y_new = y + y_diff
|
||||
|
||||
if h.DOORS in self._state_slices.values():
|
||||
if h.DOORS in self._state_slices.values() and self._agent_states[agent_i]._last_pos != (-1, -1):
|
||||
door = [door for door in self._door_states if door.pos == (x, y)]
|
||||
if door:
|
||||
door = door[0]
|
||||
@ -326,7 +335,7 @@ class BaseFactory(gym.Env):
|
||||
# Returns: Reward, Info
|
||||
raise NotImplementedError
|
||||
|
||||
def render(self):
|
||||
def render(self, mode='human'):
|
||||
raise NotImplementedError
|
||||
|
||||
def save_params(self, filepath: Path):
|
||||
|
Reference in New Issue
Block a user