renaming
0
marl_factory_grid/environment/__init__.py
Normal file
100
marl_factory_grid/environment/actions.py
Normal file
@@ -0,0 +1,100 @@
|
||||
import abc
|
||||
from typing import Union
|
||||
|
||||
from marl_factory_grid.environment import rewards as r, constants as c
|
||||
from marl_factory_grid.utils.helpers import MOVEMAP
|
||||
from marl_factory_grid.utils.results import ActionResult
|
||||
|
||||
|
||||
class Action(abc.ABC):
    """Abstract base of all actions; wraps the string identifier of the action."""

    @abc.abstractmethod
    def __init__(self, identifier: str):
        # Kept private; exposed read-only through ``name``.
        self._identifier = identifier

    @property
    def name(self):
        """The identifier this action was created with."""
        return self._identifier

    @abc.abstractmethod
    def do(self, entity, state) -> Union[None, ActionResult]:
        """Perform the action for ``entity``; the base implementation returns ``None``."""
        return None

    def __repr__(self):
        return f'Action[{self._identifier}]'
|
||||
|
||||
|
||||
class Noop(Action):
    """Action registered under ``c.NOOP`` whose result is always valid."""

    def __init__(self):
        super().__init__(c.NOOP)

    def do(self, entity, *_) -> Union[None, ActionResult]:
        """Return a valid ``ActionResult`` carrying the ``r.NOOP`` reward.

        Any further positional arguments are accepted and ignored.
        """
        result = ActionResult(identifier=self._identifier, validity=c.VALID,
                              reward=r.NOOP, entity=entity)
        return result
|
||||
|
||||
|
||||
class Move(Action, abc.ABC):
    """Abstract movement action; the step offset is looked up in ``MOVEMAP`` by identifier."""

    @abc.abstractmethod
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def _calc_new_pos(self, pos):
        """Return ``pos`` shifted by this action's ``MOVEMAP`` offset."""
        dx, dy = MOVEMAP[self._identifier]
        return pos[0] + dx, pos[1] + dy

    def do(self, entity, env):
        """Try to move ``entity`` by one step and report the outcome as an ``ActionResult``."""
        target_pos = self._calc_new_pos(entity.pos)
        target_tile = env[c.FLOOR].by_pos(target_pos)
        # Without a (truthy) floor tile at the target position the move is invalid.
        # noinspection PyUnresolvedReferences
        valid = entity.move(target_tile) if target_tile else c.NOT_VALID
        if valid:
            reward = r.MOVEMENTS_VALID
        else:
            reward = r.MOVEMENTS_FAIL
        return ActionResult(entity=entity, identifier=self._identifier, validity=valid, reward=reward)
|
||||
|
||||
|
||||
class North(Move):
    """Movement action bound to the ``c.NORTH`` identifier."""

    def __init__(self, *args, **kwargs):
        identifier = c.NORTH
        super().__init__(identifier, *args, **kwargs)
|
||||
|
||||
|
||||
class NorthEast(Move):
    """Movement action bound to the ``c.NORTHEAST`` identifier."""

    def __init__(self, *args, **kwargs):
        identifier = c.NORTHEAST
        super().__init__(identifier, *args, **kwargs)
|
||||
|
||||
|
||||
class East(Move):
    """Movement action bound to the ``c.EAST`` identifier."""

    def __init__(self, *args, **kwargs):
        identifier = c.EAST
        super().__init__(identifier, *args, **kwargs)
|
||||
|
||||
|
||||
class SouthEast(Move):
    """Movement action bound to the ``c.SOUTHEAST`` identifier."""

    def __init__(self, *args, **kwargs):
        identifier = c.SOUTHEAST
        super().__init__(identifier, *args, **kwargs)
|
||||
|
||||
|
||||
class South(Move):
    """Movement action bound to the ``c.SOUTH`` identifier."""

    def __init__(self, *args, **kwargs):
        identifier = c.SOUTH
        super().__init__(identifier, *args, **kwargs)
|
||||
|
||||
|
||||
class SouthWest(Move):
    """Movement action bound to the ``c.SOUTHWEST`` identifier."""

    def __init__(self, *args, **kwargs):
        identifier = c.SOUTHWEST
        super().__init__(identifier, *args, **kwargs)
|
||||
|
||||
|
||||
class West(Move):
    """Movement action bound to the ``c.WEST`` identifier."""

    def __init__(self, *args, **kwargs):
        identifier = c.WEST
        super().__init__(identifier, *args, **kwargs)
|
||||
|
||||
|
||||
class NorthWest(Move):
    """Movement action bound to the ``c.NORTHWEST`` identifier."""

    def __init__(self, *args, **kwargs):
        identifier = c.NORTHWEST
        super().__init__(identifier, *args, **kwargs)
