diff --git a/environments/factory/factory_cleaning.py b/environments/factory/factory_cleaning.py index 158595d..f01b769 100644 --- a/environments/factory/factory_cleaning.py +++ b/environments/factory/factory_cleaning.py @@ -16,7 +16,7 @@ class Factory(object): def reset(self): self.done = False - self.agents = np.zeros((self.n_agents, *self.level.shape)) + self.agents = np.zeros((self.n_agents, *self.level.shape), dtype=np.int8) free_cells = np.argwhere(self.level == 0) np.random.shuffle(free_cells) for i in range(self.n_agents): @@ -24,17 +24,39 @@ class Factory(object): self.agents[i, r, c] = 1 free_cells = free_cells[self.n_agents:] self.state = np.concatenate((self.level[np.newaxis, ...], self.agents), 0) + return self.state, 0, self.done, {} def step(self, actions): assert type(actions) in [int, list] if type(actions) == int: actions = [actions] + r = 0 # level, agent 1,..., agent n, for i, a in enumerate(actions): old_pos, new_pos, valid = h.check_agent_move(state=self.state, dim=i+1, action=a) - print(old_pos, new_pos, valid) + if valid: + self.make_move(i, old_pos, new_pos) + collision_vecs = [] + for i in range(self.n_agents): # might as well save the positions (redundant) + agent_slice = self.state[i+1] + x, y = np.argwhere(agent_slice == 1)[0] + collisions_vec = self.state[:, x, y] + collisions_vec[i+1] = 0 # no self-collisions + collision_vecs.append(collisions_vec) + self.handle_collisions(collisions_vec) + + return self.state, r, self.done, {} + + def make_move(self, agent_i, old_pos, new_pos): + (x, y), (x_new, y_new) = old_pos, new_pos + self.state[agent_i, x, y] = 0 + self.state[agent_i, x_new, y_new] = 1 + + def handle_collisions(self, vecs): + pass if __name__ == '__main__': factory = Factory(n_agents=1) - factory.step(0) \ No newline at end of file + factory.step(0) + print(factory.state.shape) \ No newline at end of file diff --git a/environments/helpers.py b/environments/helpers.py index bbb839a..641d228 100644 --- a/environments/helpers.py +++ b/environments/helpers.py @@ -16,7 +16,7 @@ def parse_level(path): def one_hot_level(level, wall_char=WALL): grid = np.array(level) - binary_grid = np.zeros(grid.shape) + binary_grid = np.zeros(grid.shape, dtype=np.int8) binary_grid[grid == wall_char] = 1 return binary_grid