mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-07-06 01:21:36 +02:00
Merge branch 'main' into unit_testing
This commit is contained in:
@ -32,8 +32,8 @@ class DestinationReachReward(Rule):
|
||||
|
||||
def tick_step(self, state) -> List[TickResult]:
|
||||
results = []
|
||||
reached = False
|
||||
for dest in state[d.DESTINATION]:
|
||||
reached = False
|
||||
if dest.has_just_been_reached(state) and not dest.was_reached():
|
||||
# Dest has just been reached, some agent needs to stand here
|
||||
for agent in state[c.AGENT].by_pos(dest.pos):
|
||||
@ -66,32 +66,27 @@ class DoneAtDestinationReach(DestinationReachReward):
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
self.condition = condition
|
||||
self.reward = reward_at_done
|
||||
self.reward_at_done = reward_at_done
|
||||
assert condition in CONDITIONS
|
||||
|
||||
def on_check_done(self, state) -> List[DoneResult]:
|
||||
if self.condition == ANY:
|
||||
if any(x.was_reached() for x in state[d.DESTINATION]):
|
||||
return [DoneResult(self.name, validity=c.VALID, reward=self.reward)]
|
||||
return [DoneResult(self.name, validity=c.VALID, reward=self.reward_at_done)]
|
||||
elif self.condition == ALL:
|
||||
if all(x.was_reached() for x in state[d.DESTINATION]):
|
||||
return [DoneResult(self.name, validity=c.VALID, reward=self.reward)]
|
||||
return [DoneResult(self.name, validity=c.VALID, reward=self.reward_at_done)]
|
||||
elif self.condition == SIMULTANEOUS:
|
||||
if all(x.was_reached() for x in state[d.DESTINATION]):
|
||||
return [DoneResult(self.name, validity=c.VALID, reward=self.reward)]
|
||||
return [DoneResult(self.name, validity=c.VALID, reward=self.reward_at_done)]
|
||||
else:
|
||||
for dest in state[d.DESTINATION]:
|
||||
if dest.was_reached():
|
||||
for agent in state[c.AGENT].by_pos(dest.pos):
|
||||
if dest.bound_entity:
|
||||
if dest.bound_entity == agent:
|
||||
pass
|
||||
else:
|
||||
dest.unmark_as_reached()
|
||||
return [DoneResult(f'{dest}_unmarked_as_reached',
|
||||
validity=c.NOT_VALID, entity=dest)]
|
||||
else:
|
||||
pass
|
||||
dest.unmark_as_reached()
|
||||
state.print(f'{dest} unmarked as reached, not all targets are reached in parallel.')
|
||||
else:
|
||||
pass
|
||||
return [DoneResult(f'all_unmarked_as_reached', validity=c.NOT_VALID)]
|
||||
else:
|
||||
raise ValueError('Check spelling of Parameter "condition".')
|
||||
|
||||
@ -104,10 +99,10 @@ class SpawnDestinationsPerAgent(Rule):
|
||||
|
||||
!!! This rule does not introduce any reward or done condition.
|
||||
|
||||
:param coords_or_quantity: Please provide a dictionary with agent names as keys; and a list of possible
|
||||
destination coords as value. Example: {Wolfgang: [(0, 0), (1, 1), ...]}
|
||||
:param coords_or_quantity: Please provide a dictionary with agent names as keys; and a list of possible
|
||||
destination coords as value. Example: {Wolfgang: [(0, 0), (1, 1), ...]}
|
||||
"""
|
||||
super(Rule, self).__init__()
|
||||
super().__init__()
|
||||
self.per_agent_positions = dict()
|
||||
for agent_name, value in coords_or_quantity.items():
|
||||
if isinstance(value, int):
|
||||
@ -142,3 +137,25 @@ class SpawnDestinationsPerAgent(Rule):
|
||||
continue
|
||||
state[d.DESTINATION].add_item(destination)
|
||||
pass
|
||||
|
||||
|
||||
class SpawnDestinationOnAgent(Rule):
    """Special rule which spawns a single destination per agent, bound to that agent.

    Each destination is created at the agent's current position, i.e. just
    `below` the agent. Useful for the `N-Puzzle` configurations.

    !!! This rule does not introduce any reward or done condition.
    """

    def __init__(self):
        """Set up the rule; it takes no configuration parameters."""
        super().__init__()

    def on_reset(self, state: Gamestate):
        """Create one agent-bound destination underneath every agent.

        :param state: The current game state. Its agent collection (`c.AGENT`) is iterated and its
                      destination collection (`d.DESTINATION`) receives one new item per agent.
        :raises AssertionError: If, after adding the destination, the agent's position does not hold
                                exactly one destination.
        """
        state.print("Spawn Desitnations")
        for agent in state[c.AGENT]:
            destination = Destination(agent.pos, bind_to=agent)
            state[d.DESTINATION].add_item(destination)
            # Invariant check: exactly one destination may exist on the agent's position
            # once its own destination has been added.
            assert len(state[d.DESTINATION].by_pos(agent.pos)) == 1
|
||||
|
Reference in New Issue
Block a user