|
||||
|
||||
|
||||
# The four axis-aligned movement action classes.
Move4 = [North, East, South, West]
# Move4 followed by the four diagonal movement action classes (order preserved).
# noinspection PyTypeChecker
Move8 = [*Move4, NorthEast, SouthEast, SouthWest, NorthWest]
|
0
marl_factory_grid/environment/assets/__init__.py
Normal file
BIN
marl_factory_grid/environment/assets/agent/adversary.png
Normal file
After Width: | Height: | Size: 8.3 KiB |
BIN
marl_factory_grid/environment/assets/agent/agent.png
Normal file
After Width: | Height: | Size: 3.3 KiB |
BIN
marl_factory_grid/environment/assets/agent/agent_collision.png
Normal file
After Width: | Height: | Size: 18 KiB |
BIN
marl_factory_grid/environment/assets/agent/idle.png
Normal file
After Width: | Height: | Size: 1.6 KiB |
BIN
marl_factory_grid/environment/assets/agent/invalid.png
Normal file
After Width: | Height: | Size: 1.6 KiB |
BIN
marl_factory_grid/environment/assets/agent/move.png
Normal file
After Width: | Height: | Size: 5.8 KiB |
BIN
marl_factory_grid/environment/assets/agent/valid.png
Normal file
After Width: | Height: | Size: 5.6 KiB |
BIN
marl_factory_grid/environment/assets/wall.png
Normal file
After Width: | Height: | Size: 1.4 KiB |
60
marl_factory_grid/environment/constants.py
Normal file
@@ -0,0 +1,60 @@
|
||||
# Names
|
||||
DANGER_ZONE = 'x' # Danger Zone tile identifier for resolving the string-based map files.
|
||||
DEFAULTS = 'Defaults'
|
||||
SELF = 'Self'
|
||||
PLACEHOLDER = 'Placeholder'
|
||||
FLOOR = 'Floor' # Identifier of Floor-objects and groups (groups).
|
||||
FLOORS = 'Floors' # Identifier of Floor-objects and groups (groups).
|
||||
WALL = 'Wall' # Identifier of Wall-objects and groups (groups).
|
||||
WALLS = 'Walls' # Identifier of Wall-objects and groups (groups).
|
||||
LEVEL = 'Level' # Identifier of Level-objects and groups (groups).
|
||||
AGENT = 'Agent' # Identifier of Agent-objects and groups (groups).
|
||||
AGENTS = 'Agents' # Identifier of Agent-objects and groups (groups).
|
||||
OTHERS = 'Other'
|
||||
COMBINED = 'Combined'
|
||||
GLOBAL_POSITION = 'GLOBAL_POSITION' # Identifier of the global position slice
|
||||
|
||||
|
||||
# Attributes
|
||||
IS_BLOCKING_LIGHT = 'is_blocking_light'
|
||||
HAS_POSITION = 'has_position'
|
||||
HAS_NO_POSITION = 'has_no_position'
|
||||
ALL = 'All'
|
||||
|
||||
# Symbols (Read from map-files)
|
||||
SYMBOL_WALL = '#'
|
||||
SYMBOL_FLOOR = '-'
|
||||
|
||||
VALID = True # Identifier to rename boolean values in the context of actions.
|
||||
NOT_VALID = False # Identifier to rename boolean values in the context of actions.
|
||||
VALUE_FREE_CELL = 0 # Free-Cell value used in observation
|
||||
VALUE_OCCUPIED_CELL = 1 # Occupied-Cell value used in observation
|
||||
VALUE_NO_POS = (-9999, -9999) # Invalid Position value used in the environment (smth. is off-grid)
|
||||
|
||||
|
||||
ACTION = 'action' # Identifier of Action-objects and groups (groups).
|
||||
COLLISION = 'Collision' # Identifier to use in the context of collisions.
|
||||
LAST_POS = 'LAST_POS' # Identifier for retrieving an entity's last pos.
|
||||
VALIDITY = 'VALIDITY' # Identifier for retrieving the validity of Action, Tick, etc. ...
|
||||
|
||||
# Actions
|
||||
# Movements
|
||||
NORTH = 'north'
|
||||
EAST = 'east'
|
||||
SOUTH = 'south'
|
||||
WEST = 'west'
|
||||
NORTHEAST = 'north_east'
|
||||
SOUTHEAST = 'south_east'
|
||||
SOUTHWEST = 'south_west'
|
||||
NORTHWEST = 'north_west'
|
||||
|
||||
# Move Groups
|
||||
MOVE8 = 'Move8'
|
||||
MOVE4 = 'Move4'
|
||||
|
||||
# No-Action / Wait
|
||||
NOOP = 'Noop'
|
||||
|
||||
# Result Identifier
|
||||
MOVEMENTS_VALID = 'motion_valid'
|
||||
MOVEMENTS_FAIL = 'motion_not_valid'
|
0
marl_factory_grid/environment/entity/__init__.py
Normal file
76
marl_factory_grid/environment/entity/agent.py
Normal file
@@ -0,0 +1,76 @@
|
||||
from typing import List, Union
|
||||
|
||||
from marl_factory_grid.environment import constants as c
|
||||
from marl_factory_grid.environment.actions import Action
|
||||
from marl_factory_grid.environment.entity.entity import Entity
|
||||
from marl_factory_grid.utils.render import RenderEntity
|
||||
from marl_factory_grid.utils import renderer
|
||||
from marl_factory_grid.utils.helpers import is_move
|
||||
from marl_factory_grid.utils.results import ActionResult, Result
|
||||
|
||||
|
||||
class Agent(Entity):
    """Entity that owns a list of actions and observations and remembers its latest action result."""

    def __init__(self, actions: List[Action], observations: List[str], *args, **kwargs):
        super(Agent, self).__init__(*args, **kwargs)
        # Instance attribute that shadows the ``step_result`` method defined below.
        self.step_result = dict()
        self._actions = actions
        self._observations = observations
        self._state: Union[Result, None] = None

    @property
    def obs_tag(self):
        """Agents are tagged by their own name."""
        return self.name

    @property
    def actions(self):
        return self._actions

    @property
    def observations(self):
        return self._observations

    @property
    def can_collide(self):
        return True

    def step_result(self):
        pass

    @property
    def collection(self):
        return self._collection

    @property
    def state(self):
        """The stored result if it is truthy, otherwise a fresh valid ``c.NOOP`` result with reward 0."""
        if self._state:
            return self._state
        return ActionResult(entity=self, identifier=c.NOOP, validity=c.VALID, reward=0)

    # noinspection PyAttributeOutsideInit
    def clear_temp_state(self):
        """Forget the stored action result and return ``self``."""
        self._state = None
        return self

    def summarize_state(self):
        """Extend the parent summary with the validity and identifier of the current state."""
        summary = super().summarize_state()
        summary.update(valid=bool(self.state.validity), action=str(self.state.identifier))
        return summary

    def set_state(self, action_result):
        self._state = action_result

    def render(self):
        """Build the ``RenderEntity`` for this agent, choosing the render state from ``state``."""
        # Position of this agent within its collection; rendered 1-based.
        index = next(pos for pos, member in enumerate(self._collection) if member.name == self.name)
        current = self.state
        if current.identifier == c.COLLISION:
            render_state = renderer.STATE_COLLISION
        elif not current.validity:
            render_state = renderer.STATE_INVALID
        elif current.identifier == c.NOOP:
            render_state = renderer.STATE_IDLE
        elif is_move(current.identifier):
            render_state = renderer.STATE_MOVE
        else:
            render_state = renderer.STATE_VALID

        return RenderEntity(c.AGENT, self.pos, 1, 'none', render_state, index + 1, real_name=self.name)
|
79
marl_factory_grid/environment/entity/entity.py
Normal file
@@ -0,0 +1,79 @@
|
||||
import abc
|
||||
|
||||
from marl_factory_grid.environment import constants as c
|
||||
from marl_factory_grid.environment.entity.object import EnvObject
|
||||
from marl_factory_grid.utils.render import RenderEntity
|
||||
|
||||
|
||||
class Entity(EnvObject, abc.ABC):
    """Full environment entity that lives on a tile of the environment grid (doors, items, dirt piles, ...)."""

    def __init__(self, tile, **kwargs):
        super().__init__(**kwargs)
        self._tile = tile
        # Register on the starting tile right away.
        tile.enter(self)

    @property
    def has_position(self):
        return self.pos != c.VALUE_NO_POS

    @property
    def x(self):
        return self.pos[0]

    @property
    def y(self):
        return self.pos[1]

    @property
    def pos(self):
        """Position of the tile this entity currently occupies."""
        return self._tile.pos

    @property
    def tile(self):
        return self._tile

    @property
    def last_tile(self):
        """The tile left by the latest successful move; ``None`` (and stored as such) if never set."""
        if not hasattr(self, '_last_tile'):
            # noinspection PyAttributeOutsideInit
            self._last_tile = None
        return self._last_tile

    @property
    def last_pos(self):
        """Position of ``last_tile``, or ``c.VALUE_NO_POS`` when that lookup raises AttributeError."""
        try:
            previous_pos = self.last_tile.pos
        except AttributeError:
            previous_pos = c.VALUE_NO_POS
        return previous_pos

    @property
    def direction_of_view(self):
        """Component-wise difference ``last_pos - pos``."""
        prev_x, prev_y = self.last_pos
        now_x, now_y = self.pos
        return prev_x - now_x, prev_y - now_y

    def move(self, next_tile):
        """Try to move onto ``next_tile``.

        Returns the (falsy) comparison result when ``next_tile`` is the current tile,
        otherwise whatever ``next_tile.enter`` returned. On success the old tile is
        left, ``last_tile`` is updated and all observers are notified.
        """
        origin = self.tile
        tiles_differ = origin != next_tile
        if not tiles_differ:
            return tiles_differ

        entered = next_tile.enter(self)
        if entered:
            origin.leave(self)
            self._tile = next_tile
            self._last_tile = origin
            for observer in self.observers:
                observer.notify_change_pos(self)
        return entered

    def summarize_state(self) -> dict:
        """Return name, coordinates, tile name and collision flag as plain Python values."""
        return dict(
            name=str(self.name),
            x=int(self.x),
            y=int(self.y),
            tile=str(self.tile.name),
            can_collide=bool(self.can_collide),
        )

    @abc.abstractmethod
    def render(self):
        """Default rendering: a ``RenderEntity`` named after the lower-cased class name."""
        return RenderEntity(self.__class__.__name__.lower(), self.pos)

    def __repr__(self):
        return super(Entity, self).__repr__() + f'(@{self.pos})'
