Rework for performance

This commit is contained in:
Steffen Illium
2022-01-10 15:54:22 +01:00
parent 78bf19f7f4
commit 435056f373
10 changed files with 525 additions and 469 deletions

View File

@ -85,19 +85,27 @@ class EnvObjectRegister(ObjectRegister):
def encodings(self):
return [x.encoding for x in self]
def __init__(self, obs_shape: (int, int), *args, individual_slices: bool = False, **kwargs):
def __init__(self, obs_shape: (int, int), *args,
individual_slices: bool = False,
is_blocking_light: bool = False,
can_collide: bool = False,
can_be_shadowed: bool = True, **kwargs):
super(EnvObjectRegister, self).__init__(*args, **kwargs)
self._shape = obs_shape
self._array = None
self._individual_slices = individual_slices
self._lazy_eval_transforms = []
self.is_blocking_light = is_blocking_light
self.can_be_shadowed = can_be_shadowed
self.can_collide = can_collide
def register_item(self, other: EnvObject):
super(EnvObjectRegister, self).register_item(other)
if self._array is None:
self._array = np.zeros((1, *self._shape))
if self._individual_slices:
self._array = np.vstack((self._array, np.zeros((1, *self._shape))))
else:
if self._individual_slices:
self._array = np.vstack((self._array, np.zeros((1, *self._shape))))
self.notify_change_to_value(other)
def as_array(self):
@ -179,14 +187,9 @@ class EntityRegister(EnvObjectRegister, ABC):
def tiles(self):
return [entity.tile for entity in self]
def __init__(self, level_shape, *args,
is_blocking_light: bool = False,
can_be_shadowed: bool = True,
**kwargs):
def __init__(self, level_shape, *args, **kwargs):
super(EntityRegister, self).__init__(level_shape, *args, **kwargs)
self._lazy_eval_transforms = []
self.can_be_shadowed = can_be_shadowed
self.is_blocking_light = is_blocking_light
def __delitem__(self, name):
idx, obj = next((i, obj) for i, obj in enumerate(self) if obj.name == name)
@ -220,7 +223,7 @@ class EntityRegister(EnvObjectRegister, ABC):
return None
class BoundRegisterMixin(EnvObjectRegister, ABC):
class BoundEnvObjRegister(EnvObjectRegister, ABC):
def __init__(self, entity_to_be_bound, *args, **kwargs):
super().__init__(*args, **kwargs)
@ -229,6 +232,21 @@ class BoundRegisterMixin(EnvObjectRegister, ABC):
def belongs_to_entity(self, entity):
return self._bound_entity == entity
def by_entity(self, entity):
try:
return next((x for x in self if x.belongs_to_entity(entity)))
except StopIteration:
return None
def idx_by_entity(self, entity):
try:
return next((idx for idx, x in enumerate(self) if x.belongs_to_entity(entity)))
except StopIteration:
return None
def as_array_by_entity(self, entity):
return self._array[self.idx_by_entity(entity)]
class MovingEntityObjectRegister(EntityRegister, ABC):
@ -255,6 +273,7 @@ class GlobalPositions(EnvObjectRegister):
is_blocking_light = False
can_be_shadowed = False
can_collide = False
def __init__(self, *args, **kwargs):
super(GlobalPositions, self).__init__(*args, is_per_agent=True, individual_slices=True, **kwargs)
@ -360,7 +379,6 @@ class Entities(ObjectRegister):
class WallTiles(EntityRegister):
_accepted_objects = Wall
_light_blocking = True
def as_array(self):
if not np.any(self._array):
@ -371,9 +389,10 @@ class WallTiles(EntityRegister):
self._array[0, x, y] = self._value
return self._array
def __init__(self, *args, **kwargs):
super(WallTiles, self).__init__(*args, is_blocking_light=self._light_blocking, individual_slices=False,
**kwargs)
def __init__(self, *args, is_blocking_light=True, **kwargs):
super(WallTiles, self).__init__(*args, individual_slices=False,
can_collide=True,
is_blocking_light=is_blocking_light, **kwargs)
self._value = c.OCCUPIED_CELL
@classmethod
@ -381,7 +400,7 @@ class WallTiles(EntityRegister):
tiles = cls(*args, **kwargs)
# noinspection PyTypeChecker
tiles.register_additional_items(
[cls._accepted_objects(pos, tiles, is_blocking_light=cls._light_blocking)
[cls._accepted_objects(pos, tiles)
for pos in argwhere_coordinates]
)
return tiles
@ -399,10 +418,9 @@ class WallTiles(EntityRegister):
class FloorTiles(WallTiles):
_accepted_objects = Tile
_light_blocking = False
def __init__(self, *args, **kwargs):
super(FloorTiles, self).__init__(*args, **kwargs)
def __init__(self, *args, is_blocking_light=False, **kwargs):
super(FloorTiles, self).__init__(*args, is_blocking_light=is_blocking_light, **kwargs)
self._value = c.FREE_CELL
@property
@ -430,7 +448,7 @@ class Agents(MovingEntityObjectRegister):
_accepted_objects = Agent
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
super().__init__(*args, can_collide=True, **kwargs)
@property
def positions(self):
@ -446,7 +464,7 @@ class Agents(MovingEntityObjectRegister):
class Doors(EntityRegister):
def __init__(self, *args, **kwargs):
super(Doors, self).__init__(*args, is_blocking_light=True, **kwargs)
super(Doors, self).__init__(*args, is_blocking_light=True, can_collide=True, **kwargs)
_accepted_objects = Door