mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-12-20 05:56:07 +01:00
Resolved some warnings and style issues
This commit is contained in:
@@ -1,15 +1,14 @@
|
||||
import abc
|
||||
from collections import defaultdict
|
||||
|
||||
import numpy as np
|
||||
|
||||
from .object import _Object
|
||||
from .object import Object
|
||||
from .. import constants as c
|
||||
from ...utils.results import ActionResult
|
||||
from ...utils.utility_classes import RenderEntity
|
||||
|
||||
|
||||
class Entity(_Object, abc.ABC):
|
||||
class Entity(Object, abc.ABC):
|
||||
"""Full Env Entity that lives on the environment Grid. Doors, Items, DirtPile etc..."""
|
||||
|
||||
@property
|
||||
@@ -96,8 +95,9 @@ class Entity(_Object, abc.ABC):
|
||||
|
||||
def __init__(self, pos, bind_to=None, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self._view_directory = c.VALUE_NO_POS
|
||||
self._status = None
|
||||
self.set_pos(pos)
|
||||
self._pos = pos
|
||||
self._last_pos = pos
|
||||
if bind_to:
|
||||
try:
|
||||
@@ -113,10 +113,6 @@ class Entity(_Object, abc.ABC):
|
||||
def render(self):
|
||||
return RenderEntity(self.__class__.__name__.lower(), self.pos)
|
||||
|
||||
@abc.abstractmethod
|
||||
def render(self):
|
||||
return RenderEntity(self.__class__.__name__.lower(), self.pos)
|
||||
|
||||
@property
|
||||
def obs_tag(self):
|
||||
try:
|
||||
@@ -133,25 +129,3 @@ class Entity(_Object, abc.ABC):
|
||||
self._collection.delete_env_object(self)
|
||||
self._collection = other_collection
|
||||
return self._collection == other_collection
|
||||
|
||||
@classmethod
|
||||
def from_coordinates(cls, positions: [(int, int)], *args, entity_kwargs=None, **kwargs, ):
|
||||
collection = cls(*args, **kwargs)
|
||||
collection.add_items(
|
||||
[cls._entity(tuple(pos), **entity_kwargs if entity_kwargs is not None else {}) for pos in positions])
|
||||
return collection
|
||||
|
||||
def notify_del_entity(self, entity):
|
||||
try:
|
||||
self.pos_dict[entity.pos].remove(entity)
|
||||
except (ValueError, AttributeError):
|
||||
pass
|
||||
|
||||
def by_pos(self, pos: (int, int)):
|
||||
pos = tuple(pos)
|
||||
try:
|
||||
return self.state.entities.pos_dict[pos]
|
||||
except StopIteration:
|
||||
pass
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
@@ -5,7 +5,7 @@ from marl_factory_grid.environment import constants as c
|
||||
import marl_factory_grid.utils.helpers as h
|
||||
|
||||
|
||||
class _Object:
|
||||
class Object:
|
||||
"""Generell Objects for Organisation and Maintanance such as Actions etc..."""
|
||||
|
||||
_u_idx = defaultdict(lambda: 0)
|
||||
@@ -50,15 +50,15 @@ class _Object:
|
||||
print(f'Following kwargs were passed, but ignored: {kwargs}')
|
||||
|
||||
def __repr__(self):
|
||||
name = self.name
|
||||
if self.bound_entity:
|
||||
name = h.add_bound_name(name, self.bound_entity)
|
||||
try:
|
||||
if self.var_has_position:
|
||||
name = h.add_pos_name(name, self)
|
||||
except (AttributeError):
|
||||
pass
|
||||
return name
|
||||
name = self.name
|
||||
if self.bound_entity:
|
||||
name = h.add_bound_name(name, self.bound_entity)
|
||||
try:
|
||||
if self.var_has_position:
|
||||
name = h.add_pos_name(name, self)
|
||||
except AttributeError:
|
||||
pass
|
||||
return name
|
||||
|
||||
def __eq__(self, other) -> bool:
|
||||
return other == self.identifier
|
||||
@@ -67,8 +67,8 @@ class _Object:
|
||||
return hash(self.identifier)
|
||||
|
||||
def _identify_and_count_up(self):
|
||||
idx = _Object._u_idx[self.__class__.__name__]
|
||||
_Object._u_idx[self.__class__.__name__] += 1
|
||||
idx = Object._u_idx[self.__class__.__name__]
|
||||
Object._u_idx[self.__class__.__name__] += 1
|
||||
return idx
|
||||
|
||||
def set_collection(self, collection):
|
||||
@@ -98,79 +98,3 @@ class _Object:
|
||||
|
||||
def unbind(self):
|
||||
self._bound_entity = None
|
||||
|
||||
|
||||
# class EnvObject(_Object):
|
||||
# """Objects that hold Information that are observable, but have no position on the environment grid. Inventories etc..."""
|
||||
#
|
||||
# _u_idx = defaultdict(lambda: 0)
|
||||
#
|
||||
# @property
|
||||
# def obs_tag(self):
|
||||
# try:
|
||||
# return self._collection.name or self.name
|
||||
# except AttributeError:
|
||||
# return self.name
|
||||
#
|
||||
# @property
|
||||
# def var_is_blocking_light(self):
|
||||
# try:
|
||||
# return self._collection.var_is_blocking_light or False
|
||||
# except AttributeError:
|
||||
# return False
|
||||
#
|
||||
# @property
|
||||
# def var_can_be_bound(self):
|
||||
# try:
|
||||
# return self._collection.var_can_be_bound or False
|
||||
# except AttributeError:
|
||||
# return False
|
||||
#
|
||||
# @property
|
||||
# def var_can_move(self):
|
||||
# try:
|
||||
# return self._collection.var_can_move or False
|
||||
# except AttributeError:
|
||||
# return False
|
||||
#
|
||||
# @property
|
||||
# def var_is_blocking_pos(self):
|
||||
# try:
|
||||
# return self._collection.var_is_blocking_pos or False
|
||||
# except AttributeError:
|
||||
# return False
|
||||
#
|
||||
# @property
|
||||
# def var_has_position(self):
|
||||
# try:
|
||||
# return self._collection.var_has_position or False
|
||||
# except AttributeError:
|
||||
# return False
|
||||
#
|
||||
# @property
|
||||
# def var_can_collide(self):
|
||||
# try:
|
||||
# return self._collection.var_can_collide or False
|
||||
# except AttributeError:
|
||||
# return False
|
||||
#
|
||||
#
|
||||
# @property
|
||||
# def encoding(self):
|
||||
# return c.VALUE_OCCUPIED_CELL
|
||||
#
|
||||
#
|
||||
# def __init__(self, **kwargs):
|
||||
# self._bound_entity = None
|
||||
# super(EnvObject, self).__init__(**kwargs)
|
||||
#
|
||||
#
|
||||
# def change_parent_collection(self, other_collection):
|
||||
# other_collection.add_item(self)
|
||||
# self._collection.delete_env_object(self)
|
||||
# self._collection = other_collection
|
||||
# return self._collection == other_collection
|
||||
#
|
||||
#
|
||||
# def summarize_state(self):
|
||||
# return dict(name=str(self.name))
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import numpy as np
|
||||
|
||||
from marl_factory_grid.environment.entity.object import _Object
|
||||
from marl_factory_grid.environment.entity.object import Object
|
||||
|
||||
|
||||
##########################################################################
|
||||
@@ -8,7 +8,7 @@ from marl_factory_grid.environment.entity.object import _Object
|
||||
##########################################################################
|
||||
|
||||
|
||||
class PlaceHolder(_Object):
|
||||
class PlaceHolder(Object):
|
||||
|
||||
def __init__(self, *args, fill_value=0, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
@@ -27,7 +27,7 @@ class PlaceHolder(_Object):
|
||||
return self.__class__.__name__
|
||||
|
||||
|
||||
class GlobalPosition(_Object):
|
||||
class GlobalPosition(Object):
|
||||
|
||||
@property
|
||||
def encoding(self):
|
||||
|
||||
@@ -56,15 +56,18 @@ class Factory(gym.Env):
|
||||
self.level_filepath = Path(custom_level_path)
|
||||
else:
|
||||
self.level_filepath = Path(__file__).parent.parent / h.LEVELS_DIR / f'{self.conf.level_name}.txt'
|
||||
self._renderer = None # expensive - don't use; unless required !
|
||||
|
||||
parsed_entities = self.conf.load_entities()
|
||||
self.map = LevelParser(self.level_filepath, parsed_entities, self.conf.pomdp_r)
|
||||
|
||||
# Init for later usage:
|
||||
self.state: Gamestate
|
||||
self.map: LevelParser
|
||||
self.obs_builder: OBSBuilder
|
||||
# noinspection PyTypeChecker
|
||||
self.state: Gamestate = None
|
||||
# noinspection PyTypeChecker
|
||||
self.obs_builder: OBSBuilder = None
|
||||
|
||||
# expensive - don't use; unless required !
|
||||
self._renderer = None
|
||||
|
||||
# reset env to initial state, preparing env for new episode.
|
||||
# returns tuple where the first dict contains initial observation for each agent in the env
|
||||
@@ -74,7 +77,7 @@ class Factory(gym.Env):
|
||||
return self.state.entities[item]
|
||||
|
||||
def reset(self) -> (dict, dict):
|
||||
if hasattr(self, 'state'):
|
||||
if self.state is not None:
|
||||
for entity_group in self.state.entities:
|
||||
try:
|
||||
entity_group[0].reset_uid()
|
||||
@@ -160,7 +163,7 @@ class Factory(gym.Env):
|
||||
# Finalize
|
||||
reward, reward_info, done = self.summarize_step_results(tick_result, done_results)
|
||||
|
||||
info = reward_info
|
||||
info = dict(reward_info)
|
||||
|
||||
info.update(step_reward=sum(reward), step=self.state.curr_step)
|
||||
|
||||
|
||||
@@ -1,15 +1,15 @@
|
||||
from typing import List, Tuple, Union, Dict
|
||||
|
||||
from marl_factory_grid.environment.entity.entity import Entity
|
||||
from marl_factory_grid.environment.groups.objects import _Objects
|
||||
from marl_factory_grid.environment.groups.objects import Objects
|
||||
# noinspection PyProtectedMember
|
||||
from marl_factory_grid.environment.entity.object import _Object
|
||||
from marl_factory_grid.environment.entity.object import Object
|
||||
import marl_factory_grid.environment.constants as c
|
||||
from marl_factory_grid.utils.results import Result
|
||||
|
||||
|
||||
class Collection(_Objects):
|
||||
_entity = _Object # entity?
|
||||
class Collection(Objects):
|
||||
_entity = Object # entity?
|
||||
symbol = None
|
||||
|
||||
@property
|
||||
@@ -58,7 +58,7 @@ class Collection(_Objects):
|
||||
def trigger_spawn(self, state, *entity_args, coords_or_quantity=None, ignore_blocking=False, **entity_kwargs):
|
||||
coords_or_quantity = coords_or_quantity if coords_or_quantity else self._coords_or_quantity
|
||||
if self.var_has_position:
|
||||
if isinstance(coords_or_quantity, int):
|
||||
if self.var_has_position and isinstance(coords_or_quantity, int):
|
||||
if ignore_blocking or self._ignore_blocking:
|
||||
coords_or_quantity = state.entities.floorlist[:coords_or_quantity]
|
||||
else:
|
||||
@@ -87,8 +87,8 @@ class Collection(_Objects):
|
||||
raise ValueError(f'{self._entity.__name__} has no position!')
|
||||
return c.VALID
|
||||
|
||||
def despawn(self, items: List[_Object]):
|
||||
items = [items] if isinstance(items, _Object) else items
|
||||
def despawn(self, items: List[Object]):
|
||||
items = [items] if isinstance(items, Object) else items
|
||||
for item in items:
|
||||
del self[item]
|
||||
|
||||
|
||||
@@ -3,12 +3,12 @@ from operator import itemgetter
|
||||
from random import shuffle
|
||||
from typing import Dict
|
||||
|
||||
from marl_factory_grid.environment.groups.objects import _Objects
|
||||
from marl_factory_grid.environment.groups.objects import Objects
|
||||
from marl_factory_grid.utils.helpers import POS_MASK
|
||||
|
||||
|
||||
class Entities(_Objects):
|
||||
_entity = _Objects
|
||||
class Entities(Objects):
|
||||
_entity = Objects
|
||||
|
||||
@staticmethod
|
||||
def neighboring_positions(pos):
|
||||
@@ -87,7 +87,7 @@ class Entities(_Objects):
|
||||
def __delitem__(self, name):
|
||||
assert_str = 'This group of entity does not exist in this collection!'
|
||||
assert any([key for key in name.keys() if key in self.keys()]), assert_str
|
||||
self[name]._observers.delete(self)
|
||||
self[name].del_observer(self)
|
||||
for entity in self[name]:
|
||||
entity.del_observer(self)
|
||||
return super(Entities, self).__delitem__(name)
|
||||
|
||||
@@ -1,15 +1,15 @@
|
||||
from collections import defaultdict
|
||||
from typing import List
|
||||
from typing import List, Iterator, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from marl_factory_grid.environment.entity.object import _Object
|
||||
from marl_factory_grid.environment.entity.object import Object
|
||||
import marl_factory_grid.environment.constants as c
|
||||
from marl_factory_grid.utils import helpers as h
|
||||
|
||||
|
||||
class _Objects:
|
||||
_entity = _Object
|
||||
class Objects:
|
||||
_entity = Object
|
||||
|
||||
@property
|
||||
def var_can_be_bound(self):
|
||||
@@ -50,7 +50,7 @@ class _Objects:
|
||||
def __len__(self):
|
||||
return len(self._data)
|
||||
|
||||
def __iter__(self):
|
||||
def __iter__(self) -> Iterator[Union[Object, None]]:
|
||||
return iter(self.values())
|
||||
|
||||
def add_item(self, item: _entity):
|
||||
@@ -130,13 +130,14 @@ class _Objects:
|
||||
repr_dict = {key: val for key, val in self._data.items() if key not in [c.WALLS]}
|
||||
return f'{self.__class__.__name__}[{repr_dict}]'
|
||||
|
||||
def notify_del_entity(self, entity: _Object):
|
||||
def notify_del_entity(self, entity: Object):
|
||||
try:
|
||||
# noinspection PyUnresolvedReferences
|
||||
self.pos_dict[entity.pos].remove(entity)
|
||||
except (AttributeError, ValueError, IndexError):
|
||||
pass
|
||||
|
||||
def notify_add_entity(self, entity: _Object):
|
||||
def notify_add_entity(self, entity: Object):
|
||||
try:
|
||||
if self not in entity.observers:
|
||||
entity.add_observer(self)
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
import abc
|
||||
from random import shuffle
|
||||
from typing import List, Collection, Union
|
||||
from typing import List, Collection
|
||||
|
||||
from marl_factory_grid.environment import rewards as r, constants as c
|
||||
from marl_factory_grid.environment.entity.agent import Agent
|
||||
from marl_factory_grid.utils import helpers as h
|
||||
from marl_factory_grid.utils.results import TickResult, DoneResult
|
||||
from marl_factory_grid.environment import rewards as r, constants as c
|
||||
|
||||
|
||||
class Rule(abc.ABC):
|
||||
@@ -118,8 +118,7 @@ class AssignGlobalPositions(Rule):
|
||||
def on_init(self, state, lvl_map):
|
||||
from marl_factory_grid.environment.entity.util import GlobalPosition
|
||||
for agent in state[c.AGENT]:
|
||||
gp = GlobalPosition(lvl_map.level_shape)
|
||||
gp.bind_to(agent)
|
||||
gp = GlobalPosition(agent, lvl_map.level_shape)
|
||||
state[c.GLOBALPOSITIONS].add_item(gp)
|
||||
return []
|
||||
|
||||
|
||||
Reference in New Issue
Block a user