2021-08-23 09:51:35 +02:00

301 lines
9.1 KiB
Python

import random
from enum import Enum
from typing import List, Union
import numpy as np
from environments.factory.base.objects import Entity, Tile, Agent, Door, Slice, Action
from environments.utility_classes import MovementProperties
from environments import helpers as h
from environments.helpers import Constants as c
class Register:
    """Insertion-ordered collection of named objects addressable by index or name.

    Every added item is stored in ``_register`` under the next free integer
    index, and its ``name`` is recorded in ``_names`` so the item can be found
    by name (or by an ``Enum`` member carrying that name) as well.
    Subclasses narrow ``_accepted_objects`` to the type they may hold.
    """

    # Type every item passed to ``__add__`` must be an instance of.
    _accepted_objects = Entity

    @classmethod
    def from_argwhere_coordinates(cls, positions: [(int, int)], tiles):
        """Build a register from ``positions`` by resolving each one via ``tiles.by_pos``."""
        return cls.from_tiles([tiles.by_pos(position) for position in positions])

    @classmethod
    def from_tiles(cls, tiles, **kwargs):
        """Build a register holding one ``_accepted_objects`` instance per tile.

        Each object is constructed as
        ``_accepted_objects(i, tile, name_is_identifier=True, **kwargs)`` where
        ``i`` is the tile's position in ``tiles``.
        """
        entities = [cls._accepted_objects(i, tile, name_is_identifier=True, **kwargs) for i, tile in enumerate(tiles)]
        registered_obj = cls()
        registered_obj.register_additional_items(entities)
        return registered_obj

    @property
    def name(self):
        """Name of the concrete register class."""
        return self.__class__.__name__

    @property
    def n(self):
        """Number of registered items."""
        return len(self)

    def __init__(self):
        self._register = dict()  # index -> item
        self._names = dict()     # item.name -> index

    def __len__(self):
        return len(self._register)

    def __iter__(self):
        return iter(self.values())

    def __add__(self, other: _accepted_objects):
        """Register ``other`` under the next free index and return ``self``.

        Note that this mutates the register in place instead of creating a
        new one.

        Raises:
            AssertionError: if ``other`` is not an instance of ``_accepted_objects``.
        """
        assert isinstance(other, self._accepted_objects), (
            f'All item names have to be of type '
            f'{self._accepted_objects}, '
            f'but were {other.__class__}.,'
        )
        # Both mappings use the size *before* insertion as the new index.
        self._names.update({other.name: len(self._register)})
        self._register.update({len(self._register): other})
        return self

    def register_additional_items(self, others: List[_accepted_objects]):
        """Add every item of ``others`` in order and return ``self``."""
        for other in others:
            # ``+`` mutates the register in place; its result is not needed here.
            self + other
        return self

    def keys(self):
        return self._register.keys()

    def values(self):
        return self._register.values()

    def items(self):
        return self._register.items()

    def __getitem__(self, item):
        """Return the item stored under index ``item``.

        Raises:
            KeyError: if nothing is registered under ``item``.
        """
        # The former debug ``print('NO')`` on a miss was dropped; the KeyError
        # itself already tells the caller which key was missing.
        return self._register[item]

    def by_name(self, item):
        """Return the item whose ``name`` equals ``item``."""
        return self[self._names[item]]

    def by_enum(self, enum_obj: Enum):
        """Return the item whose ``name`` equals ``enum_obj.name``."""
        return self[self._names[enum_obj.name]]

    def __repr__(self):
        return f'{self.__class__.__name__}({self._register})'

    def get_name(self, item):
        """Return the ``name`` of the item stored under index ``item``."""
        return self._register[item].name

    def get_idx_by_name(self, item):
        """Return the index of the item whose ``name`` equals ``item``."""
        return self._names[item]

    def get_idx(self, enum_obj: Enum):
        """Return the index of the item whose ``name`` equals ``enum_obj.name``."""
        return self._names[enum_obj.name]
