major redesign of observations and entities
0
environment/__init__.py
Normal file
101
environment/actions.py
Normal file
@@ -0,0 +1,101 @@
|
||||
import abc
|
||||
from typing import Union
|
||||
|
||||
from environment import rewards as r
|
||||
from environment import constants as c
|
||||
from environment.utils.helpers import MOVEMAP
|
||||
from environment.utils.results import ActionResult
|
||||
|
||||
|
||||
class Action(abc.ABC):
    """Abstract base class of all actions.

    Each action stores a string identifier which is exposed through ``name``
    and shown in ``repr``.
    """

    @abc.abstractmethod
    def __init__(self, identifier: str):
        # Declared abstract, so concrete subclasses have to define their own
        # constructor and forward the identifier to this one.
        self._identifier = identifier

    @property
    def name(self):
        """Identifier handed to the constructor."""
        return self._identifier

    @abc.abstractmethod
    def do(self, entity, state) -> Union[None, ActionResult]:
        """Perform the action for ``entity``; this base implementation returns ``None``."""
        return None

    def __repr__(self):
        return 'Action[{}]'.format(self._identifier)
|
||||
|
||||
|
||||
class Noop(Action):
    """Action registered under ``c.NOOP`` that always yields a valid result."""

    def __init__(self):
        super(Noop, self).__init__(c.NOOP)

    def do(self, entity, *_) -> Union[None, ActionResult]:
        """Build the result for ``entity``; any further positional arguments are ignored.

        :param entity: The entity performing the action; it is attached to the result.
        :return: An ``ActionResult`` with validity ``c.VALID`` and reward ``r.NOOP``.
        """
        outcome = ActionResult(
            identifier=self._identifier,
            validity=c.VALID,
            reward=r.NOOP,
            entity=entity,
        )
        return outcome
|
||||
|
||||
|
||||
class Move(Action, abc.ABC):
    """Abstract base of all movement actions.

    The positional offset of a concrete movement is looked up in ``MOVEMAP``
    under the action's identifier.
    """

    @abc.abstractmethod
    def __init__(self, *args, **kwargs):
        super(Move, self).__init__(*args, **kwargs)

    def _calc_new_pos(self, pos):
        """Return ``pos`` shifted by this action's ``MOVEMAP`` offset as an ``(x, y)`` tuple."""
        offset = MOVEMAP[self._identifier]
        x_diff, y_diff = offset
        return pos[0] + x_diff, pos[1] + y_diff

    def do(self, entity, env):
        """Try to move ``entity`` and report the outcome.

        :param entity: The entity to move; its ``pos`` and ``move`` are used.
        :param env:    Container that is indexed with ``c.FLOOR`` to obtain the floor group.
        :return:       An ``ActionResult`` holding validity and the matching movement reward.
        """
        target_pos = self._calc_new_pos(entity.pos)
        next_tile = env[c.FLOOR].by_pos(target_pos)
        # Without a (truthy) tile at the target position the move is invalid.
        # noinspection PyUnresolvedReferences
        valid = entity.move(next_tile) if next_tile else c.NOT_VALID
        if valid:
            reward = r.MOVEMENTS_VALID
        else:
            reward = r.MOVEMENTS_FAIL
        return ActionResult(entity=entity, identifier=self._identifier, validity=valid, reward=reward)
|
||||
|
||||
|
||||
class North(Move):
    """Movement action registered under the ``c.NORTH`` identifier."""

    def __init__(self, *args, **kwargs):
        super(North, self).__init__(c.NORTH, *args, **kwargs)
|
||||
|
||||
|
||||
class NorthEast(Move):
    """Movement action registered under the ``c.NORTHEAST`` identifier."""

    def __init__(self, *args, **kwargs):
        super(NorthEast, self).__init__(c.NORTHEAST, *args, **kwargs)
|
||||
|
||||
|
||||
class East(Move):
    """Movement action registered under the ``c.EAST`` identifier."""

    def __init__(self, *args, **kwargs):
        super(East, self).__init__(c.EAST, *args, **kwargs)
|
||||
|
||||
|
||||
class SouthEast(Move):
    """Movement action registered under the ``c.SOUTHEAST`` identifier."""

    def __init__(self, *args, **kwargs):
        super(SouthEast, self).__init__(c.SOUTHEAST, *args, **kwargs)
|
||||
|
||||
|
||||
class South(Move):
    """Movement action registered under the ``c.SOUTH`` identifier."""

    def __init__(self, *args, **kwargs):
        super(South, self).__init__(c.SOUTH, *args, **kwargs)
|
||||
|
||||
|
||||
class SouthWest(Move):
    """Movement action registered under the ``c.SOUTHWEST`` identifier."""

    def __init__(self, *args, **kwargs):
        super(SouthWest, self).__init__(c.SOUTHWEST, *args, **kwargs)
|
||||
|
||||
|
||||
class West(Move):
    """Movement action registered under the ``c.WEST`` identifier."""

    def __init__(self, *args, **kwargs):
        super(West, self).__init__(c.WEST, *args, **kwargs)
|
||||
|
||||
|
||||
class NorthWest(Move):
    """Movement action registered under the ``c.NORTHWEST`` identifier."""

    def __init__(self, *args, **kwargs):
        super(NorthWest, self).__init__(c.NORTHWEST, *args, **kwargs)
|
||||
|
||||
|
||||
# Action classes for the four single-word directions.
Move4 = [North, East, South, West]
# Move4 followed by the four compound directions.
# noinspection PyTypeChecker
Move8 = [*Move4, NorthEast, SouthEast, SouthWest, NorthWest]
|
0
environment/assets/__init__.py
Normal file
BIN
environment/assets/agent/adversary.png
Normal file
After Width: | Height: | Size: 8.3 KiB |
BIN
environment/assets/agent/agent.png
Normal file
After Width: | Height: | Size: 3.3 KiB |
BIN
environment/assets/agent/agent_collision.png
Normal file
After Width: | Height: | Size: 18 KiB |
BIN
environment/assets/agent/idle.png
Normal file
After Width: | Height: | Size: 1.6 KiB |
BIN
environment/assets/agent/invalid.png
Normal file
After Width: | Height: | Size: 1.6 KiB |
BIN
environment/assets/agent/move.png
Normal file
After Width: | Height: | Size: 5.8 KiB |
BIN
environment/assets/agent/valid.png
Normal file
After Width: | Height: | Size: 5.6 KiB |
BIN
environment/assets/wall.png
Normal file
After Width: | Height: | Size: 1.4 KiB |
60
environment/constants.py
Normal file
@@ -0,0 +1,60 @@
|
||||
# Names
DANGER_ZONE = 'x'  # Danger-zone tile identifier for resolving the string-based map files.
DEFAULTS = 'Defaults'
SELF = 'Self'
PLACEHOLDER = 'Placeholder'
FLOOR = 'Floor'  # Identifier of Floor-objects and groups.
FLOORS = 'Floors'  # Identifier of Floor-objects and groups.
WALL = 'Wall'  # Identifier of Wall-objects and groups.
WALLS = 'Walls'  # Identifier of Wall-objects and groups.
LEVEL = 'Level'  # Identifier of Level-objects and groups.
AGENT = 'Agent'  # Identifier of Agent-objects and groups.
AGENTS = 'Agents'  # Identifier of Agent-objects and groups.
OTHERS = 'Other'
COMBINED = 'Combined'
GLOBAL_POSITION = 'GLOBAL_POSITION'  # Identifier of the global position slice.


# Attributes
IS_BLOCKING_LIGHT = 'is_blocking_light'
HAS_POSITION = 'has_position'
HAS_NO_POSITION = 'has_no_position'
ALL = 'All'

# Symbols (read from map files)
SYMBOL_WALL = '#'
SYMBOL_FLOOR = '-'

VALID = True  # Readable alias for a boolean value in the context of actions.
NOT_VALID = False  # Readable alias for a boolean value in the context of actions.
VALUE_FREE_CELL = 0  # Free-cell value used in observations.
VALUE_OCCUPIED_CELL = 1  # Occupied-cell value used in observations.
VALUE_NO_POS = (-9999, -9999)  # Invalid position value used in the environment (something is off-grid).


ACTION = 'action'  # Identifier of Action-objects and groups.
COLLISION = 'Collision'  # Identifier to use in the context of collisions.
LAST_POS = 'LAST_POS'  # Identifier for retrieving an entity's last position.
VALIDITY = 'VALIDITY'  # Identifier for retrieving the validity of an Action, Tick, etc.

# Actions
# Movements
NORTH = 'north'
EAST = 'east'
SOUTH = 'south'
WEST = 'west'
NORTHEAST = 'north_east'
SOUTHEAST = 'south_east'
SOUTHWEST = 'south_west'
NORTHWEST = 'north_west'

# Move Groups
MOVE8 = 'Move8'
MOVE4 = 'Move4'

# No-Action / Wait
NOOP = 'Noop'

# Result Identifier
MOVEMENTS_VALID = 'motion_valid'
MOVEMENTS_FAIL = 'motion_not_valid'
|
0
environment/entity/__init__.py
Normal file
76
environment/entity/agent.py
Normal file
@@ -0,0 +1,76 @@
|
||||
from typing import List, Union
|
||||
|
||||
from environment import constants as c
|
||||
from environment.actions import Action
|
||||
from environment.entity.entity import Entity
|
||||
from environment.utils.render import RenderEntity
|
||||
from environment.utils import renderer
|
||||
from environment.utils.helpers import is_move
|
||||
from environment.utils.results import ActionResult, Result
|
||||
|
||||
|
||||
class Agent(Entity):
    """Entity that owns a list of actions and a list of observation names.

    The result of the most recent action is kept as a temporary state which
    drives both ``summarize_state`` and ``render``.
    """

    def __init__(self, actions: List[Action], observations: List[str], *args, **kwargs):
        """Store actions and observations; all other arguments go to ``Entity``."""
        super(Agent, self).__init__(*args, **kwargs)
        # This instance attribute shadows the ``step_result`` method defined below.
        self.step_result = dict()
        self._actions = actions
        self._observations = observations
        self._state: Union[Result, None] = None

    @property
    def obs_tag(self):
        """Agents are tagged by their own name in observations."""
        return self.name

    @property
    def actions(self):
        return self._actions

    @property
    def observations(self):
        return self._observations

    @property
    def can_collide(self):
        return True

    def step_result(self):
        pass

    @property
    def collection(self):
        return self._collection

    @property
    def state(self):
        """The stored result, or a fresh valid ``c.NOOP`` result with reward 0 if none is stored."""
        if self._state:
            return self._state
        return ActionResult(entity=self, identifier=c.NOOP, validity=c.VALID, reward=0)

    # noinspection PyAttributeOutsideInit
    def clear_temp_state(self):
        """Drop the stored result and return ``self``."""
        self._state = None
        return self

    def summarize_state(self):
        """Extend the parent's summary by the keys ``valid`` and ``action``."""
        summary = super().summarize_state()
        summary['valid'] = bool(self.state.validity)
        summary['action'] = str(self.state.identifier)
        return summary

    def set_state(self, action_result):
        self._state = action_result

    def render(self):
        """Return a ``RenderEntity`` whose state mirrors the current action result."""
        # Index of this agent within its collection; raises StopIteration if absent.
        i = next(n for n, member in enumerate(self._collection) if member.name == self.name)
        current = self.state
        if current.identifier == c.COLLISION:
            render_state = renderer.STATE_COLLISION
        elif not current.validity:
            render_state = renderer.STATE_INVALID
        elif current.identifier == c.NOOP:
            render_state = renderer.STATE_IDLE
        elif is_move(current.identifier):
            render_state = renderer.STATE_MOVE
        else:
            render_state = renderer.STATE_VALID

        return RenderEntity(c.AGENT, self.pos, 1, 'none', render_state, i + 1, real_name=self.name)
|
79
environment/entity/entity.py
Normal file
@@ -0,0 +1,79 @@
|
||||
import abc
|
||||
|
||||
from environment import constants as c
|
||||
from environment.entity.object import EnvObject
|
||||
from environment.utils.render import RenderEntity
|
||||
|
||||
|
||||
class Entity(EnvObject, abc.ABC):
    """Full Env Entity that lives on the env Grid. Doors, Items, DirtPile etc..."""

    def __init__(self, tile, **kwargs):
        """Place the entity on ``tile``; remaining keyword arguments go to the parent."""
        super().__init__(**kwargs)
        self._tile = tile
        tile.enter(self)

    @property
    def has_position(self):
        """``True`` unless the position equals ``c.VALUE_NO_POS``."""
        return self.pos != c.VALUE_NO_POS

    @property
    def x(self):
        return self.pos[0]

    @property
    def y(self):
        return self.pos[1]

    @property
    def pos(self):
        """Position of the tile the entity currently occupies."""
        return self._tile.pos

    @property
    def tile(self):
        return self._tile

    @property
    def last_tile(self):
        """Tile occupied before the last successful move, ``None`` if there was none."""
        if not hasattr(self, '_last_tile'):
            # The constructor does not create the attribute, so do it lazily.
            # noinspection PyAttributeOutsideInit
            self._last_tile = None
        return self._last_tile

    @property
    def last_pos(self):
        """Position of ``last_tile``, or ``c.VALUE_NO_POS`` if that lookup fails."""
        try:
            previous = self.last_tile
            return previous.pos
        except AttributeError:
            return c.VALUE_NO_POS

    @property
    def direction_of_view(self):
        """Component-wise difference ``last_pos - pos`` as an ``(x, y)`` tuple."""
        last_x, last_y = self.last_pos
        curr_x, curr_y = self.pos
        return last_x - curr_x, last_y - curr_y

    def move(self, next_tile):
        """Move the entity to ``next_tile`` and notify all observers on success.

        :return: The comparison result if ``next_tile`` is the current tile,
                 otherwise whatever ``next_tile.enter`` returned.
        """
        origin = self.tile
        is_other_tile = origin != next_tile
        if not is_other_tile:
            return is_other_tile
        valid = next_tile.enter(self)
        if valid:
            origin.leave(self)
            self._tile = next_tile
            self._last_tile = origin
            for observer in self.observers:
                observer.notify_change_pos(self)
        return valid

    def summarize_state(self) -> dict:
        """Return name, coordinates, tile name and collision flag as plain builtins."""
        return {
            'name': str(self.name),
            'x': int(self.x),
            'y': int(self.y),
            'tile': str(self.tile.name),
            'can_collide': bool(self.can_collide),
        }

    @abc.abstractmethod
    def render(self):
        """Default rendering: a ``RenderEntity`` named after the lower-cased class name."""
        return RenderEntity(self.__class__.__name__.lower(), self.pos)

    def __repr__(self):
        base = super(Entity, self).__repr__()
        return base + f'(@{self.pos})'
|
18
environment/entity/mixin.py
Normal file
@@ -0,0 +1,18 @@
|
||||
|
||||
|
||||
# noinspection PyAttributeOutsideInit
|
||||
class BoundEntityMixin:
    """Mixin for objects that are attached to exactly one other entity.

    The binding is created with ``bind_to``; ``name`` and ``bound_entity``
    require that this has happened.
    """

    def bind_to(self, entity):
        """Remember ``entity`` as the owner of this object."""
        self._bound_entity = entity

    def belongs_to_entity(self, entity):
        """Return the result of comparing ``entity`` with the bound entity."""
        return entity == self.bound_entity

    @property
    def bound_entity(self):
        return self._bound_entity

    @property
    def name(self):
        """Class name followed by the bound entity's name in parentheses."""
        owner_name = self._bound_entity.name
        return f'{self.__class__.__name__}({owner_name})'
