Logging Monitor Callback

This commit is contained in:
steffen-illium
2021-05-20 09:49:08 +02:00
parent c1cb7a4ffc
commit e7d31aa272
4 changed files with 83 additions and 30 deletions

View File

@ -56,6 +56,7 @@ class BaseFactory(gym.Env):
self.allow_vertical_movement = True
self.allow_horizontal_movement = True
self.allow_no_OP = True
self.done_at_collision = True
self._registered_actions = self.movement_actions + int(self.allow_no_OP) + self.register_additional_actions()
self.level = h.one_hot_level(
h.parse_level(Path(__file__).parent / h.LEVELS_DIR / f'{level}.txt')
@ -96,7 +97,7 @@ class BaseFactory(gym.Env):
self.steps += 1
# Move this in a seperate function?
states = list()
agent_states = list()
for agent_i, action in enumerate(actions):
agent_i_state = AgentState(agent_i, action)
if self._is_moving_action(action):
@ -107,13 +108,15 @@ class BaseFactory(gym.Env):
pos, valid = self.additional_actions(agent_i, action)
# Update state accordingly
agent_i_state.update(pos=pos, action_valid=valid)
states.append(agent_i_state)
agent_states.append(agent_i_state)
for i, collision_vec in enumerate(self.check_all_collisions(states, self.state.shape[0])):
states[i].update(collision_vector=collision_vec)
for i, collision_vec in enumerate(self.check_all_collisions(agent_states, self.state.shape[0])):
agent_states[i].update(collision_vector=collision_vec)
if self.done_at_collision and collision_vec.any():
self.done = True
self.agent_states = states
reward, info = self.calculate_reward(states)
self.agent_states = agent_states
reward, info = self.calculate_reward(agent_states)
if self.steps >= self.max_steps:
self.done = True