Resolved some warnings and style issues

This commit is contained in:
Steffen Illium
2023-11-10 09:29:54 +01:00
parent a9462a8b6f
commit 6711a0976b
64 changed files with 331 additions and 361 deletions

View File

@ -1,15 +1,15 @@
from typing import List, Tuple, Union, Dict
from marl_factory_grid.environment.entity.entity import Entity
from marl_factory_grid.environment.groups.objects import _Objects
from marl_factory_grid.environment.groups.objects import Objects
# noinspection PyProtectedMember
from marl_factory_grid.environment.entity.object import _Object
from marl_factory_grid.environment.entity.object import Object
import marl_factory_grid.environment.constants as c
from marl_factory_grid.utils.results import Result
class Collection(_Objects):
_entity = _Object # entity?
class Collection(Objects):
_entity = Object # entity?
symbol = None
@property
@ -58,7 +58,7 @@ class Collection(_Objects):
def trigger_spawn(self, state, *entity_args, coords_or_quantity=None, ignore_blocking=False, **entity_kwargs):
coords_or_quantity = coords_or_quantity if coords_or_quantity else self._coords_or_quantity
if self.var_has_position:
if isinstance(coords_or_quantity, int):
if self.var_has_position and isinstance(coords_or_quantity, int):
if ignore_blocking or self._ignore_blocking:
coords_or_quantity = state.entities.floorlist[:coords_or_quantity]
else:
@ -87,8 +87,8 @@ class Collection(_Objects):
raise ValueError(f'{self._entity.__name__} has no position!')
return c.VALID
def despawn(self, items: List[_Object]):
items = [items] if isinstance(items, _Object) else items
def despawn(self, items: List[Object]):
items = [items] if isinstance(items, Object) else items
for item in items:
del self[item]

View File

@ -3,12 +3,12 @@ from operator import itemgetter
from random import shuffle
from typing import Dict
from marl_factory_grid.environment.groups.objects import _Objects
from marl_factory_grid.environment.groups.objects import Objects
from marl_factory_grid.utils.helpers import POS_MASK
class Entities(_Objects):
_entity = _Objects
class Entities(Objects):
_entity = Objects
@staticmethod
def neighboring_positions(pos):
@ -87,7 +87,7 @@ class Entities(_Objects):
def __delitem__(self, name):
assert_str = 'This group of entity does not exist in this collection!'
assert any([key for key in name.keys() if key in self.keys()]), assert_str
self[name]._observers.delete(self)
self[name].del_observer(self)
for entity in self[name]:
entity.del_observer(self)
return super(Entities, self).__delitem__(name)

View File

@ -1,15 +1,15 @@
from collections import defaultdict
from typing import List
from typing import List, Iterator, Union
import numpy as np
from marl_factory_grid.environment.entity.object import _Object
from marl_factory_grid.environment.entity.object import Object
import marl_factory_grid.environment.constants as c
from marl_factory_grid.utils import helpers as h
class _Objects:
_entity = _Object
class Objects:
_entity = Object
@property
def var_can_be_bound(self):
@ -50,7 +50,7 @@ class _Objects:
def __len__(self):
return len(self._data)
def __iter__(self):
def __iter__(self) -> Iterator[Union[Object, None]]:
return iter(self.values())
def add_item(self, item: _entity):
@ -130,13 +130,14 @@ class _Objects:
repr_dict = {key: val for key, val in self._data.items() if key not in [c.WALLS]}
return f'{self.__class__.__name__}[{repr_dict}]'
def notify_del_entity(self, entity: _Object):
def notify_del_entity(self, entity: Object):
try:
# noinspection PyUnresolvedReferences
self.pos_dict[entity.pos].remove(entity)
except (AttributeError, ValueError, IndexError):
pass
def notify_add_entity(self, entity: _Object):
def notify_add_entity(self, entity: Object):
try:
if self not in entity.observers:
entity.add_observer(self)