Resolved some warnings and style issues

This commit is contained in:
Steffen Illium
2023-11-10 09:29:54 +01:00
parent a9462a8b6f
commit 6711a0976b
64 changed files with 331 additions and 361 deletions

View File

@@ -1,15 +1,14 @@
import abc
from collections import defaultdict
import numpy as np
from .object import _Object
from .object import Object
from .. import constants as c
from ...utils.results import ActionResult
from ...utils.utility_classes import RenderEntity
class Entity(_Object, abc.ABC):
class Entity(Object, abc.ABC):
"""Full Env Entity that lives on the environment Grid. Doors, Items, DirtPile etc..."""
@property
@@ -96,8 +95,9 @@ class Entity(_Object, abc.ABC):
def __init__(self, pos, bind_to=None, **kwargs):
super().__init__(**kwargs)
self._view_directory = c.VALUE_NO_POS
self._status = None
self.set_pos(pos)
self._pos = pos
self._last_pos = pos
if bind_to:
try:
@@ -113,10 +113,6 @@ class Entity(_Object, abc.ABC):
def render(self):
return RenderEntity(self.__class__.__name__.lower(), self.pos)
@abc.abstractmethod
def render(self):
return RenderEntity(self.__class__.__name__.lower(), self.pos)
@property
def obs_tag(self):
try:
@@ -133,25 +129,3 @@ class Entity(_Object, abc.ABC):
self._collection.delete_env_object(self)
self._collection = other_collection
return self._collection == other_collection
@classmethod
def from_coordinates(cls, positions: [(int, int)], *args, entity_kwargs=None, **kwargs, ):
collection = cls(*args, **kwargs)
collection.add_items(
[cls._entity(tuple(pos), **entity_kwargs if entity_kwargs is not None else {}) for pos in positions])
return collection
def notify_del_entity(self, entity):
try:
self.pos_dict[entity.pos].remove(entity)
except (ValueError, AttributeError):
pass
def by_pos(self, pos: (int, int)):
pos = tuple(pos)
try:
return self.state.entities.pos_dict[pos]
except StopIteration:
pass
except ValueError:
pass

View File

@@ -5,7 +5,7 @@ from marl_factory_grid.environment import constants as c
import marl_factory_grid.utils.helpers as h
class _Object:
class Object:
"""Generell Objects for Organisation and Maintanance such as Actions etc..."""
_u_idx = defaultdict(lambda: 0)
@@ -50,15 +50,15 @@ class _Object:
print(f'Following kwargs were passed, but ignored: {kwargs}')
def __repr__(self):
name = self.name
if self.bound_entity:
name = h.add_bound_name(name, self.bound_entity)
try:
if self.var_has_position:
name = h.add_pos_name(name, self)
except (AttributeError):
pass
return name
name = self.name
if self.bound_entity:
name = h.add_bound_name(name, self.bound_entity)
try:
if self.var_has_position:
name = h.add_pos_name(name, self)
except AttributeError:
pass
return name
def __eq__(self, other) -> bool:
return other == self.identifier
@@ -67,8 +67,8 @@ class _Object:
return hash(self.identifier)
def _identify_and_count_up(self):
idx = _Object._u_idx[self.__class__.__name__]
_Object._u_idx[self.__class__.__name__] += 1
idx = Object._u_idx[self.__class__.__name__]
Object._u_idx[self.__class__.__name__] += 1
return idx
def set_collection(self, collection):
@@ -98,79 +98,3 @@ class _Object:
def unbind(self):
self._bound_entity = None
# class EnvObject(_Object):
# """Objects that hold Information that are observable, but have no position on the environment grid. Inventories etc..."""
#
# _u_idx = defaultdict(lambda: 0)
#
# @property
# def obs_tag(self):
# try:
# return self._collection.name or self.name
# except AttributeError:
# return self.name
#
# @property
# def var_is_blocking_light(self):
# try:
# return self._collection.var_is_blocking_light or False
# except AttributeError:
# return False
#
# @property
# def var_can_be_bound(self):
# try:
# return self._collection.var_can_be_bound or False
# except AttributeError:
# return False
#
# @property
# def var_can_move(self):
# try:
# return self._collection.var_can_move or False
# except AttributeError:
# return False
#
# @property
# def var_is_blocking_pos(self):
# try:
# return self._collection.var_is_blocking_pos or False
# except AttributeError:
# return False
#
# @property
# def var_has_position(self):
# try:
# return self._collection.var_has_position or False
# except AttributeError:
# return False
#
# @property
# def var_can_collide(self):
# try:
# return self._collection.var_can_collide or False
# except AttributeError:
# return False
#
#
# @property
# def encoding(self):
# return c.VALUE_OCCUPIED_CELL
#
#
# def __init__(self, **kwargs):
# self._bound_entity = None
# super(EnvObject, self).__init__(**kwargs)
#
#
# def change_parent_collection(self, other_collection):
# other_collection.add_item(self)
# self._collection.delete_env_object(self)
# self._collection = other_collection
# return self._collection == other_collection
#
#
# def summarize_state(self):
# return dict(name=str(self.name))

