mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-11-02 21:47:25 +01:00
added documentation for agents
This commit is contained in:
@@ -17,6 +17,16 @@ future_planning = 7
|
|||||||
class TSPBaseAgent(ABC):
|
class TSPBaseAgent(ABC):
|
||||||
|
|
||||||
def __init__(self, state, agent_i, static_problem: bool = True):
|
def __init__(self, state, agent_i, static_problem: bool = True):
|
||||||
|
"""
|
||||||
|
Abstract base class for agents in the environment.
|
||||||
|
|
||||||
|
:param state: The environment state
|
||||||
|
:type state:
|
||||||
|
:param agent_i: Index of the agent
|
||||||
|
:type agent_i: int
|
||||||
|
:param static_problem: Indicates whether the TSP is a static problem. (Default: True)
|
||||||
|
:type static_problem: bool
|
||||||
|
"""
|
||||||
self.static_problem = static_problem
|
self.static_problem = static_problem
|
||||||
self.local_optimization = True
|
self.local_optimization = True
|
||||||
self._env = state
|
self._env = state
|
||||||
@@ -26,9 +36,25 @@ class TSPBaseAgent(ABC):
|
|||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def predict(self, *_, **__) -> int:
|
def predict(self, *_, **__) -> int:
|
||||||
|
"""
|
||||||
|
Predicts the next action based on the environment state.
|
||||||
|
|
||||||
|
:return: Predicted action.
|
||||||
|
:rtype: int
|
||||||
|
"""
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
def _use_door_or_move(self, door, target):
|
def _use_door_or_move(self, door, target):
|
||||||
|
"""
|
||||||
|
Helper method to decide whether to use a door or move towards a target.
|
||||||
|
|
||||||
|
:param door: Door entity.
|
||||||
|
:type door: Door
|
||||||
|
:param target: Target type. For example 'Dirt', 'Dropoff' or 'Destination'
|
||||||
|
:type target: str
|
||||||
|
|
||||||
|
:return: Action to perform (use door or move).
|
||||||
|
"""
|
||||||
if door.is_closed:
|
if door.is_closed:
|
||||||
# Translate the action_object to an integer to have the same output as any other model
|
# Translate the action_object to an integer to have the same output as any other model
|
||||||
action = do.ACTION_DOOR_USE
|
action = do.ACTION_DOOR_USE
|
||||||
@@ -37,6 +63,15 @@ class TSPBaseAgent(ABC):
|
|||||||
return action
|
return action
|
||||||
|
|
||||||
def calculate_tsp_route(self, target_identifier):
|
def calculate_tsp_route(self, target_identifier):
|
||||||
|
"""
|
||||||
|
Calculate the TSP route to reach a target.
|
||||||
|
|
||||||
|
:param target_identifier: Identifier of the target entity
|
||||||
|
:type target_identifier: str
|
||||||
|
|
||||||
|
:return: TSP route
|
||||||
|
:rtype: List[int]
|
||||||
|
"""
|
||||||
positions = [x for x in self._env.state[target_identifier].positions if x != c.VALUE_NO_POS]
|
positions = [x for x in self._env.state[target_identifier].positions if x != c.VALUE_NO_POS]
|
||||||
if self.local_optimization:
|
if self.local_optimization:
|
||||||
nodes = \
|
nodes = \
|
||||||
@@ -55,6 +90,15 @@ class TSPBaseAgent(ABC):
|
|||||||
return route
|
return route
|
||||||
|
|
||||||
def _door_is_close(self, state):
|
def _door_is_close(self, state):
|
||||||
|
"""
|
||||||
|
Check if a door is close to the agent's position.
|
||||||
|
|
||||||
|
:param state: Current environment state.
|
||||||
|
:type state: Gamestate
|
||||||
|
|
||||||
|
:return: Closest door entity or None if no door is close.
|
||||||
|
:rtype: Door | None
|
||||||
|
"""
|
||||||
try:
|
try:
|
||||||
return next(y for x in state.entities.neighboring_positions(self.state.pos)
|
return next(y for x in state.entities.neighboring_positions(self.state.pos)
|
||||||
for y in state.entities.pos_dict[x] if do.DOOR in y.name)
|
for y in state.entities.pos_dict[x] if do.DOOR in y.name)
|
||||||
@@ -62,9 +106,27 @@ class TSPBaseAgent(ABC):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
def _has_targets(self, target_identifier):
|
def _has_targets(self, target_identifier):
|
||||||
|
"""
|
||||||
|
Check if there are targets available in the environment.
|
||||||
|
|
||||||
|
:param target_identifier: Identifier of the target entity.
|
||||||
|
:type target_identifier: str
|
||||||
|
|
||||||
|
:return: True if there are targets, False otherwise.
|
||||||
|
:rtype: bool
|
||||||
|
"""
|
||||||
return bool(len([x for x in self._env.state[target_identifier] if x.pos != c.VALUE_NO_POS]) >= 1)
|
return bool(len([x for x in self._env.state[target_identifier] if x.pos != c.VALUE_NO_POS]) >= 1)
|
||||||
|
|
||||||
def _predict_move(self, target_identifier):
|
def _predict_move(self, target_identifier):
|
||||||
|
"""
|
||||||
|
Predict the next move based on the given target.
|
||||||
|
|
||||||
|
:param target_identifier: Identifier of the target entity.
|
||||||
|
:type target_identifier: str
|
||||||
|
|
||||||
|
:return: Predicted action.
|
||||||
|
:rtype: int
|
||||||
|
"""
|
||||||
if self._has_targets(target_identifier):
|
if self._has_targets(target_identifier):
|
||||||
if self.static_problem:
|
if self.static_problem:
|
||||||
if not self._static_route:
|
if not self._static_route:
|
||||||
|
|||||||
@@ -8,9 +8,18 @@ future_planning = 7
|
|||||||
class TSPDirtAgent(TSPBaseAgent):
|
class TSPDirtAgent(TSPBaseAgent):
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
Initializes a TSPDirtAgent that aims to clean dirt in the environment.
|
||||||
|
"""
|
||||||
super(TSPDirtAgent, self).__init__(*args, **kwargs)
|
super(TSPDirtAgent, self).__init__(*args, **kwargs)
|
||||||
|
|
||||||
def predict(self, *_, **__):
|
def predict(self, *_, **__):
|
||||||
|
"""
|
||||||
|
Predicts the next action based on the presence of dirt in the environment.
|
||||||
|
|
||||||
|
:return: Predicted action.
|
||||||
|
:rtype: int
|
||||||
|
"""
|
||||||
if self._env.state[di.DIRT].by_pos(self.state.pos) is not None:
|
if self._env.state[di.DIRT].by_pos(self.state.pos) is not None:
|
||||||
# Translate the action_object to an integer to have the same output as any other model
|
# Translate the action_object to an integer to have the same output as any other model
|
||||||
action = di.CLEAN_UP
|
action = di.CLEAN_UP
|
||||||
|
|||||||
@@ -14,6 +14,12 @@ MODE_BRING = 'Mode_Bring'
|
|||||||
class TSPItemAgent(TSPBaseAgent):
|
class TSPItemAgent(TSPBaseAgent):
|
||||||
|
|
||||||
def __init__(self, *args, mode=MODE_GET, **kwargs):
|
def __init__(self, *args, mode=MODE_GET, **kwargs):
|
||||||
|
"""
|
||||||
|
Initializes a TSPItemAgent that collects items in the environment, stores them in its inventory and drops them off
|
||||||
|
at a drop-off location.
|
||||||
|
|
||||||
|
:param mode: Mode of the agent, either MODE_GET or MODE_BRING.
|
||||||
|
"""
|
||||||
super(TSPItemAgent, self).__init__(*args, **kwargs)
|
super(TSPItemAgent, self).__init__(*args, **kwargs)
|
||||||
self.mode = mode
|
self.mode = mode
|
||||||
|
|
||||||
@@ -46,6 +52,12 @@ class TSPItemAgent(TSPBaseAgent):
|
|||||||
return action_obj
|
return action_obj
|
||||||
|
|
||||||
def _choose(self):
|
def _choose(self):
|
||||||
|
"""
|
||||||
|
Internal Usage. Chooses the action based on the agent's mode and the environment state.
|
||||||
|
|
||||||
|
:return: Chosen action.
|
||||||
|
:rtype: int
|
||||||
|
"""
|
||||||
target = i.DROP_OFF if self.mode == MODE_BRING else i.ITEM
|
target = i.DROP_OFF if self.mode == MODE_BRING else i.ITEM
|
||||||
if len(self._env.state[i.ITEM]) >= 1:
|
if len(self._env.state[i.ITEM]) >= 1:
|
||||||
action = self._predict_move(target)
|
action = self._predict_move(target)
|
||||||
|
|||||||
@@ -9,9 +9,20 @@ future_planning = 7
|
|||||||
class TSPTargetAgent(TSPBaseAgent):
|
class TSPTargetAgent(TSPBaseAgent):
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
Initializes a TSPTargetAgent that aims to reach destinations.
|
||||||
|
"""
|
||||||
super(TSPTargetAgent, self).__init__(*args, **kwargs)
|
super(TSPTargetAgent, self).__init__(*args, **kwargs)
|
||||||
|
|
||||||
def _handle_doors(self, state):
|
def _handle_doors(self, state):
|
||||||
|
"""
|
||||||
|
Internal Usage. Handles the doors in the environment.
|
||||||
|
|
||||||
|
:param state: The current environment state.
|
||||||
|
:type state: marl_factory_grid.utils.states.Gamestate
|
||||||
|
:return: Closest door entity or None if no doors are close.
|
||||||
|
:rtype: marl_factory_grid.environment.entity.object.Entity or None
|
||||||
|
"""
|
||||||
|
|
||||||
try:
|
try:
|
||||||
return next(y for x in state.entities.neighboring_positions(self.state.pos)
|
return next(y for x in state.entities.neighboring_positions(self.state.pos)
|
||||||
|
|||||||
@@ -8,8 +8,20 @@ future_planning = 7
|
|||||||
class TSPRandomAgent(TSPBaseAgent):
|
class TSPRandomAgent(TSPBaseAgent):
|
||||||
|
|
||||||
def __init__(self, n_actions, *args, **kwargs):
|
def __init__(self, n_actions, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
Initializes a TSPRandomAgent that performs random actions from within its action space.
|
||||||
|
|
||||||
|
:param n_actions: Number of possible actions.
|
||||||
|
:type n_actions: int
|
||||||
|
"""
|
||||||
super(TSPRandomAgent, self).__init__(*args, **kwargs)
|
super(TSPRandomAgent, self).__init__(*args, **kwargs)
|
||||||
self.n_action = n_actions
|
self.n_action = n_actions
|
||||||
|
|
||||||
def predict(self, *_, **__):
|
def predict(self, *_, **__):
|
||||||
|
"""
|
||||||
|
Predicts the next action randomly.
|
||||||
|
|
||||||
|
:return: Predicted action.
|
||||||
|
:rtype: int
|
||||||
|
"""
|
||||||
return randint(0, self.n_action - 1)
|
return randint(0, self.n_action - 1)
|
||||||
|
|||||||
Reference in New Issue
Block a user