mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-06-23 20:11:34 +02:00
Documentation
This commit is contained in:

committed by
Steffen Illium

parent
604c0c6f57
commit
855f53b406
@ -1,16 +1,17 @@
|
||||
from typing import Union
|
||||
|
||||
import marl_factory_grid.modules.destinations.constants
|
||||
from marl_factory_grid.environment.actions import Action
|
||||
from marl_factory_grid.utils.results import ActionResult
|
||||
|
||||
from marl_factory_grid.modules.destinations import constants as d
|
||||
from marl_factory_grid.environment import constants as c
|
||||
from marl_factory_grid.environment.actions import Action
|
||||
from marl_factory_grid.modules.destinations import constants as d
|
||||
from marl_factory_grid.utils.results import ActionResult
|
||||
|
||||
|
||||
class DestAction(Action):
|
||||
|
||||
def __init__(self):
|
||||
"""
|
||||
Attempts to wait at destination.
|
||||
"""
|
||||
super().__init__(d.DESTINATION, d.REWARD_WAIT_VALID, d.REWARD_WAIT_FAIL)
|
||||
|
||||
def do(self, entity, state) -> Union[None, ActionResult]:
|
||||
|
@ -9,24 +9,37 @@ from marl_factory_grid.utils.utility_classes import RenderEntity
|
||||
|
||||
class Destination(Entity):
|
||||
|
||||
def was_reached(self):
|
||||
return self._was_reached
|
||||
|
||||
@property
|
||||
def encoding(self):
|
||||
return d.DEST_SYMBOL
|
||||
|
||||
def __init__(self, *args, action_counts=0, **kwargs):
|
||||
"""
|
||||
Represents a destination in the environment that agents aim to reach.
|
||||
|
||||
"""
|
||||
super(Destination, self).__init__(*args, **kwargs)
|
||||
self._was_reached = False
|
||||
self.action_counts = action_counts
|
||||
self._per_agent_actions = defaultdict(lambda: 0)
|
||||
|
||||
def do_wait_action(self, agent: Agent):
|
||||
def do_wait_action(self, agent) -> bool:
|
||||
"""
|
||||
Performs a wait action for the given agent at the destination.
|
||||
|
||||
:param agent: The agent performing the wait action.
|
||||
:type agent: Agent
|
||||
|
||||
:return: Whether the action was valid or not.
|
||||
:rtype: bool
|
||||
"""
|
||||
self._per_agent_actions[agent.name] += 1
|
||||
return c.VALID
|
||||
|
||||
def has_just_been_reached(self, state):
|
||||
"""
|
||||
Checks if the destination has just been reached based on the current state.
|
||||
"""
|
||||
if self.was_reached():
|
||||
return False
|
||||
agent_at_position = any(state[c.AGENT].by_pos(self.pos))
|
||||
@ -38,6 +51,9 @@ class Destination(Entity):
|
||||
return agent_at_position or any(x >= self.action_counts for x in self._per_agent_actions.values())
|
||||
|
||||
def agent_did_action(self, agent: Agent):
|
||||
"""
|
||||
Internal usage, currently no usage.
|
||||
"""
|
||||
return self._per_agent_actions[agent.name] >= self.action_counts
|
||||
|
||||
def summarize_state(self) -> dict:
|
||||
@ -57,3 +73,6 @@ class Destination(Entity):
|
||||
|
||||
def unmark_as_reached(self):
|
||||
self._was_reached = False
|
||||
|
||||
def was_reached(self) -> bool:
|
||||
return self._was_reached
|
||||
|
@ -5,13 +5,30 @@ from marl_factory_grid.modules.destinations.entitites import Destination
|
||||
class Destinations(Collection):
|
||||
_entity = Destination
|
||||
|
||||
var_is_blocking_light = False
|
||||
var_can_collide = False
|
||||
var_can_move = False
|
||||
var_has_position = True
|
||||
var_can_be_bound = True
|
||||
@property
|
||||
def var_is_blocking_light(self):
|
||||
return False
|
||||
|
||||
@property
|
||||
def var_can_collide(self):
|
||||
return False
|
||||
|
||||
@property
|
||||
def var_can_move(self):
|
||||
return False
|
||||
|
||||
@property
|
||||
def var_has_position(self):
|
||||
return True
|
||||
|
||||
@property
|
||||
def var_can_be_bound(self):
|
||||
return True
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
"""
|
||||
A collection of destinations.
|
||||
"""
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def __repr__(self):
|
||||
|
@ -11,18 +11,17 @@ from marl_factory_grid.modules.destinations import constants as d
|
||||
from marl_factory_grid.modules.destinations.entitites import Destination
|
||||
from marl_factory_grid.utils.states import Gamestate
|
||||
|
||||
|
||||
ANY = 'any'
|
||||
ALL = 'all'
|
||||
SIMULTANOIUS = 'simultanious'
|
||||
CONDITIONS =[ALL, ANY, SIMULTANOIUS]
|
||||
ANY = 'any'
|
||||
ALL = 'all'
|
||||
SIMULTANEOUS = 'simultanious'
|
||||
CONDITIONS = [ALL, ANY, SIMULTANEOUS]
|
||||
|
||||
|
||||
class DestinationReachReward(Rule):
|
||||
|
||||
def __init__(self, dest_reach_reward=d.REWARD_DEST_REACHED):
|
||||
"""
|
||||
This rule introduces the basic functionality, so that targts (Destinations) can be reached and marked as such.
|
||||
This rule introduces the basic functionality, so that targets (Destinations) can be reached and marked as such.
|
||||
Additionally, rewards are reported.
|
||||
|
||||
:type dest_reach_reward: float
|
||||
@ -62,7 +61,7 @@ class DoneAtDestinationReach(DestinationReachReward):
|
||||
This rule triggers and sets the done flag if ALL Destinations have been reached.
|
||||
|
||||
:type reward_at_done: float
|
||||
:param reward_at_done: Specifies the reward, agent get, whenn all destinations are reached.
|
||||
:param reward_at_done: Specifies the reward, agent get, when all destinations are reached.
|
||||
:type dest_reach_reward: float
|
||||
:param dest_reach_reward: Specify the reward, agents get when reaching a single destination.
|
||||
"""
|
||||
@ -78,7 +77,7 @@ class DoneAtDestinationReach(DestinationReachReward):
|
||||
elif self.condition == ALL:
|
||||
if all(x.was_reached() for x in state[d.DESTINATION]):
|
||||
return [DoneResult(self.name, validity=c.VALID, reward=self.reward)]
|
||||
elif self.condition == SIMULTANOIUS:
|
||||
elif self.condition == SIMULTANEOUS:
|
||||
if all(x.was_reached() for x in state[d.DESTINATION]):
|
||||
return [DoneResult(self.name, validity=c.VALID, reward=self.reward)]
|
||||
else:
|
||||
@ -101,13 +100,13 @@ class DoneAtDestinationReach(DestinationReachReward):
|
||||
class SpawnDestinationsPerAgent(Rule):
|
||||
def __init__(self, coords_or_quantity: Dict[str, List[Tuple[int, int] | int]]):
|
||||
"""
|
||||
Special rule, that spawn distinations, that are bound to a single agent a fixed set of positions.
|
||||
Usefull for introducing specialists, etc. ..
|
||||
Special rule, that spawn destinations, that are bound to a single agent a fixed set of positions.
|
||||
Useful for introducing specialists, etc. ..
|
||||
|
||||
!!! This rule does not introduce any reward or done condition.
|
||||
|
||||
:param coords_or_quantity: Please provide a dictionary with agent names as keys; and a list of possible
|
||||
destiantion coords as value. Example: {Wolfgang: [(0, 0), (1, 1), ...]}
|
||||
destination coords as value. Example: {Wolfgang: [(0, 0), (1, 1), ...]}
|
||||
"""
|
||||
super(Rule, self).__init__()
|
||||
self.per_agent_positions = dict()
|
||||
|
Reference in New Issue
Block a user