Everything is an object now

This commit is contained in:
Steffen Illium
2021-08-26 17:47:15 +02:00
parent bd0a8090ab
commit 0fc4db193f
7 changed files with 613 additions and 447 deletions

View File

@ -1,10 +1,11 @@
import random
from abc import ABC
from enum import Enum
from typing import List, Union
from typing import List, Union, Dict
import numpy as np
from environments.factory.base.objects import Entity, Tile, Agent, Door, Slice, Action
from environments.factory.base.objects import Entity, Tile, Agent, Door, Action, Wall
from environments.utility_classes import MovementProperties
from environments import helpers as h
from environments.helpers import Constants as c
@ -13,10 +14,6 @@ from environments.helpers import Constants as c
class Register:
_accepted_objects = Entity
@classmethod
def from_argwhere_coordinates(cls, positions: [(int, int)], tiles):
return cls.from_tiles([tiles.by_pos(position) for position in positions])
@property
def name(self):
return self.__class__.__name__
@ -25,7 +22,7 @@ class Register:
def n(self):
return len(self)
def __init__(self):
def __init__(self, *args, **kwargs):
self._register = dict()
self._names = dict()
@ -35,17 +32,18 @@ class Register:
def __iter__(self):
return iter(self.values())
def __add__(self, other: _accepted_objects):
def register_item(self, other: _accepted_objects):
assert isinstance(other, self._accepted_objects), f'All item names have to be of type ' \
f'{self._accepted_objects}, ' \
f'but were {other.__class__}.,'
self._names.update({other.name: len(self._register)})
self._register.update({len(self._register): other})
new_idx = len(self._register)
self._names.update({other.name: new_idx})
self._register.update({new_idx: other})
return self
def register_additional_items(self, others: List[_accepted_objects]):
for other in others:
self + other
self.register_item(other)
return self
def keys(self):
@ -60,8 +58,9 @@ class Register:
def __getitem__(self, item):
try:
return self._register[item]
except KeyError:
except KeyError as e:
print('NO')
print(e)
raise
def by_name(self, item):
@ -82,29 +81,66 @@ class Register:
def get_idx(self, enum_obj: Enum):
return self._names[enum_obj.name]
class ObjectRegister(Register):
def __init__(self, level_shape: (int, int), *args, individual_slices=False, is_per_agent=False, **kwargs):
super(ObjectRegister, self).__init__(*args, **kwargs)
self.is_per_agent = is_per_agent
self.individual_slices = individual_slices
self._level_shape = level_shape
self._array = None
def register_item(self, other):
super(ObjectRegister, self).register_item(other)
if self._array is None:
self._array = np.zeros((1, *self._level_shape))
else:
if self.individual_slices:
self._array = np.concatenate((self._array, np.zeros(1, *self._level_shape)))
class EntityObjectRegister(ObjectRegister, ABC):
def as_array(self):
raise NotImplementedError
@classmethod
def from_tiles(cls, tiles, **kwargs):
def from_tiles(cls, tiles, *args, **kwargs):
# objects_name = cls._accepted_objects.__name__
entities = [cls._accepted_objects(i, tile, name_is_identifier=True, **kwargs) for i, tile in enumerate(tiles)]
registered_obj = cls()
registered_obj.register_additional_items(entities)
return registered_obj
entities = [cls._accepted_objects(i, tile, name_is_identifier=True, **kwargs)
for i, tile in enumerate(tiles)]
register_obj = cls(*args)
register_obj.register_additional_items(entities)
return register_obj
class EntityRegister(Register):
@classmethod
def from_argwhere_coordinates(cls, positions: [(int, int)], tiles, *args, **kwargs):
return cls.from_tiles([tiles.by_pos(position) for position in positions], *args, **kwargs)
@property
def positions(self):
return [agent.pos for agent in self]
return list(self._tiles.keys())
def __init__(self):
super(EntityRegister, self).__init__()
@property
def tiles(self):
return [entity.tile for entity in self]
def __init__(self, *args, is_blocking_light=False, is_observable=True, can_be_shadowed=True, **kwargs):
super(EntityObjectRegister, self).__init__(*args, **kwargs)
self.can_be_shadowed = can_be_shadowed
self._tiles = dict()
self.is_blocking_light = is_blocking_light
self.is_observable = is_observable
def __add__(self, other):
super(EntityRegister, self).__add__(other)
def register_item(self, other):
super(EntityObjectRegister, self).register_item(other)
self._tiles[other.pos] = other
def register_additional_items(self, others):
for other in others:
self.register_item(other)
return self
def by_pos(self, pos):
if isinstance(pos, np.ndarray):
pos = tuple(pos)
@ -114,9 +150,34 @@ class EntityRegister(Register):
return None
class MovingEntityObjectRegister(EntityObjectRegister, ABC):
def __init__(self, *args, **kwargs):
super(MovingEntityObjectRegister, self).__init__(*args, **kwargs)
def by_pos(self, pos):
if isinstance(pos, np.ndarray):
pos = tuple(pos)
try:
return [x for x in self if x == pos][0]
except IndexError:
return None
def delete_item(self, item):
self
class Entities(Register):
_accepted_objects = Register
_accepted_objects = EntityObjectRegister
@property
def arrays(self):
return {key: val.as_array() for key, val in self.items() if val.is_observable}
@property
def names(self):
return list(self._register.keys())
def __init__(self):
super(Entities, self).__init__()
@ -124,23 +185,64 @@ class Entities(Register):
def __iter__(self):
return iter([x for sublist in self.values() for x in sublist])
@classmethod
def from_argwhere_coordinates(cls, positions):
raise AttributeError()
def register_item(self, other: dict):
assert not any([key for key in other.keys() if key in self._names]), \
"This group of entities has already been registered!"
self._register.update(other)
return self
def register_additional_items(self, others: Dict):
return self.register_item(others)
class FloorTiles(EntityRegister):
_accepted_objects = Tile
class WallTiles(EntityObjectRegister):
_accepted_objects = Wall
_light_blocking = True
def as_array(self):
if not np.any(self._array):
x, y = zip(*[x.pos for x in self])
self._array[0, x, y] = self.encoding
return self._array
def __init__(self, *args, **kwargs):
super(WallTiles, self).__init__(*args, individual_slices=False, is_blocking_light=self._light_blocking, **kwargs)
@property
def encoding(self):
return c.OCCUPIED_CELL.value
@property
def array(self):
return self._array
@classmethod
def from_argwhere_coordinates(cls, argwhere_coordinates):
tiles = cls()
def from_argwhere_coordinates(cls, argwhere_coordinates, *args, **kwargs):
tiles = cls(*args, **kwargs)
# noinspection PyTypeChecker
tiles.register_additional_items(
[cls._accepted_objects(i, pos, name_is_identifier=True) for i, pos in enumerate(argwhere_coordinates)]
[cls._accepted_objects(i, pos, name_is_identifier=True, is_blocking_light=cls._light_blocking)
for i, pos in enumerate(argwhere_coordinates)]
)
return tiles
@classmethod
def from_tiles(cls, tiles, *args, **kwargs):
raise RuntimeError()
class FloorTiles(WallTiles):
_accepted_objects = Tile
_light_blocking = False
def __init__(self, *args, **kwargs):
super(self.__class__, self).__init__(*args, is_observable=False, **kwargs)
@property
def encoding(self):
return c.FREE_CELL.value
@property
def occupied_tiles(self):
tiles = [tile for tile in self if tile.is_occupied()]
@ -153,8 +255,22 @@ class FloorTiles(EntityRegister):
random.shuffle(tiles)
return tiles
@classmethod
def from_tiles(cls, tiles, *args, **kwargs):
raise RuntimeError()
class Agents(EntityRegister):
class Agents(MovingEntityObjectRegister):
def as_array(self):
self._array[:] = c.FREE_CELL.value
# noinspection PyTupleAssignmentBalance
z, x, y = range(len(self)), *zip(*[x.pos for x in self])
self._array[z, x, y] = c.OCCUPIED_CELL.value
if self.individual_slices:
return self._array
else:
return self._array.sum(axis=0, keepdims=True)
_accepted_objects = Agent
@ -163,7 +279,17 @@ class Agents(EntityRegister):
return [agent.pos for agent in self]
class Doors(EntityRegister):
class Doors(EntityObjectRegister):
def __init__(self, *args, **kwargs):
super(Doors, self).__init__(*args, is_blocking_light=True, **kwargs)
def as_array(self):
self._array[:] = 0
for door in self:
self._array[0, door.x, door.y] = door.encoding
return self._array
_accepted_objects = Door
def get_near_position(self, position: (int, int)) -> Union[None, Door]:
@ -221,47 +347,6 @@ class Actions(Register):
return action == h.EnvActions.USE_DOOR.name
class StateSlices(Register):
_accepted_objects = Slice
@property
def n_observable_slices(self):
return len([x for x in self if x.is_observable])
@property
def AGENTSTARTIDX(self):
if self._agent_start_idx:
return self._agent_start_idx
else:
self._agent_start_idx = min([idx for idx, x in self.items() if c.AGENT.value in x.name])
return self._agent_start_idx
def __init__(self):
super(StateSlices, self).__init__()
self._agent_start_idx = None
def _gather_occupation(self, excluded_slices):
exclusion = excluded_slices or []
assert isinstance(exclusion, (int, list))
exclusion = exclusion if isinstance(exclusion, list) else [exclusion]
result = np.sum([x for i, x in self.items() if i not in exclusion], axis=0)
return result
def free_cells(self, excluded_slices: Union[None, List[int], int] = None) -> np.array:
occupation = self._gather_occupation(excluded_slices)
free_cells = np.argwhere(occupation == c.IS_FREE_CELL)
np.random.shuffle(free_cells)
return free_cells
def occupied_cells(self, excluded_slices: Union[None, List[int], int] = None) -> np.array:
occupation = self._gather_occupation(excluded_slices)
occupied_cells = np.argwhere(occupation == c.IS_OCCUPIED_CELL.value)
np.random.shuffle(occupied_cells)
return occupied_cells
class Zones(Register):
@property
@ -279,9 +364,9 @@ class Zones(Register):
self._accounting_zones = list()
self._danger_zones = list()
for symbol in np.unique(parsed_level):
if symbol == h.WALL:
if symbol == c.WALL.value:
continue
elif symbol == h.DANGER_ZONE:
elif symbol == c.DANGER_ZONE.value:
self + symbol
slices.append(h.one_hot_level(parsed_level, symbol))
self._danger_zones.append(symbol)