Description and better naming scheme for the Destination-Module

This commit is contained in:
Steffen Illium 2023-10-26 16:43:40 +02:00
parent 20068e8e1b
commit ce4108380f
5 changed files with 83 additions and 21 deletions

View File

@ -1,4 +1,4 @@
from .actions import DestAction from .actions import DestAction
from .entitites import Destination from .entitites import Destination
from .groups import Destinations from .groups import Destinations
from .rules import DestinationReachAll, DestinationSpawn from .rules import DoneAtDestinationReachAll, SpawnDestinations

View File

@ -1,9 +1,10 @@
from typing import Union from typing import Union
import marl_factory_grid.modules.destinations.constants
from marl_factory_grid.environment.actions import Action from marl_factory_grid.environment.actions import Action
from marl_factory_grid.utils.results import ActionResult from marl_factory_grid.utils.results import ActionResult
from marl_factory_grid.modules.destinations import constants as d, rewards as r from marl_factory_grid.modules.destinations import constants as d
from marl_factory_grid.environment import constants as c from marl_factory_grid.environment import constants as c
@ -20,4 +21,4 @@ class DestAction(Action):
valid = c.NOT_VALID valid = c.NOT_VALID
state.print(f'{entity.name} just tried to do_wait_action do_wait_action at {entity.pos} but failed') state.print(f'{entity.name} just tried to do_wait_action do_wait_action at {entity.pos} but failed')
return ActionResult(entity=entity, identifier=self._identifier, validity=valid, return ActionResult(entity=entity, identifier=self._identifier, validity=valid,
reward=r.WAIT_VALID if valid else r.WAIT_FAIL) reward=marl_factory_grid.modules.destinations.constants.REWARD_WAIT_VALID if valid else marl_factory_grid.modules.destinations.constants.REWARD_WAIT_FAIL)

View File

@ -5,3 +5,9 @@ DEST_SYMBOL = 1
MODE_SINGLE = 'SINGLE' MODE_SINGLE = 'SINGLE'
MODE_GROUPED = 'GROUPED' MODE_GROUPED = 'GROUPED'
SPAWN_MODES = [MODE_SINGLE, MODE_GROUPED]
REWARD_WAIT_VALID: float = 0.1
REWARD_WAIT_FAIL: float = -0.1
REWARD_DEST_REACHED: float = 1.0
REWARD_DEST_DONE: float = 5.0

View File

@ -1,3 +0,0 @@
WAIT_VALID: float = 0.1
WAIT_FAIL: float = -0.1
DEST_REACHED: float = 5.0

View File

