mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-07-06 09:31:35 +02:00
new observation properties for testing of technical limitations
This commit is contained in:
@ -1,3 +1,4 @@
|
||||
import numbers
|
||||
import random
|
||||
from abc import ABC
|
||||
from typing import List, Union, Dict
|
||||
@ -91,21 +92,18 @@ class EntityObjectRegister(ObjectRegister, ABC):
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def from_tiles(cls, tiles, *args, **kwargs):
|
||||
def from_tiles(cls, tiles, *args, entity_kwargs=None, **kwargs):
|
||||
# objects_name = cls._accepted_objects.__name__
|
||||
register_obj = cls(*args, **kwargs)
|
||||
try:
|
||||
del kwargs['individual_slices']
|
||||
except KeyError:
|
||||
pass
|
||||
entities = [cls._accepted_objects(tile, str_ident=i, **kwargs)
|
||||
entities = [cls._accepted_objects(tile, str_ident=i, **entity_kwargs if entity_kwargs is not None else {})
|
||||
for i, tile in enumerate(tiles)]
|
||||
register_obj.register_additional_items(entities)
|
||||
return register_obj
|
||||
|
||||
@classmethod
|
||||
def from_argwhere_coordinates(cls, positions: [(int, int)], tiles, *args, **kwargs):
|
||||
return cls.from_tiles([tiles.by_pos(position) for position in positions], *args, **kwargs)
|
||||
def from_argwhere_coordinates(cls, positions: [(int, int)], tiles, *args, entity_kwargs=None, **kwargs, ):
|
||||
return cls.from_tiles([tiles.by_pos(position) for position in positions], *args, entity_kwargs=entity_kwargs,
|
||||
**kwargs)
|
||||
|
||||
@property
|
||||
def positions(self):
|
||||
@ -166,10 +164,15 @@ class PlaceHolders(MovingEntityObjectRegister):
|
||||
|
||||
# noinspection DuplicatedCode
|
||||
def as_array(self):
|
||||
if isinstance(self.fill_value, int):
|
||||
if isinstance(self.fill_value, numbers.Number):
|
||||
self._array[:] = self.fill_value
|
||||
elif self.fill_value == "normal":
|
||||
self._array = np.random.normal(size=self._array.shape)
|
||||
elif isinstance(self.fill_value, str):
|
||||
if self.fill_value.lower() in ['normal', 'n']:
|
||||
self._array = np.random.normal(size=self._array.shape)
|
||||
else:
|
||||
raise ValueError('Choose one of: ["normal", "N"]')
|
||||
else:
|
||||
raise TypeError('Objects of type "str" or "number" is required here.')
|
||||
|
||||
if self.individual_slices:
|
||||
return self._array
|
||||
@ -183,10 +186,12 @@ class Entities(Register):
|
||||
|
||||
@property
|
||||
def observable_arrays(self):
|
||||
# FIXME: Find a better name
|
||||
return {key: val.as_array() for key, val in self.items() if val.is_observable}
|
||||
|
||||
@property
|
||||
def obs_arrays(self):
|
||||
# FIXME: Find a better name
|
||||
return {key: val.as_array() for key, val in self.items() if val.is_observable and not val.hide_from_obs_builder}
|
||||
|
||||
@property
|
||||
@ -208,6 +213,10 @@ class Entities(Register):
|
||||
def register_additional_items(self, others: Dict):
|
||||
return self.register_item(others)
|
||||
|
||||
def by_pos(self, pos: (int, int)):
|
||||
found_entities = [y for y in (x.by_pos(pos) for x in self.values() if hasattr(x, 'by_pos')) if y is not None]
|
||||
return found_entities
|
||||
|
||||
|
||||
class WallTiles(EntityObjectRegister):
|
||||
_accepted_objects = Wall
|
||||
@ -289,6 +298,10 @@ class Agents(MovingEntityObjectRegister):
|
||||
|
||||
_accepted_objects = Agent
|
||||
|
||||
def __init__(self, *args, hide_from_obs_builder=False, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.hide_from_obs_builder = hide_from_obs_builder
|
||||
|
||||
# noinspection DuplicatedCode
|
||||
def as_array(self):
|
||||
self._array[:] = c.FREE_CELL.value
|
||||
|
Reference in New Issue
Block a user