mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-06-26 05:01:36 +02:00
Resolved some warnings and style issues
This commit is contained in:
@ -1,15 +1,15 @@
|
||||
from typing import List, Tuple, Union, Dict
|
||||
|
||||
from marl_factory_grid.environment.entity.entity import Entity
|
||||
from marl_factory_grid.environment.groups.objects import _Objects
|
||||
from marl_factory_grid.environment.groups.objects import Objects
|
||||
# noinspection PyProtectedMember
|
||||
from marl_factory_grid.environment.entity.object import _Object
|
||||
from marl_factory_grid.environment.entity.object import Object
|
||||
import marl_factory_grid.environment.constants as c
|
||||
from marl_factory_grid.utils.results import Result
|
||||
|
||||
|
||||
class Collection(_Objects):
|
||||
_entity = _Object # entity?
|
||||
class Collection(Objects):
|
||||
_entity = Object # entity?
|
||||
symbol = None
|
||||
|
||||
@property
|
||||
@ -58,7 +58,7 @@ class Collection(_Objects):
|
||||
def trigger_spawn(self, state, *entity_args, coords_or_quantity=None, ignore_blocking=False, **entity_kwargs):
|
||||
coords_or_quantity = coords_or_quantity if coords_or_quantity else self._coords_or_quantity
|
||||
if self.var_has_position:
|
||||
if isinstance(coords_or_quantity, int):
|
||||
if self.var_has_position and isinstance(coords_or_quantity, int):
|
||||
if ignore_blocking or self._ignore_blocking:
|
||||
coords_or_quantity = state.entities.floorlist[:coords_or_quantity]
|
||||
else:
|
||||
@ -87,8 +87,8 @@ class Collection(_Objects):
|
||||
raise ValueError(f'{self._entity.__name__} has no position!')
|
||||
return c.VALID
|
||||
|
||||
def despawn(self, items: List[_Object]):
|
||||
items = [items] if isinstance(items, _Object) else items
|
||||
def despawn(self, items: List[Object]):
|
||||
items = [items] if isinstance(items, Object) else items
|
||||
for item in items:
|
||||
del self[item]
|
||||
|
||||
|
@ -3,12 +3,12 @@ from operator import itemgetter
|
||||
from random import shuffle
|
||||
from typing import Dict
|
||||
|
||||
from marl_factory_grid.environment.groups.objects import _Objects
|
||||
from marl_factory_grid.environment.groups.objects import Objects
|
||||
from marl_factory_grid.utils.helpers import POS_MASK
|
||||
|
||||
|
||||
class Entities(_Objects):
|
||||
_entity = _Objects
|
||||
class Entities(Objects):
|
||||
_entity = Objects
|
||||
|
||||
@staticmethod
|
||||
def neighboring_positions(pos):
|
||||
@ -87,7 +87,7 @@ class Entities(_Objects):
|
||||
def __delitem__(self, name):
|
||||
assert_str = 'This group of entity does not exist in this collection!'
|
||||
assert any([key for key in name.keys() if key in self.keys()]), assert_str
|
||||
self[name]._observers.delete(self)
|
||||
self[name].del_observer(self)
|
||||
for entity in self[name]:
|
||||
entity.del_observer(self)
|
||||
return super(Entities, self).__delitem__(name)
|
||||
|
@ -1,15 +1,15 @@
|
||||
from collections import defaultdict
|
||||
from typing import List
|
||||
from typing import List, Iterator, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from marl_factory_grid.environment.entity.object import _Object
|
||||
from marl_factory_grid.environment.entity.object import Object
|
||||
import marl_factory_grid.environment.constants as c
|
||||
from marl_factory_grid.utils import helpers as h
|
||||
|
||||
|
||||
class _Objects:
|
||||
_entity = _Object
|
||||
class Objects:
|
||||
_entity = Object
|
||||
|
||||
@property
|
||||
def var_can_be_bound(self):
|
||||
@ -50,7 +50,7 @@ class _Objects:
|
||||
def __len__(self):
|
||||
return len(self._data)
|
||||
|
||||
def __iter__(self):
|
||||
def __iter__(self) -> Iterator[Union[Object, None]]:
|
||||
return iter(self.values())
|
||||
|
||||
def add_item(self, item: _entity):
|
||||
@ -130,13 +130,14 @@ class _Objects:
|
||||
repr_dict = {key: val for key, val in self._data.items() if key not in [c.WALLS]}
|
||||
return f'{self.__class__.__name__}[{repr_dict}]'
|
||||
|
||||
def notify_del_entity(self, entity: _Object):
|
||||
def notify_del_entity(self, entity: Object):
|
||||
try:
|
||||
# noinspection PyUnresolvedReferences
|
||||
self.pos_dict[entity.pos].remove(entity)
|
||||
except (AttributeError, ValueError, IndexError):
|
||||
pass
|
||||
|
||||
def notify_add_entity(self, entity: _Object):
|
||||
def notify_add_entity(self, entity: Object):
|
||||
try:
|
||||
if self not in entity.observers:
|
||||
entity.add_observer(self)
|
||||
|
Reference in New Issue
Block a user