|
18
marl_factory_grid/environment/entity/mixin.py
Normal file
@@ -0,0 +1,18 @@
|
||||
|
||||
|
||||
# noinspection PyAttributeOutsideInit
|
||||
class BoundEntityMixin:
    """Mixin for objects that are attached ("bound") to another entity via ``bind_to``."""

    # Class-level default so the attribute exists before ``bind_to`` was called;
    # previously ``bound_entity`` raised AttributeError on an unbound instance.
    _bound_entity = None

    @property
    def bound_entity(self):
        """The entity given to the latest ``bind_to`` call, or ``None`` when unbound."""
        return self._bound_entity

    @property
    def name(self):
        """``ClassName(<name of bound entity>)``; raises AttributeError while unbound."""
        return f'{self.__class__.__name__}({self._bound_entity.name})'

    def belongs_to_entity(self, entity):
        """Return whether ``entity`` compares equal to the bound entity."""
        return entity == self.bound_entity

    def bind_to(self, entity):
        """Attach this object to ``entity``, replacing any previous binding."""
        self._bound_entity = entity
|
126
marl_factory_grid/environment/entity/object.py
Normal file
@@ -0,0 +1,126 @@
|
||||
from collections import defaultdict
|
||||
from typing import Union
|
||||
|
||||
from marl_factory_grid.environment import constants as c
|
||||
|
||||
|
||||
class Object:
    """General object for organisation and maintenance purposes, such as actions etc."""

    # Running index per class name, shared through ``Object`` itself.
    _u_idx = defaultdict(int)

    def __init__(self, str_ident: Union[str, None] = None, **kwargs):
        self._observers = []
        self._str_ident = str_ident
        self.identifier_int = self._identify_and_count_up()
        self._collection = None

        if kwargs:
            print(f'Following kwargs were passed, but ignored: {kwargs}')

    def __bool__(self):
        return True

    @property
    def observers(self):
        return self._observers

    @property
    def name(self):
        """``Class[str_ident]`` when a string identifier was given, else ``Class#<running index>``."""
        class_name = self.__class__.__name__
        if self._str_ident is None:
            return f'{class_name}#{self.identifier_int}'
        return f'{class_name}[{self._str_ident}]'

    @property
    def identifier(self):
        """The given string identifier, falling back to ``name`` when none was given."""
        return self.name if self._str_ident is None else self._str_ident

    def __repr__(self):
        return f'{self.name}'

    def __eq__(self, other) -> bool:
        # Equality is delegated to a comparison of ``other`` with this object's identifier.
        return other == self.identifier

    def __hash__(self):
        return hash(self.identifier)

    def _identify_and_count_up(self):
        """Return the current running index of this class name and increment the counter."""
        key = self.__class__.__name__
        current = Object._u_idx[key]
        Object._u_idx[key] = current + 1
        return current

    def set_collection(self, collection):
        self._collection = collection

    def add_observer(self, observer):
        """Register ``observer`` and immediately notify it about this object."""
        self.observers.append(observer)
        observer.notify_change_pos(self)

    def del_observer(self, observer):
        self.observers.remove(observer)
|
||||
|
||||
|
||||
class EnvObject(Object):
    """Objects that hold observable information but have no position on the environment grid (inventories etc.)."""

    _u_idx = defaultdict(lambda: 0)

    def __init__(self, **kwargs):
        super(EnvObject, self).__init__(**kwargs)

    def _collection_flag(self, attribute):
        """Return the truthy ``attribute`` of the parent collection, else ``False``.

        An AttributeError during the lookup (e.g. no collection set) also yields ``False``.
        """
        try:
            return getattr(self._collection, attribute) or False
        except AttributeError:
            return False

    @property
    def obs_tag(self):
        """Name of the parent collection if available and truthy, otherwise this object's name."""
        try:
            return self._collection.name or self.name
        except AttributeError:
            return self.name

    @property
    def is_blocking_light(self):
        return self._collection_flag('is_blocking_light')

    @property
    def can_move(self):
        return self._collection_flag('can_move')

    @property
    def is_blocking_pos(self):
        return self._collection_flag('is_blocking_pos')

    @property
    def has_position(self):
        return self._collection_flag('has_position')

    @property
    def can_collide(self):
        return self._collection_flag('can_collide')

    @property
    def encoding(self):
        return c.VALUE_OCCUPIED_CELL

    def change_parent_collection(self, other_collection):
        """Add this object to ``other_collection``, remove it from the old one and switch over."""
        other_collection.add_item(self)
        previous = self._collection
        previous.delete_env_object(self)
        self._collection = other_collection
        return self._collection == other_collection
|
45
marl_factory_grid/environment/entity/util.py
Normal file
@@ -0,0 +1,45 @@
|
||||
import math
|
||||
|
||||
import numpy as np
|
||||
|
||||
from marl_factory_grid.environment.entity.mixin import BoundEntityMixin
|
||||
from marl_factory_grid.environment.entity.object import Object, EnvObject
|
||||
|
||||
|
||||
##########################################################################
|
||||
# ####################### Objects and Entitys ########################## #
|
||||
##########################################################################
|
||||
|
||||
|
||||
class PlaceHolder(Object):
    """Non-colliding stand-in object whose encoding is a configurable fill value."""

    def __init__(self, *args, fill_value=0, **kwargs):
        super().__init__(*args, **kwargs)
        self._fill_value = fill_value

    @property
    def name(self):
        # Fixed name, independent of identifier or running index.
        return "PlaceHolder"

    @property
    def encoding(self):
        """The ``fill_value`` given at construction."""
        return self._fill_value

    @property
    def can_collide(self):
        return False
|
||||
|
||||
|
||||
class GlobalPosition(BoundEntityMixin, EnvObject):
    """Observable object whose encoding is the position of the entity it is bound to."""

    def __init__(self, *args, normalized: bool = True, **kwargs):
        super(GlobalPosition, self).__init__(*args, **kwargs)
        # NOTE(review): ``self.size`` is not set by any base class visible here - confirm its origin.
        self._level_shape = math.sqrt(self.size)
        self._normalized = normalized

    @property
    def encoding(self):
        """Bound entity position; divided by ``_level_shape`` and returned as tuple when normalized."""
        if not self._normalized:
            return self.bound_entity.pos
        scaled = np.divide(self._bound_entity.pos, self._level_shape)
        return tuple(scaled)
