Merge remote-tracking branch 'origin/main'
This commit is contained in:
@@ -68,6 +68,7 @@ class BaseFactory(gym.Env):
|
||||
self.steps = 0
|
||||
self.cumulative_reward = 0
|
||||
self.monitor = FactoryMonitor(self)
|
||||
self.agent_states = []
|
||||
# Agent placement ...
|
||||
agents = np.zeros((self.n_agents, *self.level.shape), dtype=np.int8)
|
||||
floor_tiles = np.argwhere(self.level == h.IS_FREE_CELL)
|
||||
@@ -75,9 +76,13 @@ class BaseFactory(gym.Env):
|
||||
np.random.shuffle(floor_tiles)
|
||||
for i, (x, y) in enumerate(floor_tiles[:self.n_agents]):
|
||||
agents[i, x, y] = h.IS_OCCUPIED_CELL
|
||||
agent_state = AgentState(i, -1)
|
||||
agent_state.update(pos=[x, y])
|
||||
self.agent_states.append(agent_state)
|
||||
# state.shape = (1 + n_agents, *level.shape): the level layer first, then one layer per agent
|
||||
self.state = np.concatenate((np.expand_dims(self.level, axis=0), agents), axis=0)
|
||||
# Returns State, Reward, Done, Info
|
||||
|
||||
return self.state, 0, self.done, {}
|
||||
|
||||
def additional_actions(self, agent_i: int, action: int) -> ((int, int), bool):
|
||||
@@ -103,6 +108,7 @@ class BaseFactory(gym.Env):
|
||||
for i, collision_vec in enumerate(self.check_all_collisions(states, self.state.shape[0])):
|
||||
states[i].update(collision_vector=collision_vec)
|
||||
|
||||
self.agent_states = states
|
||||
reward, info = self.calculate_reward(states)
|
||||
self.cumulative_reward += reward
|
||||
|
||||
|
||||
Reference in New Issue
Block a user