mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-05-23 07:16:44 +02:00
is debugged, no longer dirty at [2, 8]
This commit is contained in:
parent
8769bc8d7b
commit
2acf91b395
@ -160,9 +160,7 @@ class BaseFactory:
|
||||
excluded_slices = [inds[x] if x < 0 else x for x in excluded_slices]
|
||||
state = self.state[[x for x in inds if x not in excluded_slices]]
|
||||
|
||||
free_cells = state.sum(0)
|
||||
free_cells[excluded_slices] = 0
|
||||
free_cells = np.argwhere(free_cells == h.IS_FREE_CELL)
|
||||
free_cells = np.argwhere(state.sum(0) == h.IS_FREE_CELL)
|
||||
np.random.shuffle(free_cells)
|
||||
return free_cells
|
||||
|
||||
|
@ -90,13 +90,9 @@ class GettingDirty(BaseFactory):
|
||||
return self.state, r, self.done, {}
|
||||
|
||||
def calculate_reward(self, agent_states: List[AgentState]) -> (int, dict):
|
||||
# TODO: What reward to use?
|
||||
this_step_reward = 0
|
||||
|
||||
dirt_vs_level_collisions = np.argwhere(self.state[h.LEVEL_IDX] * self.state[DIRT_INDEX] == h.IS_OCCUPIED_CELL)
|
||||
for dirt_vs_level_collision in dirt_vs_level_collisions:
|
||||
print(f'Dirt was placed on Level at: {dirt_vs_level_collision.squeeze()}')
|
||||
pass
|
||||
|
||||
for agent_state in agent_states:
|
||||
collisions = agent_state.collisions
|
||||
print(f't = {self.steps}\tAgent {agent_state.i} has collisions with '
|
||||
|
Loading…
x
Reference in New Issue
Block a user