|
131
marl_factory_grid/environment/entity/wall_floor.py
Normal file
@@ -0,0 +1,131 @@
|
||||
from typing import List
|
||||
|
||||
import numpy as np
|
||||
|
||||
from marl_factory_grid.environment import constants as c
|
||||
from marl_factory_grid.environment.entity.object import EnvObject
|
||||
from marl_factory_grid.utils.render import RenderEntity
|
||||
from marl_factory_grid.utils import helpers as h
|
||||
|
||||
|
||||
class Floor(EnvObject):
    """Grid tile with a fixed position that keeps track of the guests standing on it."""

    def __init__(self, pos, **kwargs):
        super(Floor, self).__init__(**kwargs)
        self._guests = dict()
        self.pos = tuple(pos)
        self._neighboring_floor: List[Floor] = list()
        self._blocked_by = None

    @property
    def has_position(self):
        return True

    @property
    def can_collide(self):
        return False

    @property
    def can_move(self):
        return False

    @property
    def is_blocking_pos(self):
        return False

    @property
    def is_blocking_light(self):
        return False

    @property
    def neighboring_floor_pos(self):
        """Positions of all tiles in ``neighboring_floor``."""
        return [tile.pos for tile in self.neighboring_floor]

    @property
    def neighboring_floor(self):
        """Truthy tiles of the parent collection at the ``h.POS_MASK`` offsets around this tile.

        The list is computed on first access and reused while it is non-empty.
        """
        if not self._neighboring_floor:
            found = []
            for offset in h.POS_MASK.reshape(-1, 2):
                if np.all(offset == [0, 0]):
                    continue  # zero offset is this tile itself
                neighbor = self._collection.by_pos(np.add(self.pos, offset))
                if neighbor:
                    found.append(neighbor)
            self._neighboring_floor = found
        return self._neighboring_floor

    @property
    def encoding(self):
        return c.VALUE_OCCUPIED_CELL

    @property
    def guests_that_can_collide(self):
        return [guest for guest in self.guests if guest.can_collide]

    @property
    def guests(self):
        return self._guests.values()

    @property
    def x(self):
        return self.pos[0]

    @property
    def y(self):
        return self.pos[1]

    @property
    def is_blocked(self):
        """Whether any current guest blocks this position."""
        return any([guest.is_blocking_pos for guest in self.guests])

    def __len__(self):
        return len(self._guests)

    def is_empty(self):
        return not len(self._guests)

    def is_occupied(self):
        return bool(len(self._guests))

    def enter(self, guest):
        """Register ``guest`` on this tile and return ``c.VALID``, or ``c.NOT_VALID`` if refused.

        Entry is refused when the guest is already registered, when the tile is
        blocked, or when a position-blocking guest tries to enter an occupied tile.
        """
        if guest.name in self._guests or self.is_blocked:
            return c.NOT_VALID
        if guest.is_blocking_pos and self.is_occupied():
            return c.NOT_VALID
        self._guests.update({guest.name: guest})
        return c.VALID

    def leave(self, guest):
        """Unregister ``guest``; ``c.NOT_VALID`` if it was not registered on this tile."""
        try:
            del self._guests[guest.name]
        except (ValueError, KeyError):
            return c.NOT_VALID
        else:
            return c.VALID

    def __repr__(self):
        return f'{self.name}(@{self.pos})'

    def summarize_state(self, **_):
        return dict(name=self.name, x=int(self.x), y=int(self.y))

    def render(self):
        # Plain floor tiles produce no render entity.
        return None
|
||||
|
||||
|
||||
class Wall(Floor):
    """Tile variant that collides, blocks its position and blocks light."""

    @property
    def can_collide(self):
        return True

    @property
    def is_blocking_pos(self):
        return True

    @property
    def is_blocking_light(self):
        return True

    @property
    def encoding(self):
        return c.VALUE_OCCUPIED_CELL

    def render(self):
        """Walls are drawn as a ``RenderEntity`` named ``c.WALL`` at their position."""
        return RenderEntity(c.WALL, self.pos)
