This commit is contained in:
Chanumask
2023-10-27 13:08:01 +02:00
parent ef2fdd5d28
commit b13dff925b
13 changed files with 141 additions and 99 deletions

View File

@ -2,11 +2,11 @@ from typing import List, Tuple
from marl_factory_grid.environment.entity.entity import Entity
from marl_factory_grid.environment.groups.objects import _Objects
from marl_factory_grid.environment.entity.object import EnvObject
from marl_factory_grid.environment.entity.object import _Object
class Collection(_Objects):
_entity = EnvObject # entity? object? objects?
_entity = _Object # entity?
@property
def var_is_blocking_light(self):
@ -22,13 +22,13 @@ class Collection(_Objects):
@property
def var_has_position(self):
return False # alles was posmixin hat true
return False
# @property
# def var_has_bound(self):
# return False # batteries, globalpos, inventories true
@property
def var_has_bound(self):
return False # batteries, globalpos, inventories true
@property # beide bounds hier? inventory can be bound
def var_can_be_bound(self):
return False
@ -40,12 +40,12 @@ class Collection(_Objects):
super(Collection, self).__init__(*args, **kwargs)
self.size = size
def add_item(self, item: EnvObject):
def add_item(self, item: Entity):
assert self.var_has_position or (len(self) <= self.size)
super(Collection, self).add_item(item)
return self
def delete_env_object(self, env_object: EnvObject):
def delete_env_object(self, env_object):
del self[env_object.name]
def delete_env_object_by_name(self, name):

View File

@ -36,7 +36,6 @@ class Entities(_Objects):
def guests_that_can_collide(self, pos):
return[x for val in self.pos_dict[pos] for x in val if x.var_can_collide]
@property
def empty_positions(self):
empty_positions= [key for key in self.floorlist if self.pos_dict[key]]
shuffle(empty_positions)

View File

@ -122,7 +122,7 @@ class _Objects:
raise TypeError
def __repr__(self):
repr_dict = { key: val for key, val in self._data.items() if key not in [c.WALLS]}
repr_dict = {key: val for key, val in self._data.items() if key not in [c.WALLS]}
return f'{self.__class__.__name__}[{repr_dict}]'
def spawn(self, n: int):
@ -169,3 +169,15 @@ class _Objects:
# FIXME PROTOBUFF
# return [e.summarize_state() for e in self]
return [e.summarize_state() for e in self]
def by_entity(self, entity):
try:
return next((x for x in self if x.belongs_to_entity(entity)))
except (StopIteration, AttributeError):
return None
def idx_by_entity(self, entity):
try:
return next((idx for idx, x in enumerate(self) if x.belongs_to_entity(entity)))
except (StopIteration, AttributeError):
return None