mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-05-22 14:56:43 +02:00
101 lines
2.6 KiB
Python
101 lines
2.6 KiB
Python
import abc
|
|
from typing import Union
|
|
|
|
from mfg_package.environment import rewards as r, constants as c
|
|
from mfg_package.utils.helpers import MOVEMAP
|
|
from mfg_package.utils.results import ActionResult
|
|
|
|
|
|
class Action(abc.ABC):
    """Abstract base for actions that carry a string identifier.

    Both ``__init__`` and ``do`` are declared abstract, so a concrete
    subclass has to override each of them before it can be instantiated.
    """

    @abc.abstractmethod
    def __init__(self, identifier: str):
        """Remember the identifier of this action.

        :param identifier: Value later returned by ``name`` and embedded in
                           the ``repr`` of the instance.
        """
        self._identifier = identifier

    @abc.abstractmethod
    def do(self, entity, state) -> Union[None, ActionResult]:
        """Perform the action; the base implementation does nothing.

        :param entity: Unused by this base implementation.
        :param state: Unused by this base implementation.
        :return: ``None`` from this base implementation.
        """
        return None

    @property
    def name(self):
        """The identifier that was handed to the constructor."""
        return self._identifier

    def __repr__(self):
        """Return ``Action[<identifier>]``."""
        return f'Action[{self._identifier}]'
|
|
|
|
|
|
class Noop(Action):
    """Action whose ``do`` only reports a valid result and calls nothing on the entity."""

    def __init__(self):
        """Register this action under the ``c.NOOP`` identifier."""
        super().__init__(c.NOOP)

    def do(self, entity, *_) -> Union[None, ActionResult]:
        """Build the result of a no-op step.

        :param entity: Entity the result is attributed to; no method is called on it here.
        :param _: Any further positional arguments are accepted and ignored.
        :return: ``ActionResult`` with validity ``c.VALID`` and reward ``r.NOOP``.
        """
        outcome = ActionResult(
            identifier=self._identifier,
            validity=c.VALID,
            reward=r.NOOP,
            entity=entity,
        )
        return outcome
|
|
|
|
|
|
class Move(Action, abc.ABC):
    """Abstract base for actions that shift an entity by a fixed offset.

    The offset is looked up in ``MOVEMAP`` under the action's identifier.
    ``__init__`` is abstract, so only subclasses overriding it can be
    instantiated.
    """

    @abc.abstractmethod
    def __init__(self, *args, **kwargs):
        """Pass every argument on unchanged to the parent initializer."""
        super().__init__(*args, **kwargs)

    def _calc_new_pos(self, pos):
        """Return ``pos`` shifted by this action's ``MOVEMAP`` offset.

        :param pos: Indexable whose items 0 and 1 are the current coordinates.
        :return: Two-tuple holding the shifted coordinates.
        """
        dx, dy = MOVEMAP[self._identifier]
        shifted = (pos[0] + dx, pos[1] + dy)
        return shifted

    def do(self, entity, env):
        """Try to move ``entity`` by one step and report the outcome.

        :param entity: Object exposing ``pos`` and ``move(tile)``.
        :param env: Mapping-like object; ``env[c.FLOOR].by_pos(pos)`` is used
                    to look up the destination tile.
        :return: ``ActionResult`` whose validity is the return value of
                 ``entity.move`` or ``c.NOT_VALID`` when no tile was found,
                 with the matching ``r.MOVEMENTS_*`` reward.
        """
        destination = self._calc_new_pos(entity.pos)
        floor_tile = env[c.FLOOR].by_pos(destination)

        if floor_tile:
            # noinspection PyUnresolvedReferences
            valid = entity.move(floor_tile)
        else:
            # A falsy lookup result is treated as "no tile to step onto".
            valid = c.NOT_VALID

        if valid:
            reward = r.MOVEMENTS_VALID
        else:
            reward = r.MOVEMENTS_FAIL

        return ActionResult(
            entity=entity,
            identifier=self._identifier,
            validity=valid,
            reward=reward,
        )
|
|
|
|
|
|
class North(Move):
    """Concrete ``Move`` registered under the ``c.NORTH`` identifier."""

    def __init__(self, *args, **kwargs):
        """Prepend ``c.NORTH`` as identifier; all other arguments pass through."""
        super().__init__(c.NORTH, *args, **kwargs)
|
|
|
|
|
|
class NorthEast(Move):
    """Concrete ``Move`` registered under the ``c.NORTHEAST`` identifier."""

    def __init__(self, *args, **kwargs):
        """Prepend ``c.NORTHEAST`` as identifier; all other arguments pass through."""
        super().__init__(c.NORTHEAST, *args, **kwargs)
|
|
|
|
|
|
class East(Move):
    """Concrete ``Move`` registered under the ``c.EAST`` identifier."""

    def __init__(self, *args, **kwargs):
        """Prepend ``c.EAST`` as identifier; all other arguments pass through."""
        super().__init__(c.EAST, *args, **kwargs)
|
|
|
|
|
|
class SouthEast(Move):
    """Concrete ``Move`` registered under the ``c.SOUTHEAST`` identifier."""

    def __init__(self, *args, **kwargs):
        """Prepend ``c.SOUTHEAST`` as identifier; all other arguments pass through."""
        super().__init__(c.SOUTHEAST, *args, **kwargs)
|
|
|
|
|
|
class South(Move):
    """Concrete ``Move`` registered under the ``c.SOUTH`` identifier."""

    def __init__(self, *args, **kwargs):
        """Prepend ``c.SOUTH`` as identifier; all other arguments pass through."""
        super().__init__(c.SOUTH, *args, **kwargs)
|
|
|
|
|
|
class SouthWest(Move):
    """Concrete ``Move`` registered under the ``c.SOUTHWEST`` identifier."""

    def __init__(self, *args, **kwargs):
        """Prepend ``c.SOUTHWEST`` as identifier; all other arguments pass through."""
        super().__init__(c.SOUTHWEST, *args, **kwargs)
|
|
|
|
|
|
class West(Move):
    """Concrete ``Move`` registered under the ``c.WEST`` identifier."""

    def __init__(self, *args, **kwargs):
        """Prepend ``c.WEST`` as identifier; all other arguments pass through."""
        super().__init__(c.WEST, *args, **kwargs)
|
|
|
|
|
|
class NorthWest(Move):
    """Concrete ``Move`` registered under the ``c.NORTHWEST`` identifier."""

    def __init__(self, *args, **kwargs):
        """Prepend ``c.NORTHWEST`` as identifier; all other arguments pass through."""
        super().__init__(c.NORTHWEST, *args, **kwargs)
|
|
|
|
|
|
# Move classes for the four non-diagonal identifiers.
Move4 = [North, East, South, West]

# Move4 followed by the four diagonal move classes; built as a new list so
# that Move4 itself is left unchanged.
# noinspection PyTypeChecker
Move8 = [*Move4, NorthEast, SouthEast, SouthWest, NorthWest]
|