diff --git a/marl_factory_grid/modules/destinations/rules.py b/marl_factory_grid/modules/destinations/rules.py index 55c66a9..3327ee1 100644 --- a/marl_factory_grid/modules/destinations/rules.py +++ b/marl_factory_grid/modules/destinations/rules.py @@ -53,6 +53,20 @@ class DestinationDone(Rule): return [] +class DoneOnReach(Rule): + + def __init__(self): + super(DoneOnReach, self).__init__() + + def on_check_done(self, state) -> List[DoneResult]: + dests = [x.pos for x in state[d.DESTINATION]] + agents = [x.pos for x in state[c.AGENT]] + + if any([x in dests for x in agents]): + return [DoneResult(self.name, validity=c.VALID, reward=r.DEST_REACHED)] + return [DoneResult(self.name, validity=c.NOT_VALID, reward=0)] + + class DestinationSpawn(Rule): def __init__(self, spawn_frequency: int = 5, n_dests: int = 1,