Agent Trained on Doors

This commit is contained in:
steffen-illium
2021-06-17 14:27:18 +02:00
parent 26d7705e19
commit d9d8784338
10 changed files with 125 additions and 28 deletions

View File

@ -22,10 +22,11 @@ class BaseFactory(gym.Env):
@property
def observation_space(self):
agent_slice = self.n_agents if self.omit_agent_slice_in_obs else 0
agent_slice = 1 if self.combin_agent_slices_in_obs else agent_slice
agent_slice = (self.n_agents - 1) if self.combin_agent_slices_in_obs else agent_slice
if self.pomdp_radius:
return spaces.Box(low=0, high=1, shape=(self._state.shape[0] - agent_slice, self.pomdp_radius * 2 + 1,
self.pomdp_radius * 2 + 1), dtype=np.float32)
shape = (self._state.shape[0] - agent_slice, self.pomdp_radius * 2 + 1, self.pomdp_radius * 2 + 1)
space = spaces.Box(low=0, high=1, shape=shape, dtype=np.float32)
return space
else:
shape = [x-agent_slice if idx == 0 else x for idx, x in enumerate(self._state.shape)]
space = spaces.Box(low=0, high=1, shape=shape, dtype=np.float32)
@ -194,6 +195,14 @@ class BaseFactory(gym.Env):
if self.done_at_collision and collision_vec.any():
done = True
# Step the door close interval
agents_pos = [agent.pos for agent in self._agent_states]
for door_i, door in enumerate(self._door_states):
if door.is_open and door.time_to_close and door.pos not in agents_pos:
door.time_to_close -= 1
elif door.is_open and not door.time_to_close and door.pos not in agents_pos:
door.use()
reward, info = self.calculate_reward(self._agent_states)
if self._steps >= self.max_steps:
@ -256,7 +265,7 @@ class BaseFactory(gym.Env):
x_new = x + x_diff
y_new = y + y_diff
if h.DOORS in self._state_slices.values():
if h.DOORS in self._state_slices.values() and self._agent_states[agent_i]._last_pos != (-1, -1):
door = [door for door in self._door_states if door.pos == (x, y)]
if door:
door = door[0]
@ -326,7 +335,7 @@ class BaseFactory(gym.Env):
# Returns: Reward, Info
raise NotImplementedError
def render(self):
def render(self, mode='human'):
raise NotImplementedError
def save_params(self, filepath: Path):