fix mismatching signatures of spawn

This commit is contained in:
Chanumask
2023-10-27 17:46:13 +02:00
parent dd5737e3ff
commit 115a79e930
11 changed files with 38 additions and 42 deletions

View File

@@ -1,8 +1,9 @@
from typing import List, Tuple
from typing import List, Tuple, Union
from marl_factory_grid.environment.entity.entity import Entity
from marl_factory_grid.environment.groups.objects import _Objects
from marl_factory_grid.environment.entity.object import _Object
import marl_factory_grid.environment.constants as c
class Collection(_Objects):
@@ -40,6 +41,18 @@ class Collection(_Objects):
super(Collection, self).__init__(*args, **kwargs)
self.size = size
def spawn(self, coords_or_quantity: Union[int, List[Tuple[(int, int)]]], *entity_args): # woihn mit den args
if isinstance(coords_or_quantity, int):
self.add_items([self._entity() for _ in range(coords_or_quantity)])
else:
self.add_items([self._entity(pos) for pos in coords_or_quantity])
return c.VALID
def despawn(self, items: List[_Object]):
items = [items] if isinstance(items, _Object) else items
for item in items:
del self[item]
def add_item(self, item: Entity):
assert self.var_has_position or (len(self) <= self.size)
super(Collection, self).add_item(item)
@@ -67,9 +80,6 @@ class Collection(_Objects):
except (StopIteration, AttributeError):
return None
def spawn(self, coords: List[Tuple[(int, int)]]):
self.add_items([self._entity(pos) for pos in coords])
def render(self):
return [y for y in [x.render() for x in self] if y is not None]