mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-05-23 07:16:44 +02:00
is debugged, no longer dirty at [2, 8]
This commit is contained in:
parent
8769bc8d7b
commit
2acf91b395
@ -160,9 +160,7 @@ class BaseFactory:
|
|||||||
excluded_slices = [inds[x] if x < 0 else x for x in excluded_slices]
|
excluded_slices = [inds[x] if x < 0 else x for x in excluded_slices]
|
||||||
state = self.state[[x for x in inds if x not in excluded_slices]]
|
state = self.state[[x for x in inds if x not in excluded_slices]]
|
||||||
|
|
||||||
free_cells = state.sum(0)
|
free_cells = np.argwhere(state.sum(0) == h.IS_FREE_CELL)
|
||||||
free_cells[excluded_slices] = 0
|
|
||||||
free_cells = np.argwhere(free_cells == h.IS_FREE_CELL)
|
|
||||||
np.random.shuffle(free_cells)
|
np.random.shuffle(free_cells)
|
||||||
return free_cells
|
return free_cells
|
||||||
|
|
||||||
|
@ -90,13 +90,9 @@ class GettingDirty(BaseFactory):
|
|||||||
return self.state, r, self.done, {}
|
return self.state, r, self.done, {}
|
||||||
|
|
||||||
def calculate_reward(self, agent_states: List[AgentState]) -> (int, dict):
|
def calculate_reward(self, agent_states: List[AgentState]) -> (int, dict):
|
||||||
|
# TODO: What reward to use?
|
||||||
this_step_reward = 0
|
this_step_reward = 0
|
||||||
|
|
||||||
dirt_vs_level_collisions = np.argwhere(self.state[h.LEVEL_IDX] * self.state[DIRT_INDEX] == h.IS_OCCUPIED_CELL)
|
|
||||||
for dirt_vs_level_collision in dirt_vs_level_collisions:
|
|
||||||
print(f'Dirt was placed on Level at: {dirt_vs_level_collision.squeeze()}')
|
|
||||||
pass
|
|
||||||
|
|
||||||
for agent_state in agent_states:
|
for agent_state in agent_states:
|
||||||
collisions = agent_state.collisions
|
collisions = agent_state.collisions
|
||||||
print(f't = {self.steps}\tAgent {agent_state.i} has collisions with '
|
print(f't = {self.steps}\tAgent {agent_state.i} has collisions with '
|
||||||
|
Loading…
x
Reference in New Issue
Block a user