|
201
marl_factory_grid/environment/factory.py
Normal file
@@ -0,0 +1,201 @@
|
||||
import shutil
|
||||
|
||||
from collections import defaultdict
|
||||
from itertools import chain
|
||||
from os import PathLike
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
import gymnasium as gym
|
||||
|
||||
from marl_factory_grid.utils.level_parser import LevelParser
|
||||
from marl_factory_grid.utils.observation_builder import OBSBuilder
|
||||
from marl_factory_grid.utils.config_parser import FactoryConfigParser
|
||||
from marl_factory_grid.utils import helpers as h
|
||||
import marl_factory_grid.environment.constants as c
|
||||
|
||||
from marl_factory_grid.utils.states import Gamestate
|
||||
|
||||
REC_TAC = 'rec_'
|
||||
|
||||
|
||||
class BaseFactory(gym.Env):
    """Gymnasium environment assembled from a config file and a level file.

    The config is read by ``FactoryConfigParser`` and the level by ``LevelParser``;
    ``reset`` builds a fresh ``Gamestate`` and ``OBSBuilder`` from both.
    """

    @property
    def action_space(self):
        return self.state[c.AGENT].action_space

    @property
    def named_action_space(self):
        return self.state[c.AGENT].named_action_space

    @property
    def observation_space(self):
        return self.obs_builder.observation_space(self.state)

    @property
    def named_observation_space(self):
        return self.obs_builder.named_observation_space(self.state)

    @property
    def params(self) -> dict:
        """Parse the YAML config file this environment was created from and return it."""
        import yaml
        # The context manager closes the file handle, which was previously left open.
        with Path(self._config_file).open() as config_stream:
            config_dict = yaml.safe_load(config_stream)
        return config_dict

    @property
    def summarize_header(self):
        """State summary restricted to the stateless entity groups."""
        summary_dict = self._summarize_state(stateless_entities=True)
        return summary_dict

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()

    def __init__(self, config_file: Union[str, PathLike]):
        """Load config and level, then call ``reset`` once so the environment is ready to use.

        :param config_file: Path of the config file handed to ``FactoryConfigParser``.
        """
        self._config_file = config_file
        self.conf = FactoryConfigParser(self._config_file)
        # Attribute Assignment
        self.level_filepath = Path(__file__).parent.parent / h.LEVELS_DIR / f'{self.conf.level_name}.txt'
        self._renderer = None  # expensive - don't use; unless required !

        parsed_entities = self.conf.load_entities()
        self.map = LevelParser(self.level_filepath, parsed_entities, self.conf.pomdp_r)

        # Init for later usage:
        self.state: Gamestate
        self.map: LevelParser
        self.obs_builder: OBSBuilder

        # TODO: Reset ---> document this
        self.reset()

    def __getitem__(self, item):
        return self.state.entities[item]

    def reset(self) -> (dict, dict):
        """Rebuild game state, agents, rules and observation builder.

        :return: The result of ``OBSBuilder.refresh_and_build_for_all`` for the new state.
        """
        self.state = None

        # Init entity:
        entities = self.map.do_init()

        # Grab all rules:
        rules = self.conf.load_rules()

        # Agents
        # noinspection PyAttributeOutsideInit
        self.state = Gamestate(entities, rules, self.conf.env_seed)

        agents = self.conf.load_agents(self.map.size, self[c.FLOOR].empty_tiles)
        self.state.entities.add_item({c.AGENT: agents})

        # All is set up, trigger additional init (after agent entity spawn etc)
        self.state.rules.do_all_init(self.state)

        # Observations
        # noinspection PyAttributeOutsideInit
        self.obs_builder = OBSBuilder(self.map.level_shape, self.state, self.map.pomdp_r)
        return self.obs_builder.refresh_and_build_for_all(self.state)

    def step(self, actions):
        """Advance the environment by one tick.

        :param actions: A list of actions, or a single value that is wrapped as ``[int(actions)]``.
        :return: ``(None, list of observations, reward, done, info)``.
        """
        if not isinstance(actions, list):
            actions = [int(actions)]

        # Apply rules, do actions, tick the state, etc...
        tick_result = self.state.tick(actions)

        # Check Done Conditions
        done_results = self.state.check_done()

        # Finalize
        reward, reward_info, done = self.summarize_step_results(tick_result, done_results)

        info = reward_info

        # ``reward`` is a list with individual rewards but a plain number otherwise;
        # calling ``sum`` on the plain number raised a TypeError before.
        step_reward = sum(reward) if isinstance(reward, (list, tuple)) else reward
        info.update(step_reward=step_reward, step=self.state.curr_step)
        # TODO:
        # if self._record_episodes:
        #     info.update(self._summarize_state())

        obs, reset_info = self.obs_builder.refresh_and_build_for_all(self.state)
        info.update(reset_info)
        return None, [x for x in obs.values()], reward, done, info

    def summarize_step_results(self, tick_results: list, done_check_results: list) -> (int, dict, bool):
        """Fold tick and done-check results into reward, info dict and done flag.

        :return: ``(reward, combined_info_dict, done)``; ``reward`` is a per-agent list when
                 ``conf.individual_rewards`` is set and the summed total otherwise.
        """
        rewards = defaultdict(lambda: 0.0)

        # Gather per agent environment rewards and
        # Combine Info dicts into a global one
        combined_info_dict = defaultdict(lambda: 0.0)
        for result in chain(tick_results, done_check_results):
            if result.reward is not None:
                try:
                    rewards[result.entity.name] += result.reward
                except AttributeError:
                    # Results without a named entity count towards the global reward.
                    rewards['global'] += result.reward
            infos = result.get_infos()
            for info in infos:
                assert isinstance(info.value, (float, int))
                combined_info_dict[info.identifier] += info.value

        # Check Done Rule Results
        try:
            done_reason = next(x for x in done_check_results if x.validity)
            done = True
            self.state.print(f'Env done, Reason: {done_reason.name}.')
        except StopIteration:
            done = False

        if self.conf.individual_rewards:
            global_rewards = rewards['global']
            del rewards['global']
            reward = [rewards[agent.name] for agent in self.state[c.AGENT]]
            # Every agent additionally receives the global share.
            reward = [x + global_rewards for x in reward]
            self.state.print(f"rewards are {rewards}")
            return reward, combined_info_dict, done
        else:
            reward = sum(rewards.values())
            self.state.print(f"reward is {reward}")
            return reward, combined_info_dict, done

    def start_recording(self):
        """Switch ``conf.do_record`` on and return the new flag."""
        self.conf.do_record = True
        return self.conf.do_record

    def stop_recording(self):
        """Switch ``conf.do_record`` off and return ``True`` once it is off."""
        self.conf.do_record = False
        return not self.conf.do_record

    # noinspection PyGlobalUndefined
    def render(self, mode='human'):
        """Render all entities; the renderer is created on first use."""
        if not self._renderer:  # lazy init
            # Declared before the import that binds the name at module level.
            global Renderer
            from marl_factory_grid.utils.renderer import Renderer
            self._renderer = Renderer(self.map.level_shape, view_radius=self.conf.pomdp_r, fps=10)

        render_entities = self.state.entities.render()
        if self.conf.pomdp_r:
            for render_entity in render_entities:
                if render_entity.name == c.AGENT:
                    render_entity.aux = self.obs_builder.curr_lightmaps[render_entity.real_name]
        return self._renderer.render(render_entities)

    def _summarize_state(self, stateless_entities=False):
        """Collect the summaries of all groups whose ``is_stateless`` equals ``stateless_entities``."""
        summary = {f'{REC_TAC}step': self.state.curr_step}

        for entity_group in self.state:
            if entity_group.is_stateless == stateless_entities:
                summary.update({f'{REC_TAC}{entity_group.name}': entity_group.summarize_states()})
        return summary

    def print(self, string):
        """Print ``string`` only when ``conf.verbose`` is set."""
        if self.conf.verbose:
            print(string)

    def save_params(self, filepath: Path):
        """Copy the config file to ``filepath``, creating parent directories as needed."""
        # noinspection PyProtectedMember
        filepath = Path(filepath)
        filepath.parent.mkdir(parents=True, exist_ok=True)
        shutil.copyfile(self._config_file, filepath)
|
0
marl_factory_grid/environment/groups/__init__.py
Normal file
29
marl_factory_grid/environment/groups/agents.py
Normal file
@@ -0,0 +1,29 @@
|
||||
from marl_factory_grid.environment.groups.env_objects import EnvObjects
|
||||
from marl_factory_grid.environment.groups.mixins import PositionMixin
|
||||
from marl_factory_grid.environment.entity.agent import Agent
|
||||
|
||||
|
||||
class Agents(PositionMixin, EnvObjects):
    """Group of ``Agent`` entities; agents can move and do not block light."""

    _entity = Agent
    is_blocking_light = False
    can_move = True

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    @property
    def obs_pairs(self):
        """``(name, agent)`` tuples for every agent in the group."""
        return [(agent.name, agent) for agent in self]

    @property
    def action_space(self):
        """Tuple space holding one ``Discrete(len(agent.actions))`` per agent."""
        from gymnasium import spaces
        per_agent = [spaces.Discrete(len(agent.actions)) for agent in self]
        return spaces.Tuple(per_agent)

    @property
    def named_action_space(self):
        """Mapping ``agent name -> {action name: index in agent.actions}``."""
        return {
            agent.name: {action.name: idx for idx, action in enumerate(agent.actions)}
            for agent in self
        }
|
33
marl_factory_grid/environment/groups/env_objects.py
Normal file
@@ -0,0 +1,33 @@
|
||||
from marl_factory_grid.environment.groups.objects import Objects
|
||||
from marl_factory_grid.environment.entity.object import EnvObject
|
||||
|
||||
|
||||
class EnvObjects(Objects):
    """Group of ``EnvObject`` items with a size limit and shared default flags."""

    _entity = EnvObject
    is_blocking_light: bool = False
    can_collide: bool = False
    has_position: bool = False
    can_move: bool = False

    def __init__(self, size, *args, **kwargs):
        super(EnvObjects, self).__init__(*args, **kwargs)
        self.size = size

    @property
    def encodings(self):
        """The ``encoding`` of every item in the group."""
        return [item.encoding for item in self]

    def add_item(self, item: EnvObject):
        """Add ``item`` and return the group; position-less groups must not exceed ``size``."""
        within_limit = self.has_position or (len(self) <= self.size)
        assert within_limit
        super(EnvObjects, self).add_item(item)
        return self

    def summarize_states(self):
        return [member.summarize_state() for member in self.values()]

    def delete_env_object(self, env_object: EnvObject):
        """Remove ``env_object`` from the group, addressed by its name."""
        del self[env_object.name]

    def delete_env_object_by_name(self, name):
        del self[name]
