mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-12-20 05:56:07 +01:00
Results resolved. Small ADjustments
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
import abc
|
||||
import random
|
||||
from typing import Union
|
||||
|
||||
from marl_factory_grid.environment import rewards as r, constants as c
|
||||
@@ -13,45 +14,50 @@ class Action(abc.ABC):
|
||||
return self._identifier
|
||||
|
||||
@abc.abstractmethod
|
||||
def __init__(self, identifier: str):
|
||||
def __init__(self, identifier: str, default_valid_reward: float, default_fail_reward: float,
|
||||
valid_reward: float | None = None, fail_reward: float | None = None):
|
||||
self.fail_reward = fail_reward if fail_reward is not None else default_fail_reward
|
||||
self.valid_reward = valid_reward if valid_reward is not None else default_valid_reward
|
||||
self._identifier = identifier
|
||||
|
||||
@abc.abstractmethod
|
||||
def do(self, entity, state) -> Union[None, ActionResult]:
|
||||
print()
|
||||
return
|
||||
validity = bool(random.choice([0, 1]))
|
||||
return self.get_result(validity, entity)
|
||||
|
||||
def __repr__(self):
|
||||
return f'Action[{self._identifier}]'
|
||||
|
||||
def get_result(self, validity, entity):
|
||||
reward = self.valid_reward if validity else self.fail_reward
|
||||
return ActionResult(self.__class__.__name__, validity, reward=reward, entity=entity)
|
||||
|
||||
|
||||
class Noop(Action):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(c.NOOP)
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(c.NOOP, r.NOOP, r.NOOP, **kwargs)
|
||||
|
||||
def do(self, entity, *_) -> Union[None, ActionResult]:
|
||||
return ActionResult(identifier=self._identifier, validity=c.VALID,
|
||||
reward=r.NOOP, entity=entity)
|
||||
return self.get_result(c.VALID, entity)
|
||||
|
||||
|
||||
class Move(Action, abc.ABC):
|
||||
|
||||
@abc.abstractmethod
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
def __init__(self, identifier, **kwargs):
|
||||
super().__init__(identifier, r.MOVEMENTS_VALID, r.MOVEMENTS_FAIL, **kwargs)
|
||||
|
||||
def do(self, entity, state):
|
||||
new_pos = self._calc_new_pos(entity.pos)
|
||||
if state.check_move_validity(entity, new_pos):
|
||||
# noinspection PyUnresolvedReferences
|
||||
move_validity = entity.move(new_pos, state)
|
||||
reward = r.MOVEMENTS_VALID if move_validity else r.MOVEMENTS_FAIL
|
||||
return ActionResult(entity=entity, identifier=self._identifier, validity=move_validity, reward=reward)
|
||||
move_really_was_valid = entity.move(new_pos, state)
|
||||
return self.get_result(move_really_was_valid, entity)
|
||||
else: # There is no place to go, propably collision
|
||||
# This is currently handeld by the WatchCollisions rule, so that it can be switched on and off by conf.yml
|
||||
# return ActionResult(entity=entity, identifier=self._identifier, validity=c.NOT_VALID, reward=r.COLLISION)
|
||||
return ActionResult(entity=entity, identifier=self._identifier, validity=c.NOT_VALID)
|
||||
return self.get_result(c.NOT_VALID, entity)
|
||||
|
||||
def _calc_new_pos(self, pos):
|
||||
x_diff, y_diff = MOVEMAP[self._identifier]
|
||||
@@ -59,43 +65,43 @@ class Move(Action, abc.ABC):
|
||||
|
||||
|
||||
class North(Move):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(c.NORTH, *args, **kwargs)
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(c.NORTH, **kwargs)
|
||||
|
||||
|
||||
class NorthEast(Move):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(c.NORTHEAST, *args, **kwargs)
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(c.NORTHEAST, **kwargs)
|
||||
|
||||
|
||||
class East(Move):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(c.EAST, *args, **kwargs)
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(c.EAST, **kwargs)
|
||||
|
||||
|
||||
class SouthEast(Move):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(c.SOUTHEAST, *args, **kwargs)
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(c.SOUTHEAST, **kwargs)
|
||||
|
||||
|
||||
class South(Move):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(c.SOUTH, *args, **kwargs)
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(c.SOUTH, **kwargs)
|
||||
|
||||
|
||||
class SouthWest(Move):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(c.SOUTHWEST, *args, **kwargs)
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(c.SOUTHWEST, **kwargs)
|
||||
|
||||
|
||||
class West(Move):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(c.WEST, *args, **kwargs)
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(c.WEST, **kwargs)
|
||||
|
||||
|
||||
class NorthWest(Move):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(c.NORTHWEST, *args, **kwargs)
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(c.NORTHWEST, **kwargs)
|
||||
|
||||
|
||||
Move4 = [North, East, South, West]
|
||||
|
||||
Reference in New Issue
Block a user