mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-12-06 15:40:37 +01:00
Changed Collision Check
This commit is contained in:
@@ -200,6 +200,11 @@ class WatchCollisions(Rule):
|
||||
self.curr_done = False
|
||||
pos_with_collisions = state.get_collision_positions()
|
||||
results = list()
|
||||
for agent in state[c.AGENT]:
|
||||
a_s = agent.state
|
||||
if h.is_move(a_s.identifier) and a_s.action_introduced_collision:
|
||||
results.append(TickResult(entity=agent, identifier=c.COLLISION,
|
||||
reward=self.reward, validity=c.VALID))
|
||||
for pos in pos_with_collisions:
|
||||
guests = [x for x in state.entities.pos_dict[pos] if x.var_can_collide]
|
||||
if len(guests) >= 2:
|
||||
@@ -210,15 +215,18 @@ class WatchCollisions(Rule):
|
||||
)
|
||||
except AttributeError:
|
||||
pass
|
||||
results.append(TickResult(entity=guest, identifier=c.COLLISION,
|
||||
reward=self.reward, validity=c.VALID))
|
||||
if not any([x.entity == guest for x in results]):
|
||||
results.append(TickResult(entity=guest, identifier=c.COLLISION,
|
||||
reward=self.reward, validity=c.VALID))
|
||||
self.curr_done = True if self.done_at_collisions else False
|
||||
return results
|
||||
|
||||
def on_check_done(self, state) -> List[DoneResult]:
|
||||
if self.done_at_collisions:
|
||||
inter_entity_collision_detected = self.curr_done
|
||||
move_failed = any(h.is_move(x.state.identifier) and not x.state.validity for x in state[c.AGENT])
|
||||
if inter_entity_collision_detected or move_failed:
|
||||
collision_in_step = any(h.is_move(x.state.identifier) and x.state.action_introduced_collision
|
||||
for x in state[c.AGENT]
|
||||
)
|
||||
if inter_entity_collision_detected or collision_in_step:
|
||||
return [DoneResult(validity=c.VALID, identifier=c.COLLISION, reward=self.reward_at_done)]
|
||||
return []
|
||||
|
||||
Reference in New Issue
Block a user