View File

@@ -1,6 +1,6 @@
import numpy as np
from marl_factory_grid.environment.entity.object import _Object
from marl_factory_grid.environment.entity.object import Object
##########################################################################
@@ -8,7 +8,7 @@ from marl_factory_grid.environment.entity.object import _Object
##########################################################################
class PlaceHolder(_Object):
class PlaceHolder(Object):
def __init__(self, *args, fill_value=0, **kwargs):
super().__init__(*args, **kwargs)
@@ -27,7 +27,7 @@ class PlaceHolder(_Object):
return self.__class__.__name__
class GlobalPosition(_Object):
class GlobalPosition(Object):
@property
def encoding(self):

View File

@@ -56,15 +56,18 @@ class Factory(gym.Env):
self.level_filepath = Path(custom_level_path)
else:
self.level_filepath = Path(__file__).parent.parent / h.LEVELS_DIR / f'{self.conf.level_name}.txt'
self._renderer = None # expensive - don't use; unless required !
parsed_entities = self.conf.load_entities()
self.map = LevelParser(self.level_filepath, parsed_entities, self.conf.pomdp_r)
# Init for later usage:
self.state: Gamestate
self.map: LevelParser
self.obs_builder: OBSBuilder
# noinspection PyTypeChecker
self.state: Gamestate = None
# noinspection PyTypeChecker
self.obs_builder: OBSBuilder = None
# expensive - don't use; unless required !
self._renderer = None
# reset env to initial state, preparing env for new episode.
# returns tuple where the first dict contains initial observation for each agent in the env
@@ -74,7 +77,7 @@ class Factory(gym.Env):
return self.state.entities[item]
def reset(self) -> (dict, dict):
if hasattr(self, 'state'):
if self.state is not None:
for entity_group in self.state.entities:
try:
entity_group[0].reset_uid()
@@ -160,7 +163,7 @@ class Factory(gym.Env):
# Finalize
reward, reward_info, done = self.summarize_step_results(tick_result, done_results)
info = reward_info
info = dict(reward_info)
info.update(step_reward=sum(reward), step=self.state.curr_step)

View File