|
63
marl_factory_grid/environment/groups/global_entities.py
Normal file
@@ -0,0 +1,63 @@
|
||||
from collections import defaultdict
|
||||
from operator import itemgetter
|
||||
from typing import Dict
|
||||
|
||||
from marl_factory_grid.environment.groups.objects import Objects
|
||||
from marl_factory_grid.utils.helpers import POS_MASK
|
||||
|
||||
|
||||
class Entities(Objects):
    """Top-level container mapping group names to entity groups, plus a position index."""

    _entity = Objects

    @staticmethod
    def neighboring_positions(pos):
        """Return ``POS_MASK`` shifted by ``pos``, reshaped to rows of two coordinates."""
        return (POS_MASK + pos).reshape(-1, 2)

    def get_near_pos(self, pos):
        """Return every entry of ``pos_dict`` stored under the positions around ``pos``."""
        near = (tuple(position) for position in self.neighboring_positions(pos))
        return [y for x in itemgetter(*near)(self.pos_dict) for y in x]

    def render(self):
        """Flatten the render output of all groups into one list."""
        return [y for x in self for y in x.render() if x is not None]

    @property
    def names(self):
        return list(self._data.keys())

    def __init__(self):
        # Created before the parent constructor runs, as in the original order.
        self.pos_dict = defaultdict(list)
        super().__init__()

    def iter_entities(self):
        """Iterate over the individual entities of all groups."""
        return iter((x for sublist in self.values() for x in sublist))

    def add_items(self, items: Dict):
        return self.add_item(items)

    def add_item(self, item: dict):
        """Register the groups in ``item`` (name -> group) and observe each of them.

        :raises AssertionError: If one of the names is already registered.
        """
        assert_str = 'This group of entity has already been added!'
        # Test membership itself; the original tested the truthiness of the matching keys.
        assert not any(key in self.keys() for key in item.keys()), assert_str
        self._data.update(item)
        for val in item.values():
            val.add_observer(self)
        return self

    def __delitem__(self, name):
        """Remove the group stored under ``name`` and stop observing its entities.

        :raises AssertionError: If no such group is registered.
        """
        assert_str = 'This group of entity does not exist in this collection!'
        # ``name`` is used as a plain key below, so ``name.keys()`` failed for it;
        # mapping arguments are still accepted for backward compatibility.
        requested = name.keys() if hasattr(name, 'keys') else (name,)
        assert any(key in self.keys() for key in requested), assert_str
        # NOTE(review): ``delete`` is not a method of the built-in list/set types -
        # confirm the type of ``_observers`` on the group class.
        self[name]._observers.delete(self)
        for entity in self[name]:
            entity.del_observer(self)
        return super(Entities, self).__delitem__(name)

    @property
    def obs_pairs(self):
        return [y for x in self for y in x.obs_pairs]

    def by_pos(self, pos: (int, int)):
        """Return the ``pos_dict`` entry for ``pos``."""
        return self.pos_dict[pos]

    @property
    def positions(self):
        """One copy of each position key per entry stored under it."""
        return [k for k, v in self.pos_dict.items() for _ in v]
|
97
marl_factory_grid/environment/groups/mixins.py
Normal file
@@ -0,0 +1,97 @@
|
||||
from marl_factory_grid.environment import constants as c
|
||||
|
||||
from marl_factory_grid.environment.entity.entity import Entity
|
||||
|
||||
|
||||
# noinspection PyUnresolvedReferences,PyTypeChecker,PyArgumentList
|
||||
class PositionMixin:
    """Mixin for collections whose entities carry a ``tile`` and a ``pos``."""

    _entity = Entity
    is_blocking_light: bool = True
    can_collide: bool = True
    has_position: bool = True

    def render(self):
        """Return the render output of each entity, dropping ``None`` results."""
        return [y for y in [x.render() for x in self] if y is not None]

    @classmethod
    def from_tiles(cls, tiles, *args, entity_kwargs=None, **kwargs):
        """Build a collection holding one ``cls._entity`` per tile.

        Args:
            tiles: iterable of tiles; the enumeration index is passed as ``str_ident``.
            *args, **kwargs: forwarded to the collection constructor.
            entity_kwargs: optional keyword arguments forwarded to every entity.
        """
        collection = cls(*args, **kwargs)
        extra = entity_kwargs if entity_kwargs is not None else {}
        entities = [cls._entity(tile, str_ident=i, **extra) for i, tile in enumerate(tiles)]
        collection.add_items(entities)
        return collection

    @classmethod
    def from_coordinates(cls, positions: [(int, int)], tiles, *args, entity_kwargs=None, **kwargs, ):
        """Resolve ``positions`` through ``tiles.by_pos`` and delegate to :meth:`from_tiles`.

        ``tiles.size`` is passed as the first positional constructor argument.
        """
        return cls.from_tiles([tiles.by_pos(position) for position in positions], tiles.size, *args,
                              entity_kwargs=entity_kwargs,
                              **kwargs)

    @property
    def tiles(self):
        """Tiles of all entities in this collection."""
        return [entity.tile for entity in self]

    def __delitem__(self, name):
        """Make the entity called ``name`` leave its tile, then delete it.

        Raises:
            StopIteration: if no entity with that name exists (unchanged from the original).
        """
        obj = next(obj for obj in self if obj.name == name)
        obj.tile.leave(obj)
        super().__delitem__(name)

    def by_pos(self, pos: (int, int)):
        """Return the first entity whose ``pos`` equals ``pos``, or ``None`` if there is none."""
        pos = tuple(pos)
        try:
            return next((e for e in self if e.pos == pos), None)
        except ValueError:
            # Best effort as before: a failing comparison yields None.
            # The stray debugging print() was removed.
            return None

    @property
    def positions(self):
        """Positions of all entities in this collection."""
        return [e.pos for e in self]

    def notify_del_entity(self, entity: Entity):
        """Remove ``entity`` from ``pos_dict``; missing entries or attributes are ignored."""
        try:
            self.pos_dict[entity.pos].remove(entity)
        except (ValueError, AttributeError):
            pass
|
||||
|
||||
|
||||
# noinspection PyUnresolvedReferences,PyTypeChecker
|
||||
class IsBoundMixin:
    """Mixin for objects that are attached to exactly one owning entity."""

    @property
    def name(self):
        """Class name followed by the owner's name in parentheses."""
        owner_name = self._bound_entity.name
        return f'{self.__class__.__name__}({owner_name})'

    def __repr__(self):
        cls_name = self.__class__.__name__
        owner_name = self._bound_entity.name
        return f'{cls_name}#{owner_name}({self._data})'

    def bind(self, entity):
        """Store ``entity`` as the owner and return ``c.VALID``."""
        # noinspection PyAttributeOutsideInit
        self._bound_entity = entity
        return c.VALID

    def belongs_to_entity(self, entity):
        """Return whether ``entity`` equals the stored owner."""
        owner = self._bound_entity
        return owner == entity
|
||||
|
||||
|
||||
# noinspection PyUnresolvedReferences,PyTypeChecker
|
||||
class HasBoundedMixin:
    """Mixin for collections whose members are each bound to an entity."""

    @property
    def obs_names(self):
        """Names of all members, in iteration order."""
        collected = []
        for member in self:
            collected.append(member.name)
        return collected

    def by_entity(self, entity):
        """Return the first member bound to ``entity``, or ``None`` if there is none."""
        matches = (member for member in self if member.belongs_to_entity(entity))
        return next(matches, None)

    def idx_by_entity(self, entity):
        """Return the iteration index of the first member bound to ``entity``, or ``None``."""
        matches = (position for position, member in enumerate(self)
                   if member.belongs_to_entity(entity))
        return next(matches, None)
