Logging Monitor Callback
This commit is contained in:
@ -56,6 +56,7 @@ class BaseFactory(gym.Env):
|
||||
self.allow_vertical_movement = True
|
||||
self.allow_horizontal_movement = True
|
||||
self.allow_no_OP = True
|
||||
self.done_at_collision = True
|
||||
self._registered_actions = self.movement_actions + int(self.allow_no_OP) + self.register_additional_actions()
|
||||
self.level = h.one_hot_level(
|
||||
h.parse_level(Path(__file__).parent / h.LEVELS_DIR / f'{level}.txt')
|
||||
@ -96,7 +97,7 @@ class BaseFactory(gym.Env):
|
||||
self.steps += 1
|
||||
|
||||
# Move this in a seperate function?
|
||||
states = list()
|
||||
agent_states = list()
|
||||
for agent_i, action in enumerate(actions):
|
||||
agent_i_state = AgentState(agent_i, action)
|
||||
if self._is_moving_action(action):
|
||||
@ -107,13 +108,15 @@ class BaseFactory(gym.Env):
|
||||
pos, valid = self.additional_actions(agent_i, action)
|
||||
# Update state accordingly
|
||||
agent_i_state.update(pos=pos, action_valid=valid)
|
||||
states.append(agent_i_state)
|
||||
agent_states.append(agent_i_state)
|
||||
|
||||
for i, collision_vec in enumerate(self.check_all_collisions(states, self.state.shape[0])):
|
||||
states[i].update(collision_vector=collision_vec)
|
||||
for i, collision_vec in enumerate(self.check_all_collisions(agent_states, self.state.shape[0])):
|
||||
agent_states[i].update(collision_vector=collision_vec)
|
||||
if self.done_at_collision and collision_vec.any():
|
||||
self.done = True
|
||||
|
||||
self.agent_states = states
|
||||
reward, info = self.calculate_reward(states)
|
||||
self.agent_states = agent_states
|
||||
reward, info = self.calculate_reward(agent_states)
|
||||
|
||||
if self.steps >= self.max_steps:
|
||||
self.done = True
|
||||
|
Reference in New Issue
Block a user