Multiple Fixes:

- Config Explainer
 - Rewards
 - Destination Reach Condition
 - Additional Step Callback
This commit is contained in:
Steffen Illium
2023-11-24 14:43:49 +01:00
parent 0ec260f6a2
commit 803d0dae7f
15 changed files with 158 additions and 143 deletions

View File

@ -33,8 +33,8 @@ class DestinationReachReward(Rule):
def tick_step(self, state) -> List[TickResult]:
results = []
reached = False
for dest in state[d.DESTINATION]:
reached = False
if dest.has_just_been_reached(state) and not dest.was_reached():
# Dest has just been reached, some agent needs to stand here
for agent in state[c.AGENT].by_pos(dest.pos):
@ -67,32 +67,27 @@ class DoneAtDestinationReach(DestinationReachReward):
"""
super().__init__(**kwargs)
self.condition = condition
self.reward = reward_at_done
self.reward_at_done = reward_at_done
assert condition in CONDITIONS
def on_check_done(self, state) -> List[DoneResult]:
if self.condition == ANY:
if any(x.was_reached() for x in state[d.DESTINATION]):
return [DoneResult(self.name, validity=c.VALID, reward=self.reward)]
return [DoneResult(self.name, validity=c.VALID, reward=self.reward_at_done)]
elif self.condition == ALL:
if all(x.was_reached() for x in state[d.DESTINATION]):
return [DoneResult(self.name, validity=c.VALID, reward=self.reward)]
return [DoneResult(self.name, validity=c.VALID, reward=self.reward_at_done)]
elif self.condition == SIMULTANEOUS:
if all(x.was_reached() for x in state[d.DESTINATION]):
return [DoneResult(self.name, validity=c.VALID, reward=self.reward)]
return [DoneResult(self.name, validity=c.VALID, reward=self.reward_at_done)]
else:
for dest in state[d.DESTINATION]:
if dest.was_reached():
for agent in state[c.AGENT].by_pos(dest.pos):
if dest.bound_entity:
if dest.bound_entity == agent:
pass
else:
dest.unmark_as_reached()
return [DoneResult(f'{dest}_unmarked_as_reached',
validity=c.NOT_VALID, entity=dest)]
else:
pass
dest.unmark_as_reached()
state.print(f'{dest} unmarked as reached, not all targets are reached in parallel.')
else:
pass
return [DoneResult(f'all_unmarked_as_reached', validity=c.NOT_VALID)]
else:
raise ValueError('Check spelling of Parameter "condition".')