|
141
marl_factory_grid/environment/groups/objects.py
Normal file
@@ -0,0 +1,141 @@
|
||||
from collections import defaultdict
|
||||
from typing import List
|
||||
|
||||
import numpy as np
|
||||
|
||||
from marl_factory_grid.environment.entity.object import Object
|
||||
|
||||
|
||||
class Objects:
    """Name-keyed collection of ``_entity`` instances with observer support.

    Items are stored in ``_data`` under their ``name``. Observers registered
    through :meth:`add_observer` are notified when items are added or deleted.
    ``pos_dict`` maps a position to the list of entities reported at it by the
    ``notify_*`` hooks.
    """
    _entity = Object

    @property
    def observers(self):
        """List of registered observers."""
        return self._observers

    @property
    def obs_tag(self):
        """Class name of this collection."""
        return self.__class__.__name__

    @staticmethod
    def render():
        """Return an empty list; this base implementation renders nothing."""
        return []

    @property
    def obs_pairs(self):
        """Single ``(name, collection)`` pair describing this collection."""
        return [(self.name, self)]

    @property
    def names(self):
        """Names of all stored items."""
        # noinspection PyUnresolvedReferences
        return [x.name for x in self]

    @property
    def name(self):
        """Class name of this collection."""
        return f'{self.__class__.__name__}'

    def __init__(self, *args, **kwargs):
        self._data = defaultdict(lambda: None)
        self._observers = list()
        self.pos_dict = defaultdict(list)

    def __len__(self):
        return len(self._data)

    def __iter__(self):
        return iter(self.values())

    def add_item(self, item: _entity):
        """Store ``item`` under its name, set its collection and notify observers.

        Raises:
            AssertionError: if ``item`` is not an ``_entity`` instance or its
                name is already taken.
        """
        assert_str = f'All item names have to be of type {self._entity}, but were {item.__class__}.,'
        assert isinstance(item, self._entity), assert_str
        # .get() avoids inserting a None placeholder into the defaultdict.
        assert self._data.get(item.name) is None, f'{item.name} allready exists!!!'
        self._data.update({item.name: item})
        item.set_collection(self)
        for observer in self.observers:
            observer.notify_add_entity(item)
        return self

    # noinspection PyUnresolvedReferences
    def del_observer(self, observer):
        """Unregister ``observer`` here and from every item still observed by it."""
        self.observers.remove(observer)
        for entity in self:
            if observer in entity.observers:
                entity.del_observer(observer)

    # noinspection PyUnresolvedReferences
    def add_observer(self, observer):
        """Register ``observer`` here and on every item not yet observed by it."""
        self.observers.append(observer)
        for entity in self:
            if observer not in entity.observers:
                entity.add_observer(observer)

    def __delitem__(self, name):
        """Notify observers about the removal, then delete the item stored under ``name``.

        Raises:
            KeyError: if nothing is stored under ``name``.
        """
        # Hand observers the stored object, as notify_del_entity expects.
        # The original passed the name, whose missing ``.pos`` was silently
        # swallowed, so pos_dict was never cleaned up.
        entity = self._data.get(name)
        if entity is not None:
            for observer in self.observers:
                observer.notify_del_entity(entity)
        # noinspection PyTypeChecker
        del self._data[name]

    def add_items(self, items: List[_entity]):
        """Add every item of ``items`` via :meth:`add_item`."""
        for item in items:
            self.add_item(item)
        return self

    def keys(self):
        return self._data.keys()

    def values(self):
        return self._data.values()

    def items(self):
        return self._data.items()

    def _get_index(self, item):
        """Return the position of the first stored value equal to ``item``, or ``None``."""
        return next((i for i, v in enumerate(self._data.values()) if v == item), None)

    def __getitem__(self, item):
        """Look up an item by integer position (negative counts from the end) or by key.

        Returns ``None`` when the position or key does not exist.

        Raises:
            TypeError: if ``item`` is unhashable.
        """
        if isinstance(item, (int, np.int64, np.int32)):
            if item < 0:
                item = len(self._data) - abs(item)
            return next((v for i, v in enumerate(self._data.values()) if i == item), None)
        # .get() does not trigger the default factory; plain indexing inserted a
        # None entry for every missing key, corrupting len() and iteration.
        return self._data.get(item)

    def __repr__(self):
        return f'{self.__class__.__name__}[{dict(self._data)}]'

    def notify_change_pos(self, entity: object):
        """Move ``entity`` in ``pos_dict`` from its ``last_pos`` to its ``pos``."""
        try:
            self.pos_dict[entity.last_pos].remove(entity)
        except (ValueError, AttributeError):
            # Not registered at last_pos, or the entity has no last_pos.
            pass
        if entity.has_position:
            try:
                self.pos_dict[entity.pos].append(entity)
            except (ValueError, AttributeError):
                pass

    def notify_del_entity(self, entity: Object):
        """Remove ``entity`` from ``pos_dict``; missing entries or attributes are ignored."""
        try:
            self.pos_dict[entity.pos].remove(entity)
        except (ValueError, AttributeError):
            pass

    def notify_add_entity(self, entity: Object):
        """Start observing ``entity`` and register it in ``pos_dict`` at its position."""
        try:
            entity.add_observer(self)
            self.pos_dict[entity.pos].append(entity)
        except (ValueError, AttributeError):
            pass
|
77
marl_factory_grid/environment/groups/utils.py
Normal file
@@ -0,0 +1,77 @@
|
||||
from typing import List, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from marl_factory_grid.environment.groups.env_objects import EnvObjects
|
||||
from marl_factory_grid.environment.groups.objects import Objects
|
||||
from marl_factory_grid.environment.groups.mixins import HasBoundedMixin, PositionMixin
|
||||
from marl_factory_grid.environment.entity.util import GlobalPosition
|
||||
from marl_factory_grid.utils import helpers as h
|
||||
from marl_factory_grid.environment import constants as c
|
||||
|
||||
|
||||
class Combined(PositionMixin, EnvObjects):
    """Collection that stands in for several named groups at once."""

    @property
    def name(self):
        """Base name followed by the identifier (or, if unset, the list of names)."""
        label = self._ident or self._names
        return f'{super().name}({label})'

    @property
    def names(self):
        """The group names handed to the constructor."""
        return self._names

    def __init__(self, names: List[str], *args, identifier: Union[None, str] = None, **kwargs):
        super().__init__(*args, **kwargs)
        self._ident = identifier
        # Fall back to a fresh list when no (or an empty) name list is given.
        self._names = names if names else list()

    @property
    def obs_tag(self):
        """Same as :attr:`name`."""
        return self.name

    @property
    def obs_pairs(self):
        """One ``(name, None)`` pair per combined group name."""
        pairs = []
        for group_name in self.names:
            pairs.append((group_name, None))
        return pairs
|
||||
|
||||
|
||||
class GlobalPositions(HasBoundedMixin, EnvObjects):
    """Collection of ``GlobalPosition`` objects, each bound to an entity."""

    _entity = GlobalPosition
    # Was ``False,`` — the trailing comma made this the tuple ``(False,)``,
    # which is truthy and therefore read as "blocks light".
    is_blocking_light = False
    can_collide = False

    def __init__(self, *args, **kwargs):
        super(GlobalPositions, self).__init__(*args, **kwargs)