class EntityRegister(Register):
    """Register that additionally indexes its items by their ``pos`` attribute."""

    def __init__(self):
        super(EntityRegister, self).__init__()
        # item.pos -> item; a later item with the same ``pos`` replaces the earlier one here.
        self._tiles = dict()

    def __add__(self, other):
        """Register ``other`` by index, name and position, then return ``self``.

        Returning ``self`` keeps the operator consistent with
        ``Register.__add__``; without it ``register + item`` would evaluate to
        ``None``.
        """
        super(EntityRegister, self).__add__(other)
        self._tiles[other.pos] = other
        return self

    def by_pos(self, pos):
        """Return the item registered at ``pos`` or ``None`` if there is none.

        ``pos`` may be given as a ``numpy`` array; it is converted to a tuple
        so it can be used as a dictionary key.
        """
        if isinstance(pos, np.ndarray):
            pos = tuple(pos)
        return self._tiles.get(pos)
class Entities(Register):
    """Register whose items are themselves registers.

    Iterating over it yields the members of every contained register, one
    register after the other, instead of the registers themselves.
    """

    _accepted_objects = Register

    def __init__(self):
        super().__init__()

    def __iter__(self):
        # Flatten eagerly so the returned iterator works on a snapshot.
        flattened = []
        for sub_register in self.values():
            for member in sub_register:
                flattened.append(member)
        return iter(flattened)

    @classmethod
    def from_argwhere_coordinates(cls, positions):
        """Not supported for this register; always raises ``AttributeError``."""
        raise AttributeError()
class FloorTiles(EntityRegister):
    """Position-indexed register of ``Tile`` objects."""

    _accepted_objects = Tile

    @classmethod
    def from_argwhere_coordinates(cls, argwhere_coordinates):
        """Create one tile per coordinate and return them as a new register.

        Each tile is constructed as ``Tile(i, pos, name_is_identifier=True)``
        with ``i`` being the coordinate's position in ``argwhere_coordinates``.
        """
        floor = cls()
        fresh_tiles = []
        for idx, coordinate in enumerate(argwhere_coordinates):
            fresh_tiles.append(cls._accepted_objects(idx, coordinate, name_is_identifier=True))
        # noinspection PyTypeChecker
        floor.register_additional_items(fresh_tiles)
        return floor

    @property
    def occupied_tiles(self):
        """All tiles reporting ``is_occupied()``, in random order."""
        selection = []
        for candidate in self:
            if candidate.is_occupied():
                selection.append(candidate)
        random.shuffle(selection)
        return selection

    @property
    def empty_tiles(self) -> List[Tile]:
        """All tiles reporting ``is_empty()``, in random order."""
        selection = []
        for candidate in self:
            if candidate.is_empty():
                selection.append(candidate)
        random.shuffle(selection)
        return selection
class Agents(Register):
    """Register of ``Agent`` objects."""

    _accepted_objects = Agent

    @property
    def positions(self):
        """The ``pos`` of every registered agent, in registration order."""
        collected = []
        for member in self:
            collected.append(member.pos)
        return collected
class Doors(EntityRegister):
    """Position-indexed register of ``Door`` objects."""

    _accepted_objects = Door

    def get_near_position(self, position: (int, int)) -> Union[None, Door]:
        """Return the first door whose ``access_area`` contains ``position``.

        Doors are checked in registration order; ``None`` is returned when no
        door matches.
        """
        candidates = []
        for entry in self:
            if position in entry.access_area:
                candidates.append(entry)
        return candidates[0] if candidates else None

    def tick_doors(self):
        """Call ``tick()`` on every registered door."""
        for entry in self:
            entry.tick()