|
127
environment/entity/object.py
Normal file
@@ -0,0 +1,127 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import defaultdict
|
||||
from typing import Union
|
||||
|
||||
from environment import constants as c
|
||||
|
||||
|
||||
class Object:
    """General objects for organisation and maintenance, such as actions.

    Every instance receives a running integer index that is counted per class
    name, keeps a list of observers and may optionally carry a string identifier.
    """

    # Next free index for every class name; shared by all subclasses.
    _u_idx = defaultdict(lambda: 0)

    def __init__(self, str_ident: Union[str, None] = None, **kwargs):
        """
        :param str_ident: Optional identifier used for ``name`` and ``identifier``.
        :param kwargs:    Not used; a message is printed if any are given.
        """
        self._observers = []
        self._str_ident = str_ident
        self.identifier_int = self._identify_and_count_up()
        self._collection = None

        if kwargs:
            print(f'Following kwargs were passed, but ignored: {kwargs}')

    def __bool__(self):
        # Objects are always truthy, regardless of any ``__len__`` in subclasses.
        return True

    @property
    def observers(self):
        return self._observers

    @property
    def name(self):
        """``Class[str_ident]`` if an identifier was given, else ``Class#index``."""
        cls_name = self.__class__.__name__
        if self._str_ident is None:
            return f'{cls_name}#{self.identifier_int}'
        return f'{cls_name}[{self._str_ident}]'

    @property
    def identifier(self):
        """The string identifier if one was given, otherwise ``name``."""
        return self.name if self._str_ident is None else self._str_ident

    def __repr__(self):
        return '{}'.format(self.name)

    def __eq__(self, other) -> bool:
        # Equality is delegated to a comparison against the identifier.
        return other == self.identifier

    def __hash__(self):
        return hash(self.identifier)

    def _identify_and_count_up(self):
        """Return the current index for this class name and advance the counter."""
        key = self.__class__.__name__
        counters = Object._u_idx
        current = counters[key]
        counters[key] = current + 1
        return current

    def set_collection(self, collection):
        self._collection = collection

    def add_observer(self, observer):
        """Register ``observer`` and immediately notify it about this object."""
        self.observers.append(observer)
        observer.notify_change_pos(self)

    def del_observer(self, observer):
        self.observers.remove(observer)
|
||||
|
||||
|
||||
class EnvObject(Object):
    """Objects that hold Information that are observable, but have no position on the env grid. Inventories etc..."""

    _u_idx = defaultdict(lambda: 0)

    def __init__(self, **kwargs):
        super(EnvObject, self).__init__(**kwargs)

    def _collection_flag(self, attribute):
        """Read ``attribute`` from the parent collection, falling back to ``False``.

        ``False`` is returned when the value is falsy or when the lookup raises
        ``AttributeError`` (for instance while no collection is set).
        """
        try:
            return getattr(self._collection, attribute) or False
        except AttributeError:
            return False

    @property
    def obs_tag(self):
        """Name of the parent collection, or the object's own name as fallback."""
        try:
            return self._collection.name or self.name
        except AttributeError:
            return self.name

    @property
    def is_blocking_light(self):
        return self._collection_flag('is_blocking_light')

    @property
    def can_move(self):
        return self._collection_flag('can_move')

    @property
    def is_blocking_pos(self):
        return self._collection_flag('is_blocking_pos')

    @property
    def has_position(self):
        return self._collection_flag('has_position')

    @property
    def can_collide(self):
        return self._collection_flag('can_collide')

    @property
    def encoding(self):
        return c.VALUE_OCCUPIED_CELL

    def change_parent_collection(self, other_collection):
        """Move this object from its current collection into ``other_collection``.

        :return: Result of comparing the stored collection with ``other_collection``.
        """
        # Register with the new collection before leaving the old one.
        other_collection.add_item(self)
        previous = self._collection
        previous.delete_env_object(self)
        self._collection = other_collection
        return self._collection == other_collection
|
45
environment/entity/util.py
Normal file
@@ -0,0 +1,45 @@
|
||||
import math
|
||||
|
||||
import numpy as np
|
||||
|
||||
from environment.entity.mixin import BoundEntityMixin
|
||||
from environment.entity.object import Object, EnvObject
|
||||
|
||||
|
||||
##########################################################################
|
||||
# ####################### Objects and Entities ######################### #
|
||||
##########################################################################
|
||||
|
||||
|
||||
class PlaceHolder(Object):
    """Non-colliding object whose encoding is a configurable fill value."""

    def __init__(self, *args, fill_value=0, **kwargs):
        """
        :param fill_value: Value returned by ``encoding`` (defaults to 0).
        """
        super(PlaceHolder, self).__init__(*args, **kwargs)
        self._fill_value = fill_value

    @property
    def name(self):
        # Fixed name, independent of any identifier or running index.
        return "PlaceHolder"

    @property
    def encoding(self):
        return self._fill_value

    @property
    def can_collide(self):
        return False
|
||||
|
||||
|
||||
class GlobalPosition(BoundEntityMixin, EnvObject):
    """Observable object that encodes the position of the entity it is bound to."""

    def __init__(self, *args, normalized: bool = True, **kwargs):
        """
        :param normalized: If true, ``encoding`` divides the position by the level shape.
        """
        super(GlobalPosition, self).__init__(*args, **kwargs)
        # NOTE(review): ``self.size`` is not assigned by this class or by the
        # base classes visible here - confirm where it is supposed to come from.
        self._level_shape = math.sqrt(self.size)
        self._normalized = normalized

    @property
    def encoding(self):
        """Position of the bound entity, divided by ``_level_shape`` when normalized."""
        if not self._normalized:
            return self.bound_entity.pos
        scaled = np.divide(self._bound_entity.pos, self._level_shape)
        return tuple(scaled)
|
131
environment/entity/wall_floor.py
Normal file
@@ -0,0 +1,131 @@
|
||||
from typing import List
|
||||
|
||||
import numpy as np
|
||||
|
||||
from environment import constants as c
|
||||
from environment.entity.object import EnvObject
|
||||
from environment.utils.render import RenderEntity
|
||||
from environment.utils import helpers as h
|
||||
|
||||
|
||||
class Floor(EnvObject):
    """Grid tile that records the guests currently standing on it, keyed by name."""

    def __init__(self, pos, **kwargs):
        """
        :param pos: Position of the tile; stored as a tuple.
        """
        super(Floor, self).__init__(**kwargs)
        self._guests = dict()
        self.pos = tuple(pos)
        # Filled lazily on first access of ``neighboring_floor``.
        self._neighboring_floor: List[Floor] = list()
        self._blocked_by = None

    @property
    def has_position(self):
        return True

    @property
    def can_collide(self):
        return False

    @property
    def can_move(self):
        return False

    @property
    def is_blocking_pos(self):
        return False

    @property
    def is_blocking_light(self):
        return False

    @property
    def neighboring_floor_pos(self):
        """Positions of all neighboring tiles."""
        return [neighbor.pos for neighbor in self.neighboring_floor]

    @property
    def neighboring_floor(self):
        """Tiles of the parent collection around this tile; cached once non-empty."""
        if not self._neighboring_floor:
            looked_up = []
            for offset in h.POS_MASK.reshape(-1, 2):
                # The zero offset is this tile itself.
                if np.all(offset == [0, 0]):
                    continue
                looked_up.append(self._collection.by_pos(np.add(self.pos, offset)))
            self._neighboring_floor = [tile for tile in looked_up if tile]
        return self._neighboring_floor

    @property
    def encoding(self):
        return c.VALUE_OCCUPIED_CELL

    @property
    def guests_that_can_collide(self):
        return [guest for guest in self.guests if guest.can_collide]

    @property
    def guests(self):
        return self._guests.values()

    @property
    def x(self):
        return self.pos[0]

    @property
    def y(self):
        return self.pos[1]

    @property
    def is_blocked(self):
        """``True`` if at least one guest blocks the position."""
        blocking_flags = [guest.is_blocking_pos for guest in self.guests]
        return any(blocking_flags)

    def __len__(self):
        return len(self._guests)

    def is_empty(self):
        return len(self._guests) == 0

    def is_occupied(self):
        return len(self._guests) > 0

    def enter(self, guest):
        """Register ``guest`` on this tile.

        Entering is refused if the guest is already registered, if the tile is
        blocked, or if a position-blocking guest tries to enter an occupied tile.

        :return: ``c.VALID`` on success, ``c.NOT_VALID`` otherwise.
        """
        if guest.name in self._guests:
            return c.NOT_VALID
        if self.is_blocked:
            return c.NOT_VALID
        if guest.is_blocking_pos and self.is_occupied():
            return c.NOT_VALID
        self._guests[guest.name] = guest
        return c.VALID

    def leave(self, guest):
        """Unregister ``guest``; ``c.NOT_VALID`` if it was not registered."""
        try:
            self._guests.pop(guest.name)
        except (ValueError, KeyError):
            return c.NOT_VALID
        return c.VALID

    def __repr__(self):
        return f'{self.name}(@{self.pos})'

    def summarize_state(self, **_):
        """Return name and integer coordinates; keyword arguments are ignored."""
        return {'name': self.name, 'x': int(self.x), 'y': int(self.y)}

    def render(self):
        # Plain floor tiles produce no render entity.
        return None
|
||||
|
||||
|
||||
class Wall(Floor):
    """Tile variant that collides, blocks its position and blocks light."""

    @property
    def can_collide(self):
        return True

    @property
    def is_blocking_pos(self):
        return True

    @property
    def is_blocking_light(self):
        return True

    @property
    def encoding(self):
        return c.VALUE_OCCUPIED_CELL

    def render(self):
        """Walls are drawn as a ``RenderEntity`` named ``c.WALL`` at their position."""
        return RenderEntity(c.WALL, self.pos)
|
201
environment/factory.py
Normal file
@@ -0,0 +1,201 @@
|
||||
import shutil
|
||||
|
||||
from collections import defaultdict
|
||||
from itertools import chain
|
||||
from os import PathLike
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
import gymnasium as gym
|
||||
|
||||
from environment.utils.level_parser import LevelParser
|
||||
from environment.utils.observation_builder import OBSBuilder
|
||||
from environment.utils.config_parser import FactoryConfigParser
|
||||
from environment.utils import helpers as h
|
||||
import environment.constants as c
|
||||
|
||||
from environment.utils.states import Gamestate
|
||||
|
||||
REC_TAC = 'rec_'
|
||||
|
||||
|
||||
class BaseFactory(gym.Env):
    """Gymnasium environment assembled from a config file and a level file.

    The constructor parses the config, derives the level file path, creates
    the ``LevelParser`` and finally calls ``reset()``, so ``state`` and
    ``obs_builder`` exist as soon as construction has finished.
    """

    @property
    def action_space(self):
        """Action space of the agent group in the current state."""
        return self.state[c.AGENT].action_space

    @property
    def named_action_space(self):
        """Named action space of the agent group in the current state."""
        return self.state[c.AGENT].named_action_space

    @property
    def observation_space(self):
        return self.obs_builder.observation_space(self.state)

    @property
    def named_observation_space(self):
        return self.obs_builder.named_observation_space(self.state)

    @property
    def params(self) -> dict:
        """Parsed content of the YAML config file this environment was built from."""
        import yaml
        config_path = Path(self._config_file)
        # Context manager so the file handle is closed again after parsing.
        with config_path.open() as config_stream:
            config_dict = yaml.safe_load(config_stream)
        return config_dict

    @property
    def summarize_header(self):
        """State summary restricted to the stateless entity groups."""
        summary_dict = self._summarize_state(stateless_entities=True)
        return summary_dict

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()

    def __init__(self, config_file: Union[str, PathLike]):
        """
        :param config_file: Path of the config file handed to ``FactoryConfigParser``.
        """
        self._config_file = config_file
        self.conf = FactoryConfigParser(self._config_file)
        # Attribute Assignment
        self.level_filepath = Path(__file__).parent.parent / h.LEVELS_DIR / f'{self.conf.level_name}.txt'
        self._renderer = None  # expensive - don't use it when not required !

        parsed_entities = self.conf.load_entities()
        self.map = LevelParser(self.level_filepath, parsed_entities, self.conf.pomdp_r)

        # Init for later usage:
        self.state: Gamestate
        self.map: LevelParser
        self.obs_builder: OBSBuilder

        # TODO: Reset ---> document this
        self.reset()

    def __getitem__(self, item):
        """Shortcut for ``self.state.entities[item]``."""
        return self.state.entities[item]

    def reset(self) -> (dict, dict):
        """Rebuild the game state and the observation builder.

        :return: Whatever ``OBSBuilder.refresh_and_build_for_all`` returns for the new state.
        """
        self.state = None

        # Init entity:
        entities = self.map.do_init()

        # Grab all rules:
        rules = self.conf.load_rules()

        # Agents
        # noinspection PyAttributeOutsideInit
        self.state = Gamestate(entities, rules, self.conf.env_seed)

        agents = self.conf.load_agents(self.map.size, self[c.FLOOR].empty_tiles)
        self.state.entities.add_item({c.AGENT: agents})

        # All is set up, trigger additional init (after agent entity spawn etc)
        self.state.rules.do_all_init(self.state)

        # Observations
        # noinspection PyAttributeOutsideInit
        self.obs_builder = OBSBuilder(self.map.level_shape, self.state, self.map.pomdp_r)
        return self.obs_builder.refresh_and_build_for_all(self.state)

    def step(self, actions):
        """Advance the environment by one tick.

        :param actions: A list of actions, or a single value which is converted
                        with ``int`` and wrapped into a one-element list.
        :return: The tuple ``(None, observations, reward, done, info)``.
        """
        if not isinstance(actions, list):
            actions = [int(actions)]

        # Apply rules, do actions, tick the state, etc...
        tick_result = self.state.tick(actions)

        # Check Done Conditions
        done_results = self.state.check_done()

        # Finalize
        reward, reward_info, done = self.summarize_step_results(tick_result, done_results)

        info = reward_info

        # ``reward`` is a list with individual rewards and a plain number
        # otherwise; only the list form can (and has to) be summed up.
        step_reward = sum(reward) if isinstance(reward, (list, tuple)) else reward
        info.update(step_reward=step_reward, step=self.state.curr_step)
        # TODO:
        # if self._record_episodes:
        #     info.update(self._summarize_state())

        obs, reset_info = self.obs_builder.refresh_and_build_for_all(self.state)
        info.update(reset_info)
        return None, list(obs.values()), reward, done, info

    def summarize_step_results(self, tick_results: list, done_check_results: list) -> (int, dict, bool):
        """Fold tick and done-check results into reward, info dict and done flag.

        :param tick_results:       Results produced by the state tick.
        :param done_check_results: Results of the done checks; the first one
                                   with a truthy ``validity`` ends the episode.
        :return: ``(reward, info, done)``. ``reward`` is a list with one entry
                 per agent if ``conf.individual_rewards`` is set, otherwise the
                 sum over all collected rewards.
        """
        rewards = defaultdict(lambda: 0.0)

        # Gather per agent env rewards and
        # Combine Info dicts into a global one
        combined_info_dict = defaultdict(lambda: 0.0)
        for result in chain(tick_results, done_check_results):
            if result.reward is not None:
                try:
                    rewards[result.entity.name] += result.reward
                except AttributeError:
                    # Results whose entity has no name count as global reward.
                    rewards['global'] += result.reward
            infos = result.get_infos()
            for info in infos:
                assert isinstance(info.value, (float, int))
                combined_info_dict[info.identifier] += info.value

        # Check Done Rule Results
        try:
            done_reason = next(x for x in done_check_results if x.validity)
            done = True
            self.state.print(f'Env done, Reason: {done_reason.name}.')
        except StopIteration:
            done = False

        if self.conf.individual_rewards:
            global_rewards = rewards['global']
            del rewards['global']
            reward = [rewards[agent.name] for agent in self.state[c.AGENT]]
            # Every agent additionally receives the full global reward.
            reward = [x + global_rewards for x in reward]
            self.state.print(f"rewards are {rewards}")
            return reward, combined_info_dict, done
        else:
            reward = sum(rewards.values())
            self.state.print(f"reward is {reward}")
            return reward, combined_info_dict, done

    def start_recording(self):
        self.conf.do_record = True
        return self.conf.do_record

    def stop_recording(self):
        self.conf.do_record = False
        return not self.conf.do_record

    # noinspection PyGlobalUndefined
    def render(self, mode='human'):
        """Render the current entities; the renderer is created on first use."""
        # Declared ahead of the import that binds the name at module level.
        global Renderer
        if not self._renderer:  # lazy init
            from environment.utils.renderer import Renderer
            self._renderer = Renderer(self.map.level_shape, view_radius=self.conf.pomdp_r, fps=20)

        render_entities = self.state.entities.render()
        if self.conf.pomdp_r:
            # Attach the current light map to every agent render entity.
            for render_entity in render_entities:
                if render_entity.name == c.AGENT:
                    render_entity.aux = self.obs_builder.curr_lightmaps[render_entity.real_name]
        return self._renderer.render(render_entities)

    def _summarize_state(self, stateless_entities=False):
        """Collect the summaries of all groups whose ``is_stateless`` equals the flag."""
        summary = {f'{REC_TAC}step': self.state.curr_step}

        for entity_group in self.state:
            if entity_group.is_stateless == stateless_entities:
                summary.update({f'{REC_TAC}{entity_group.name}': entity_group.summarize_states()})
        return summary

    def print(self, string):
        """Print ``string`` only if the config is set to verbose."""
        if self.conf.verbose:
            print(string)

    def save_params(self, filepath: Path):
        """Copy the config file to ``filepath``, creating parent directories as needed."""
        # noinspection PyProtectedMember
        filepath = Path(filepath)
        filepath.parent.mkdir(parents=True, exist_ok=True)
        shutil.copyfile(self._config_file, filepath)
