mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-07-06 09:31:35 +02:00
Rework for performance
This commit is contained in:
@ -85,19 +85,27 @@ class EnvObjectRegister(ObjectRegister):
|
||||
def encodings(self):
|
||||
return [x.encoding for x in self]
|
||||
|
||||
def __init__(self, obs_shape: (int, int), *args, individual_slices: bool = False, **kwargs):
|
||||
def __init__(self, obs_shape: (int, int), *args,
|
||||
individual_slices: bool = False,
|
||||
is_blocking_light: bool = False,
|
||||
can_collide: bool = False,
|
||||
can_be_shadowed: bool = True, **kwargs):
|
||||
super(EnvObjectRegister, self).__init__(*args, **kwargs)
|
||||
self._shape = obs_shape
|
||||
self._array = None
|
||||
self._individual_slices = individual_slices
|
||||
self._lazy_eval_transforms = []
|
||||
self.is_blocking_light = is_blocking_light
|
||||
self.can_be_shadowed = can_be_shadowed
|
||||
self.can_collide = can_collide
|
||||
|
||||
def register_item(self, other: EnvObject):
|
||||
super(EnvObjectRegister, self).register_item(other)
|
||||
if self._array is None:
|
||||
self._array = np.zeros((1, *self._shape))
|
||||
if self._individual_slices:
|
||||
self._array = np.vstack((self._array, np.zeros((1, *self._shape))))
|
||||
else:
|
||||
if self._individual_slices:
|
||||
self._array = np.vstack((self._array, np.zeros((1, *self._shape))))
|
||||
self.notify_change_to_value(other)
|
||||
|
||||
def as_array(self):
|
||||
@ -179,14 +187,9 @@ class EntityRegister(EnvObjectRegister, ABC):
|
||||
def tiles(self):
|
||||
return [entity.tile for entity in self]
|
||||
|
||||
def __init__(self, level_shape, *args,
|
||||
is_blocking_light: bool = False,
|
||||
can_be_shadowed: bool = True,
|
||||
**kwargs):
|
||||
def __init__(self, level_shape, *args, **kwargs):
|
||||
super(EntityRegister, self).__init__(level_shape, *args, **kwargs)
|
||||
self._lazy_eval_transforms = []
|
||||
self.can_be_shadowed = can_be_shadowed
|
||||
self.is_blocking_light = is_blocking_light
|
||||
|
||||
def __delitem__(self, name):
|
||||
idx, obj = next((i, obj) for i, obj in enumerate(self) if obj.name == name)
|
||||
@ -220,7 +223,7 @@ class EntityRegister(EnvObjectRegister, ABC):
|
||||
return None
|
||||
|
||||
|
||||
class BoundRegisterMixin(EnvObjectRegister, ABC):
|
||||
class BoundEnvObjRegister(EnvObjectRegister, ABC):
|
||||
|
||||
def __init__(self, entity_to_be_bound, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
@ -229,6 +232,21 @@ class BoundRegisterMixin(EnvObjectRegister, ABC):
|
||||
def belongs_to_entity(self, entity):
|
||||
return self._bound_entity == entity
|
||||
|
||||
def by_entity(self, entity):
|
||||
try:
|
||||
return next((x for x in self if x.belongs_to_entity(entity)))
|
||||
except StopIteration:
|
||||
return None
|
||||
|
||||
def idx_by_entity(self, entity):
|
||||
try:
|
||||
return next((idx for idx, x in enumerate(self) if x.belongs_to_entity(entity)))
|
||||
except StopIteration:
|
||||
return None
|
||||
|
||||
def as_array_by_entity(self, entity):
|
||||
return self._array[self.idx_by_entity(entity)]
|
||||
|
||||
|
||||
class MovingEntityObjectRegister(EntityRegister, ABC):
|
||||
|
||||
@ -255,6 +273,7 @@ class GlobalPositions(EnvObjectRegister):
|
||||
|
||||
is_blocking_light = False
|
||||
can_be_shadowed = False
|
||||
can_collide = False
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(GlobalPositions, self).__init__(*args, is_per_agent=True, individual_slices=True, **kwargs)
|
||||
@ -360,7 +379,6 @@ class Entities(ObjectRegister):
|
||||
|
||||
class WallTiles(EntityRegister):
|
||||
_accepted_objects = Wall
|
||||
_light_blocking = True
|
||||
|
||||
def as_array(self):
|
||||
if not np.any(self._array):
|
||||
@ -371,9 +389,10 @@ class WallTiles(EntityRegister):
|
||||
self._array[0, x, y] = self._value
|
||||
return self._array
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(WallTiles, self).__init__(*args, is_blocking_light=self._light_blocking, individual_slices=False,
|
||||
**kwargs)
|
||||
def __init__(self, *args, is_blocking_light=True, **kwargs):
|
||||
super(WallTiles, self).__init__(*args, individual_slices=False,
|
||||
can_collide=True,
|
||||
is_blocking_light=is_blocking_light, **kwargs)
|
||||
self._value = c.OCCUPIED_CELL
|
||||
|
||||
@classmethod
|
||||
@ -381,7 +400,7 @@ class WallTiles(EntityRegister):
|
||||
tiles = cls(*args, **kwargs)
|
||||
# noinspection PyTypeChecker
|
||||
tiles.register_additional_items(
|
||||
[cls._accepted_objects(pos, tiles, is_blocking_light=cls._light_blocking)
|
||||
[cls._accepted_objects(pos, tiles)
|
||||
for pos in argwhere_coordinates]
|
||||
)
|
||||
return tiles
|
||||
@ -399,10 +418,9 @@ class WallTiles(EntityRegister):
|
||||
|
||||
class FloorTiles(WallTiles):
|
||||
_accepted_objects = Tile
|
||||
_light_blocking = False
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(FloorTiles, self).__init__(*args, **kwargs)
|
||||
def __init__(self, *args, is_blocking_light=False, **kwargs):
|
||||
super(FloorTiles, self).__init__(*args, is_blocking_light=is_blocking_light, **kwargs)
|
||||
self._value = c.FREE_CELL
|
||||
|
||||
@property
|
||||
@ -430,7 +448,7 @@ class Agents(MovingEntityObjectRegister):
|
||||
_accepted_objects = Agent
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
super().__init__(*args, can_collide=True, **kwargs)
|
||||
|
||||
@property
|
||||
def positions(self):
|
||||
@ -446,7 +464,7 @@ class Agents(MovingEntityObjectRegister):
|
||||
class Doors(EntityRegister):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(Doors, self).__init__(*args, is_blocking_light=True, **kwargs)
|
||||
super(Doors, self).__init__(*args, is_blocking_light=True, can_collide=True, **kwargs)
|
||||
|
||||
_accepted_objects = Door
|
||||
|
||||
|
Reference in New Issue
Block a user