mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-05-22 14:56:43 +02:00
148 lines
4.9 KiB
Python
148 lines
4.9 KiB
Python
import abc
|
|
import random
|
|
from typing import Union
|
|
|
|
from marl_factory_grid.environment import rewards as r, constants as c
|
|
from marl_factory_grid.utils.helpers import MOVEMAP
|
|
from marl_factory_grid.utils.results import ActionResult
|
|
|
|
TYPE_COLLISION = 'collision'
|
|
|
|
|
|
class Action(abc.ABC):
|
|
@property
|
|
def name(self):
|
|
return self._identifier
|
|
|
|
@abc.abstractmethod
|
|
def __init__(self, identifier: str, default_valid_reward: float, default_fail_reward: float,
|
|
valid_reward: float | None = None, fail_reward: float | None = None):
|
|
"""
|
|
Abstract base class representing an action that can be performed in the environment.
|
|
|
|
:param identifier: A unique identifier for the action.
|
|
:type identifier: str
|
|
:param default_valid_reward: Default reward for a valid action.
|
|
:type default_valid_reward: float
|
|
:param default_fail_reward: Default reward for a failed action.
|
|
:type default_fail_reward: float
|
|
:param valid_reward: Custom reward for a valid action (optional).
|
|
:type valid_reward: Union[float, optional]
|
|
:param fail_reward: Custom reward for a failed action (optional).
|
|
:type fail_reward: Union[float, optional]
|
|
"""
|
|
self.fail_reward = fail_reward if fail_reward is not None else default_fail_reward
|
|
self.valid_reward = valid_reward if valid_reward is not None else default_valid_reward
|
|
self._identifier = identifier
|
|
|
|
@abc.abstractmethod
|
|
def do(self, entity, state) -> Union[None, ActionResult]:
|
|
"""
|
|
Let the :class:`marl_factory_grid.environment.entity.entity.Entity` perform the given.
|
|
|
|
:param entity: The entity to perform the action; mostly `marl_factory_grid.environment.entity.agent.Agent`
|
|
:param state: The current :class:'marl_factory_grid.utils.states.Gamestate'
|
|
:return:
|
|
"""
|
|
validity = bool(random.choice([0, 1]))
|
|
return self.get_result(validity, entity)
|
|
|
|
def __repr__(self):
|
|
return f'Action[{self._identifier}]'
|
|
|
|
def get_result(self, validity, entity, action_introduced_collision=False):
|
|
"""
|
|
Generate an ActionResult for the action based on its validity.
|
|
"""
|
|
reward = self.valid_reward if validity else self.fail_reward
|
|
return ActionResult(self.__class__.__name__, validity, reward=reward, entity=entity,
|
|
action_introduced_collision=action_introduced_collision)
|
|
|
|
|
|
class Noop(Action):
|
|
|
|
def __init__(self, **kwargs):
|
|
super().__init__(c.NOOP, r.NOOP, r.NOOP, **kwargs)
|
|
|
|
def do(self, entity, *_) -> Union[None, ActionResult]:
|
|
return self.get_result(c.VALID, entity)
|
|
|
|
|
|
class Move(Action, abc.ABC):
|
|
|
|
@abc.abstractmethod
|
|
def __init__(self, identifier, **kwargs):
|
|
super().__init__(identifier, r.MOVEMENTS_VALID, r.MOVEMENTS_FAIL, **kwargs)
|
|
|
|
def do(self, entity, state):
|
|
new_pos = self._calc_new_pos(entity.pos)
|
|
collision = False
|
|
if state.check_move_validity(entity, new_pos):
|
|
valid = entity.move(new_pos, state)
|
|
# Aftermath Collision Check
|
|
if len([x for x in state.entities.by_pos(entity.pos) if x.var_can_collide]) > 1:
|
|
# The entity did move, but there was something to collide with...
|
|
collision = True
|
|
|
|
else:
|
|
# There is no place to go, propably collision
|
|
# This is currently handeld by the WatchCollisions rule, so that it can be switched on and off by conf.yml
|
|
valid = c.NOT_VALID
|
|
collision = True
|
|
if valid:
|
|
state.print(f'{entity.name} just moved to {entity.pos}.')
|
|
else:
|
|
state.print(f'{entity.name} just tried to move to {new_pos} but either failed or hat a Collision.')
|
|
return self.get_result(valid, entity, action_introduced_collision=collision)
|
|
|
|
def _calc_new_pos(self, pos):
|
|
x_diff, y_diff = MOVEMAP[self._identifier]
|
|
return pos[0] + x_diff, pos[1] + y_diff
|
|
|
|
|
|
class North(Move):
|
|
def __init__(self, **kwargs):
|
|
super().__init__(c.NORTH, **kwargs)
|
|
|
|
|
|
class NorthEast(Move):
|
|
def __init__(self, **kwargs):
|
|
super().__init__(c.NORTHEAST, **kwargs)
|
|
|
|
|
|
class East(Move):
|
|
def __init__(self, **kwargs):
|
|
super().__init__(c.EAST, **kwargs)
|
|
|
|
|
|
class SouthEast(Move):
|
|
def __init__(self, **kwargs):
|
|
super().__init__(c.SOUTHEAST, **kwargs)
|
|
|
|
|
|
class South(Move):
|
|
def __init__(self, **kwargs):
|
|
super().__init__(c.SOUTH, **kwargs)
|
|
|
|
|
|
class SouthWest(Move):
|
|
def __init__(self, **kwargs):
|
|
super().__init__(c.SOUTHWEST, **kwargs)
|
|
|
|
|
|
class West(Move):
|
|
def __init__(self, **kwargs):
|
|
super().__init__(c.WEST, **kwargs)
|
|
|
|
|
|
class NorthWest(Move):
|
|
def __init__(self, **kwargs):
|
|
super().__init__(c.NORTHWEST, **kwargs)
|
|
|
|
|
|
Move4 = [North, East, South, West]
|
|
# noinspection PyTypeChecker
|
|
Move8 = Move4 + [NorthEast, SouthEast, SouthWest, NorthWest]
|
|
|
|
ALL_BASEACTIONS = Move8 + [Noop]
|