mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-07-06 09:31:35 +02:00
recoder adaption
This commit is contained in:
@ -81,8 +81,8 @@ class ObjectRegister(Register):
|
||||
if self.individual_slices:
|
||||
self._array = np.concatenate((self._array, np.zeros((1, *self._array.shape[1:]))))
|
||||
|
||||
def summarize_states(self):
|
||||
return [val.summarize_state() for val in self.values()]
|
||||
def summarize_states(self, n_steps=None):
|
||||
return [val.summarize_state(n_steps=n_steps) for val in self.values()]
|
||||
|
||||
|
||||
class EntityObjectRegister(ObjectRegister, ABC):
|
||||
@ -156,23 +156,25 @@ class MovingEntityObjectRegister(EntityObjectRegister, ABC):
|
||||
del self[name]
|
||||
|
||||
|
||||
class PlaceHolderRegister(MovingEntityObjectRegister):
|
||||
class PlaceHolders(MovingEntityObjectRegister):
|
||||
|
||||
_accepted_objects = PlaceHolder
|
||||
|
||||
def __init__(self, *args, fill_value: Union[str, int] = 0, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.fill_value = fill_value
|
||||
|
||||
# noinspection DuplicatedCode
|
||||
def as_array(self):
|
||||
self._array[:] = c.FREE_CELL.value
|
||||
# noinspection PyTupleAssignmentBalance
|
||||
for z, x, y, v in zip(range(len(self)), *zip(*[x.pos for x in self]), [x.encoding for x in self]):
|
||||
if self.individual_slices:
|
||||
self._array[z, x, y] += v
|
||||
else:
|
||||
self._array[0, x, y] += v
|
||||
if isinstance(self.fill_value, int):
|
||||
self._array[:] = self.fill_value
|
||||
elif self.fill_value == "normal":
|
||||
self._array = np.random.normal(size=self._array.shape)
|
||||
|
||||
if self.individual_slices:
|
||||
return self._array
|
||||
else:
|
||||
return self._array.sum(axis=0, keepdims=True)
|
||||
return self._array[None, 0]
|
||||
|
||||
|
||||
class Entities(Register):
|
||||
@ -243,6 +245,12 @@ class WallTiles(EntityObjectRegister):
|
||||
def from_tiles(cls, tiles, *args, **kwargs):
|
||||
raise RuntimeError()
|
||||
|
||||
def summarize_states(self, n_steps=None):
|
||||
if n_steps == h.STEPS_START:
|
||||
return super(WallTiles, self).summarize_states(n_steps=n_steps)
|
||||
else:
|
||||
return {}
|
||||
|
||||
|
||||
class FloorTiles(WallTiles):
|
||||
|
||||
@ -272,6 +280,10 @@ class FloorTiles(WallTiles):
|
||||
def from_tiles(cls, tiles, *args, **kwargs):
|
||||
raise RuntimeError()
|
||||
|
||||
def summarize_states(self, n_steps=None):
|
||||
# Do not summarize
|
||||
return {}
|
||||
|
||||
|
||||
class Agents(MovingEntityObjectRegister):
|
||||
|
||||
|
Reference in New Issue
Block a user