|
0
environment/groups/__init__.py
Normal file
30
environment/groups/agents.py
Normal file
@@ -0,0 +1,30 @@
|
||||
from environment.groups.env_objects import EnvObjects
|
||||
from environment.groups.mixins import PositionMixin
|
||||
from environment.entity.agent import Agent
|
||||
import environment.constants as c
|
||||
|
||||
|
||||
class Agents(PositionMixin, EnvObjects):
    """Group of ``Agent`` entities; its members can move and do not block light."""

    _entity = Agent
    is_blocking_light = False
    can_move = True

    def __init__(self, *args, **kwargs):
        super(Agents, self).__init__(*args, **kwargs)

    @property
    def obs_pairs(self):
        """One ``(name, agent)`` pair per agent."""
        return [(agent.name, agent) for agent in self]

    @property
    def action_space(self):
        """Tuple space holding one ``Discrete`` space per agent, sized by its action count."""
        from gymnasium import spaces
        per_agent = [spaces.Discrete(len(agent.actions)) for agent in self]
        return spaces.Tuple(per_agent)

    @property
    def named_action_space(self):
        """Mapping ``agent name -> {action name: action index}``."""
        return {
            agent.name: {action.name: idx for idx, action in enumerate(agent.actions)}
            for agent in self
        }
|
33
environment/groups/env_objects.py
Normal file
@@ -0,0 +1,33 @@
|
||||
from environment.groups.objects import Objects
|
||||
from environment.entity.object import EnvObject
|
||||
|
||||
|
||||
class EnvObjects(Objects):
    """Group of ``EnvObject`` items with a configured size."""

    _entity = EnvObject
    is_blocking_light: bool = False
    can_collide: bool = False
    has_position: bool = False
    can_move: bool = False

    def __init__(self, size, *args, **kwargs):
        """
        :param size: Stored as ``self.size``; checked by ``add_item`` for groups without position.
        """
        super(EnvObjects, self).__init__(*args, **kwargs)
        self.size = size

    @property
    def encodings(self):
        """Encoding of every member, in iteration order."""
        return [member.encoding for member in self]

    def add_item(self, item: EnvObject):
        """Add ``item`` and return the group itself."""
        # Groups without position must not hold more than ``size`` items at this point.
        assert self.has_position or (len(self) <= self.size)
        super(EnvObjects, self).add_item(item)
        return self

    def summarize_states(self):
        """State summary of every member."""
        summaries = []
        for member in self.values():
            summaries.append(member.summarize_state())
        return summaries

    def delete_env_object(self, env_object: EnvObject):
        """Remove ``env_object`` using its name as key."""
        del self[env_object.name]

    def delete_env_object_by_name(self, name):
        """Remove the member stored under ``name``."""
        del self[name]
|
64
environment/groups/global_entities.py
Normal file
@@ -0,0 +1,64 @@
|
||||
from collections import defaultdict
|
||||
from operator import itemgetter
|
||||
from typing import Dict
|
||||
|
||||
from environment.groups.objects import Objects
|
||||
from environment.entity.entity import Entity
|
||||
from environment.utils.helpers import POS_MASK
|
||||
|
||||
|
||||
class Entities(Objects):
    """Top-level container mapping group names to entity groups.

    Besides the groups themselves it keeps ``pos_dict``, a mapping from
    position to the list of entities at that position.
    """

    _entity = Objects

    @staticmethod
    def neighboring_positions(pos):
        """Return ``POS_MASK`` shifted by ``pos``, reshaped to rows of two values."""
        return (POS_MASK + pos).reshape(-1, 2)

    def get_near_pos(self, pos):
        """Flat list of all ``pos_dict`` entries found at the neighboring positions."""
        return [y for x in itemgetter(*(tuple(x) for x in self.neighboring_positions(pos)))(self.pos_dict) for y in x]

    def render(self):
        """Flat list of everything the contained groups render."""
        # NOTE(review): the ``x is not None`` filter is evaluated after
        # ``x.render()`` has been called, so it cannot guard that call -
        # confirm whether ``y is not None`` was intended.
        return [y for x in self for y in x.render() if x is not None]

    @property
    def names(self):
        return list(self._data.keys())

    def __init__(self):
        # Created before the parent constructor runs.
        self.pos_dict = defaultdict(list)
        super().__init__()

    def iter_entities(self):
        """Iterate over the individual entities of all groups."""
        return iter((x for sublist in self.values() for x in sublist))

    def add_items(self, items: Dict):
        return self.add_item(items)

    def add_item(self, item: dict):
        """Add one or more groups given as ``{name: group}`` and observe them.

        :return: The container itself.
        """
        assert_str = 'This group of entity has already been added!'
        # Test membership of every key instead of the truthiness of the keys found.
        assert not any(key in self.keys() for key in item.keys()), assert_str
        self._data.update(item)
        for val in item.values():
            val.add_observer(self)
        return self

    def __delitem__(self, name):
        """Remove the group stored under ``name`` and stop observing it.

        :param name: Key of the group, as used in ``add_item``.
        """
        assert_str = 'This group of entity does not exist in this collection!'
        # ``name`` is a single key, so test it directly for membership.
        assert name in self.keys(), assert_str
        # NOTE(review): assumes the group's ``_observers`` supports ``remove``
        # like a list does - confirm against ``Objects``.
        self[name]._observers.remove(self)
        for entity in self[name]:
            entity.del_observer(self)
        return super(Entities, self).__delitem__(name)

    @property
    def obs_pairs(self):
        return [y for x in self for y in x.obs_pairs]

    def by_pos(self, pos: (int, int)):
        """Return the ``pos_dict`` entry for ``pos`` (an empty list is created if missing)."""
        return self.pos_dict[pos]

    @property
    def positions(self):
        """Every position in ``pos_dict``, repeated once per entity stored there."""
        return [k for k, v in self.pos_dict.items() for _ in v]
|
102
environment/groups/mixins.py
Normal file
@@ -0,0 +1,102 @@
|
||||
from abc import ABC
|
||||
from typing import Tuple
|
||||
|
||||
import numpy as np
|
||||
|
||||
from environment import constants as c
|
||||
|
||||
from environment.entity.entity import Entity
|
||||
|
||||
|
||||
# noinspection PyUnresolvedReferences,PyTypeChecker,PyArgumentList
class PositionMixin:
    """Mixin for entity collections whose members have a position / live on a tile.

    The mixin expects the host class to be iterable over its entities and to provide
    ``add_items``, ``__delitem__`` and a ``pos_dict`` position index.
    """

    _entity = Entity
    is_blocking_light: bool = True
    can_collide: bool = True
    has_position: bool = True

    def render(self):
        """Return the render output of all entities, dropping entities that render to ``None``."""
        rendered = (entity.render() for entity in self)
        return [item for item in rendered if item is not None]

    @classmethod
    def from_tiles(cls, tiles, *args, entity_kwargs=None, **kwargs):
        """Build a collection holding one ``cls._entity`` per given tile.

        :param tiles: Iterable of tiles; each one is passed as first argument to ``cls._entity``.
        :param args: Positional arguments for the collection constructor.
        :param entity_kwargs: Optional keyword arguments passed to every created entity.
        :param kwargs: Keyword arguments for the collection constructor.

        :return: The newly created and filled collection.
        """
        collection = cls(*args, **kwargs)
        extra_kwargs = entity_kwargs if entity_kwargs is not None else {}
        # The running index doubles as the string identifier of each entity.
        entities = [cls._entity(tile, str_ident=i, **extra_kwargs) for i, tile in enumerate(tiles)]
        collection.add_items(entities)
        return collection

    @classmethod
    def from_coordinates(cls, positions: [(int, int)], tiles, *args, entity_kwargs=None, **kwargs, ):
        """Like ``from_tiles``, but the tiles are looked up by position in ``tiles`` first.

        ``tiles.size`` is passed as first positional argument to the collection constructor.
        """
        return cls.from_tiles([tiles.by_pos(position) for position in positions], tiles.size, *args,
                              entity_kwargs=entity_kwargs,
                              **kwargs)

    @property
    def tiles(self):
        """The tile of every entity in this collection."""
        return [entity.tile for entity in self]

    def __delitem__(self, name):
        """Let the entity called ``name`` leave its tile, then remove it from the collection.

        :raises StopIteration: If no entity carries that name (behaviour kept from the original).
        """
        obj = next(obj for obj in self if obj.name == name)
        obj.tile.leave(obj)
        super().__delitem__(name)

    def by_pos(self, pos: (int, int)):
        """Return the first entity located at ``pos`` or ``None`` if there is none.

        :param pos: Position; converted to a tuple before comparison.
        """
        pos = tuple(pos)
        try:
            return next((e for e in self if e.pos == pos), None)
        except ValueError:
            # Best-effort behaviour kept from the original (which additionally emitted a stray,
            # empty `print()` here): a failing position comparison counts as "not found".
            return None

    @property
    def positions(self):
        """The position of every entity in this collection."""
        return [e.pos for e in self]

    def notify_del_entity(self, entity: Entity):
        """Observer callback: drop ``entity`` from the position index, if it is registered there."""
        try:
            self.pos_dict[entity.pos].remove(entity)
        except (ValueError, AttributeError):
            # Entity not indexed, or the given object has no position: nothing to clean up.
            pass
|
||||
|
||||
|
||||
# noinspection PyUnresolvedReferences,PyTypeChecker
class IsBoundMixin:
    """Mixin for objects that are bound to exactly one entity via ``bind``."""

    @property
    def name(self):
        """Class name followed by the bound entity's name in parentheses."""
        cls_name = self.__class__.__name__
        owner_name = self._bound_entity.name
        return f'{cls_name}({owner_name})'

    def __repr__(self):
        cls_name = self.__class__.__name__
        owner_name = self._bound_entity.name
        return f'{cls_name}#{owner_name}({self._data})'

    def bind(self, entity):
        """Remember ``entity`` as the owner of this object.

        :return: ``c.VALID``
        """
        # noinspection PyAttributeOutsideInit
        self._bound_entity = entity
        return c.VALID

    def belongs_to_entity(self, entity):
        """Tell whether ``entity`` compares equal to the bound entity."""
        is_owner = self._bound_entity == entity
        return is_owner
|
||||
|
||||
|
||||
# noinspection PyUnresolvedReferences,PyTypeChecker
class HasBoundedMixin:
    """Mixin for collections whose members are each bound to an entity."""

    @property
    def obs_names(self):
        """Names of all members, in iteration order."""
        member_names = []
        for member in self:
            member_names.append(member.name)
        return member_names

    def by_entity(self, entity):
        """Return the first member that belongs to ``entity`` or ``None`` if there is none."""
        matches = (member for member in self if member.belongs_to_entity(entity))
        return next(matches, None)

    def idx_by_entity(self, entity):
        """Return the index of the first member that belongs to ``entity`` or ``None``."""
        matches = (idx for idx, member in enumerate(self) if member.belongs_to_entity(entity))
        return next(matches, None)
|
141
environment/groups/objects.py
Normal file
@@ -0,0 +1,141 @@
|
||||
from collections import defaultdict
|
||||
from typing import List
|
||||
|
||||
import numpy as np
|
||||
|
||||
from environment.entity.object import Object
|
||||
|
||||
|
||||
class Objects:
    """Name-keyed collection of ``Object`` instances that keeps registered observers informed.

    Items are stored in ``_data`` under their ``name``. Observers registered via ``add_observer``
    are told about additions (``notify_add_entity``) and removals (``notify_del_entity``). The
    collection implements the ``notify_*`` callbacks itself as well and uses them to maintain
    ``pos_dict``, an index from a position to the list of entities located there.
    """

    _entity = Object

    @property
    def observers(self):
        """List of observers registered on this collection."""
        return self._observers

    @property
    def obs_tag(self):
        """Tag of this collection; here the class name."""
        return self.__class__.__name__

    @staticmethod
    def render():
        """Always returns an empty list."""
        return []

    @property
    def obs_pairs(self):
        """Single ``(name, collection)`` pair for this collection."""
        return [(self.name, self)]

    @property
    def names(self):
        """Names of all stored items, in insertion order."""
        # noinspection PyUnresolvedReferences
        return [x.name for x in self]

    @property
    def name(self):
        """Name of this collection; here the class name."""
        return f'{self.__class__.__name__}'

    def __init__(self, *args, **kwargs):
        # Positional and keyword arguments are accepted, but not used by this class.
        self._data = defaultdict(lambda: None)
        self._observers = list()
        self.pos_dict = defaultdict(list)

    def __len__(self):
        return len(self._data)

    def __iter__(self):
        return iter(self.values())

    def add_item(self, item: _entity):
        """Store ``item`` under its name and notify all observers about the addition.

        :param item: Instance of ``self._entity`` with a name that is not stored yet.

        :return: ``self``, so that calls can be chained.
        :raises AssertionError: If ``item`` has the wrong type or its name already exists.
        """
        assert_str = f'All item names have to be of type {self._entity}, but were {item.__class__}.,'
        assert isinstance(item, self._entity), assert_str
        # `.get` instead of indexing: `_data` is a defaultdict and indexing inserts a `None` entry.
        assert self._data.get(item.name) is None, f'{item.name} allready exists!!!'
        self._data.update({item.name: item})
        item.set_collection(self)
        for observer in self.observers:
            observer.notify_add_entity(item)
        return self

    # noinspection PyUnresolvedReferences
    def del_observer(self, observer):
        """Unregister ``observer`` from this collection and from every stored item observing it.

        :raises ValueError: If ``observer`` is not registered on this collection.
        """
        self.observers.remove(observer)
        for entity in self:
            if observer in entity.observers:
                entity.del_observer(observer)

    # noinspection PyUnresolvedReferences
    def add_observer(self, observer):
        """Register ``observer`` on this collection and on every stored item not yet observed."""
        self.observers.append(observer)
        for entity in self:
            if observer not in entity.observers:
                entity.add_observer(observer)

    def __delitem__(self, name):
        """Remove the item stored under ``name`` after notifying all observers.

        :raises KeyError: If nothing is stored under ``name``.
        """
        # Observers expect the entity itself (they read `entity.pos`). Handing over the name, as
        # before, made every `notify_del_entity` fail silently and left stale index entries.
        entity = self._data.get(name)
        if entity is not None:
            for observer in self.observers:
                observer.notify_del_entity(entity)
        # noinspection PyTypeChecker
        del self._data[name]

    def add_items(self, items: List[_entity]):
        """Add every item of ``items`` via ``add_item`` and return ``self``."""
        for item in items:
            self.add_item(item)
        return self

    def keys(self):
        return self._data.keys()

    def values(self):
        return self._data.values()

    def items(self):
        return self._data.items()

    def _get_index(self, item):
        """Return the insertion index of the first stored value equal to ``item`` or ``None``."""
        return next((i for i, v in enumerate(self._data.values()) if v == item), None)

    def __getitem__(self, item):
        """Look an item up by integer index (negative allowed) or by name.

        :return: The stored item, or ``None`` if the index is out of range / the name is unknown.
        :raises TypeError: If ``item`` is unhashable.
        """
        if isinstance(item, (int, np.int64, np.int32)):
            values = list(self._data.values())
            if item < 0:
                item = len(values) - abs(item)
            return values[item] if 0 <= item < len(values) else None
        # `.get` keeps a lookup of an unknown name from inserting a `None` entry into the
        # defaultdict, which would otherwise show up in `len()` and during iteration.
        return self._data.get(item)

    def __repr__(self):
        return f'{self.__class__.__name__}[{dict(self._data)}]'

    def notify_change_pos(self, entity: object):
        """Observer callback: move ``entity`` from its last position to its current one in the index."""
        try:
            self.pos_dict[entity.last_pos].remove(entity)
        except (ValueError, AttributeError):
            # Not indexed at the last position, or no `last_pos` attribute: nothing to remove.
            pass
        if entity.has_position:
            try:
                self.pos_dict[entity.pos].append(entity)
            except (ValueError, AttributeError):
                pass

    def notify_del_entity(self, entity: Object):
        """Observer callback: drop ``entity`` from the position index, if it is registered there."""
        try:
            self.pos_dict[entity.pos].remove(entity)
        except (ValueError, AttributeError):
            pass

    def notify_add_entity(self, entity: Object):
        """Observer callback: start observing ``entity`` and add it to the position index."""
        try:
            entity.add_observer(self)
            self.pos_dict[entity.pos].append(entity)
        except (ValueError, AttributeError):
            # Entities without observer support or without a position are not indexed.
            pass
