mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-07-06 09:31:35 +02:00
Added explanation for narrow_corridor.yaml
This commit is contained in:
marl_factory_grid
@ -126,8 +126,10 @@ class AssignGlobalPositions(Rule):
|
||||
|
||||
class WatchCollisions(Rule):
|
||||
|
||||
def __init__(self, done_at_collisions: bool = False):
|
||||
def __init__(self, reward=r.COLLISION, done_at_collisions: bool = False, reward_at_done=r.COLLISION_DONE):
|
||||
super().__init__()
|
||||
self.reward_at_done = reward_at_done
|
||||
self.reward = reward
|
||||
self.done_at_collisions = done_at_collisions
|
||||
self.curr_done = False
|
||||
|
||||
@ -140,12 +142,12 @@ class WatchCollisions(Rule):
|
||||
if len(guests) >= 2:
|
||||
for i, guest in enumerate(guests):
|
||||
try:
|
||||
guest.set_state(TickResult(identifier=c.COLLISION, reward=r.COLLISION,
|
||||
guest.set_state(TickResult(identifier=c.COLLISION, reward=self.reward,
|
||||
validity=c.NOT_VALID, entity=self))
|
||||
except AttributeError:
|
||||
pass
|
||||
results.append(TickResult(entity=guest, identifier=c.COLLISION,
|
||||
reward=r.COLLISION, validity=c.VALID))
|
||||
reward=self.reward, validity=c.VALID))
|
||||
self.curr_done = True if self.done_at_collisions else False
|
||||
return results
|
||||
|
||||
@ -154,5 +156,5 @@ class WatchCollisions(Rule):
|
||||
inter_entity_collision_detected = self.curr_done
|
||||
move_failed = any(h.is_move(x.state.identifier) and not x.state.validity for x in state[c.AGENT])
|
||||
if inter_entity_collision_detected or move_failed:
|
||||
return [DoneResult(validity=c.VALID, identifier=c.COLLISION, reward=r.COLLISION)]
|
||||
return [DoneResult(validity=c.NOT_VALID, identifier=self.name)]
|
||||
return [DoneResult(validity=c.VALID, identifier=c.COLLISION, reward=self.reward_at_done)]
|
||||
return []
|
||||
|
Reference in New Issue
Block a user