new rules, new spawn logic, small fixes, default and narrow corridor debugged

This commit is contained in:
Steffen Illium
2023-11-09 17:50:20 +01:00
parent 9b9c6e0385
commit 06a5130b25
67 changed files with 768 additions and 921 deletions

View File

@@ -2,8 +2,8 @@ import ast
from random import shuffle
from typing import List, Dict, Tuple
import marl_factory_grid.modules.destinations.constants
from marl_factory_grid.environment.rules import Rule
from marl_factory_grid.utils import helpers as h
from marl_factory_grid.utils.results import TickResult, DoneResult
from marl_factory_grid.environment import constants as c
@@ -54,7 +54,7 @@ class DoneAtDestinationReachAll(DestinationReachReward):
"""
This rule triggers and sets the done flag if ALL Destinations have been reached.
:type reward_at_done: object
:type reward_at_done: float
:param reward_at_done: Specifies the reward, agent get, whenn all destinations are reached.
:type dest_reach_reward: float
:param dest_reach_reward: Specify the reward, agents get when reaching a single destination.
@@ -65,7 +65,7 @@ class DoneAtDestinationReachAll(DestinationReachReward):
def on_check_done(self, state) -> List[DoneResult]:
if all(x.was_reached() for x in state[d.DESTINATION]):
return [DoneResult(self.name, validity=c.VALID, reward=self.reward)]
return [DoneResult(self.name, validity=c.NOT_VALID, reward=0)]
return [DoneResult(self.name, validity=c.NOT_VALID)]
class DoneAtDestinationReachAny(DestinationReachReward):
@@ -75,7 +75,7 @@ class DoneAtDestinationReachAny(DestinationReachReward):
This rule triggers and sets the done flag if ANY Destinations has been reached.
!!! IMPORTANT: 'reward_at_done' is shared between the agents; 'dest_reach_reward' is bound to a specific one.
:type reward_at_done: object
:type reward_at_done: float
:param reward_at_done: Specifies the reward, all agent get, when any destinations has been reached.
Default {d.REWARD_DEST_DONE}
:type dest_reach_reward: float
@@ -87,67 +87,29 @@ class DoneAtDestinationReachAny(DestinationReachReward):
def on_check_done(self, state) -> List[DoneResult]:
if any(x.was_reached() for x in state[d.DESTINATION]):
return [DoneResult(self.name, validity=c.VALID, reward=marl_factory_grid.modules.destinations.constants.REWARD_DEST_REACHED)]
return [DoneResult(self.name, validity=c.VALID, reward=d.REWARD_DEST_REACHED)]
return []
class SpawnDestinations(Rule):
def __init__(self, n_dests: int = 1, spawn_mode: str = d.MODE_GROUPED):
f"""
Defines how destinations are initially spawned and respawned in addition.
!!! This rule introduces no kind of reward or Env.-Done condition!
:type n_dests: int
:param n_dests: How many destiantions should be maintained (and initally spawnewd) on the map?
:type spawn_mode: str
:param spawn_mode: One of {d.SPAWN_MODES}. {d.MODE_GROUPED}: Always wait for all Dstiantions do be gone,
then respawn after the given time. {d.MODE_SINGLE}: Just spawn every destination,
that has been reached, after the given time
"""
super(SpawnDestinations, self).__init__()
self.n_dests = n_dests
self.spawn_mode = spawn_mode
def on_init(self, state, lvl_map):
# noinspection PyAttributeOutsideInit
state[d.DESTINATION].trigger_destination_spawn(self.n_dests, state)
pass
def tick_pre_step(self, state) -> List[TickResult]:
pass
def tick_step(self, state) -> List[TickResult]:
if n_dest_spawn := max(0, self.n_dests - len(state[d.DESTINATION])):
if self.spawn_mode == d.MODE_GROUPED and n_dest_spawn == self.n_dests:
validity = state[d.DESTINATION].trigger_destination_spawn(n_dest_spawn, state)
return [TickResult(self.name, validity=validity, entity=None, value=n_dest_spawn)]
elif self.spawn_mode == d.MODE_SINGLE and n_dest_spawn:
validity = state[d.DESTINATION].trigger_destination_spawn(n_dest_spawn, state)
return [TickResult(self.name, validity=validity, entity=None, value=n_dest_spawn)]
else:
pass
class SpawnDestinationsPerAgent(Rule):
def __init__(self, per_agent_positions: Dict[str, List[Tuple[int, int]]]):
def __init__(self, coords_or_quantity: Dict[str, List[Tuple[int, int]]]):
"""
Special rule, that spawn distinations, that are bound to a single agent a fixed set of positions.
Usefull for introducing specialists, etc. ..
!!! This rule does not introduce any reward or done condition.
:type per_agent_positions: Dict[str, List[Tuple[int, int]]
:param per_agent_positions: Please provide a dictionary with agent names as keys; and a list of possible
:type coords_or_quantity: Dict[str, List[Tuple[int, int]]
:param coords_or_quantity: Please provide a dictionary with agent names as keys; and a list of possible
destiantion coords as value. Example: {Wolfgang: [(0, 0), (1, 1), ...]}
"""
super(Rule, self).__init__()
self.per_agent_positions = {key: [ast.literal_eval(x) for x in val] for key, val in per_agent_positions.items()}
self.per_agent_positions = {key: [ast.literal_eval(x) for x in val] for key, val in coords_or_quantity.items()}
def on_init(self, state, lvl_map):
for (agent_name, position_list) in self.per_agent_positions.items():
agent = next(x for x in state[c.AGENT] if agent_name in x.name) # Fixme: Ugly AF
agent = h.get_first(state[c.AGENT], lambda x: agent_name in x.name)
assert agent
position_list = position_list.copy()
shuffle(position_list)
while True:
@@ -155,7 +117,7 @@ class SpawnDestinationsPerAgent(Rule):
pos = position_list.pop()
except IndexError:
print(f"Could not spawn Destinations at: {self.per_agent_positions[agent_name]}")
print(f'Check your agent palcement: {state[c.AGENT]} ... Exit ...')
print(f'Check your agent placement: {state[c.AGENT]} ... Exit ...')
exit(9999)
if (not pos == agent.pos) and (not state[d.DESTINATION].by_pos(pos)):
destination = Destination(pos, bind_to=agent)