Changed Collision Check

This commit is contained in:
Steffen Illium
2023-11-23 11:07:16 +01:00
parent b68f9e1911
commit 142bbb2c0c
4 changed files with 30 additions and 14 deletions

View File

@@ -200,6 +200,11 @@ class WatchCollisions(Rule):
self.curr_done = False
pos_with_collisions = state.get_collision_positions()
results = list()
for agent in state[c.AGENT]:
a_s = agent.state
if h.is_move(a_s.identifier) and a_s.action_introduced_collision:
results.append(TickResult(entity=agent, identifier=c.COLLISION,
reward=self.reward, validity=c.VALID))
for pos in pos_with_collisions:
guests = [x for x in state.entities.pos_dict[pos] if x.var_can_collide]
if len(guests) >= 2:
@@ -210,15 +215,18 @@ class WatchCollisions(Rule):
)
except AttributeError:
pass
results.append(TickResult(entity=guest, identifier=c.COLLISION,
reward=self.reward, validity=c.VALID))
if not any([x.entity == guest for x in results]):
results.append(TickResult(entity=guest, identifier=c.COLLISION,
reward=self.reward, validity=c.VALID))
self.curr_done = True if self.done_at_collisions else False
return results
def on_check_done(self, state) -> List[DoneResult]:
if self.done_at_collisions:
inter_entity_collision_detected = self.curr_done
move_failed = any(h.is_move(x.state.identifier) and not x.state.validity for x in state[c.AGENT])
if inter_entity_collision_detected or move_failed:
collision_in_step = any(h.is_move(x.state.identifier) and x.state.action_introduced_collision
for x in state[c.AGENT]
)
if inter_entity_collision_detected or collision_in_step:
return [DoneResult(validity=c.VALID, identifier=c.COLLISION, reward=self.reward_at_done)]
return []