class Actions(Register):
    """Register of the ``Action`` objects an agent may choose from.

    Which actions get registered is decided at construction time from the
    given movement properties and the ``can_use_doors`` flag.
    """

    _accepted_objects = Action

    # noinspection PyTypeChecker
    def __init__(self, movement_properties: MovementProperties, can_use_doors=False):
        """Register movement actions first, then door usage, then the no-op.

        Args:
            movement_properties: supplies ``allow_no_op``,
                ``allow_diagonal_movement`` and ``allow_square_movement``.
            can_use_doors: whether ``h.EnvActions.USE_DOOR`` is registered.
        """
        self.allow_no_op = movement_properties.allow_no_op
        self.allow_diagonal_movement = movement_properties.allow_diagonal_movement
        self.allow_square_movement = movement_properties.allow_square_movement
        self.can_use_doors = can_use_doors
        super(Actions, self).__init__()

        if self.allow_square_movement:
            square_moves = [self._accepted_objects(move) for move in h.ManhattanMoves]
            self.register_additional_items(square_moves)
        if self.allow_diagonal_movement:
            diagonal_moves = [self._accepted_objects(move) for move in h.DiagonalMoves]
            self.register_additional_items(diagonal_moves)

        # Snapshot taken before any non-movement action is added.
        self._movement_actions = self._register.copy()

        if self.can_use_doors:
            door_action = self._accepted_objects(h.EnvActions.USE_DOOR)
            self.register_additional_items([door_action])
        if self.allow_no_op:
            idle_action = self._accepted_objects(h.EnvActions.NOOP)
            self.register_additional_items([idle_action])

    @property
    def movement_actions(self):
        """Index -> action mapping of the movement actions only."""
        return self._movement_actions

    def _resolve_name(self, action):
        """Reduce an index or ``Action`` to its name; other values pass through unchanged."""
        if isinstance(action, int):
            action = self[action]
        if isinstance(action, Action):
            action = action.name
        return action

    def is_moving_action(self, action: Union[int]):
        """Return whether ``action`` is contained in the movement action objects."""
        movement_objects = self.movement_actions.values()
        return action in movement_objects

    def is_no_op(self, action: Union[str, Action, int]):
        """Return whether ``action`` (index, ``Action`` or name) is the no-op."""
        return self._resolve_name(action) == h.EnvActions.NOOP.name

    def is_door_usage(self, action: Union[str, int]):
        """Return whether ``action`` (index, ``Action`` or name) is the door usage action."""
        return self._resolve_name(action) == h.EnvActions.USE_DOOR.name