|
78
environment/groups/utils.py
Normal file
@@ -0,0 +1,78 @@
|
||||
import numbers
|
||||
from typing import List, Union, Dict
|
||||
|
||||
import numpy as np
|
||||
|
||||
from environment.groups.env_objects import EnvObjects
|
||||
from environment.groups.objects import Objects
|
||||
from environment.groups.mixins import HasBoundedMixin, PositionMixin
|
||||
from environment.entity.util import PlaceHolder, GlobalPosition
|
||||
from environment.utils import helpers as h
|
||||
from environment import constants as c
|
||||
|
||||
|
||||
class Combined(PositionMixin, EnvObjects):
    """Collection that stands for several named observations at once."""

    @property
    def name(self):
        """Parent name followed by the identifier (or, if unset, the list of names) in parentheses."""
        base_name = super().name
        label = self._ident or self._names
        return f'{base_name}({label})'

    @property
    def names(self):
        """The names handed over at construction time."""
        return self._names

    def __init__(self, names: List[str], *args, identifier: Union[None, str] = None, **kwargs):
        """
        :param names: Names covered by this collection; a falsy value becomes an empty list.
        :param identifier: Optional identifier used for ``name`` instead of the list of names.
        """
        super().__init__(*args, **kwargs)
        self._ident = identifier
        self._names = names if names else list()

    @property
    def obs_tag(self):
        return self.name

    @property
    def obs_pairs(self):
        """One ``(name, None)`` pair per covered name."""
        pairs = []
        for name in self.names:
            pairs.append((name, None))
        return pairs
|
||||
|
||||
|
||||
class GlobalPositions(HasBoundedMixin, EnvObjects):
    """Collection of ``GlobalPosition`` objects, each bound to an entity."""

    _entity = GlobalPosition
    # Was `False,` - the trailing comma made this the tuple `(False,)`, which is truthy.
    is_blocking_light = False
    can_collide = False

    def __init__(self, *args, **kwargs):
        super(GlobalPositions, self).__init__(*args, **kwargs)
|
||||
|
||||
|
||||
class Zones(Objects):
    """Zone collection; currently disabled, the constructor always raises ``NotImplementedError``."""

    @property
    def accounting_zones(self):
        """Zone slices of all entries whose value differs from ``c.DANGER_ZONE``."""
        zones = []
        for idx, name in self.items():
            if name != c.DANGER_ZONE:
                zones.append(self[idx])
        return zones

    def __init__(self, parsed_level):
        raise NotImplementedError('This needs a Rework')
        # Everything below is unreachable until the rework has happened.
        super(Zones, self).__init__()
        slices = list()
        self._accounting_zones = list()
        self._danger_zones = list()
        for symbol in np.unique(parsed_level):
            if symbol == c.VALUE_OCCUPIED_CELL:
                continue
            # Both zone kinds are handled alike; only the bookkeeping list differs.
            self + symbol
            slices.append(h.one_hot_level(parsed_level, symbol))
            if symbol == c.DANGER_ZONE:
                self._danger_zones.append(symbol)
            else:
                self._accounting_zones.append(symbol)

        self._zone_slices = np.stack(slices)

    def __getitem__(self, item):
        return self._zone_slices[item]

    def add_items(self, other: Union[str, List[str]]):
        """Zones are fixed; adding more is rejected with an ``AttributeError``."""
        raise AttributeError('You are not allowed to add additional Zones in runtime.')
|
56
environment/groups/wall_n_floors.py
Normal file
@@ -0,0 +1,56 @@
|
||||
import random
|
||||
from typing import List
|
||||
|
||||
import numpy as np
|
||||
|
||||
from environment import constants as c
|
||||
from environment.groups.env_objects import EnvObjects
|
||||
from environment.groups.mixins import PositionMixin
|
||||
from environment.entity.wall_floor import Wall, Floor
|
||||
|
||||
|
||||
class Walls(PositionMixin, EnvObjects):
    """Collection of ``Wall`` entities, built from coordinates."""

    _entity = Wall
    symbol = c.SYMBOL_WALL

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._value = c.VALUE_OCCUPIED_CELL

    @classmethod
    def from_coordinates(cls, argwhere_coordinates, *args, **kwargs):
        """Create a collection with one ``cls._entity`` per given coordinate.

        :param argwhere_coordinates: Iterable of positions; each is passed to ``cls._entity``.
        :param args: Positional arguments for the collection constructor.
        :param kwargs: Keyword arguments for the collection constructor.
        """
        collection = cls(*args, **kwargs)
        entities = []
        for pos in argwhere_coordinates:
            entities.append(cls._entity(pos))
        # noinspection PyTypeChecker
        collection.add_items(entities)
        return collection

    @classmethod
    def from_tiles(cls, tiles, *args, **kwargs):
        """Not supported for this collection; always raises ``RuntimeError``."""
        raise RuntimeError()
|
||||
|
||||
|
||||
class Floors(Walls):
    """Collection of ``Floor`` tiles."""

    _entity = Floor
    symbol = c.SYMBOL_FLOOR
    is_blocking_light: bool = False
    can_collide: bool = False

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._value = c.VALUE_FREE_CELL

    @property
    def occupied_tiles(self):
        """All tiles reporting ``is_occupied()``, in random order."""
        occupied = []
        for tile in self:
            if tile.is_occupied():
                occupied.append(tile)
        random.shuffle(occupied)
        return occupied

    @property
    def empty_tiles(self) -> List[Floor]:
        """All tiles reporting ``is_empty()``, in random order."""
        empty = []
        for tile in self:
            if tile.is_empty():
                empty.append(tile)
        random.shuffle(empty)
        return empty

    @classmethod
    def from_tiles(cls, tiles, *args, **kwargs):
        """Not supported for this collection; always raises ``RuntimeError``."""
        raise RuntimeError()
|
0
environment/logging/__init__.py
Normal file
64
environment/logging/envmonitor.py
Normal file
@@ -0,0 +1,64 @@
|
||||
import pickle
|
||||
from os import PathLike
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
from gymnasium import Wrapper
|
||||
|
||||
from environment.utils.helpers import IGNORED_DF_COLUMNS
|
||||
from environment.factory import REC_TAC
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from plotting.compare_runs import plot_single_run
|
||||
|
||||
|
||||
class EnvMonitor(Wrapper):
    """Environment wrapper that gathers the per-step ``info`` dicts and condenses them per episode.

    While an episode runs, the filtered ``info`` dict of every step is buffered. Once a step
    reports ``done``, the buffer is aggregated into a single row that is added to the monitor
    data frame, which ``save_run`` pickles to disk.
    """

    ext = 'png'

    def __init__(self, env, filepath: Union[str, PathLike] = None):
        """
        :param env: The environment to wrap.
        :param filepath: Default target of ``save_run``.
        """
        super(EnvMonitor, self).__init__(env)
        self._filepath = filepath
        self._monitor_df = pd.DataFrame()  # one row per finished episode
        self._monitor_dict = dict()  # running step number -> filtered info dict

    def __getattr__(self, item):
        # Only reached when regular attribute lookup fails: delegate to the unwrapped environment.
        return getattr(self.unwrapped, item)

    def step(self, action):
        """Step the wrapped environment, record ``info`` and ``done``, and pass the result through."""
        obs_type, obs, reward, done, info = self.env.step(action)
        self._read_info(info)
        self._read_done(done)
        return obs_type, obs, reward, done, info

    def reset(self):
        return self.unwrapped.reset()

    def _read_info(self, info: dict):
        """Buffer ``info`` without the keys 'terminal_observation' / 'episode' and ``REC_TAC`` keys."""
        self._monitor_dict[len(self._monitor_dict)] = {
            key: val for key, val in info.items() if
            key not in ['terminal_observation', 'episode'] and not key.startswith(REC_TAC)}
        return

    def _read_done(self, done):
        """On ``done``, aggregate the buffered step infos into one row and reset the buffer."""
        if not done:
            return
        episode_df = pd.DataFrame.from_dict(self._monitor_dict, orient='index')
        self._monitor_dict = dict()
        columns = [col for col in episode_df.columns if col not in IGNORED_DF_COLUMNS]
        # Columns ending in 'ount' are averaged over the episode, all others are summed up.
        summary = episode_df.aggregate(
            {col: 'mean' if col.endswith('ount') else 'sum' for col in columns}
        )
        summary['episode'] = len(self._monitor_df)
        # `DataFrame.append` was removed in pandas 2.0; concatenating a one-row frame built from
        # the summary series is what `append([series])` did.
        self._monitor_df = pd.concat([self._monitor_df, pd.DataFrame([summary])])
        return

    def save_run(self, filepath: Union[Path, str, None] = None, auto_plotting_keys=None):
        """Pickle the monitor data frame and optionally plot it.

        :param filepath: Target file; falls back to the path given at construction time.
        :param auto_plotting_keys: If truthy, passed as ``column_keys`` to ``plot_single_run``.
        """
        filepath = Path(filepath or self._filepath)
        filepath.parent.mkdir(exist_ok=True, parents=True)
        with filepath.open('wb') as f:
            pickle.dump(self._monitor_df.reset_index(), f, protocol=pickle.HIGHEST_PROTOCOL)
        if auto_plotting_keys:
            plot_single_run(filepath, column_keys=auto_plotting_keys)
|
152
environment/logging/recorder.py
Normal file
@@ -0,0 +1,152 @@
|
||||
import warnings
|
||||
from collections import defaultdict
|
||||
from os import PathLike
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
from gymnasium import Wrapper
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import simplejson
|
||||
|
||||
from environment.factory import REC_TAC
|
||||
|
||||
|
||||
class EnvRecorder(Wrapper):
    """Environment wrapper that records the ``REC_TAC``-prefixed entries of the step ``info`` dicts.

    Recorded steps are buffered per environment index and, once an episode is done, moved to an
    output list that ``save_records`` dumps as JSON.
    """

    def __init__(self, env, entities: str = 'all', filepath: Union[str, PathLike] = None, freq: int = 0):
        """
        :param env: The environment to wrap.
        :param entities: 'all' to record everything, a single key, or a collection of keys to keep.
        :param filepath: Default target of ``save_records``.
        :param freq: Record every ``freq``-th episode; -1 records every episode, 0 records nothing.
        """
        super(EnvRecorder, self).__init__(env)
        self.filepath = filepath
        self.freq = freq
        self._recorder_dict = defaultdict(list)  # env index -> recorded steps of the running episode
        self._recorder_out_list = list()  # finished episodes
        self._episode_counter = 1
        self._do_record_dict = defaultdict(lambda: False)  # env index -> record current episode?
        if isinstance(entities, str):
            if entities.lower() == 'all':
                self._entities = None
            else:
                self._entities = [entities]
        else:
            self._entities = entities

    def __getattr__(self, item):
        # Only reached when regular attribute lookup fails: delegate to the unwrapped environment.
        return getattr(self.unwrapped, item)

    def reset(self):
        self._on_training_start()
        return self.unwrapped.reset()

    def _on_training_start(self) -> None:
        # NOTE(review): `start_recording` is not defined in this class; it is resolved through
        # `__getattr__` on the unwrapped environment - confirm that it exists there.
        assert self.start_recording()

    def _read_info(self, env_idx, info: dict):
        """Buffer the ``REC_TAC`` entries of ``info`` (prefix stripped) for ``env_idx``."""
        info_dict = {key.replace(REC_TAC, ''): val for key, val in info.items() if key.startswith(f'{REC_TAC}')}
        if info_dict:
            if self._entities:
                info_dict = {k: v for k, v in info_dict.items() if k in self._entities}
            self._recorder_dict[env_idx].append(info_dict)
        return True

    def _read_done(self, env_idx, done):
        """On ``done``, move the buffered steps of ``env_idx`` to the output list."""
        if done:
            self._recorder_out_list.append({'steps': self._recorder_dict[env_idx],
                                            'episode': self._episode_counter})
            self._recorder_dict[env_idx] = list()

    def step(self, actions):
        """Step the unwrapped environment and record info / done if this episode is recorded."""
        step_result = self.unwrapped.step(actions)
        if self.do_record_episode(0):
            info = step_result[-1]
            self._read_info(0, info)
        if self._do_record_dict[0]:
            self._read_done(0, step_result[-2])
        return step_result

    def finalize(self):
        self._on_training_end()
        return True

    def save_records(self, filepath: Union[Path, str, None] = None,
                     only_deltas=True,
                     save_occupation_map=False,
                     save_trajectory_map=False,
                     ):
        """Dump the recorded episodes as JSON and optionally show an occupation heatmap.

        :param filepath: Target file; falls back to the path given at construction time.
        :param only_deltas: If True, store ``DeepDiff`` results of consecutive episodes instead of
                            the episodes themselves.
        :param save_occupation_map: If True, show a heatmap of the recorded agent positions.
        :param save_trajectory_map: Not implemented; a truthy value raises ``NotImplementedError``.
        """
        filepath = Path(filepath or self.filepath)
        filepath.parent.mkdir(exist_ok=True, parents=True)
        with filepath.open('w') as f:
            if only_deltas:
                from deepdiff import DeepDiff
                diff_dict = [DeepDiff(t1, t2, ignore_order=True)
                             for t1, t2 in zip(self._recorder_out_list, self._recorder_out_list[1:])
                             ]
                out_dict = {'episodes': diff_dict}
            else:
                out_dict = {'episodes': self._recorder_out_list}
            out_dict.update(
                {'n_episodes': self._episode_counter,
                 'env_params': self.env.params,
                 'header': self.env.summarize_header
                 })
            try:
                simplejson.dump(out_dict, f, indent=4)
            except TypeError as error:
                # Still best effort (no crash), but say what went wrong instead of printing a
                # meaningless word; the file may be incomplete at this point.
                warnings.warn(f'Could not serialise the recorded episodes to "{filepath}": {error}')

        if save_occupation_map:
            a = np.zeros((15, 15))
            # Always use the full records: with `only_deltas`, `out_dict['episodes']` holds
            # DeepDiff results, which have no 'steps' entry.
            # noinspection PyTypeChecker
            for episode in self._recorder_out_list:
                df = pd.DataFrame([y for x in episode['steps'] for y in x['Agents']])

                b = list(df[['x', 'y']].to_records(index=False))

                np.add.at(a, tuple(zip(*b)), 1)

            import seaborn as sns
            from matplotlib import pyplot as plt
            hm = sns.heatmap(data=a)
            hm.set_title('Very Nice Heatmap')
            plt.show()

        if save_trajectory_map:
            raise NotImplementedError('This has not yet been implemented.')

    def do_record_episode(self, env_idx):
        """Decide at the start of an episode whether it is recorded and return that decision.

        The decision is only re-evaluated while nothing is buffered for ``env_idx``.
        """
        if not self._recorder_dict[env_idx]:
            if self.freq:
                self._do_record_dict[env_idx] = (self.freq == -1) or (self._episode_counter % self.freq) == 0
            else:
                self._do_record_dict[env_idx] = False
                warnings.warn('You did wrap your Environment with a recorder, but set the freq to zero\n'
                              'Nothing will be recorded')
            self._episode_counter += 1
        return self._do_record_dict[env_idx]

    def _on_step(self) -> bool:
        """Read infos and dones for all environment indices from ``self.locals``."""
        for env_idx, info in enumerate(self.locals.get('infos', [])):
            if self._do_record_dict[env_idx]:
                self._read_info(env_idx, info)
        dones = list(enumerate(self.locals.get('dones', [])))
        dones.extend(list(enumerate(self.locals.get('done', []))))
        for env_idx, done in dones:
            if self._do_record_dict[env_idx]:
                self._read_done(env_idx, done)

        return True

    def _on_training_end(self) -> None:
        """Move all steps that are still buffered to the output list."""
        for env_idx in range(len(self._recorder_dict)):
            if self._recorder_dict[env_idx]:
                self._recorder_out_list.append({'steps': self._recorder_dict[env_idx],
                                                'episode': self._episode_counter})
