diff --git a/marl_factory_grid/modules/destinations/__init__.py b/marl_factory_grid/modules/destinations/__init__.py index 3bbf997..83e5988 100644 --- a/marl_factory_grid/modules/destinations/__init__.py +++ b/marl_factory_grid/modules/destinations/__init__.py @@ -1,4 +1,4 @@ from .actions import DestAction from .entitites import Destination from .groups import Destinations -from .rules import DestinationReachAll, DestinationSpawn +from .rules import DoneAtDestinationReachAll, SpawnDestinations diff --git a/marl_factory_grid/modules/destinations/actions.py b/marl_factory_grid/modules/destinations/actions.py index 05c4955..13f7fe3 100644 --- a/marl_factory_grid/modules/destinations/actions.py +++ b/marl_factory_grid/modules/destinations/actions.py @@ -1,9 +1,10 @@ from typing import Union +import marl_factory_grid.modules.destinations.constants from marl_factory_grid.environment.actions import Action from marl_factory_grid.utils.results import ActionResult -from marl_factory_grid.modules.destinations import constants as d, rewards as r +from marl_factory_grid.modules.destinations import constants as d from marl_factory_grid.environment import constants as c @@ -20,4 +21,4 @@ class DestAction(Action): valid = c.NOT_VALID state.print(f'{entity.name} just tried to do_wait_action do_wait_action at {entity.pos} but failed') return ActionResult(entity=entity, identifier=self._identifier, validity=valid, - reward=r.WAIT_VALID if valid else r.WAIT_FAIL) + reward=marl_factory_grid.modules.destinations.constants.REWARD_WAIT_VALID if valid else marl_factory_grid.modules.destinations.constants.REWARD_WAIT_FAIL) diff --git a/marl_factory_grid/modules/destinations/constants.py b/marl_factory_grid/modules/destinations/constants.py index 97c3586..a67b6ee 100644 --- a/marl_factory_grid/modules/destinations/constants.py +++ b/marl_factory_grid/modules/destinations/constants.py @@ -5,3 +5,9 @@ DEST_SYMBOL = 1 MODE_SINGLE = 'SINGLE' MODE_GROUPED = 'GROUPED' +SPAWN_MODES = [MODE_SINGLE, MODE_GROUPED] + +REWARD_WAIT_VALID: float = 0.1 +REWARD_WAIT_FAIL: float = -0.1 +REWARD_DEST_REACHED: float = 1.0 +REWARD_DEST_DONE: float = 5.0 diff --git a/marl_factory_grid/modules/destinations/rewards.py b/marl_factory_grid/modules/destinations/rewards.py deleted file mode 100644 index 395988f..0000000 --- a/marl_factory_grid/modules/destinations/rewards.py +++ /dev/null @@ -1,3 +0,0 @@ -WAIT_VALID: float = 0.1 -WAIT_FAIL: float = -0.1 -DEST_REACHED: float = 5.0 \ No newline at end of file diff --git a/marl_factory_grid/modules/destinations/rules.py b/marl_factory_grid/modules/destinations/rules.py index 8773f2d..3133e5f 100644 --- a/marl_factory_grid/modules/destinations/rules.py +++ b/marl_factory_grid/modules/destinations/rules.py @@ -1,18 +1,29 @@ import ast from random import shuffle from typing import List, Dict, Tuple + +import marl_factory_grid.modules.destinations.constants from marl_factory_grid.environment.rules import Rule from marl_factory_grid.utils.results import TickResult, DoneResult from marl_factory_grid.environment import constants as c -from marl_factory_grid.modules.destinations import constants as d, rewards as r +from marl_factory_grid.modules.destinations import constants as d from marl_factory_grid.modules.destinations.entitites import Destination -class DestinationReachAll(Rule): +class DestinationReachReward(Rule): - def __init__(self): - super(DestinationReachAll, self).__init__() + def __init__(self, dest_reach_reward=marl_factory_grid.modules.destinations.constants.REWARD_DEST_REACHED): + """ + This rule introduces the basic functionality, so that targts (Destinations) can be reached and marked as such. + Additionally, rewards are reported. + + :type dest_reach_reward: float + :param dest_reach_reward: Specifies the reward, agents get at destination reach. + + """ + super(DestinationReachReward, self).__init__() + self.reward = dest_reach_reward def tick_step(self, state) -> List[TickResult]: results = [] @@ -33,32 +44,69 @@ class DestinationReachAll(Rule): if reached: state.print(f'{dest.name} is reached now, mark as reached...') dest.mark_as_reached() - results.append(TickResult(self.name, validity=c.VALID, reward=r.DEST_REACHED, entity=agent)) + results.append(TickResult(self.name, validity=c.VALID, reward=self.reward, entity=agent)) return results +class DoneAtDestinationReachAll(DestinationReachReward): + + def __init__(self, reward_at_done=marl_factory_grid.modules.destinations.constants.REWARD_DEST_DONE, **kwargs): + """ + This rule triggers and sets the done flag if ALL Destinations have been reached. + + :type reward_at_done: object + :param reward_at_done: Specifies the reward, agent get, whenn all destinations are reached. + :type dest_reach_reward: float + :param dest_reach_reward: Specify the reward, agents get when reaching a single destination. + """ + super(DoneAtDestinationReachAll, self).__init__(**kwargs) + self.reward = reward_at_done + def on_check_done(self, state) -> List[DoneResult]: if all(x.was_reached() for x in state[d.DESTINATION]): - return [DoneResult(self.name, validity=c.VALID, reward=r.DEST_REACHED)] + return [DoneResult(self.name, validity=c.VALID, reward=self.reward)] return [DoneResult(self.name, validity=c.NOT_VALID, reward=0)] -class DestinationReachAny(DestinationReachAll): +class DoneAtDestinationReachAny(DestinationReachReward): - def __init__(self): - super(DestinationReachAny, self).__init__() + def __init__(self, reward_at_done=d.REWARD_DEST_DONE, **kwargs): + f""" + This rule triggers and sets the done flag if ANY Destinations has been reached. + !!! IMPORTANT: 'reward_at_done' is shared between the agents; 'dest_reach_reward' is bound to a specific one. + + :type reward_at_done: object + :param reward_at_done: Specifies the reward, all agent get, when any destinations has been reached. + Default {d.REWARD_DEST_DONE} + :type dest_reach_reward: float + :param dest_reach_reward: Specify a single agents reward forreaching a single destination. + Default {d.REWARD_DEST_REACHED} + """ + super(DoneAtDestinationReachAny, self).__init__(**kwargs) + self.reward = reward_at_done def on_check_done(self, state) -> List[DoneResult]: if any(x.was_reached() for x in state[d.DESTINATION]): - return [DoneResult(self.name, validity=c.VALID, reward=r.DEST_REACHED)] + return [DoneResult(self.name, validity=c.VALID, reward=marl_factory_grid.modules.destinations.constants.REWARD_DEST_REACHED)] return [] -class DestinationSpawn(Rule): +class SpawnDestinations(Rule): - def __init__(self, n_dests: int = 1, - spawn_mode: str = d.MODE_GROUPED): - super(DestinationSpawn, self).__init__() + def __init__(self, n_dests: int = 1, spawn_mode: str = d.MODE_GROUPED): + f""" + Defines how destinations are initially spawned and respawned in addition. + !!! This rule introduces no kind of reward or Env.-Done condition! + + :type n_dests: int + :param n_dests: How many destiantions should be maintained (and initally spawnewd) on the map? + :type spawn_mode: str + :param spawn_mode: One of {d.SPAWN_MODES}. {d.MODE_GROUPED}: Always wait for all Dstiantions do be gone, + then respawn after the given time. {d.MODE_SINGLE}: Just spawn every destination, + that has been reached, after the given time + + """ + super(SpawnDestinations, self).__init__() self.n_dests = n_dests self.spawn_mode = spawn_mode @@ -82,8 +130,18 @@ class DestinationSpawn(Rule): pass -class FixedDestinationSpawn(Rule): +class SpawnDestinationsPerAgent(Rule): def __init__(self, per_agent_positions: Dict[str, List[Tuple[int, int]]]): + """ + Special rule, that spawn distinations, that are bound to a single agent a fixed set of positions. + Usefull for introducing specialists, etc. .. + + !!! This rule does not introduce any reward or done condition. + + :type per_agent_positions: Dict[str, List[Tuple[int, int]] + :param per_agent_positions: Please provide a dictionary with agent names as keys; and a list of possible + destiantion coords as value. Example: {Wolfgang: [(0, 0), (1, 1), ...]} + """ super(Rule, self).__init__() self.per_agent_positions = {key: [ast.literal_eval(x) for x in val] for key, val in per_agent_positions.items()}