@@ -1,15 +1,15 @@
from typing import List, Tuple, Union, Dict
from marl_factory_grid.environment.entity.entity import Entity
from marl_factory_grid.environment.groups.objects import _Objects
from marl_factory_grid.environment.groups.objects import Objects
# noinspection PyProtectedMember
from marl_factory_grid.environment.entity.object import _Object
from marl_factory_grid.environment.entity.object import Object
import marl_factory_grid.environment.constants as c
from marl_factory_grid.utils.results import Result
class Collection(_Objects):
_entity = _Object # entity?
class Collection(Objects):
_entity = Object # entity?
symbol = None
@property
@@ -58,7 +58,7 @@ class Collection(_Objects):
def trigger_spawn(self, state, *entity_args, coords_or_quantity=None, ignore_blocking=False, **entity_kwargs):
coords_or_quantity = coords_or_quantity if coords_or_quantity else self._coords_or_quantity
if self.var_has_position:
if isinstance(coords_or_quantity, int):
if self.var_has_position and isinstance(coords_or_quantity, int):
if ignore_blocking or self._ignore_blocking:
coords_or_quantity = state.entities.floorlist[:coords_or_quantity]
else:
@@ -87,8 +87,8 @@ class Collection(_Objects):
raise ValueError(f'{self._entity.__name__} has no position!')
return c.VALID
def despawn(self, items: List[_Object]):
items = [items] if isinstance(items, _Object) else items
def despawn(self, items: List[Object]):
items = [items] if isinstance(items, Object) else items
for item in items:
del self[item]

View File

@@ -3,12 +3,12 @@ from operator import itemgetter
from random import shuffle
from typing import Dict
from marl_factory_grid.environment.groups.objects import _Objects
from marl_factory_grid.environment.groups.objects import Objects
from marl_factory_grid.utils.helpers import POS_MASK
class Entities(_Objects):
_entity = _Objects
class Entities(Objects):
_entity = Objects
@staticmethod
def neighboring_positions(pos):
@@ -87,7 +87,7 @@ class Entities(_Objects):
def __delitem__(self, name):
assert_str = 'This group of entity does not exist in this collection!'
assert any([key for key in name.keys() if key in self.keys()]), assert_str
self[name]._observers.delete(self)
self[name].del_observer(self)
for entity in self[name]:
entity.del_observer(self)
return super(Entities, self).__delitem__(name)

View File

@@ -1,15 +1,15 @@
from collections import defaultdict
from typing import List
from typing import List, Iterator, Union
import numpy as np
from marl_factory_grid.environment.entity.object import _Object
from marl_factory_grid.environment.entity.object import Object
import marl_factory_grid.environment.constants as c
from marl_factory_grid.utils import helpers as h
class _Objects:
_entity = _Object
class Objects:
_entity = Object
@property
def var_can_be_bound(self):
@@ -50,7 +50,7 @@ class _Objects:
def __len__(self):
return len(self._data)
def __iter__(self):
def __iter__(self) -> Iterator[Union[Object, None]]:
return iter(self.values())
def add_item(self, item: _entity):
@@ -130,13 +130,14 @@ class _Objects:
repr_dict = {key: val for key, val in self._data.items() if key not in [c.WALLS]}
return f'{self.__class__.__name__}[{repr_dict}]'
def notify_del_entity(self, entity: _Object):
def notify_del_entity(self, entity: Object):
try:
# noinspection PyUnresolvedReferences
self.pos_dict[entity.pos].remove(entity)
except (AttributeError, ValueError, IndexError):
pass
def notify_add_entity(self, entity: _Object):
def notify_add_entity(self, entity: Object):
try:
if self not in entity.observers:
entity.add_observer(self)

View File

@@ -1,11 +1,11 @@
import abc
from random import shuffle
from typing import List, Collection, Union
from typing import List, Collection
from marl_factory_grid.environment import rewards as r, constants as c
from marl_factory_grid.environment.entity.agent import Agent
from marl_factory_grid.utils import helpers as h
from marl_factory_grid.utils.results import TickResult, DoneResult
from marl_factory_grid.environment import rewards as r, constants as c
class Rule(abc.ABC):
@@ -118,8 +118,7 @@ class AssignGlobalPositions(Rule):
def on_init(self, state, lvl_map):
from marl_factory_grid.environment.entity.util import GlobalPosition
for agent in state[c.AGENT]:
gp = GlobalPosition(lvl_map.level_shape)
gp.bind_to(agent)
gp = GlobalPosition(agent, lvl_map.level_shape)
state[c.GLOBALPOSITIONS].add_item(gp)
return []