|
4
environment/rewards.py
Normal file
@@ -0,0 +1,4 @@
|
||||
# Reward values referenced by the environment's actions and rules.
MOVEMENTS_VALID: float = -0.001  # a move that succeeded
MOVEMENTS_FAIL: float = -0.05  # a move that was not valid
NOOP: float = -0.01  # the no-op action
COLLISION: float = -0.5  # being part of a collision
|
83
environment/rules.py
Normal file
@@ -0,0 +1,83 @@
|
||||
import abc
|
||||
from typing import Union, List
|
||||
|
||||
from environment.utils.results import Result, TickResult, DoneResult, ActionResult
|
||||
from environment import constants as c
|
||||
from environment import rewards as r
|
||||
|
||||
|
||||
class Rule(abc.ABC):
    """Base class for environment rules; every hook does nothing and returns an empty list."""

    @property
    def name(self):
        """Name of the rule; here the class name."""
        return self.__class__.__name__

    def __init__(self):
        pass

    def __repr__(self):
        return '{}'.format(self.name)

    def on_init(self, state):
        """Hook named for environment initialisation; the base rule returns no results."""
        return list()

    def on_reset(self):
        """Hook named for environment reset; the base rule returns no results."""
        return list()

    def tick_pre_step(self, state) -> List[TickResult]:
        """Hook named for the phase before a step; the base rule returns no results."""
        return list()

    def tick_step(self, state) -> List[TickResult]:
        """Hook named for the step itself; the base rule returns no results."""
        return list()

    def tick_post_step(self, state) -> List[TickResult]:
        """Hook named for the phase after a step; the base rule returns no results."""
        return list()

    def on_check_done(self, state) -> List[DoneResult]:
        """Hook named for the done check; the base rule returns no results."""
        return list()
|
||||
|
||||
|
||||
class MaxStepsReached(Rule):
    """Done rule that becomes valid once the current step reaches ``max_steps``."""

    def __init__(self, max_steps: int = 500):
        """
        :param max_steps: Step count from which on the done result is valid.
        """
        super().__init__()
        self.max_steps = max_steps

    def on_init(self, state):
        # Nothing to set up. Return an empty list like the base class does, not `None`, so that
        # all rules yield the same type from this hook.
        return []

    def on_check_done(self, state):
        """Return a single ``DoneResult`` that is valid iff ``state.curr_step >= max_steps``."""
        if self.max_steps <= state.curr_step:
            return [DoneResult(validity=c.VALID, identifier=self.name, reward=0)]
        return [DoneResult(validity=c.NOT_VALID, identifier=self.name, reward=0)]
|
||||
|
||||
|
||||
class Collision(Rule):
    """Rule that reports collisions on tiles and can optionally end the episode on collision."""

    def __init__(self, done_at_collisions: bool = False):
        """
        :param done_at_collisions: If True, a collision in the last tick makes the done check valid.
        """
        super().__init__()
        self.done_at_collisions = done_at_collisions
        self.curr_done = False

    def tick_post_step(self, state) -> List[TickResult]:
        """Create one collision ``TickResult`` per guest on every tile holding two or more
        collidable guests and remember whether any collision happened in this tick."""
        self.curr_done = False
        results = list()
        for tile in state.get_all_tiles_with_collisions():
            guests = tile.guests_that_can_collide
            if len(guests) < 2:
                continue
            for guest in guests:
                try:
                    # NOTE(review): `entity=self` hands over the rule, not the guest - confirm intent.
                    guest.set_state(TickResult(identifier=c.COLLISION, reward=r.COLLISION,
                                               validity=c.NOT_VALID, entity=self))
                except AttributeError:
                    # The AttributeError is deliberately ignored (e.g. guests without `set_state`).
                    pass
                results.append(TickResult(entity=guest, identifier=c.COLLISION,
                                          reward=r.COLLISION, validity=c.VALID))
            self.curr_done = True
        return results

    def on_check_done(self, state) -> List[DoneResult]:
        """Valid done result iff a collision happened in the last tick and ``done_at_collisions`` is set."""
        if not (self.curr_done and self.done_at_collisions):
            return [DoneResult(validity=c.NOT_VALID, identifier=self.name, reward=0)]
        return [DoneResult(validity=c.VALID, identifier=c.COLLISION, reward=r.COLLISION)]
|
0
environment/utils/__init__.py
Normal file
120
environment/utils/config_parser.py
Normal file
@@ -0,0 +1,120 @@
|
||||
from pathlib import Path
|
||||
|
||||
import yaml
|
||||
|
||||
from environment.groups.global_entities import Entities
|
||||
from environment.groups.agents import Agents
|
||||
from environment.entity.agent import Agent
|
||||
from environment.utils.helpers import locate_and_import_class
|
||||
from environment import constants as c
|
||||
|
||||
|
||||
# Folder names handed to `locate_and_import_class`: defaults are searched in DEFAULT_PATH,
# everything else in MODULE_PATH.
DEFAULT_PATH = 'environment'
MODULE_PATH = 'modules'
|
||||
|
||||
|
||||
class FactoryConfigParser(object):
    """Reads a YAML configuration and resolves its 'Entities', 'Agents' and 'Rules' sections.

    Attributes that are not found on the parser itself are looked up in the 'General' section
    of the configuration.
    """

    default_entites = []
    default_rules = ['MaxStepsReached', 'Collision']
    default_actions = [c.MOVE8, c.NOOP]
    default_observations = [c.WALLS, c.AGENTS]

    def __init__(self, config_path):
        """
        :param config_path: Path of the YAML configuration file (``str`` or ``Path``).
        """
        self.config_path = Path(config_path)
        # Open via the normalised path (the raw argument may be a plain string) and close the
        # file handle again once the content is parsed.
        with self.config_path.open() as config_file:
            self.config = yaml.safe_load(config_file)
        self.do_record = False

    def __getattr__(self, item):
        # Only reached when regular lookup fails. A missing `config` would recurse endlessly via
        # `__getitem__`, and dunder probes (e.g. by `copy` / `pickle`) must get an AttributeError.
        if item == 'config' or (item.startswith('__') and item.endswith('__')):
            raise AttributeError(item)
        return self['General'][item]

    def _get_sub_list(self, primary_key: str, sub_key: str):
        """For every mapping in ``config[primary_key]``, map each key to the flattened ``sub_key`` values."""
        return [{key: [s for k, v in val.items() if k == sub_key for s in v] for key, val in x.items()
                 } for x in self.config[primary_key]]

    @property
    def agent_actions(self):
        return self._get_sub_list('Agents', "Actions")

    @property
    def agent_observations(self):
        return self._get_sub_list('Agents', "Observations")

    @property
    def rules(self):
        return self.config['Rules']

    @property
    def agents(self):
        return self.config['Agents']

    @property
    def entities(self):
        return self.config['Entities']

    def __repr__(self):
        return str(self.config)

    def __getitem__(self, item):
        return self.config[item]

    def load_entities(self):
        """Import the class of every configured entity.

        :return: Mapping of entity name to a dict with the keys 'class', 'kwargs' and 'symbol'.
        :rtype: dict
        """
        entity_classes = dict()
        entities = []
        if c.DEFAULTS in self.entities:
            entities.extend(self.default_entites)
        entities.extend(x for x in self.entities if x != c.DEFAULTS)

        for entity in entities:
            folder_path = MODULE_PATH if entity not in self.default_entites else DEFAULT_PATH
            entity_class = locate_and_import_class(entity, folder_path)
            entity_kwargs = self.entities.get(entity, {})
            entity_symbol = entity_class.symbol if hasattr(entity_class, 'symbol') else None
            entity_classes.update({entity: {'class': entity_class, 'kwargs': entity_kwargs, 'symbol': entity_symbol}})
        return entity_classes

    def load_agents(self, size, free_tiles):
        """Create one ``Agent`` per configured agent name.

        :param size: Passed to the ``Agents`` collection constructor.
        :param free_tiles: Sequence of tiles; one tile is popped per agent.

        :return: The filled ``Agents`` collection.
        """
        agents = Agents(size)
        base_env_actions = self.default_actions.copy() + [c.MOVE4]
        for name in self.agents:
            # Actions
            actions = list()
            if c.DEFAULTS in self.agents[name]['Actions']:
                actions.extend(self.default_actions)
            actions.extend(x for x in self.agents[name]['Actions'] if x != c.DEFAULTS)
            parsed_actions = list()
            for action in actions:
                folder_path = MODULE_PATH if action not in base_env_actions else DEFAULT_PATH
                class_or_classes = locate_and_import_class(action, folder_path)
                try:
                    # An action name may resolve to several classes ...
                    parsed_actions.extend(class_or_classes)
                except TypeError:
                    # ... or to a single, non-iterable class.
                    parsed_actions.append(class_or_classes)
            parsed_actions = [x() for x in parsed_actions]

            # Observation
            observations = list()
            if c.DEFAULTS in self.agents[name]['Observations']:
                observations.extend(self.default_observations)
            observations.extend(x for x in self.agents[name]['Observations'] if x != c.DEFAULTS)
            agent = Agent(parsed_actions, observations, free_tiles.pop(), str_ident=name)
            agents.add_item(agent)
        return agents

    def load_rules(self):
        """Import the class of every configured rule.

        :return: Mapping of rule name to a dict with the keys 'class' and 'kwargs'.
        :rtype: dict
        """
        rules_classes = dict()
        rules = []
        if c.DEFAULTS in self.rules:
            for rule in self.default_rules:
                if rule not in rules:
                    rules.append(rule)
        rules.extend(x for x in self.rules if x != c.DEFAULTS)

        for rule in rules:
            folder_path = MODULE_PATH if rule not in self.default_rules else DEFAULT_PATH
            rule_class = locate_and_import_class(rule, folder_path)
            rule_kwargs = self.rules.get(rule, {})
            rules_classes.update({rule: {'class': rule_class, 'kwargs': rule_kwargs}})
        return rules_classes
|
272
environment/utils/helpers.py
Normal file
@@ -0,0 +1,272 @@
|
||||
import importlib
|
||||
import itertools
|
||||
from collections import defaultdict
|
||||
from pathlib import PurePath, Path
|
||||
from typing import Union, Dict, List
|
||||
|
||||
import networkx as nx
|
||||
import numpy as np
|
||||
from numpy.typing import ArrayLike
|
||||
|
||||
from environment import constants as c
|
||||
|
||||
"""
|
||||
This file is used for:
|
||||
1. string based definition
|
||||
Use a class like `Constants`, to define attributes, which then reveal strings.
|
||||
These can be used for naming convention along the environments as well as keys for mappings such as dicts etc.
|
||||
When defining new envs, use class inheritance.
|
||||
|
||||
2. utility function definition
|
||||
There are static utility functions which are not bound to a specific environment.
|
||||
In this file they are defined to be used across the entire package.
|
||||
"""
|
||||
|
||||
|
||||
LEVELS_DIR = 'modules/levels'  # for use in studies and experiments
STEPS_START = 1  # Number of the first step, i.e. where the step count starts

# Not used anymore? Clean!
# TO_BE_AVERAGED = ['dirt_amount', 'dirty_tiles']

# For plotting: columns that are ignored when loading monitor files
IGNORED_DF_COLUMNS = [
    'Episode', 'Run',
    'train_step', 'step', 'index', 'dirt_amount', 'dirty_tile_count', 'terminal_observation',
    'episode',
]

# 3x3 grid of (x, y) offsets around a centre cell; rows run over y, columns over x.
POS_MASK = np.asarray([[[x, y] for x in (-1, 0, 1)] for y in (-1, 0, 1)])
|
||||
|
||||
# Position offset per movement identifier; unknown identifiers map to "no movement" (0, 0).
MOVEMAP = defaultdict(lambda: (0, 0))
MOVEMAP.update({
    c.NORTH: (-1, 0),
    c.NORTHEAST: (-1, 1),
    c.EAST: (0, 1),
    c.SOUTHEAST: (1, 1),
    c.SOUTH: (1, 0),
    c.SOUTHWEST: (1, -1),
    c.WEST: (0, -1),
    c.NORTHWEST: (-1, -1),
})
|
||||
|
||||
|
||||
class ObservationTranslator:

    def __init__(self, this_named_observation_space: Dict[str, dict],
                 *per_agent_named_obs_spaces: Dict[str, dict],
                 placeholder_fill_value: Union[int, str, None] = None):
        """
        Helper class that converts observations of a joined environment into the observations
        the individual agents expect.
        For example, agents trained in different environments may expect different observations.
        This class translates from larger observation spaces to smaller ones.
        A string identifier based approach is used.
        Currently, it is not possible to mix different obs shapes.

        :param this_named_observation_space: `Named observation space` of the joined environment.
        :type this_named_observation_space: Dict[str, dict]

        :param per_agent_named_obs_spaces: `Named observation space`, one for each agent.
        :type per_agent_named_obs_spaces: Dict[str, dict]

        :param placeholder_fill_value: How to fill observations the joined environment does not
                                       provide: 'normal' / 'n' or 'uniform' / 'u' draw random
                                       values, None fills with zeros. Integers are not
                                       implemented yet.
        :type placeholder_fill_value: Union[int, str, None]
        """
        if isinstance(placeholder_fill_value, str):
            samplers = {'normal': np.random.normal, 'n': np.random.normal,
                        'uniform': np.random.uniform, 'u': np.random.uniform}
            sampler = samplers.get(placeholder_fill_value.lower())
            if sampler is None:
                raise ValueError('Please chooe between "uniform" or "normal" ("u", "n").')
            self.random_fill = sampler
        elif isinstance(placeholder_fill_value, int):
            raise NotImplementedError('"Future Work."')
        else:
            self.random_fill = None

        self._this_named_obs_space = this_named_observation_space
        self._per_agent_named_obs_space = list(per_agent_named_obs_spaces)

    def translate_observation(self, agent_idx: int, obs):
        """Build the observation agent ``agent_idx`` expects from the joined observation ``obs``.

        Slices known to the joined space are taken from ``obs``; all others are filled with
        random values or zeros. The result is concatenated along axis -3 in target-index order.
        """
        wanted_space = self._per_agent_named_obs_space[agent_idx]
        slices = dict()
        for name, target_idxs in wanted_space.items():
            if name in self._this_named_obs_space:
                source_idxs = self._this_named_obs_space[name]
                for target_idx, source_idx in zip(target_idxs, source_idxs):
                    slices[target_idx] = np.take(obs, [source_idx], axis=1 if obs.ndim == 4 else 0)
            elif self.random_fill:
                for target_idx in target_idxs:
                    slices[target_idx] = self.random_fill(size=obs.shape[:-3] + (1,) + obs.shape[-2:])
            else:
                for target_idx in target_idxs:
                    slices[target_idx] = np.zeros(shape=(obs.shape[:-3] + (1,) + obs.shape[-2:]))

        ordered_slices = [slices[target_idx] for target_idx in sorted(slices)]
        return np.concatenate(ordered_slices, axis=-3)

    def translate_observations(self, observations: List[ArrayLike]):
        """Translate one observation per agent; the list index is used as agent index."""
        translated = []
        for agent_idx, observation in enumerate(observations):
            translated.append(self.translate_observation(agent_idx, observation))
        return translated

    def __call__(self, observations):
        return self.translate_observations(observations)
