mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-07-06 09:31:35 +02:00
Rework of Observations and Entity Differentiation, lazy obs build by notification
This commit is contained in:
@ -1,18 +1,23 @@
|
||||
import numbers
|
||||
import random
|
||||
from abc import ABC
|
||||
from typing import List, Union, Dict
|
||||
from typing import List, Union, Dict, Tuple
|
||||
|
||||
import numpy as np
|
||||
|
||||
from environments.factory.base.objects import Entity, Tile, Agent, Door, Action, Wall, Object, PlaceHolder
|
||||
from environments.factory.base.objects import Entity, Tile, Agent, Door, Action, Wall, PlaceHolder, GlobalPosition, \
|
||||
Object, EnvObject
|
||||
from environments.utility_classes import MovementProperties
|
||||
from environments import helpers as h
|
||||
from environments.helpers import Constants as c
|
||||
|
||||
##########################################################################
|
||||
# ##################### Base Register Definition ####################### #
|
||||
##########################################################################
|
||||
|
||||
class Register:
|
||||
_accepted_objects = Entity
|
||||
|
||||
class ObjectRegister:
|
||||
_accepted_objects = Object
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
@ -48,6 +53,12 @@ class Register:
|
||||
def items(self):
|
||||
return self._register.items()
|
||||
|
||||
def _get_index(self, item):
|
||||
try:
|
||||
return next(i for i, v in enumerate(self._register.values()) if v == item)
|
||||
except (StopIteration, AssertionError):
|
||||
return None
|
||||
|
||||
def __getitem__(self, item):
|
||||
if isinstance(item, (int, np.int64, np.int32)):
|
||||
if item < 0:
|
||||
@ -65,39 +76,66 @@ class Register:
|
||||
return f'{self.__class__.__name__}({self._register})'
|
||||
|
||||
|
||||
class ObjectRegister(Register):
|
||||
class EnvObjectRegister(ObjectRegister):
|
||||
|
||||
hide_from_obs_builder = False
|
||||
_accepted_objects = EnvObject
|
||||
|
||||
def __init__(self, level_shape: (int, int), *args, individual_slices=False, is_per_agent=False, **kwargs):
|
||||
super(ObjectRegister, self).__init__(*args, **kwargs)
|
||||
self.is_per_agent = is_per_agent
|
||||
self.individual_slices = individual_slices
|
||||
self._level_shape = level_shape
|
||||
def __init__(self, obs_shape: (int, int), *args, **kwargs):
|
||||
super(EnvObjectRegister, self).__init__(*args, **kwargs)
|
||||
self._shape = obs_shape
|
||||
self._array = None
|
||||
self.hide_from_obs_builder = False
|
||||
self._lazy_eval_transforms = []
|
||||
|
||||
def register_item(self, other):
|
||||
super(ObjectRegister, self).register_item(other)
|
||||
def register_item(self, other: EnvObject):
|
||||
super(EnvObjectRegister, self).register_item(other)
|
||||
if self._array is None:
|
||||
self._array = np.zeros((1, *self._level_shape))
|
||||
else:
|
||||
if self.individual_slices:
|
||||
self._array = np.concatenate((self._array, np.zeros((1, *self._array.shape[1:]))))
|
||||
self._array = np.zeros((1, *self._shape))
|
||||
self.notify_change_to_value(other)
|
||||
|
||||
def as_array(self):
|
||||
if self._lazy_eval_transforms:
|
||||
idxs, values = zip(*self._lazy_eval_transforms)
|
||||
# nuumpy put repects the ordering so that
|
||||
np.put(self._array, idxs, values)
|
||||
self._lazy_eval_transforms = []
|
||||
return self._array
|
||||
|
||||
def summarize_states(self, n_steps=None):
|
||||
return [val.summarize_state(n_steps=n_steps) for val in self.values()]
|
||||
|
||||
def notify_change_to_free(self, env_object: EnvObject):
|
||||
self._array_change_notifyer(env_object, value=c.FREE_CELL.value)
|
||||
|
||||
class EntityObjectRegister(ObjectRegister, ABC):
|
||||
def notify_change_to_value(self, env_object: EnvObject):
|
||||
self._array_change_notifyer(env_object)
|
||||
|
||||
def as_array(self):
|
||||
raise NotImplementedError
|
||||
def _array_change_notifyer(self, env_object: EnvObject, value=None):
|
||||
pos = self._get_index(env_object)
|
||||
value = value if value is not None else env_object.encoding
|
||||
self._lazy_eval_transforms.append((pos, value))
|
||||
|
||||
def __delitem__(self, name):
|
||||
self.notify_change_to_free(self._register[name])
|
||||
del self._register[name]
|
||||
|
||||
def delete_env_object(self, env_object: EnvObject):
|
||||
del self[env_object.name]
|
||||
|
||||
def delete_env_object_by_name(self, name):
|
||||
del self[name]
|
||||
|
||||
|
||||
class EntityRegister(EnvObjectRegister, ABC):
|
||||
|
||||
_accepted_objects = Entity
|
||||
|
||||
@classmethod
|
||||
def from_tiles(cls, tiles, *args, entity_kwargs=None, **kwargs):
|
||||
# objects_name = cls._accepted_objects.__name__
|
||||
register_obj = cls(*args, **kwargs)
|
||||
entities = [cls._accepted_objects(tile, str_ident=i, **entity_kwargs if entity_kwargs is not None else {})
|
||||
entities = [cls._accepted_objects(tile, register_obj, str_ident=i,
|
||||
**entity_kwargs if entity_kwargs is not None else {})
|
||||
for i, tile in enumerate(tiles)]
|
||||
register_obj.register_additional_items(entities)
|
||||
return register_obj
|
||||
@ -115,86 +153,172 @@ class EntityObjectRegister(ObjectRegister, ABC):
|
||||
def tiles(self):
|
||||
return [entity.tile for entity in self]
|
||||
|
||||
def __init__(self, *args, is_blocking_light=False, is_observable=True, can_be_shadowed=True, **kwargs):
|
||||
super(EntityObjectRegister, self).__init__(*args, **kwargs)
|
||||
self.can_be_shadowed = can_be_shadowed
|
||||
self.is_blocking_light = is_blocking_light
|
||||
self.is_observable = is_observable
|
||||
@property
|
||||
def encodings(self):
|
||||
return [x.encoding for x in self]
|
||||
|
||||
def by_pos(self, pos):
|
||||
if isinstance(pos, np.ndarray):
|
||||
pos = tuple(pos)
|
||||
def __init__(self, level_shape, *args,
|
||||
is_blocking_light: bool = False,
|
||||
can_be_shadowed: bool = True,
|
||||
individual_slices: bool = False, **kwargs):
|
||||
super(EntityRegister, self).__init__(level_shape, *args, **kwargs)
|
||||
self._lazy_eval_transforms = []
|
||||
self.can_be_shadowed = can_be_shadowed
|
||||
self.individual_slices = individual_slices
|
||||
self.is_blocking_light = is_blocking_light
|
||||
|
||||
def __delitem__(self, name):
|
||||
idx, obj = next((i, obj) for i, obj in enumerate(self) if obj.name == name)
|
||||
obj.tile.leave(obj)
|
||||
super(EntityRegister, self).__delitem__(name)
|
||||
if self.individual_slices:
|
||||
self._array = np.delete(self._array, idx, axis=0)
|
||||
|
||||
def as_array(self):
|
||||
if self._lazy_eval_transforms:
|
||||
idxs, values = zip(*self._lazy_eval_transforms)
|
||||
# numpy put repects the ordering so that
|
||||
# Todo: Export the index building in a seperate function
|
||||
np.put(self._array, [np.ravel_multi_index(idx, self._array.shape) for idx in idxs], values)
|
||||
self._lazy_eval_transforms = []
|
||||
return self._array
|
||||
|
||||
def _array_change_notifyer(self, entity, pos=None, value=None):
|
||||
# Todo: Export the contruction in a seperate function
|
||||
pos = pos if pos is not None else entity.pos
|
||||
value = value if value is not None else entity.encoding
|
||||
x, y = pos
|
||||
if self.individual_slices:
|
||||
idx = (self._get_index(entity), x, y)
|
||||
else:
|
||||
idx = (0, x, y)
|
||||
self._lazy_eval_transforms.append((idx, value))
|
||||
|
||||
def by_pos(self, pos: Tuple[int, int]):
|
||||
try:
|
||||
return next(item for item in self.values() if item.pos == pos)
|
||||
return next(item for item in self if item.pos == tuple(pos))
|
||||
except StopIteration:
|
||||
return None
|
||||
|
||||
|
||||
class MovingEntityObjectRegister(EntityObjectRegister, ABC):
|
||||
class BoundRegisterMixin(EnvObjectRegister, ABC):
|
||||
|
||||
@classmethod
|
||||
def from_entities_to_bind(self, entitites):
|
||||
def from_values(cls, values: Union[str, numbers.Number, List[Union[str, numbers.Number]]],
|
||||
*args, object_kwargs=None, **kwargs):
|
||||
# objects_name = cls._accepted_objects.__name__
|
||||
if isinstance(values, (str, numbers.Number)):
|
||||
values = [values]
|
||||
register_obj = cls(*args, **kwargs)
|
||||
objects = [cls._accepted_objects(register_obj, str_ident=i, fill_value=value,
|
||||
**object_kwargs if object_kwargs is not None else {})
|
||||
for i, value in enumerate(values)]
|
||||
register_obj.register_additional_items(objects)
|
||||
return register_obj
|
||||
|
||||
|
||||
class MovingEntityObjectRegister(EntityRegister, ABC):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(MovingEntityObjectRegister, self).__init__(*args, **kwargs)
|
||||
|
||||
def by_pos(self, pos):
|
||||
if isinstance(pos, np.ndarray):
|
||||
pos = tuple(pos)
|
||||
def notify_change_to_value(self, entity):
|
||||
super(MovingEntityObjectRegister, self).notify_change_to_value(entity)
|
||||
if entity.last_pos != c.NO_POS.value:
|
||||
try:
|
||||
self._array_change_notifyer(entity, entity.last_pos, value=c.FREE_CELL.value)
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
|
||||
##########################################################################
|
||||
# ################# Objects and Entity Registers ####################### #
|
||||
##########################################################################
|
||||
|
||||
|
||||
class GlobalPositions(EnvObjectRegister):
|
||||
_accepted_objects = GlobalPosition
|
||||
is_blocking_light = False
|
||||
can_be_shadowed = False
|
||||
hide_from_obs_builder = True
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(GlobalPositions, self).__init__(*args, is_per_agent=True, individual_slices=True, **kwargs)
|
||||
|
||||
def as_array(self):
|
||||
# Todo make this lazy?
|
||||
return np.stack([gp.as_array() for inv_idx, gp in enumerate(self)])
|
||||
|
||||
def spawn_GlobalPositionObjects(self, obs_shape, agents):
|
||||
global_positions = [self._accepted_objects(self._shape, obs_shape, agent)
|
||||
for _, agent in enumerate(agents)]
|
||||
# noinspection PyTypeChecker
|
||||
self.register_additional_items(global_positions)
|
||||
|
||||
def summarize_states(self, n_steps=None):
|
||||
return {}
|
||||
|
||||
def idx_by_entity(self, entity):
|
||||
try:
|
||||
return next(x for x in self if x.pos == pos)
|
||||
return next((idx for idx, inv in enumerate(self) if inv.belongs_to_entity(entity)))
|
||||
except StopIteration:
|
||||
return None
|
||||
|
||||
def __delitem__(self, name):
|
||||
idx = next(i for i, entity in enumerate(self) if entity.name == name)
|
||||
del self._register[name]
|
||||
if self.individual_slices:
|
||||
self._array = np.delete(self._array, idx, axis=0)
|
||||
|
||||
def delete_entity(self, item):
|
||||
self.delete_entity_by_name(item.name)
|
||||
|
||||
def delete_entity_by_name(self, name):
|
||||
del self[name]
|
||||
def by_entity(self, entity):
|
||||
try:
|
||||
return next((inv for inv in self if inv.belongs_to_entity(entity)))
|
||||
except StopIteration:
|
||||
return None
|
||||
|
||||
|
||||
class PlaceHolders(MovingEntityObjectRegister):
|
||||
|
||||
class PlaceHolders(EnvObjectRegister):
|
||||
_accepted_objects = PlaceHolder
|
||||
|
||||
def __init__(self, *args, fill_value: Union[str, int] = 0, **kwargs):
|
||||
def __init__(self, *args, **kwargs):
|
||||
assert not 'individual_slices' in kwargs, 'Keyword - "individual_slices": "True" and must not be altered'
|
||||
kwargs.update(individual_slices=False)
|
||||
super().__init__(*args, **kwargs)
|
||||
self.fill_value = fill_value
|
||||
|
||||
@classmethod
|
||||
def from_values(cls, values: Union[str, numbers.Number, List[Union[str, numbers.Number]]],
|
||||
*args, object_kwargs=None, **kwargs):
|
||||
# objects_name = cls._accepted_objects.__name__
|
||||
if isinstance(values, (str, numbers.Number)):
|
||||
values = [values]
|
||||
register_obj = cls(*args, **kwargs)
|
||||
objects = [cls._accepted_objects(register_obj, str_ident=i, fill_value=value,
|
||||
**object_kwargs if object_kwargs is not None else {})
|
||||
for i, value in enumerate(values)]
|
||||
register_obj.register_additional_items(objects)
|
||||
return register_obj
|
||||
|
||||
# noinspection DuplicatedCode
|
||||
def as_array(self):
|
||||
if isinstance(self.fill_value, numbers.Number):
|
||||
self._array[:] = self.fill_value
|
||||
elif isinstance(self.fill_value, str):
|
||||
if self.fill_value.lower() in ['normal', 'n']:
|
||||
self._array = np.random.normal(size=self._array.shape)
|
||||
for idx, placeholder in enumerate(self):
|
||||
if isinstance(placeholder.encoding, numbers.Number):
|
||||
self._array[idx][:] = placeholder.fill_value
|
||||
elif isinstance(placeholder.fill_value, str):
|
||||
if placeholder.fill_value.lower() in ['normal', 'n']:
|
||||
self._array[:] = np.random.normal(size=self._array.shape)
|
||||
else:
|
||||
raise ValueError('Choose one of: ["normal", "N"]')
|
||||
else:
|
||||
raise ValueError('Choose one of: ["normal", "N"]')
|
||||
else:
|
||||
raise TypeError('Objects of type "str" or "number" is required here.')
|
||||
raise TypeError('Objects of type "str" or "number" is required here.')
|
||||
|
||||
if self.individual_slices:
|
||||
return self._array
|
||||
else:
|
||||
return self._array[None, 0]
|
||||
return self._array
|
||||
|
||||
|
||||
class Entities(Register):
|
||||
|
||||
_accepted_objects = EntityObjectRegister
|
||||
class Entities(ObjectRegister):
|
||||
_accepted_objects = EntityRegister
|
||||
|
||||
@property
|
||||
def observable_arrays(self):
|
||||
# FIXME: Find a better name
|
||||
return {key: val.as_array() for key, val in self.items() if val.is_observable}
|
||||
def arrays(self):
|
||||
return {key: val.as_array() for key, val in self.items()}
|
||||
|
||||
@property
|
||||
def obs_arrays(self):
|
||||
# FIXME: Find a better name
|
||||
return {key: val.as_array() for key, val in self.items() if val.is_observable and not val.hide_from_obs_builder}
|
||||
return {key: val.as_array() for key, val in self.items() if not val.hide_from_obs_builder}
|
||||
|
||||
@property
|
||||
def names(self):
|
||||
@ -220,34 +344,34 @@ class Entities(Register):
|
||||
return found_entities
|
||||
|
||||
|
||||
class WallTiles(EntityObjectRegister):
|
||||
class WallTiles(EntityRegister):
|
||||
_accepted_objects = Wall
|
||||
_light_blocking = True
|
||||
hide_from_obs_builder = True
|
||||
|
||||
def as_array(self):
|
||||
if not np.any(self._array):
|
||||
# Which is Faster?
|
||||
# indices = [x.pos for x in self]
|
||||
# np.put(self._array, [np.ravel_multi_index((0, *x), self._array.shape) for x in indices], self.encodings)
|
||||
x, y = zip(*[x.pos for x in self])
|
||||
self._array[0, x, y] = self.encoding
|
||||
return self._array
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(WallTiles, self).__init__(*args, individual_slices=False,
|
||||
is_blocking_light=self._light_blocking, **kwargs)
|
||||
super(WallTiles, self).__init__(*args, is_blocking_light=self._light_blocking, individual_slices=False,
|
||||
**kwargs)
|
||||
|
||||
@property
|
||||
def encoding(self):
|
||||
return c.OCCUPIED_CELL.value
|
||||
|
||||
@property
|
||||
def array(self):
|
||||
return self._array
|
||||
|
||||
@classmethod
|
||||
def from_argwhere_coordinates(cls, argwhere_coordinates, *args, **kwargs):
|
||||
tiles = cls(*args, **kwargs)
|
||||
# noinspection PyTypeChecker
|
||||
tiles.register_additional_items(
|
||||
[cls._accepted_objects(pos, is_blocking_light=cls._light_blocking)
|
||||
[cls._accepted_objects(pos, tiles, is_blocking_light=cls._light_blocking)
|
||||
for pos in argwhere_coordinates]
|
||||
)
|
||||
return tiles
|
||||
@ -264,12 +388,11 @@ class WallTiles(EntityObjectRegister):
|
||||
|
||||
|
||||
class FloorTiles(WallTiles):
|
||||
|
||||
_accepted_objects = Tile
|
||||
_light_blocking = False
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(FloorTiles, self).__init__(*args, is_observable=False, **kwargs)
|
||||
super(FloorTiles, self).__init__(*args, **kwargs)
|
||||
|
||||
@property
|
||||
def encoding(self):
|
||||
@ -297,22 +420,21 @@ class FloorTiles(WallTiles):
|
||||
|
||||
|
||||
class Agents(MovingEntityObjectRegister):
|
||||
|
||||
_accepted_objects = Agent
|
||||
|
||||
def __init__(self, *args, hide_from_obs_builder=False, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.hide_from_obs_builder = hide_from_obs_builder
|
||||
|
||||
# noinspection DuplicatedCode
|
||||
def as_array(self):
|
||||
self._array[:] = c.FREE_CELL.value
|
||||
# noinspection PyTupleAssignmentBalance
|
||||
for z, x, y, v in zip(range(len(self)), *zip(*[x.pos for x in self]), [x.encoding for x in self]):
|
||||
if self.individual_slices:
|
||||
self._array[z, x, y] += v
|
||||
else:
|
||||
self._array[0, x, y] += v
|
||||
@DeprecationWarning
|
||||
def Xas_array(self):
|
||||
# Super Safe Version
|
||||
# self._array[:] = c.FREE_CELL.value
|
||||
indices = list(zip(range(len(self)), *zip(*[x.last_pos for x in self])))
|
||||
np.put(self._array, [np.ravel_multi_index(x, self._array.shape) for x in indices], c.FREE_CELL.value)
|
||||
indices = list(zip(range(len(self)), *zip(*[x.pos for x in self])))
|
||||
np.put(self._array, [np.ravel_multi_index(x, self._array.shape) for x in indices], self.encodings)
|
||||
|
||||
if self.individual_slices:
|
||||
return self._array
|
||||
else:
|
||||
@ -329,17 +451,11 @@ class Agents(MovingEntityObjectRegister):
|
||||
self._register[agent.name] = agent
|
||||
|
||||
|
||||
class Doors(EntityObjectRegister):
|
||||
class Doors(EntityRegister):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(Doors, self).__init__(*args, is_blocking_light=True, **kwargs)
|
||||
|
||||
def as_array(self):
|
||||
self._array[:] = 0
|
||||
for door in self:
|
||||
self._array[0, door.x, door.y] = door.encoding
|
||||
return self._array
|
||||
|
||||
_accepted_objects = Door
|
||||
|
||||
def get_near_position(self, position: (int, int)) -> Union[None, Door]:
|
||||
@ -353,8 +469,7 @@ class Doors(EntityObjectRegister):
|
||||
door.tick()
|
||||
|
||||
|
||||
class Actions(Register):
|
||||
|
||||
class Actions(ObjectRegister):
|
||||
_accepted_objects = Action
|
||||
|
||||
@property
|
||||
@ -385,7 +500,7 @@ class Actions(Register):
|
||||
return action in self.movement_actions.values()
|
||||
|
||||
|
||||
class Zones(Register):
|
||||
class Zones(ObjectRegister):
|
||||
|
||||
@property
|
||||
def accounting_zones(self):
|
||||
|
Reference in New Issue
Block a user