In Debugging

This commit is contained in:
Steffen Illium
2021-08-27 11:56:06 +02:00
parent 0fc4db193f
commit 4731f63ba6
7 changed files with 219 additions and 230 deletions

View File

@ -1,6 +1,5 @@
import random
from abc import ABC
from enum import Enum
from typing import List, Union, Dict
import numpy as np
@ -18,13 +17,8 @@ class Register:
def name(self):
return self.__class__.__name__
@property
def n(self):
return len(self)
def __init__(self, *args, **kwargs):
self._register = dict()
self._names = dict()
def __len__(self):
return len(self._register)
@ -36,9 +30,7 @@ class Register:
assert isinstance(other, self._accepted_objects), f'All item names have to be of type ' \
f'{self._accepted_objects}, ' \
f'but were {other.__class__}.,'
new_idx = len(self._register)
self._names.update({other.name: new_idx})
self._register.update({new_idx: other})
self._register.update({other.name: other})
return self
def register_additional_items(self, others: List[_accepted_objects]):
@ -56,31 +48,16 @@ class Register:
return self._register.items()
def __getitem__(self, item):
try:
return self._register[item]
except KeyError as e:
print('NO')
print(e)
raise
def by_name(self, item):
return self[self._names[item]]
def by_enum(self, enum_obj: Enum):
return self[self._names[enum_obj.name]]
if isinstance(item, int):
try:
return next(v for i, v in enumerate(self._register.values()) if i == item)
except StopIteration:
return None
return self._register[item]
def __repr__(self):
return f'{self.__class__.__name__}({self._register})'
def get_name(self, item):
return self._register[item].name
def get_idx_by_name(self, item):
return self._names[item]
def get_idx(self, enum_obj: Enum):
return self._names[enum_obj.name]
class ObjectRegister(Register):
def __init__(self, level_shape: (int, int), *args, individual_slices=False, is_per_agent=False, **kwargs):
@ -96,7 +73,7 @@ class ObjectRegister(Register):
self._array = np.zeros((1, *self._level_shape))
else:
if self.individual_slices:
self._array = np.concatenate((self._array, np.zeros(1, *self._level_shape)))
self._array = np.concatenate((self._array, np.zeros((1, *self._level_shape))))
class EntityObjectRegister(ObjectRegister, ABC):
@ -107,8 +84,8 @@ class EntityObjectRegister(ObjectRegister, ABC):
@classmethod
def from_tiles(cls, tiles, *args, **kwargs):
# objects_name = cls._accepted_objects.__name__
entities = [cls._accepted_objects(i, tile, name_is_identifier=True, **kwargs)
for i, tile in enumerate(tiles)]
entities = [cls._accepted_objects(tile, **kwargs)
for tile in tiles]
register_obj = cls(*args)
register_obj.register_additional_items(entities)
return register_obj
@ -119,7 +96,7 @@ class EntityObjectRegister(ObjectRegister, ABC):
@property
def positions(self):
return list(self._tiles.keys())
return [x.pos for x in self]
@property
def tiles(self):
@ -128,25 +105,15 @@ class EntityObjectRegister(ObjectRegister, ABC):
def __init__(self, *args, is_blocking_light=False, is_observable=True, can_be_shadowed=True, **kwargs):
super(EntityObjectRegister, self).__init__(*args, **kwargs)
self.can_be_shadowed = can_be_shadowed
self._tiles = dict()
self.is_blocking_light = is_blocking_light
self.is_observable = is_observable
def register_item(self, other):
super(EntityObjectRegister, self).register_item(other)
self._tiles[other.pos] = other
def register_additional_items(self, others):
for other in others:
self.register_item(other)
return self
def by_pos(self, pos):
if isinstance(pos, np.ndarray):
pos = tuple(pos)
try:
return self._tiles[pos]
except KeyError:
return next(item for item in self.values() if item.pos == pos)
except StopIteration:
return None
@ -159,12 +126,14 @@ class MovingEntityObjectRegister(EntityObjectRegister, ABC):
if isinstance(pos, np.ndarray):
pos = tuple(pos)
try:
return [x for x in self if x == pos][0]
except IndexError:
return next(x for x in self if x.pos == pos)
except StopIteration:
return None
def delete_item(self, item):
self
if not isinstance(item, str):
item = item.name
del self._register[item]
class Entities(Register):
@ -186,7 +155,7 @@ class Entities(Register):
return iter([x for sublist in self.values() for x in sublist])
def register_item(self, other: dict):
assert not any([key for key in other.keys() if key in self._names]), \
assert not any([key for key in other.keys() if key in self.keys()]), \
"This group of entities has already been registered!"
self._register.update(other)
return self
@ -206,7 +175,8 @@ class WallTiles(EntityObjectRegister):
return self._array
def __init__(self, *args, **kwargs):
super(WallTiles, self).__init__(*args, individual_slices=False, is_blocking_light=self._light_blocking, **kwargs)
super(WallTiles, self).__init__(*args, individual_slices=False,
is_blocking_light=self._light_blocking, **kwargs)
@property
def encoding(self):
@ -221,8 +191,8 @@ class WallTiles(EntityObjectRegister):
tiles = cls(*args, **kwargs)
# noinspection PyTypeChecker
tiles.register_additional_items(
[cls._accepted_objects(i, pos, name_is_identifier=True, is_blocking_light=cls._light_blocking)
for i, pos in enumerate(argwhere_coordinates)]
[cls._accepted_objects(pos, is_blocking_light=cls._light_blocking)
for pos in argwhere_coordinates]
)
return tiles
@ -237,7 +207,7 @@ class FloorTiles(WallTiles):
_light_blocking = False
def __init__(self, *args, **kwargs):
super(self.__class__, self).__init__(*args, is_observable=False, **kwargs)
super(FloorTiles, self).__init__(*args, is_observable=False, **kwargs)
@property
def encoding(self):
@ -265,8 +235,11 @@ class Agents(MovingEntityObjectRegister):
def as_array(self):
self._array[:] = c.FREE_CELL.value
# noinspection PyTupleAssignmentBalance
z, x, y = range(len(self)), *zip(*[x.pos for x in self])
self._array[z, x, y] = c.OCCUPIED_CELL.value
for z, x, y, v in zip(range(len(self)), *zip(*[x.pos for x in self]), [x.encoding for x in self]):
if self.individual_slices:
self._array[z, x, y] += v
else:
self._array[0, x, y] += v
if self.individual_slices:
return self._array
else:
@ -293,9 +266,9 @@ class Doors(EntityObjectRegister):
_accepted_objects = Door
def get_near_position(self, position: (int, int)) -> Union[None, Door]:
if found_doors := [door for door in self if position in door.access_area]:
return found_doors[0]
else:
try:
return next(door for door in self if position in door.access_area)
except StopIteration:
return None
def tick_doors(self):
@ -320,39 +293,23 @@ class Actions(Register):
super(Actions, self).__init__()
if self.allow_square_movement:
self.register_additional_items([self._accepted_objects(direction) for direction in h.ManhattanMoves])
self.register_additional_items([self._accepted_objects(enum_ident=direction)
for direction in h.ManhattanMoves])
if self.allow_diagonal_movement:
self.register_additional_items([self._accepted_objects(direction) for direction in h.DiagonalMoves])
self.register_additional_items([self._accepted_objects(enum_ident=direction)
for direction in h.DiagonalMoves])
self._movement_actions = self._register.copy()
if self.can_use_doors:
self.register_additional_items([self._accepted_objects(h.EnvActions.USE_DOOR)])
self.register_additional_items([self._accepted_objects(enum_ident=h.EnvActions.USE_DOOR)])
if self.allow_no_op:
self.register_additional_items([self._accepted_objects(h.EnvActions.NOOP)])
self.register_additional_items([self._accepted_objects(enum_ident=h.EnvActions.NOOP)])
def is_moving_action(self, action: Union[int]):
return action in self.movement_actions.values()
def is_no_op(self, action: Union[str, Action, int]):
if isinstance(action, int):
action = self[action]
if isinstance(action, Action):
action = action.name
return action == h.EnvActions.NOOP.name
def is_door_usage(self, action: Union[str, int]):
if isinstance(action, int):
action = self[action]
if isinstance(action, Action):
action = action.name
return action == h.EnvActions.USE_DOOR.name
class Zones(Register):
@property
def danger_zone(self):
return self._zone_slices[self.by_enum(c.DANGER_ZONE)]
@property
def accounting_zones(self):
return [self[idx] for idx, name in self.items() if name != c.DANGER_ZONE.value]
@ -380,11 +337,5 @@ class Zones(Register):
def __getitem__(self, item):
return self._zone_slices[item]
def get_name(self, item):
return self._register[item]
def by_name(self, item):
return self[super(Zones, self).by_name(item)]
def register_additional_items(self, other: Union[str, List[str]]):
raise AttributeError('You are not allowed to add additional Zones in runtime.')
raise AttributeError('You are not allowed to add additional Zones in runtime.')