|
||||
|
||||
|
||||
class ActionTranslator:

    def __init__(self, target_named_action_space: Dict[str, int], *per_agent_named_action_space: Dict[str, int]):
        """
        Helper class that converts actions from the agents' own action spaces into the action
        space of a joined environment.
        For example, agents trained in different environments may have different action spaces.
        This class translates from smaller individual agent action spaces to a larger joined one.
        A string identifier based approach is used.

        :param target_named_action_space: Joined `Named action space` of the current environment.
        :type target_named_action_space: Dict[str, int]

        :param per_agent_named_action_space: `Named action space`, one for each agent.
        :type per_agent_named_action_space: Dict[str, int]
        """
        self._target_named_action_space = target_named_action_space
        # Variadic arguments always arrive as a tuple, which is stored as is.
        self._per_agent_named_action_space = per_agent_named_action_space
        # Per agent: inverse mapping from action index to action name.
        self._per_agent_idx_actions = []
        for named_space in self._per_agent_named_action_space:
            self._per_agent_idx_actions.append({idx: name for name, idx in named_space.items()})

    def translate_action(self, agent_idx: int, action: int):
        """Map ``action`` of agent ``agent_idx`` to the index of the same-named joined action."""
        action_name = self._per_agent_idx_actions[agent_idx][action]
        return self._target_named_action_space[action_name]

    def translate_actions(self, actions: List[int]):
        """Translate one action per agent; the list index is used as agent index."""
        translated = []
        for agent_idx, action in enumerate(actions):
            translated.append(self.translate_action(agent_idx, action))
        return translated

    def __call__(self, actions):
        return self.translate_actions(actions)
|
||||
|
||||
|
||||
# Utility functions
|
||||
def parse_level(path):
    """
    Read a string based `level` or `map` representation from disk.
    Every line is stripped of surrounding whitespace and split into its characters; all resulting rows must
    have the same length.

    :param path: Path to the `level` or `map` file on harddrive; must offer `open('r')`.
    :type path: pathlib.Path

    :return: The read string representation of the `level` or `map`, one list of characters per line.
    :rtype: List[List[str]]

    :raises AssertionError: If the rows are not all of equal length.
    """
    with path.open('r') as lvl:
        level = [list(row.strip()) for row in lvl.readlines()]
    row_lengths = {len(row) for row in level}
    if len(row_lengths) > 1:
        raise AssertionError('Every row of the level string must be of equal length.')
    return level
|
||||
|
||||
|
||||
def one_hot_level(level, symbol: str):
    """
    Turn a string based level representation (list of lists, see function `parse_level`) into a binary numpy
    array or `grid`. Cells whose character equals `symbol` are set to `c.VALUE_OCCUPIED_CELL`; every other
    cell stays 0.

    :param level: String based level representation (list of lists, see function `parse_level`).
    :type level: List[List[str]]

    :param symbol: The character to filter for.
    :type symbol: str

    :return: Binary numpy array (dtype int8) with the same shape as `level`.
    :rtype: np.ndarray
    """
    grid = np.array(level)
    symbol_mask = grid == symbol
    binary_grid = np.zeros(grid.shape, dtype=np.int8)
    binary_grid[symbol_mask] = c.VALUE_OCCUPIED_CELL
    return binary_grid
|
||||
|
||||
|
||||
def is_move(action_name: str):
    """Tell whether `action_name` is one of the keys of `MOVEMAP`."""
    return action_name in MOVEMAP
|
||||
|
||||
|
||||
def asset_str(agent):
    """
    Pick the asset name and the state name for an agent from its last step result.

    FIXME @ romue

    :param agent: Object exposing `step_result`. When truthy, it is indexed with the keys 'action_name',
                  'action_valid' and 'collisions' (an iterable of objects with a `name`).

    :return: Tuple of (asset name, state name).
    :rtype: Tuple[str, str]
    """
    step_result = agent.step_result
    if not step_result:
        # Nothing happened yet (or no result was stored).
        return c.AGENT, 'idle'

    action = step_result['action_name']
    valid = step_result['action_valid']
    col_names = [x.name for x in step_result['collisions']]

    # Collisions with another agent win over everything else.
    if any(c.AGENT in name for name in col_names):
        return 'agent_collision', 'blank'
    if not valid or c.LEVEL in col_names or c.AGENT in col_names:
        return c.AGENT, 'invalid'
    # The action is valid from here on, so only its kind decides the state.
    if is_move(action):
        return c.AGENT, 'move'
    return c.AGENT, 'valid'
|
||||
|
||||
|
||||
def points_to_graph(coordiniates_or_tiles, allow_euclidean_connections=True, allow_manhattan_connections=True):
    """
    Given a set of coordinates, this function constructs an undirected graph by connecting adjacent points.
    There are three combinations of settings:
        Allow all neighbors:  Distance(a, b) <= sqrt(2)
        Allow only manhattan: Distance(a, b) == 1
        Allow only euclidean: Distance(a, b) == sqrt(2)

    Nodes enter the graph only through `add_edge`, so a point without any neighbor is not part of the result.

    :param coordiniates_or_tiles: A set of coordinates, or an object offering them through `positions`.
    :type coordiniates_or_tiles: Tiles
    :param allow_euclidean_connections: Whether to regard diagonally adjacent cells as neighbors
    :type: bool
    :param allow_manhattan_connections: Whether to regard directly adjacent cells as neighbors
    :type: bool

    :return: A graph with nodes that are connected as specified by the parameters.
    :rtype: nx.Graph
    """
    assert allow_euclidean_connections or allow_manhattan_connections
    if hasattr(coordiniates_or_tiles, 'positions'):
        coordiniates_or_tiles = coordiniates_or_tiles.positions

    # Decide once which distance rule applies instead of re-checking the flags for every pair.
    if allow_manhattan_connections and allow_euclidean_connections:
        def connected(distance):
            return distance <= np.sqrt(2)
    elif allow_euclidean_connections:
        def connected(distance):
            return distance == np.sqrt(2)
    elif allow_manhattan_connections:
        def connected(distance):
            return distance == 1
    else:
        def connected(distance):
            return False

    graph = nx.Graph()
    for a, b in itertools.combinations(coordiniates_or_tiles, 2):
        distance = np.linalg.norm(np.asarray(a) - np.asarray(b))
        if connected(distance):
            graph.add_edge(a, b)
    return graph
|
||||
|
||||
|
||||
def locate_and_import_class(class_name, folder_path: Union[str, PurePath] = ''):
    """
    Locate a class by name: import every python module found below `folder_path` and return the first
    attribute named `class_name`.

    Module names are built from the parts of each file path, so `folder_path` has to be given relative to
    an importable root. As a side effect, ".." is appended to `sys.path` (once) and every inspected module
    is imported.

    :param class_name: Name of the attribute to look up in each module.
    :type class_name: str
    :param folder_path: Folder that is searched recursively for `*.py` files (files with `__init__` in their
                        name are skipped).
    :type folder_path: Union[str, PurePath]

    :return: The first matching module attribute.
    :raises AttributeError: If no module offers `class_name`; the message lists the names that were found.
    """
    import sys
    # Keep the search path extension of the original, but do not grow sys.path with every call.
    if ".." not in sys.path:
        sys.path.append("..")
    folder_path = Path(folder_path)
    module_paths = [x for x in folder_path.rglob('*.py') if x.is_file() and '__init__' not in x.name]
    # Helper and base names that should not be offered as options in the error message.
    # ('Floor' and 'TickResult' used to be merged into 'FloorTickResult' by a missing comma.)
    ignored_names = {'Entity', 'NamedTuple', 'List', 'Rule', 'Union', 'random', 'Floor',
                     'TickResult', 'ActionResult', 'Action', 'Agent', 'deque',
                     'BoundEntityMixin', 'RenderEntity', 'TemplateRule', 'defaultdict',
                     'is_move', 'Objects', 'PositionMixin', 'IsBoundMixin', 'EnvObject',
                     'EnvObjects'}
    all_found_modules = list()
    for module_path in module_paths:
        mod = importlib.import_module('.'.join([x.replace('.py', '') for x in module_path.parts]))
        all_found_modules.extend(
            x for x in dir(mod)
            if not (x.startswith('__') or len(x) < 2 or x.isupper()) and x not in ignored_names
        )
        try:
            return getattr(mod, class_name)
        except AttributeError:
            continue
    raise AttributeError(f'Class "{class_name}" was not found!!!"\n'
                         f'Check the {folder_path.name} name.\n'
                         f'Possible Options are:\n{set(all_found_modules)}')
|
55
environment/utils/level_parser.py
Normal file
@@ -0,0 +1,55 @@
|
||||
from os import PathLike
|
||||
from pathlib import Path
|
||||
from typing import Dict
|
||||
|
||||
import numpy as np
|
||||
|
||||
from environment.groups.global_entities import Entities
|
||||
from environment.groups.wall_n_floors import Walls, Floors
|
||||
from environment.utils import helpers as h
|
||||
from environment import constants as c
|
||||
|
||||
|
||||
class LevelParser(object):
    """Reads a level file and creates the entity collections (walls, floors and configured groups) from it."""

    @property
    def pomdp_d(self):
        """Side length derived from the view radius: 2 * pomdp_r + 1."""
        return 2 * self.pomdp_r + 1

    def __init__(self, level_file_path: PathLike, entity_parse_dict: Dict[Entities, dict], pomdp_r=0):
        """
        :param level_file_path: Location of the level file; parsed immediately through `h.parse_level`.
        :param entity_parse_dict: Mapping whose values are dicts with the keys 'class' and 'kwargs'.
        :param pomdp_r: View radius; a falsy value makes `size` the number of cells of the whole level.
        """
        self.pomdp_r = pomdp_r
        self.e_p_dict = entity_parse_dict
        self._parsed_level = h.parse_level(Path(level_file_path))
        wall_grid = h.one_hot_level(self._parsed_level, c.SYMBOL_WALL)
        self.level_shape = wall_grid.shape
        # NOTE(review): this uses pomdp_r ** 2, not pomdp_d ** 2 -- confirm which one the collections expect.
        if self.pomdp_r:
            self.size = self.pomdp_r ** 2
        else:
            self.size = np.prod(self.level_shape)

    def do_init(self):
        """
        Create all entity collections described by the level file and the entity parse dict.

        Walls and floors are always created from the wall symbol grid. For every configured group, a class
        with a `symbol` attribute is placed at the matching level cells, any other class is instantiated
        directly with `size` and its kwargs.

        :return: The populated `Entities` container.
        :raises ValueError: If a class defines a `symbol` that does not occur in the level.
        """
        entities = Entities()

        # Walls
        wall_grid = h.one_hot_level(self._parsed_level, c.SYMBOL_WALL)
        wall_coords = np.argwhere(wall_grid == c.VALUE_OCCUPIED_CELL)
        entities.add_items({c.WALL: Walls.from_coordinates(wall_coords, self.size)})

        # Floor
        floor_coords = np.argwhere(wall_grid == c.VALUE_FREE_CELL)
        entities.add_items({c.FLOOR: Floors.from_coordinates(floor_coords, self.size)})

        # All other
        for parse_conf in self.e_p_dict.values():
            e_class = parse_conf['class']
            e_kwargs = parse_conf['kwargs']

            if not hasattr(e_class, 'symbol'):
                # No level symbol: the group is not tied to cells of the level file.
                group = e_class(self.size, **e_kwargs)
            else:
                symbol_grid = h.one_hot_level(self._parsed_level, symbol=e_class.symbol)
                if not np.any(symbol_grid):
                    raise ValueError(f'No {e_class} (Symbol: {e_class.symbol}) could be found!\n'
                                     f'Check your level file!')
                coords = np.argwhere(symbol_grid == c.VALUE_OCCUPIED_CELL).tolist()
                group = e_class.from_coordinates(coords, entities[c.FLOOR], self.size, entity_kwargs=e_kwargs)
            entities.add_items({group.name: group})
        return entities