@ -1,18 +1,29 @@
import ast import ast
from random import shuffle from random import shuffle
from typing import List, Dict, Tuple from typing import List, Dict, Tuple
import marl_factory_grid.modules.destinations.constants
from marl_factory_grid.environment.rules import Rule from marl_factory_grid.environment.rules import Rule
from marl_factory_grid.utils.results import TickResult, DoneResult from marl_factory_grid.utils.results import TickResult, DoneResult
from marl_factory_grid.environment import constants as c from marl_factory_grid.environment import constants as c
from marl_factory_grid.modules.destinations import constants as d, rewards as r from marl_factory_grid.modules.destinations import constants as d
from marl_factory_grid.modules.destinations.entitites import Destination from marl_factory_grid.modules.destinations.entitites import Destination
class DestinationReachAll(Rule): class DestinationReachReward(Rule):
def __init__(self): def __init__(self, dest_reach_reward=marl_factory_grid.modules.destinations.constants.REWARD_DEST_REACHED):
super(DestinationReachAll, self).__init__() """
This rule introduces the basic functionality, so that targts (Destinations) can be reached and marked as such.
Additionally, rewards are reported.
:type dest_reach_reward: float
:param dest_reach_reward: Specifies the reward, agents get at destination reach.
"""
super(DestinationReachReward, self).__init__()
self.reward = dest_reach_reward
def tick_step(self, state) -> List[TickResult]: def tick_step(self, state) -> List[TickResult]:
results = [] results = []
@ -33,32 +44,69 @@ class DestinationReachAll(Rule):
if reached: if reached:
state.print(f'{dest.name} is reached now, mark as reached...') state.print(f'{dest.name} is reached now, mark as reached...')
dest.mark_as_reached() dest.mark_as_reached()
results.append(TickResult(self.name, validity=c.VALID, reward=r.DEST_REACHED, entity=agent)) results.append(TickResult(self.name, validity=c.VALID, reward=self.reward, entity=agent))
return results return results
class DoneAtDestinationReachAll(DestinationReachReward):
def __init__(self, reward_at_done=marl_factory_grid.modules.destinations.constants.REWARD_DEST_DONE, **kwargs):
"""
This rule triggers and sets the done flag if ALL Destinations have been reached.
:type reward_at_done: object
:param reward_at_done: Specifies the reward, agent get, whenn all destinations are reached.
:type dest_reach_reward: float
:param dest_reach_reward: Specify the reward, agents get when reaching a single destination.
"""
super(DoneAtDestinationReachAll, self).__init__(**kwargs)
self.reward = reward_at_done
def on_check_done(self, state) -> List[DoneResult]: def on_check_done(self, state) -> List[DoneResult]:
if all(x.was_reached() for x in state[d.DESTINATION]): if all(x.was_reached() for x in state[d.DESTINATION]):
return [DoneResult(self.name, validity=c.VALID, reward=r.DEST_REACHED)] return [DoneResult(self.name, validity=c.VALID, reward=self.reward)]
return [DoneResult(self.name, validity=c.NOT_VALID, reward=0)] return [DoneResult(self.name, validity=c.NOT_VALID, reward=0)]
class DestinationReachAny(DestinationReachAll): class DoneAtDestinationReachAny(DestinationReachReward):
def __init__(self): def __init__(self, reward_at_done=d.REWARD_DEST_DONE, **kwargs):
super(DestinationReachAny, self).__init__() f"""
This rule triggers and sets the done flag if ANY Destinations has been reached.
!!! IMPORTANT: 'reward_at_done' is shared between the agents; 'dest_reach_reward' is bound to a specific one.
:type reward_at_done: object
:param reward_at_done: Specifies the reward, all agent get, when any destinations has been reached.
Default {d.REWARD_DEST_DONE}
:type dest_reach_reward: float
:param dest_reach_reward: Specify a single agents reward forreaching a single destination.
Default {d.REWARD_DEST_REACHED}
"""
super(DoneAtDestinationReachAny, self).__init__(**kwargs)
self.reward = reward_at_done
def on_check_done(self, state) -> List[DoneResult]: def on_check_done(self, state) -> List[DoneResult]:
if any(x.was_reached() for x in state[d.DESTINATION]): if any(x.was_reached() for x in state[d.DESTINATION]):
return [DoneResult(self.name, validity=c.VALID, reward=r.DEST_REACHED)] return [DoneResult(self.name, validity=c.VALID, reward=marl_factory_grid.modules.destinations.constants.REWARD_DEST_REACHED)]
return [] return []
class DestinationSpawn(Rule): class SpawnDestinations(Rule):
def __init__(self, n_dests: int = 1, def __init__(self, n_dests: int = 1, spawn_mode: str = d.MODE_GROUPED):
spawn_mode: str = d.MODE_GROUPED): f"""
super(DestinationSpawn, self).__init__() Defines how destinations are initially spawned and respawned in addition.
!!! This rule introduces no kind of reward or Env.-Done condition!
:type n_dests: int
:param n_dests: How many destiantions should be maintained (and initally spawnewd) on the map?
:type spawn_mode: str
:param spawn_mode: One of {d.SPAWN_MODES}. {d.MODE_GROUPED}: Always wait for all Dstiantions do be gone,
then respawn after the given time. {d.MODE_SINGLE}: Just spawn every destination,
that has been reached, after the given time
"""
super(SpawnDestinations, self).__init__()
self.n_dests = n_dests self.n_dests = n_dests
self.spawn_mode = spawn_mode self.spawn_mode = spawn_mode
@ -82,8 +130,18 @@ class DestinationSpawn(Rule):
pass pass
class FixedDestinationSpawn(Rule): class SpawnDestinationsPerAgent(Rule):
def __init__(self, per_agent_positions: Dict[str, List[Tuple[int, int]]]): def __init__(self, per_agent_positions: Dict[str, List[Tuple[int, int]]]):
"""
Special rule, that spawn distinations, that are bound to a single agent a fixed set of positions.
Usefull for introducing specialists, etc. ..
!!! This rule does not introduce any reward or done condition.
:type per_agent_positions: Dict[str, List[Tuple[int, int]]
:param per_agent_positions: Please provide a dictionary with agent names as keys; and a list of possible
destiantion coords as value. Example: {Wolfgang: [(0, 0), (1, 1), ...]}
"""
super(Rule, self).__init__() super(Rule, self).__init__()
self.per_agent_positions = {key: [ast.literal_eval(x) for x in val] for key, val in per_agent_positions.items()} self.per_agent_positions = {key: [ast.literal_eval(x) for x in val] for key, val in per_agent_positions.items()}