class StateSlices(Register):
    """Register of the ``Slice`` objects and helpers to query cell occupation."""

    _accepted_objects = Slice

    def __init__(self):
        super(StateSlices, self).__init__()
        # Cache for ``AGENTSTARTIDX``; ``None`` means "not computed yet".
        self._agent_start_idx = None

    @property
    def n_observable_slices(self):
        """Number of registered slices whose ``is_observable`` is truthy."""
        return sum(1 for x in self if x.is_observable)

    @property
    def AGENTSTARTIDX(self):
        """Smallest index of a slice whose name contains ``c.AGENT.value``.

        The result is computed on first access and cached afterwards.

        Raises:
            ValueError: if no registered slice name contains ``c.AGENT.value``.
        """
        # Compare against ``None`` explicitly: a cached index of 0 is valid but
        # falsy and would otherwise be recomputed on every access.
        if self._agent_start_idx is None:
            self._agent_start_idx = min(idx for idx, x in self.items() if c.AGENT.value in x.name)
        return self._agent_start_idx

    def _gather_occupation(self, excluded_slices):
        """Sum all slices except the excluded indices along the first axis.

        Args:
            excluded_slices: ``None``, a single slice index or a list of
                slice indices to leave out of the sum.

        Raises:
            AssertionError: if ``excluded_slices`` is neither ``None``, an
                ``int``, a list, nor another empty (falsy) value.
        """
        if excluded_slices is None:
            exclusion = []
        elif isinstance(excluded_slices, int):
            # Handled before any truthiness test so that index 0 is honoured
            # instead of being mistaken for "nothing to exclude".
            exclusion = [excluded_slices]
        else:
            exclusion = excluded_slices or []
        assert isinstance(exclusion, list)
        result = np.sum([x for i, x in self.items() if i not in exclusion], axis=0)
        return result

    def free_cells(self, excluded_slices: Union[None, List[int], int] = None) -> np.ndarray:
        """Return the shuffled coordinates of all cells marked as free.

        Args:
            excluded_slices: slice index/indices ignored when summing up the
                occupation.
        """
        occupation = self._gather_occupation(excluded_slices)
        # ``occupied_cells`` compares against ``c.IS_OCCUPIED_CELL.value``; do
        # the same here, as a bare Enum member never equals an array entry.
        # Falls back to the constant itself should it not wrap a ``value``.
        free_marker = getattr(c.IS_FREE_CELL, 'value', c.IS_FREE_CELL)
        free_cells = np.argwhere(occupation == free_marker)
        np.random.shuffle(free_cells)
        return free_cells

    def occupied_cells(self, excluded_slices: Union[None, List[int], int] = None) -> np.ndarray:
        """Return the shuffled coordinates of all cells marked as occupied.

        Args:
            excluded_slices: slice index/indices ignored when summing up the
                occupation.
        """
        occupation = self._gather_occupation(excluded_slices)
        occupied_cells = np.argwhere(occupation == c.IS_OCCUPIED_CELL.value)
        np.random.shuffle(occupied_cells)
        return occupied_cells
class Zones(Register):
    # Register of level zones that keeps one one-hot slice per zone symbol in
    # the stacked array ``_zone_slices``.
    #
    # NOTE(review): ``__init__`` raises ``NotImplementedError`` as its very
    # first statement, so this class can currently not be instantiated.

    @property
    def danger_zone(self):
        # Indexes ``_zone_slices`` with the result of ``by_enum(c.DANGER_ZONE)``.
        # NOTE(review): ``by_enum`` itself goes through this class'
        # ``__getitem__`` - confirm the resulting double lookup is intended.
        return self._zone_slices[self.by_enum(c.DANGER_ZONE)]

    @property
    def accounting_zones(self):
        # Collects ``self[idx]`` for every registered entry that differs from
        # ``c.DANGER_ZONE.value``.
        return [self[idx] for idx, name in self.items() if name != c.DANGER_ZONE.value]

    def __init__(self, parsed_level):
        # Deliberately disabled: everything below this ``raise`` is unreachable.
        raise NotImplementedError('This needs a Rework')
        super(Zones, self).__init__()
        slices = list()
        self._accounting_zones = list()
        self._danger_zones = list()
        # One one-hot slice per distinct symbol of the level, walls excluded.
        for symbol in np.unique(parsed_level):
            if symbol == h.WALL:
                continue
            elif symbol == h.DANGER_ZONE:
                self + symbol
                slices.append(h.one_hot_level(parsed_level, symbol))
                self._danger_zones.append(symbol)
            else:
                self + symbol
                slices.append(h.one_hot_level(parsed_level, symbol))
                self._accounting_zones.append(symbol)

        self._zone_slices = np.stack(slices)

    def __getitem__(self, item):
        # Unlike the base class, this indexes the stacked slice array.
        return self._zone_slices[item]

    def get_name(self, item):
        # Returns the registered entry itself rather than its ``name`` attribute.
        return self._register[item]

    def by_name(self, item):
        # NOTE(review): the base ``by_name`` already ends in ``self[...]``, so
        # its result is used as an index a second time here - verify.
        return self[super(Zones, self).by_name(item)]

    def register_additional_items(self, other: Union[str, List[str]]):
        # Always raises: adding zones after construction is not allowed.
        raise AttributeError('You are not allowed to add additional Zones in runtime.')