|
315
environment/utils/observation_builder.py
Normal file
@@ -0,0 +1,315 @@
|
||||
import math
|
||||
from collections import defaultdict
|
||||
from itertools import product
|
||||
from typing import Dict, List
|
||||
|
||||
import numpy as np
|
||||
from numba import njit
|
||||
|
||||
from environment.groups.utils import Combined
|
||||
from environment.utils.states import Gamestate
|
||||
|
||||
from environment import constants as c
|
||||
|
||||
|
||||
class OBSBuilder(object):
    """
    Builds the per-agent observation arrays.

    For every agent, the layers it asked for (`agent.observations`) are resolved to names once and cached in
    `obs_layers`. Visible entities come from one `RayCaster` per agent and are written into a window of
    `pomdp_d` x `pomdp_d` cells centered on the agent.
    """

    default_obs = [c.WALLS, c.OTHERS]

    @property
    def pomdp_d(self):
        """Side length of the observation window (2 * pomdp_r + 1), or 0 if `pomdp_r` is falsy."""
        if self.pomdp_r:
            return (self.pomdp_r * 2) + 1
        else:
            return 0

    def __init__(self, level_shape: np.size, state: Gamestate, pomdp_r: int):
        """
        :param level_shape: Shape of the level grid; used as observation shape when `pomdp_r` is falsy.
        :param state: Game state whose `entities.obs_pairs` seed the table of known observations.
        :param pomdp_r: View radius in cells.
        """
        self.all_obs = dict()
        self.light_blockers = defaultdict(lambda: False)
        self.positional = defaultdict(lambda: False)
        self.non_positional = defaultdict(lambda: False)
        self.ray_caster = dict()

        self.level_shape = level_shape
        self.pomdp_r = pomdp_r
        self.obs_shape = (self.pomdp_d, self.pomdp_d) if self.pomdp_r else self.level_shape
        self.size = np.prod(self.obs_shape)

        self.obs_layers = dict()

        self.build_structured_obs_block(state)
        self.curr_lightmaps = dict()

    def build_structured_obs_block(self, state):
        """Refresh `all_obs` with an all-zero placeholder layer and every pair of `state.entities.obs_pairs`."""
        self.all_obs[c.PLACEHOLDER] = np.full(self.obs_shape, 0, dtype=float)
        self.all_obs.update({key: obj for key, obj in state.entities.obs_pairs})

    def observation_space(self, state):
        """
        Derive the gymnasium observation space from freshly built observations.

        :return: A single `Box` if there is exactly one agent, otherwise a `Tuple` holding one `Box` per agent.
        """
        from gymnasium.spaces import Tuple, Box
        # refresh_and_build_for_all returns (observations, info); only the observations carry shapes.
        obsn, _ = self.refresh_and_build_for_all(state)
        if len(state[c.AGENT]) == 1:
            space = Box(low=0, high=1, shape=next(x for x in obsn.values()).shape, dtype=np.float32)
        else:
            space = Tuple([Box(low=0, high=1, shape=obs.shape, dtype=np.float32) for obs in obsn.values()])
        return space

    def named_observation_space(self, state):
        """Return the result of `refresh_and_build_for_all`, i.e. the tuple (observations by agent name, info)."""
        return self.refresh_and_build_for_all(state)

    def refresh_and_build_for_all(self, state) -> (dict, dict):
        """
        Refresh the known observations and build the observation array of every agent.

        :return: Tuple of (dict agent name -> observation array, info dict which is currently always empty).
        """
        self.build_structured_obs_block(state)
        info = {}
        return {agent.name: self.build_for_agent(agent, state)[0] for agent in state[c.AGENT]}, info

    def refresh_and_build_named_for_all(self, state) -> Dict[str, Dict[str, np.ndarray]]:
        """
        Same as `refresh_and_build_for_all`, but keeps the layer names next to every observation.

        :return: Dict agent name -> {'observation': array, 'names': list of layer names}.
        """
        self.build_structured_obs_block(state)
        named_obs_dict = {}
        for agent in state[c.AGENT]:
            obs, names = self.build_for_agent(agent, state)
            named_obs_dict[agent.name] = {'observation': obs, 'names': names}
        return named_obs_dict

    def _lookup_obs_entry(self, l_name, agent):
        """Resolve a layer name in `all_obs`: exact key, then '<name>(<agent name>)', then any key holding both."""
        try:
            return self.all_obs[l_name]
        except KeyError:
            pass
        try:
            return self.all_obs[f'{l_name}({agent.name})']
        except KeyError:
            pass
        try:
            key = next(x for x in self.all_obs if l_name in x and agent.name in x)
        except StopIteration:
            raise KeyError(
                f'Check typing!\n{l_name} could not be found in:\n{dict(self.all_obs).keys()}')
        # Iterating the dict yields keys; hand back the stored object, not its key.
        return self.all_obs[key]

    @staticmethod
    def _put_encoding(layer, e):
        """Write the `encodings` (or, as fallback, the `encoding`) of a non-positional entry into the flat layer."""
        try:
            v = e.encodings
        except AttributeError:
            try:
                v = e.encoding
            except AttributeError:
                raise AttributeError(f'This env. expects Entity-Clases to report their "encoding"')
        try:
            np.put(layer, range(len(v)), v, mode='raise')
        except TypeError:
            # `v` has no len(), so it is written as a single value into the first cell.
            np.put(layer, 0, v, mode='raise')
        except IndexError:
            raise ValueError(f'Max(obs.size) for {e.name}: {layer.size}, but was: {len(v)}.')

    def build_for_agent(self, agent, state) -> (np.ndarray, List[str]):
        """
        Build the observation array of a single agent.

        :param agent: Agent to build for; read are `name`, `x`, `y` and (on first use) `observations`.
        :param state: Game state; its `entities` are handed to the agent's ray caster.

        :return: Tuple of (array of shape (n_layers, pomdp_d, pomdp_d), list of the layer names).
        :raises KeyError: If a requested layer name cannot be resolved.
        :raises AttributeError: If a non-positional entry offers neither `encodings` nor `encoding`.
        :raises ValueError: If the encodings of an entry do not fit into one layer.
        """
        try:
            agent_want_obs = self.obs_layers[agent.name]
        except KeyError:
            self._sort_and_name_observation_conf(agent)
            agent_want_obs = self.obs_layers[agent.name]

        # Handle in-grid observations aka visible observations
        visible_entitites = self.ray_caster[agent.name].visible_entities(state.entities)
        pre_sort_obs = defaultdict(lambda: np.zeros((self.pomdp_d, self.pomdp_d)))
        for e in set(visible_entitites):
            # Shift into window coordinates with the agent in the center.
            x, y = (e.x - agent.x) + self.pomdp_r, (e.y - agent.y) + self.pomdp_r
            try:
                pre_sort_obs[e.obs_tag][x, y] += e.encoding
            except IndexError:
                # Seemed to be visible but is out of range
                pass

        # From here on missing tags must raise KeyError instead of creating empty layers.
        pre_sort_obs = dict(pre_sort_obs)
        obs = np.zeros((len(agent_want_obs), self.pomdp_d, self.pomdp_d))

        for idx, l_name in enumerate(agent_want_obs):
            try:
                obs[idx] = pre_sort_obs[l_name]
                continue
            except KeyError:
                pass

            if c.COMBINED in l_name:
                if combined := [pre_sort_obs[x] for x in self.all_obs[f'{c.COMBINED}({agent.name})'].names
                                if x in pre_sort_obs]:
                    obs[idx] = np.sum(combined, axis=0)
            elif l_name == c.PLACEHOLDER:
                obs[idx] = self.all_obs[c.PLACEHOLDER]
            else:
                e = self._lookup_obs_entry(l_name, agent)
                if getattr(e, 'has_position', False):
                    # Positional but not among the visible entities, so the layer stays empty.
                    pass
                else:
                    self._put_encoding(obs[idx], e)

        try:
            self.curr_lightmaps[agent.name] = pre_sort_obs[c.FLOORS].astype(bool)
        except KeyError:
            # No floor layer was collected this step; the previous lightmap is kept.
            pass
        return obs, self.obs_layers[agent.name]

    def _sort_and_name_observation_conf(self, agent):
        """
        Resolve the observation configuration of an agent into a list of layer names (stored in `obs_layers`)
        and create its ray caster and an empty lightmap.

        Entries of `agent.observations` are either strings or single-item dicts (name -> values).
        """
        self.ray_caster[agent.name] = RayCaster(agent, self.pomdp_r)
        obs_layers = []

        for obs_str in agent.observations:
            if isinstance(obs_str, dict):
                obs_str, vals = next(iter(obs_str.items()))
            else:
                vals = None
            if obs_str == c.SELF:
                obs_layers.append(agent.name)
            elif obs_str == c.DEFAULTS:
                obs_layers.extend(self.default_obs)
            elif obs_str == c.COMBINED:
                if isinstance(vals, str):
                    vals = [vals]
                names = list()
                for val in vals:
                    if val == c.SELF:
                        names.append(agent.name)
                    elif val == c.OTHERS:
                        names.extend([x.name for x in agent.collection if x.name != agent.name])
                    else:
                        names.append(val)
                combined = Combined(names, self.pomdp_r, identifier=agent.name)
                self.all_obs[combined.name] = combined
                obs_layers.append(combined.name)
            elif obs_str == c.OTHERS:
                obs_layers.extend([x for x in self.all_obs if x != agent.name and x.startswith(f'{c.AGENT}[')])
            elif obs_str == c.AGENTS:
                obs_layers.extend([x for x in self.all_obs if x.startswith(f'{c.AGENT}[')])
            else:
                obs_layers.append(obs_str)
        self.obs_layers[agent.name] = obs_layers
        self.curr_lightmaps[agent.name] = np.zeros((self.pomdp_d or self.level_shape[0],
                                                    self.pomdp_d or self.level_shape[1]
                                                    ))