|
||||
|
||||
|
||||
class Zones(Objects):
    """Zone collection derived from a parsed level (currently disabled).

    ``__init__`` raises ``NotImplementedError`` unconditionally; everything
    after the raise is unreachable draft logic kept for the pending rework.
    """

    @property
    def accounting_zones(self):
        """Zone slices of every entry whose value differs from ``c.DANGER_ZONE``."""
        selected = []
        for idx, name in self.items():
            if name != c.DANGER_ZONE:
                selected.append(self[idx])
        return selected

    def __init__(self, parsed_level):
        raise NotImplementedError('This needs a Rework')
        # noinspection PyUnreachableCode
        super().__init__()
        slices = []
        self._accounting_zones = []
        self._danger_zones = []
        for symbol in np.unique(parsed_level):
            if symbol == c.VALUE_OCCUPIED_CELL:
                continue
            self + symbol
            slices.append(h.one_hot_level(parsed_level, symbol))
            if symbol == c.DANGER_ZONE:
                self._danger_zones.append(symbol)
            else:
                self._accounting_zones.append(symbol)

        self._zone_slices = np.stack(slices)

    def __getitem__(self, item):
        """Index directly into the stacked zone slices."""
        return self._zone_slices[item]

    def add_items(self, other: Union[str, List[str]]):
        """Always raises: zones cannot be added after construction."""
        raise AttributeError('You are not allowed to add additional Zones in runtime.')
|
54
marl_factory_grid/environment/groups/wall_n_floors.py
Normal file
@@ -0,0 +1,54 @@
|
||||
import random
|
||||
from typing import List
|
||||
|
||||
from marl_factory_grid.environment import constants as c
|
||||
from marl_factory_grid.environment.groups.env_objects import EnvObjects
|
||||
from marl_factory_grid.environment.groups.mixins import PositionMixin
|
||||
from marl_factory_grid.environment.entity.wall_floor import Wall, Floor
|
||||
|
||||
|
||||
class Walls(PositionMixin, EnvObjects):
    """Collection of ``Wall`` entities."""

    _entity = Wall
    symbol = c.SYMBOL_WALL

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._value = c.VALUE_OCCUPIED_CELL

    @classmethod
    def from_coordinates(cls, argwhere_coordinates, *args, **kwargs):
        """Create the collection with one ``cls._entity`` per given coordinate."""
        collection = cls(*args, **kwargs)
        members = [cls._entity(position) for position in argwhere_coordinates]
        # noinspection PyTypeChecker
        collection.add_items(members)
        return collection

    @classmethod
    def from_tiles(cls, tiles, *args, **kwargs):
        """Not supported for this collection; always raises ``RuntimeError``."""
        raise RuntimeError()
|
||||
|
||||
|
||||
class Floors(Walls):
    """Collection of ``Floor`` tiles; neither blocks light nor collides."""

    _entity = Floor
    symbol = c.SYMBOL_FLOOR
    is_blocking_light: bool = False
    can_collide: bool = False

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._value = c.VALUE_FREE_CELL

    @property
    def occupied_tiles(self):
        """All tiles reporting ``is_occupied()``, in random order."""
        occupied = []
        for tile in self:
            if tile.is_occupied():
                occupied.append(tile)
        random.shuffle(occupied)
        return occupied

    @property
    def empty_tiles(self) -> List[Floor]:
        """All tiles reporting ``is_empty()``, in random order."""
        vacant = []
        for tile in self:
            if tile.is_empty():
                vacant.append(tile)
        random.shuffle(vacant)
        return vacant

    @classmethod
    def from_tiles(cls, tiles, *args, **kwargs):
        """Not supported for this collection; always raises ``RuntimeError``."""
        raise RuntimeError()
|
4
marl_factory_grid/environment/rewards.py
Normal file
@@ -0,0 +1,4 @@
|
||||
# Reward values attached to the results of the default actions and rules.
MOVEMENTS_VALID: float = -0.001  # valid move
MOVEMENTS_FAIL: float = -0.05    # failed move
NOOP: float = -0.01              # no-op action
COLLISION: float = -0.5          # collision
|
82
marl_factory_grid/environment/rules.py
Normal file
@@ -0,0 +1,82 @@
|
||||
import abc
|
||||
from typing import List
|
||||
|
||||
from marl_factory_grid.utils.results import TickResult, DoneResult
|
||||
from marl_factory_grid.environment import rewards as r, constants as c
|
||||
|
||||
|
||||
class Rule(abc.ABC):
    """Base class for environment rules.

    Every hook returns an empty list by default; subclasses override the
    hooks they need.
    """

    @property
    def name(self):
        """Class name of the concrete rule."""
        return type(self).__name__

    def __init__(self):
        pass

    def __repr__(self):
        return str(self.name)

    def on_init(self, state):
        """Initialisation hook; no results by default."""
        return list()

    def on_reset(self):
        """Reset hook; no results by default."""
        return list()

    def tick_pre_step(self, state) -> List[TickResult]:
        """Pre-step hook; no results by default."""
        return list()

    def tick_step(self, state) -> List[TickResult]:
        """Step hook; no results by default."""
        return list()

    def tick_post_step(self, state) -> List[TickResult]:
        """Post-step hook; no results by default."""
        return list()

    def on_check_done(self, state) -> List[DoneResult]:
        """Done-check hook; no results by default."""
        return list()
|
||||
|
||||
|
||||
class MaxStepsReached(Rule):
    """Done rule that turns valid once ``state.curr_step`` reaches ``max_steps``."""

    def __init__(self, max_steps: int = 500):
        super().__init__()
        self.max_steps = max_steps

    def on_init(self, state):
        # Intentionally a no-op; returns None, unlike the list returned by Rule.on_init.
        return None

    def on_check_done(self, state):
        """Return a single ``DoneResult`` whose validity says whether the limit is reached."""
        limit_reached = self.max_steps <= state.curr_step
        validity = c.VALID if limit_reached else c.NOT_VALID
        return [DoneResult(validity=validity, identifier=self.name, reward=0)]
|
||||
|
||||
|
||||
class Collision(Rule):
    """Rule that penalises colliding guests and can optionally end the episode."""

    def __init__(self, done_at_collisions: bool = False):
        super().__init__()
        self.done_at_collisions = done_at_collisions
        self.curr_done = False

    def tick_post_step(self, state) -> List[TickResult]:
        """Emit one collision ``TickResult`` per guest on tiles with two or more colliding guests.

        NOTE(review): statement nesting was reconstructed from a flattened
        source; ``curr_done`` is assumed to be set per colliding tile — confirm.
        """
        self.curr_done = False
        results = []
        for tile in state.get_all_tiles_with_collisions():
            guests = tile.guests_that_can_collide
            if len(guests) < 2:
                continue
            for guest in guests:
                # NOTE(review): the state result carries entity=self (the rule),
                # not the guest — confirm this is intended.
                guest_state = TickResult(identifier=c.COLLISION, reward=r.COLLISION,
                                         validity=c.NOT_VALID, entity=self)
                try:
                    guest.set_state(guest_state)
                except AttributeError:
                    # Guests without set_state are skipped silently.
                    pass
                results.append(TickResult(entity=guest, identifier=c.COLLISION,
                                          reward=r.COLLISION, validity=c.VALID))
            self.curr_done = True
        return results

    def on_check_done(self, state) -> List[DoneResult]:
        """Report done only if a collision happened this tick and ``done_at_collisions`` is set."""
        if not (self.curr_done and self.done_at_collisions):
            return [DoneResult(validity=c.NOT_VALID, identifier=self.name, reward=0)]
        return [DoneResult(validity=c.VALID, identifier=c.COLLISION, reward=r.COLLISION)]
|