diff --git a/environments/factory/base_factory.py b/environments/factory/base_factory.py index 0fbb306..8880dbc 100644 --- a/environments/factory/base_factory.py +++ b/environments/factory/base_factory.py @@ -45,19 +45,16 @@ class BaseFactory(object): collisions_vec = self.state[:, x, y].copy() # otherwise you overwrite the grid/state collisions_vec[i+1] = 0 # no self-collisions collision_vecs.append(collisions_vec) - self.handle_collisions(collisions_vec) - r += self.step_core(collisions_vec, actions, r) + reward, info = self.step_core(collisions_vec, actions, r) + r += reward if self.steps >= self.max_steps: self.done = True - return self.state, r, self.done, {} + return self.state, r, self.done, info def make_move(self, agent_i, old_pos, new_pos): (x, y), (x_new, y_new) = old_pos, new_pos self.state[agent_i+1, x, y] = 0 self.state[agent_i+1, x_new, y_new] = 1 - def handle_collisions(self, vecs): - pass - def step_core(self, collisions_vec, actions, r): - return 0 + return 0, {} diff --git a/environments/factory/simple_factory.py b/environments/factory/simple_factory.py index f809abc..ba4c77f 100644 --- a/environments/factory/simple_factory.py +++ b/environments/factory/simple_factory.py @@ -11,12 +11,16 @@ class SimpleFactory(BaseFactory): super().reset() dirt_slice = np.zeros((1, *self.state.shape[1:])) self.state = np.concatenate((self.state, dirt_slice)) # dirt is now the last slice + free_for_dirt = self.free_for_dirt() + for x, y in free_for_dirt[:self.max_dirt]: + self.state[-1, x, y] = 1 + print(self.state) + + def free_for_dirt(self): free_for_dirt = self.state.sum(0) free_for_dirt = np.argwhere(free_for_dirt == 0) np.random.shuffle(free_for_dirt) - for x,y in free_for_dirt[:self.max_dirt]: - self.state[-1, x, y] = 1 - print(self.state) + return free_for_dirt if __name__ == '__main__':