|
||||
|
||||
|
||||
class RayCaster:
    """Casts rays from an agent's position to find the entities that are visible to it."""

    def __init__(self, agent, pomdp_r, degs=360):
        """
        :param agent: Agent the rays start from; its `pos` is read whenever rays are built.
        :param pomdp_r: View radius in cells; ray targets are placed at this distance.
        :param degs: Opening angle of the field of view in degrees.
        """
        self.agent = agent
        self.pomdp_r = pomdp_r
        self.n_rays = 100  # (self.pomdp_r + 1) * 8
        self.degs = degs
        self.ray_targets = self.build_ray_targets()
        self.obs_shape_cube = np.array([self.pomdp_r, self.pomdp_r])

    def build_ray_targets(self):
        """
        Compute the unique integer offsets (relative to the agent) the rays are aimed at.

        A vector of length `pomdp_r` is rotated by `n_rays` evenly spaced angles between
        `-degs // 2` and `degs // 2`, then rounded; duplicates are removed.

        :return: Integer array of shape (n_targets, 2).
        """
        north = np.array([0, -1])*self.pomdp_r
        # NOTE(review): `-self.degs // 2` floors, so an odd `degs` gives a slightly asymmetric range.
        thetas = [np.deg2rad(deg) for deg in np.linspace(-self.degs // 2, self.degs // 2, self.n_rays)[::-1]]
        rot_M = [
            [[math.cos(theta), -math.sin(theta)],
             [math.sin(theta), math.cos(theta)]] for theta in thetas
        ]
        rot_M = np.stack(rot_M, 0)
        rot_M = np.unique(np.round(rot_M @ north), axis=0)
        return rot_M.astype(int)

    @staticmethod
    def ray_block_cache(cache_dict, key, callback, ents):
        """
        Return the cached value for `key`, computing it with `callback` on first use.

        :param cache_dict: Dict used as cache; filled in place.
        :param key: Position the value belongs to.
        :param callback: Zero-argument callable that is only invoked on a cache miss.
        :param ents: Unused; kept so that existing callers do not break.
        """
        if key not in cache_dict:
            cache_dict[key] = callback()
        return cache_dict[key]

    def visible_entities(self, entities):
        """
        Collect the entities hit by the rays of the agent.

        Each ray is walked cell by cell. The entities of a cell are added unless the step onto it was
        diagonal and both cells next to that diagonal report as blocking. A ray ends at the first cell
        that holds a light blocking entity or is diagonally blocked.

        :param entities: Container offering `pos_dict`, which maps positions to the entities located there.
        :return: List of entities hit; an entity may appear more than once if several rays cross its cell.
        """
        visible = list()
        # NOTE(review): this cache is shared by the "any entity blocks" check and the diagonal check below,
        # so the callback that sees a position first decides the value stored for it.
        cache_blocking = {}

        for ray in self.get_rays():
            rx, ry = ray[0]
            for x, y in ray:
                # Step direction relative to the previous cell of this ray.
                cx, cy = x - rx, y - ry

                entities_hit = entities.pos_dict[(x, y)]
                hits = self.ray_block_cache(cache_blocking,
                                            (x, y),
                                            lambda: any(True for e in entities_hit if e.is_blocking_light),
                                            entities)

                diag_hits = all([
                    self.ray_block_cache(
                        cache_blocking,
                        key,
                        lambda: all(False for e in entities.pos_dict[key] if not e.is_blocking_light),
                        entities)
                    for key in ((x, y-cy), (x-cx, y))
                ]) if (cx != 0 and cy != 0) else False

                visible += entities_hit if not diag_hits else []
                if hits or diag_hits:
                    break
                rx, ry = x, y
        return visible

    def get_rays(self):
        """Build one cell path per ray target, starting at the agent's current position."""
        a_pos = self.agent.pos
        outline = self.ray_targets + a_pos
        return self.bresenham_loop(a_pos, outline)

    # todo do this once and cache the points!
    def get_fov_outline(self) -> np.ndarray:
        """Return the ray targets shifted to the agent's current position."""
        return self.ray_targets + self.agent.pos

    def get_square_outline(self):
        """Return the border cells of the square of radius `pomdp_r` around the agent (corners appear twice)."""
        agent = self.agent
        x_coords = range(agent.x - self.pomdp_r, agent.x + self.pomdp_r + 1)
        y_coords = range(agent.y - self.pomdp_r, agent.y + self.pomdp_r + 1)
        outline = list(product(x_coords, [agent.y - self.pomdp_r, agent.y + self.pomdp_r])) \
                  + list(product([agent.x - self.pomdp_r, agent.x + self.pomdp_r], y_coords))
        return outline

    @staticmethod
    @njit
    def bresenham_loop(a_pos, points):
        """
        Rasterize a line from `a_pos` to each entry of `points` (Bresenham's line algorithm).

        :return: One list of [x, y] cells per end point, ordered from `a_pos` towards that end point.
        """
        results = []
        for end in points:
            x1, y1 = a_pos
            x2, y2 = end
            dx = x2 - x1
            dy = y2 - y1

            # Determine how steep the line is
            is_steep = abs(dy) > abs(dx)

            # Rotate line
            if is_steep:
                x1, y1 = y1, x1
                x2, y2 = y2, x2

            # Swap start and end points if necessary and store swap state
            swapped = False
            if x1 > x2:
                x1, x2 = x2, x1
                y1, y2 = y2, y1
                swapped = True

            # Recalculate differentials
            dx = x2 - x1
            dy = y2 - y1

            # Calculate error
            error = int(dx / 2.0)
            ystep = 1 if y1 < y2 else -1

            # Iterate over bounding box generating points between start and end
            y = y1
            points = []
            for x in range(int(x1), int(x2) + 1):
                coord = [y, x] if is_steep else [x, y]
                points.append(coord)
                error -= abs(dy)
                if error < 0:
                    y += ystep
                    error += dx

            # Reverse the list if the coordinates were swapped
            if swapped:
                points.reverse()
            results.append(points)
        return results
|
16
environment/utils/render.py
Normal file
@@ -0,0 +1,16 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
@dataclass
class RenderEntity:
    """Plain record describing one item to draw: which asset, where, and optional value/state modifiers."""
    # Asset identifier.
    name: str
    # Position of the item; presumably (row, column) grid coordinates -- TODO confirm against the renderer.
    pos: np.array
    # Magnitude belonging to `value_operation`.
    value: float = 1
    # Name of the operation `value` is meant for; 'none' is the default.
    value_operation: str = 'none'
    # Optional state name; defaults to None although annotated as str.
    state: str = None
    # Numeric identifier of the item.
    id: int = 0
    # Free-form extra payload.
    aux: Any = None
    real_name: str = 'none'
|
143
environment/utils/renderer.py
Normal file
@@ -0,0 +1,143 @@
|
||||
import sys
|
||||
|
||||
from pathlib import Path
|
||||
from collections import deque
|
||||
from itertools import product
|
||||
import pygame
|
||||
from typing import Tuple, Union
|
||||
import time
|
||||
|
||||
import torch
|
||||
|
||||
from environment.utils.render import RenderEntity
|
||||
|
||||
AGENT: str = 'agent'
|
||||
STATE_IDLE: str = 'idle'
|
||||
STATE_MOVE: str = 'move'
|
||||
STATE_VALID: str = 'valid'
|
||||
STATE_INVALID: str = 'invalid'
|
||||
STATE_COLLISION: str = 'agent_collision'
|
||||
BLANK: str = 'blank'
|
||||
DOOR: str = 'door'
|
||||
OPACITY: str = 'opacity'
|
||||
SCALE: str = 'scale'
|
||||
|
||||
|
||||
class Renderer:
    """Pygame based renderer that draws `RenderEntity` items onto a grid and returns the frame as a tensor."""

    BG_COLOR = (178, 190, 195)  # (99, 110, 114)
    WHITE = (223, 230, 233)  # (200, 200, 200)
    AGENT_VIEW_COLOR = (9, 132, 227)
    ASSETS = Path(__file__).parent.parent / 'assets'
    MODULE_ASSETS = Path(__file__).parent.parent.parent / 'modules'

    def __init__(self, lvl_shape: Tuple[int, int] = (16, 16),
                 lvl_padded_shape: Union[Tuple[int, int], None] = None,
                 cell_size: int = 40, fps: int = 7,
                 grid_lines: bool = True, view_radius: int = 2):
        """
        Initialise pygame, open the window and load every `*.png` below `ASSETS` and `MODULE_ASSETS`.

        :param lvl_shape: Grid size as (height, width) in cells.
        :param lvl_padded_shape: Shape of the padded level; defaults to `lvl_shape`.
        :param cell_size: Edge length of one cell in pixels.
        :param fps: Frame rate `render` is limited to.
        :param grid_lines: Whether cell borders are drawn onto the background.
        :param view_radius: Radius (in cells) of the agent view overlay; 0 disables the overlay.
        """
        self.grid_h, self.grid_w = lvl_shape
        self.lvl_padded_shape = lvl_padded_shape if lvl_padded_shape is not None else lvl_shape
        self.cell_size = cell_size
        self.fps = fps
        self.grid_lines = grid_lines
        self.view_radius = view_radius
        pygame.init()
        self.screen_size = (self.grid_w*cell_size, self.grid_h*cell_size)
        self.screen = pygame.display.set_mode(self.screen_size)
        self.clock = pygame.time.Clock()
        assets = list(self.ASSETS.rglob('*.png')) + list(self.MODULE_ASSETS.rglob('*.png'))
        # Assets are addressed by their file name without extension.
        self.assets = {path.stem: self.load_asset(str(path), 1) for path in assets}
        self.fill_bg()

        now = time.time()
        self.font = pygame.font.Font(None, 20)
        self.font.set_bold(True)
        print('Loading System font with pygame.font.Font took', time.time() - now)

    def fill_bg(self):
        """Clear the screen with the background color and, if enabled, draw the cell borders."""
        self.screen.fill(Renderer.BG_COLOR)
        if self.grid_lines:
            w, h = self.screen_size
            for x in range(0, w, self.cell_size):
                for y in range(0, h, self.cell_size):
                    rect = pygame.Rect(x, y, self.cell_size, self.cell_size)
                    pygame.draw.rect(self.screen, Renderer.WHITE, rect, 1)

    def blit_params(self, entity):
        """
        Build the keyword arguments for `Surface.blit` of one entity.

        :param entity: `RenderEntity`; its lower-cased `name` selects the asset, `pos` is (row, column).
        :return: dict(source=<surface>, dest=<rect centered on the entity's cell>).
        :raises KeyError: If no asset is loaded for the entity's name.
        """
        # Remove the padding so that positions refer to the visible grid.
        offset_r, offset_c = (self.lvl_padded_shape[0] - self.grid_h) // 2, \
                             (self.lvl_padded_shape[1] - self.grid_w) // 2

        r, c = entity.pos
        r, c = r - offset_r, c-offset_c

        img = self.assets[entity.name.lower()]
        if entity.value_operation == OPACITY:
            # set_alpha works in place; use a copy so the cached asset keeps its own alpha for later draws.
            img = img.copy()
            img.set_alpha(255*entity.value)
        elif entity.value_operation == SCALE:
            re = img.get_rect()
            img = pygame.transform.smoothscale(
                img, (int(entity.value*re.width), int(entity.value*re.height))
            )
        o = self.cell_size//2
        r_, c_ = r*self.cell_size + o, c*self.cell_size + o
        rect = img.get_rect()
        rect.centerx, rect.centery = c_, r_
        return dict(source=img, dest=rect)

    def load_asset(self, path, factor=1.0):
        """Load an image and scale it to a square of `factor * cell_size` pixels."""
        s = int(factor*self.cell_size)
        asset = pygame.image.load(path).convert_alpha()
        asset = pygame.transform.smoothscale(asset, (s, s))
        return asset

    def visibility_rects(self, bp, view):
        """
        Build the translucent overlay tiles for the cells an agent can see.

        :param bp: Blit parameters of the agent (see `blit_params`); its 'dest' rect is the center tile.
        :param view: Array-like indexed with [view_radius + j, view_radius + i]; truthy cells are highlighted.
                     `None` yields no tiles.
        :return: List of blit parameter dicts, one per highlighted cell.
        """
        rects = []
        if view is None:
            return rects
        offsets = range(-self.view_radius, self.view_radius+1)
        for i, j in product(offsets, offsets):
            if bool(view[self.view_radius+j, self.view_radius+i]):
                visibility_rect = bp['dest'].copy()
                visibility_rect.centerx += i*self.cell_size
                visibility_rect.centery += j*self.cell_size
                shape_surf = pygame.Surface(visibility_rect.size, pygame.SRCALPHA)
                pygame.draw.rect(shape_surf, self.AGENT_VIEW_COLOR, shape_surf.get_rect())
                shape_surf.set_alpha(64)
                rects.append(dict(source=shape_surf, dest=visibility_rect))
        return rects

    def render(self, entities):
        """
        Draw one frame and return it.

        Exits the process if the window was closed. For entities named 'agent', the view overlay is drawn
        below everything else and, unless the state is 'blank', a scaled state icon and the agent id on top.

        :param entities: Iterable of `RenderEntity`.
        :return: The frame as tensor with the color channel first (permuted from pygame's array3d layout).
        """
        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                pygame.quit()
                sys.exit()
        self.fill_bg()
        blits = deque()
        for entity in list(entities):
            bp = self.blit_params(entity)
            blits.append(bp)
            if entity.name.lower() == AGENT:
                if self.view_radius > 0:
                    vis_rects = self.visibility_rects(bp, entity.aux)
                    # Overlay tiles go to the front of the queue so they end up below the entities.
                    blits.extendleft(vis_rects)
                if entity.state != BLANK:
                    agent_state_blits = self.blit_params(
                        RenderEntity(entity.state, (entity.pos[0] + 0.12, entity.pos[1]), 0.48, SCALE)
                    )
                    textsurface = self.font.render(str(entity.id), False, (0, 0, 0))
                    text_blit = dict(source=textsurface, dest=(bp['dest'].center[0]-.07*self.cell_size,
                                                               bp['dest'].center[1]))
                    blits += [agent_state_blits, text_blit]

        for blit in blits:
            self.screen.blit(**blit)

        pygame.display.flip()
        self.clock.tick(self.fps)
        rgb_obs = pygame.surfarray.array3d(self.screen)
        return torch.from_numpy(rgb_obs).permute(2, 0, 1)
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Small visual smoke test: move a single collision marker along row 5.
    renderer = Renderer(fps=2, cell_size=40)
    for pos_i in range(15):
        entity_1 = RenderEntity(name='agent_collision', pos=[5, pos_i], value=1,
                                value_operation='idle', state='idle')
        renderer.render([entity_1])
|
48
environment/utils/results.py
Normal file
@@ -0,0 +1,48 @@
|
||||
from typing import Union
|
||||
from dataclasses import dataclass, asdict
|
||||
|
||||
from environment.entity.entity import Entity
|
||||
|
||||
TYPE_VALUE = 'value'
|
||||
TYPE_REWARD = 'reward'
|
||||
types = [TYPE_VALUE, TYPE_REWARD]
|
||||
|
||||
@dataclass
class InfoObject:
    """One named value reported for logging, e.g. the reward or the value of a result."""
    # Name of the reported quantity.
    identifier: str
    # Kind of the quantity; presumably one of the entries of the module level `types` list -- TODO confirm.
    val_type: str
    # The reported number itself.
    value: Union[float, int]
|
||||
|
||||
|
||||
@dataclass
class Result:
    """Outcome of an action, a rule tick or a done check, optionally bound to an entity."""
    identifier: str
    validity: bool
    reward: Union[float, None] = None
    value: Union[float, None] = None
    entity: Union[Entity, None] = None

    def get_infos(self):
        """
        Create one `InfoObject` per entry of `types` whose attribute on this result is not None.

        :return: List of `InfoObject`; identifiers are '<entity name or "Global">_<identifier>_<type>'.
        """
        if self.entity is not None:
            owner = self.entity.name
        else:
            owner = "Global"
        infos = []
        for info_type in types:
            reported = getattr(self, info_type)
            if reported is None:
                continue
            infos.append(InfoObject(identifier=f'{owner}_{self.identifier}_{info_type}',
                                    val_type=info_type, value=reported))
        return infos

    def __repr__(self):
        negation = "" if self.validity else "not "
        return f'{self.__class__.__name__}({self.identifier.capitalize()} {negation}valid: {self.reward})'
|
||||
|
||||
|
||||
# repr=False: a bare @dataclass would generate a new __repr__ for this subclass and hide Result.__repr__.
@dataclass(repr=False)
class TickResult(Result):
    """Result produced by a rule tick; adds no fields to `Result`."""
    pass
|
||||
|
||||
|
||||
# repr=False: a bare @dataclass would generate a new __repr__ for this subclass and hide Result.__repr__.
@dataclass(repr=False)
class ActionResult(Result):
    """Result produced by an action; adds no fields to `Result`."""
    pass
|
||||
|
||||
|
||||
# repr=False: a bare @dataclass would generate a new __repr__ for this subclass and hide Result.__repr__.
@dataclass(repr=False)
class DoneResult(Result):
    """Result produced by a done check; adds no fields to `Result`."""
    pass
|
112
environment/utils/states.py
Normal file
@@ -0,0 +1,112 @@
|
||||
from typing import List, Dict
|
||||
|
||||
import numpy as np
|
||||
|
||||
from environment.entity.wall_floor import Floor
|
||||
from environment.rules import Rule
|
||||
from environment.utils.results import Result
|
||||
from environment import constants as c
|
||||
|
||||
|
||||
class StepRules:
    """Ordered collection of rules with helpers that run one hook on every rule and gather the results."""

    def __init__(self, *args):
        """
        :param args: The rules to start with, in execution order.
        """
        self.rules = list(args)

    def __repr__(self):
        return f'Rules{[x.name for x in self]}'

    def __iter__(self):
        return iter(self.rules)

    def append(self, item):
        """
        Add a rule at the end of the collection.

        :param item: Rule to add; has to be an instance of `Rule`.
        :return: True
        """
        assert isinstance(item, Rule)
        self.rules.append(item)
        return True

    def do_all_init(self, state):
        """
        Call `on_init` of every rule; truthy return values are handed to `state.print`.

        :return: c.VALID
        """
        for rule in self.rules:
            if rule_init_printline := rule.on_init(state):
                state.print(rule_init_printline)
        return c.VALID

    def _collect(self, hook_name, state):
        """Call the hook `hook_name` on every rule and join all truthy (iterable) return values into one list."""
        results = list()
        for rule in self.rules:
            if hook_result := getattr(rule, hook_name)(state):
                results.extend(hook_result)
        return results

    def tick_step_all(self, state):
        """Run `tick_step` of every rule and return the joined results."""
        return self._collect('tick_step', state)

    def tick_pre_step_all(self, state):
        """Run `tick_pre_step` of every rule and return the joined results."""
        # This used to call `tick_post_step`, which ran the post step hooks twice per tick.
        return self._collect('tick_pre_step', state)

    def tick_post_step_all(self, state):
        """Run `tick_post_step` of every rule and return the joined results."""
        return self._collect('tick_post_step', state)
|
||||
|
||||
|
||||
class Gamestate(object):
    """Holds the entities, the rules and the step counter of a running environment."""

    @property
    def moving_entites(self):
        """All members of those entity groups whose `can_move` is truthy."""
        movers = []
        for group in self.entities:
            for member in group:
                if group.can_move:
                    movers.append(member)
        return movers

    def __init__(self, entitites, rules: Dict[str, dict], env_seed=69, verbose=False):
        """
        :param entitites: The entity container of the environment.
        :param rules: Mapping whose values are dicts with the keys 'class' and 'kwargs'; each rule is
                      instantiated as `class(**kwargs)`.
        :param env_seed: Seed for the state's random number generator.
        :param verbose: Whether `print` actually prints.
        """
        self.entities = entitites
        self.NO_POS_TILE = Floor(c.VALUE_NO_POS)
        self.curr_step = 0
        self.curr_actions = None
        self.verbose = verbose
        self.rng = np.random.default_rng(env_seed)
        rule_instances = (rule_conf['class'](**rule_conf['kwargs']) for rule_conf in rules.values())
        self.rules = StepRules(*rule_instances)

    def __getitem__(self, item):
        return self.entities[item]

    def __iter__(self):
        return iter(self.entities.values())

    def __repr__(self):
        return f'{self.__class__.__name__}({len(self.entities)} Entitites @ Step {self.curr_step})'

    def tick(self, actions) -> List[Result]:
        """
        Advance the state by one step.

        Order: pre step hooks of the rules, the action of every agent (by index), step hooks, post step hooks.

        :param actions: One integer action per agent, ordered by agent index.
        :return: All results produced during this step, in the order they occurred.
        """
        self.curr_step += 1

        # Main Agent Step
        results = list()
        results.extend(self.rules.tick_pre_step_all(self))
        for agent_idx, action_int in enumerate(actions):
            agent = self[c.AGENT][agent_idx].clear_temp_state()
            chosen_action = agent.actions[action_int]
            outcome = chosen_action.do(agent, self)
            results.append(outcome)
            agent.set_state(outcome)
        results.extend(self.rules.tick_step_all(self))
        results.extend(self.rules.tick_post_step_all(self))
        return results

    def print(self, string):
        """Print `string` only if the state was created with `verbose=True`."""
        if not self.verbose:
            return
        print(string)

    def check_done(self):
        """Call `on_check_done` of every rule and return the joined truthy results."""
        results = list()
        for rule in self.rules:
            done_results = rule.on_check_done(self)
            if done_results:
                results.extend(done_results)
        return results

    def get_all_tiles_with_collisions(self) -> List[Floor]:
        """Return the floor tile of every position that holds more than one entity with truthy `can_collide`."""
        tiles = []
        for pos, guests in self.entities.pos_dict.items():
            if sum([guest.can_collide for guest in guests]) > 1:
                tiles.append(self[c.FLOOR].by_pos(pos))
        return tiles
|
27
environment/utils/utility_classes.py
Normal file
@@ -0,0 +1,27 @@
|
||||
import gymnasium as gym
|
||||
|
||||
|
||||
class EnvCombiner(object):
    """Combines several environment classes into one new class that inherits from all of them."""

    def __init__(self, *envs_cls):
        """
        :param envs_cls: The environment classes to combine, in base class order.
        """
        self._env_dict = {env_cls.__name__: env_cls for env_cls in envs_cls}

    @staticmethod
    def combine_cls(name, *envs_cls):
        """
        Create a new class called `name` with `envs_cls` as its base classes.

        :raises TypeError: If the bases cannot be combined (e.g. no consistent method resolution order).
        """
        return type(name, envs_cls, {})

    def build(self):
        """
        Build the combined class.

        Its name is made of the capitalized class names with any "factory" removed, followed by "Factory",
        e.g. `AFactory` and `BFactory` become `ABFactory`.

        :return: The newly created class.
        """
        # str.replace needs the replacement text; it was missing and made this line raise a TypeError.
        name = f'{"".join([x.lower().replace("factory", "").capitalize() for x in self._env_dict.keys()])}Factory'

        # The classes have to be passed one by one, otherwise the whole tuple ends up as a single "base".
        return self.combine_cls(name, *self._env_dict.values())
|
||||
|
||||
|
||||
class MarlFrameStack(gym.ObservationWrapper):
    """todo @romue404"""

    def __init__(self, env):
        super().__init__(env)

    def observation(self, observation):
        """
        Swap the first two axes of the observation if the wrapped env is a `gym.wrappers.FrameStack` whose
        unwrapped env reports more than one agent; otherwise return the observation untouched.
        """
        if not isinstance(self.env, gym.wrappers.FrameStack):
            return observation
        if not self.env.unwrapped.n_agents > 1:
            return observation
        # The full slice is kept on purpose: it is applied before `swapaxes` exactly as in the original.
        return observation[0:].swapaxes(0, 1)
|