mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-07-06 09:31:35 +02:00
Everything is an object now
This commit is contained in:
@ -1,10 +1,11 @@
|
||||
import random
|
||||
from abc import ABC
|
||||
from enum import Enum
|
||||
from typing import List, Union
|
||||
from typing import List, Union, Dict
|
||||
|
||||
import numpy as np
|
||||
|
||||
from environments.factory.base.objects import Entity, Tile, Agent, Door, Slice, Action
|
||||
from environments.factory.base.objects import Entity, Tile, Agent, Door, Action, Wall
|
||||
from environments.utility_classes import MovementProperties
|
||||
from environments import helpers as h
|
||||
from environments.helpers import Constants as c
|
||||
@ -13,10 +14,6 @@ from environments.helpers import Constants as c
|
||||
class Register:
|
||||
_accepted_objects = Entity
|
||||
|
||||
@classmethod
|
||||
def from_argwhere_coordinates(cls, positions: [(int, int)], tiles):
|
||||
return cls.from_tiles([tiles.by_pos(position) for position in positions])
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
return self.__class__.__name__
|
||||
@ -25,7 +22,7 @@ class Register:
|
||||
def n(self):
|
||||
return len(self)
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self, *args, **kwargs):
|
||||
self._register = dict()
|
||||
self._names = dict()
|
||||
|
||||
@ -35,17 +32,18 @@ class Register:
|
||||
def __iter__(self):
|
||||
return iter(self.values())
|
||||
|
||||
def __add__(self, other: _accepted_objects):
|
||||
def register_item(self, other: _accepted_objects):
|
||||
assert isinstance(other, self._accepted_objects), f'All item names have to be of type ' \
|
||||
f'{self._accepted_objects}, ' \
|
||||
f'but were {other.__class__}.,'
|
||||
self._names.update({other.name: len(self._register)})
|
||||
self._register.update({len(self._register): other})
|
||||
new_idx = len(self._register)
|
||||
self._names.update({other.name: new_idx})
|
||||
self._register.update({new_idx: other})
|
||||
return self
|
||||
|
||||
def register_additional_items(self, others: List[_accepted_objects]):
|
||||
for other in others:
|
||||
self + other
|
||||
self.register_item(other)
|
||||
return self
|
||||
|
||||
def keys(self):
|
||||
@ -60,8 +58,9 @@ class Register:
|
||||
def __getitem__(self, item):
|
||||
try:
|
||||
return self._register[item]
|
||||
except KeyError:
|
||||
except KeyError as e:
|
||||
print('NO')
|
||||
print(e)
|
||||
raise
|
||||
|
||||
def by_name(self, item):
|
||||
@ -82,29 +81,66 @@ class Register:
|
||||
def get_idx(self, enum_obj: Enum):
|
||||
return self._names[enum_obj.name]
|
||||
|
||||
|
||||
class ObjectRegister(Register):
|
||||
def __init__(self, level_shape: (int, int), *args, individual_slices=False, is_per_agent=False, **kwargs):
|
||||
super(ObjectRegister, self).__init__(*args, **kwargs)
|
||||
self.is_per_agent = is_per_agent
|
||||
self.individual_slices = individual_slices
|
||||
self._level_shape = level_shape
|
||||
self._array = None
|
||||
|
||||
def register_item(self, other):
|
||||
super(ObjectRegister, self).register_item(other)
|
||||
if self._array is None:
|
||||
self._array = np.zeros((1, *self._level_shape))
|
||||
else:
|
||||
if self.individual_slices:
|
||||
self._array = np.concatenate((self._array, np.zeros(1, *self._level_shape)))
|
||||
|
||||
|
||||
class EntityObjectRegister(ObjectRegister, ABC):
|
||||
|
||||
def as_array(self):
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def from_tiles(cls, tiles, **kwargs):
|
||||
def from_tiles(cls, tiles, *args, **kwargs):
|
||||
# objects_name = cls._accepted_objects.__name__
|
||||
entities = [cls._accepted_objects(i, tile, name_is_identifier=True, **kwargs) for i, tile in enumerate(tiles)]
|
||||
registered_obj = cls()
|
||||
registered_obj.register_additional_items(entities)
|
||||
return registered_obj
|
||||
entities = [cls._accepted_objects(i, tile, name_is_identifier=True, **kwargs)
|
||||
for i, tile in enumerate(tiles)]
|
||||
register_obj = cls(*args)
|
||||
register_obj.register_additional_items(entities)
|
||||
return register_obj
|
||||
|
||||
|
||||
class EntityRegister(Register):
|
||||
@classmethod
|
||||
def from_argwhere_coordinates(cls, positions: [(int, int)], tiles, *args, **kwargs):
|
||||
return cls.from_tiles([tiles.by_pos(position) for position in positions], *args, **kwargs)
|
||||
|
||||
@property
|
||||
def positions(self):
|
||||
return [agent.pos for agent in self]
|
||||
return list(self._tiles.keys())
|
||||
|
||||
def __init__(self):
|
||||
super(EntityRegister, self).__init__()
|
||||
@property
|
||||
def tiles(self):
|
||||
return [entity.tile for entity in self]
|
||||
|
||||
def __init__(self, *args, is_blocking_light=False, is_observable=True, can_be_shadowed=True, **kwargs):
|
||||
super(EntityObjectRegister, self).__init__(*args, **kwargs)
|
||||
self.can_be_shadowed = can_be_shadowed
|
||||
self._tiles = dict()
|
||||
self.is_blocking_light = is_blocking_light
|
||||
self.is_observable = is_observable
|
||||
|
||||
def __add__(self, other):
|
||||
super(EntityRegister, self).__add__(other)
|
||||
def register_item(self, other):
|
||||
super(EntityObjectRegister, self).register_item(other)
|
||||
self._tiles[other.pos] = other
|
||||
|
||||
def register_additional_items(self, others):
|
||||
for other in others:
|
||||
self.register_item(other)
|
||||
return self
|
||||
|
||||
def by_pos(self, pos):
|
||||
if isinstance(pos, np.ndarray):
|
||||
pos = tuple(pos)
|
||||
@ -114,9 +150,34 @@ class EntityRegister(Register):
|
||||
return None
|
||||
|
||||
|
||||
class MovingEntityObjectRegister(EntityObjectRegister, ABC):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(MovingEntityObjectRegister, self).__init__(*args, **kwargs)
|
||||
|
||||
def by_pos(self, pos):
|
||||
if isinstance(pos, np.ndarray):
|
||||
pos = tuple(pos)
|
||||
try:
|
||||
return [x for x in self if x == pos][0]
|
||||
except IndexError:
|
||||
return None
|
||||
|
||||
def delete_item(self, item):
|
||||
self
|
||||
|
||||
|
||||
class Entities(Register):
|
||||
|
||||
_accepted_objects = Register
|
||||
_accepted_objects = EntityObjectRegister
|
||||
|
||||
@property
|
||||
def arrays(self):
|
||||
return {key: val.as_array() for key, val in self.items() if val.is_observable}
|
||||
|
||||
@property
|
||||
def names(self):
|
||||
return list(self._register.keys())
|
||||
|
||||
def __init__(self):
|
||||
super(Entities, self).__init__()
|
||||
@ -124,23 +185,64 @@ class Entities(Register):
|
||||
def __iter__(self):
|
||||
return iter([x for sublist in self.values() for x in sublist])
|
||||
|
||||
@classmethod
|
||||
def from_argwhere_coordinates(cls, positions):
|
||||
raise AttributeError()
|
||||
def register_item(self, other: dict):
|
||||
assert not any([key for key in other.keys() if key in self._names]), \
|
||||
"This group of entities has already been registered!"
|
||||
self._register.update(other)
|
||||
return self
|
||||
|
||||
def register_additional_items(self, others: Dict):
|
||||
return self.register_item(others)
|
||||
|
||||
|
||||
class FloorTiles(EntityRegister):
|
||||
_accepted_objects = Tile
|
||||
class WallTiles(EntityObjectRegister):
|
||||
_accepted_objects = Wall
|
||||
_light_blocking = True
|
||||
|
||||
def as_array(self):
|
||||
if not np.any(self._array):
|
||||
x, y = zip(*[x.pos for x in self])
|
||||
self._array[0, x, y] = self.encoding
|
||||
return self._array
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(WallTiles, self).__init__(*args, individual_slices=False, is_blocking_light=self._light_blocking, **kwargs)
|
||||
|
||||
@property
|
||||
def encoding(self):
|
||||
return c.OCCUPIED_CELL.value
|
||||
|
||||
@property
|
||||
def array(self):
|
||||
return self._array
|
||||
|
||||
@classmethod
|
||||
def from_argwhere_coordinates(cls, argwhere_coordinates):
|
||||
tiles = cls()
|
||||
def from_argwhere_coordinates(cls, argwhere_coordinates, *args, **kwargs):
|
||||
tiles = cls(*args, **kwargs)
|
||||
# noinspection PyTypeChecker
|
||||
tiles.register_additional_items(
|
||||
[cls._accepted_objects(i, pos, name_is_identifier=True) for i, pos in enumerate(argwhere_coordinates)]
|
||||
[cls._accepted_objects(i, pos, name_is_identifier=True, is_blocking_light=cls._light_blocking)
|
||||
for i, pos in enumerate(argwhere_coordinates)]
|
||||
)
|
||||
return tiles
|
||||
|
||||
@classmethod
|
||||
def from_tiles(cls, tiles, *args, **kwargs):
|
||||
raise RuntimeError()
|
||||
|
||||
|
||||
class FloorTiles(WallTiles):
|
||||
|
||||
_accepted_objects = Tile
|
||||
_light_blocking = False
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(self.__class__, self).__init__(*args, is_observable=False, **kwargs)
|
||||
|
||||
@property
|
||||
def encoding(self):
|
||||
return c.FREE_CELL.value
|
||||
|
||||
@property
|
||||
def occupied_tiles(self):
|
||||
tiles = [tile for tile in self if tile.is_occupied()]
|
||||
@ -153,8 +255,22 @@ class FloorTiles(EntityRegister):
|
||||
random.shuffle(tiles)
|
||||
return tiles
|
||||
|
||||
@classmethod
|
||||
def from_tiles(cls, tiles, *args, **kwargs):
|
||||
raise RuntimeError()
|
||||
|
||||
class Agents(EntityRegister):
|
||||
|
||||
class Agents(MovingEntityObjectRegister):
|
||||
|
||||
def as_array(self):
|
||||
self._array[:] = c.FREE_CELL.value
|
||||
# noinspection PyTupleAssignmentBalance
|
||||
z, x, y = range(len(self)), *zip(*[x.pos for x in self])
|
||||
self._array[z, x, y] = c.OCCUPIED_CELL.value
|
||||
if self.individual_slices:
|
||||
return self._array
|
||||
else:
|
||||
return self._array.sum(axis=0, keepdims=True)
|
||||
|
||||
_accepted_objects = Agent
|
||||
|
||||
@ -163,7 +279,17 @@ class Agents(EntityRegister):
|
||||
return [agent.pos for agent in self]
|
||||
|
||||
|
||||
class Doors(EntityRegister):
|
||||
class Doors(EntityObjectRegister):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(Doors, self).__init__(*args, is_blocking_light=True, **kwargs)
|
||||
|
||||
def as_array(self):
|
||||
self._array[:] = 0
|
||||
for door in self:
|
||||
self._array[0, door.x, door.y] = door.encoding
|
||||
return self._array
|
||||
|
||||
_accepted_objects = Door
|
||||
|
||||
def get_near_position(self, position: (int, int)) -> Union[None, Door]:
|
||||
@ -221,47 +347,6 @@ class Actions(Register):
|
||||
return action == h.EnvActions.USE_DOOR.name
|
||||
|
||||
|
||||
class StateSlices(Register):
|
||||
|
||||
_accepted_objects = Slice
|
||||
@property
|
||||
def n_observable_slices(self):
|
||||
return len([x for x in self if x.is_observable])
|
||||
|
||||
|
||||
@property
|
||||
def AGENTSTARTIDX(self):
|
||||
if self._agent_start_idx:
|
||||
return self._agent_start_idx
|
||||
else:
|
||||
self._agent_start_idx = min([idx for idx, x in self.items() if c.AGENT.value in x.name])
|
||||
return self._agent_start_idx
|
||||
|
||||
def __init__(self):
|
||||
super(StateSlices, self).__init__()
|
||||
self._agent_start_idx = None
|
||||
|
||||
def _gather_occupation(self, excluded_slices):
|
||||
exclusion = excluded_slices or []
|
||||
assert isinstance(exclusion, (int, list))
|
||||
exclusion = exclusion if isinstance(exclusion, list) else [exclusion]
|
||||
|
||||
result = np.sum([x for i, x in self.items() if i not in exclusion], axis=0)
|
||||
return result
|
||||
|
||||
def free_cells(self, excluded_slices: Union[None, List[int], int] = None) -> np.array:
|
||||
occupation = self._gather_occupation(excluded_slices)
|
||||
free_cells = np.argwhere(occupation == c.IS_FREE_CELL)
|
||||
np.random.shuffle(free_cells)
|
||||
return free_cells
|
||||
|
||||
def occupied_cells(self, excluded_slices: Union[None, List[int], int] = None) -> np.array:
|
||||
occupation = self._gather_occupation(excluded_slices)
|
||||
occupied_cells = np.argwhere(occupation == c.IS_OCCUPIED_CELL.value)
|
||||
np.random.shuffle(occupied_cells)
|
||||
return occupied_cells
|
||||
|
||||
|
||||
class Zones(Register):
|
||||
|
||||
@property
|
||||
@ -279,9 +364,9 @@ class Zones(Register):
|
||||
self._accounting_zones = list()
|
||||
self._danger_zones = list()
|
||||
for symbol in np.unique(parsed_level):
|
||||
if symbol == h.WALL:
|
||||
if symbol == c.WALL.value:
|
||||
continue
|
||||
elif symbol == h.DANGER_ZONE:
|
||||
elif symbol == c.DANGER_ZONE.value:
|
||||
self + symbol
|
||||
slices.append(h.one_hot_level(parsed_level, symbol))
|
||||
self._danger_zones.append(symbol)
|
||||
|
Reference in New Issue
Block a user