diff --git a/environments/factory/base_factory.py b/environments/factory/base_factory.py index 139af54..2f44844 100644 --- a/environments/factory/base_factory.py +++ b/environments/factory/base_factory.py @@ -160,9 +160,7 @@ class BaseFactory: excluded_slices = [inds[x] if x < 0 else x for x in excluded_slices] state = self.state[[x for x in inds if x not in excluded_slices]] - free_cells = state.sum(0) - free_cells[excluded_slices] = 0 - free_cells = np.argwhere(free_cells == h.IS_FREE_CELL) + free_cells = np.argwhere(state.sum(0) == h.IS_FREE_CELL) np.random.shuffle(free_cells) return free_cells diff --git a/environments/factory/simple_factory_getting_dirty.py b/environments/factory/simple_factory_getting_dirty.py index 5041502..8212821 100644 --- a/environments/factory/simple_factory_getting_dirty.py +++ b/environments/factory/simple_factory_getting_dirty.py @@ -90,13 +90,9 @@ class GettingDirty(BaseFactory): return self.state, r, self.done, {} def calculate_reward(self, agent_states: List[AgentState]) -> (int, dict): + # TODO: What reward to use? this_step_reward = 0 - dirt_vs_level_collisions = np.argwhere(self.state[h.LEVEL_IDX] * self.state[DIRT_INDEX] == h.IS_OCCUPIED_CELL) - for dirt_vs_level_collision in dirt_vs_level_collisions: - print(f'Dirt was placed on Level at: {dirt_vs_level_collision.squeeze()}') - pass - for agent_state in agent_states: collisions = agent_state.collisions print(f't = {self.steps}\tAgent {agent_state.i} has collisions with '