Added explanation for narrow_corridor.yaml

This commit is contained in:
Steffen Illium
2023-11-10 06:54:38 +01:00
parent 06a5130b25
commit a9462a8b6f
3 changed files with 40 additions and 6 deletions
marl_factory_grid

@ -126,8 +126,10 @@ class AssignGlobalPositions(Rule):
class WatchCollisions(Rule):
def __init__(self, done_at_collisions: bool = False):
def __init__(self, reward=r.COLLISION, done_at_collisions: bool = False, reward_at_done=r.COLLISION_DONE):
super().__init__()
self.reward_at_done = reward_at_done
self.reward = reward
self.done_at_collisions = done_at_collisions
self.curr_done = False
@ -140,12 +142,12 @@ class WatchCollisions(Rule):
if len(guests) >= 2:
for i, guest in enumerate(guests):
try:
guest.set_state(TickResult(identifier=c.COLLISION, reward=r.COLLISION,
guest.set_state(TickResult(identifier=c.COLLISION, reward=self.reward,
validity=c.NOT_VALID, entity=self))
except AttributeError:
pass
results.append(TickResult(entity=guest, identifier=c.COLLISION,
reward=r.COLLISION, validity=c.VALID))
reward=self.reward, validity=c.VALID))
self.curr_done = True if self.done_at_collisions else False
return results
@ -154,5 +156,5 @@ class WatchCollisions(Rule):
inter_entity_collision_detected = self.curr_done
move_failed = any(h.is_move(x.state.identifier) and not x.state.validity for x in state[c.AGENT])
if inter_entity_collision_detected or move_failed:
return [DoneResult(validity=c.VALID, identifier=c.COLLISION, reward=r.COLLISION)]
return [DoneResult(validity=c.NOT_VALID, identifier=self.name)]
return [DoneResult(validity=c.VALID, identifier=c.COLLISION, reward=self.reward_at_done)]
return []