WIP: object, entity rework

This commit is contained in:
Chanumask
2023-10-19 17:25:31 +02:00
parent 8709b093b8
commit 8d6dcd70ae
18 changed files with 155 additions and 53 deletions

View File

@@ -1,21 +1,57 @@
import abc
from collections import defaultdict
from .object import Object
from .. import constants as c
from .object import EnvObject
from ...utils.render import RenderEntity
from ...utils.results import ActionResult
class Entity(EnvObject, abc.ABC):
class Entity(Object, abc.ABC):
"""Full Env Entity that lives on the environment Grid. Doors, Items, DirtPile etc..."""
_u_idx = defaultdict(lambda: 0)
@property
def state(self):
return self._status or ActionResult(entity=self, identifier=c.NOOP, validity=c.VALID, reward=0)
# @property
# def var_has_position(self):
# return self.pos != c.VALUE_NO_POS
var_has_position: bool = True
@property
def var_has_position(self):
return self.pos != c.VALUE_NO_POS
def var_is_blocking_light(self):
try:
return self._collection.var_is_blocking_light or False
except AttributeError:
return False
# var_is_blocking_light: bool = True
@property
def var_can_move(self):
try:
return self._collection.var_can_move or False
except AttributeError:
return False
@property
def var_is_blocking_pos(self):
try:
return self._collection.var_is_blocking_pos or False
except AttributeError:
return False
@property
def var_can_collide(self):
try:
return self._collection.var_can_collide or False
except AttributeError:
return False
# var_can_collide: bool = True
@property
def x(self):
@@ -29,10 +65,6 @@ class Entity(EnvObject, abc.ABC):
def pos(self):
return self._pos
@property
def tile(self):
return self._tile # wall_n_floors funktionalität
# @property
# def last_tile(self):
# try:
@@ -71,7 +103,7 @@ class Entity(EnvObject, abc.ABC):
print(f'Objects of {self.__class__.__name__} can not be bound to other entities.')
exit()
def summarize_state(self) -> dict: # tile=str(self.tile.name)
def summarize_state(self) -> dict:
return dict(name=str(self.name), x=int(self.x), y=int(self.y), can_collide=bool(self.var_can_collide))
@abc.abstractmethod
@@ -80,3 +112,43 @@ class Entity(EnvObject, abc.ABC):
def __repr__(self):
return super(Entity, self).__repr__() + f'(@{self.pos})'
@property
def obs_tag(self):
try:
return self._collection.name or self.name
except AttributeError:
return self.name
@property
def encoding(self):
return c.VALUE_OCCUPIED_CELL
def change_parent_collection(self, other_collection):
other_collection.add_item(self)
self._collection.delete_env_object(self)
self._collection = other_collection
return self._collection == other_collection
@classmethod
def from_coordinates(cls, positions: [(int, int)], *args, entity_kwargs=None, **kwargs, ):
collection = cls(*args, **kwargs)
collection.add_items(
[cls._entity(tuple(pos), **entity_kwargs if entity_kwargs is not None else {}) for pos in positions])
return collection
def notify_del_entity(self, entity):
try:
self.pos_dict[entity.pos].remove(entity)
except (ValueError, AttributeError):
pass
def by_pos(self, pos: (int, int)):
pos = tuple(pos)
try:
return self.state.entities.pos_dict[pos]
# return next(e for e in self if e.pos == pos)
except StopIteration:
pass
except ValueError:
print()

View File

@@ -12,6 +12,19 @@ class Object:
def __bool__(self):
return True
@property
def var_has_position(self): # brauchen wir das hier jetzt?
try:
return self.pos != c.VALUE_NO_POS or False
except AttributeError:
return False
@property
def var_can_be_bound(self):
try:
return self._collection.var_can_be_bound or False
except AttributeError:
return False
@property
def observers(self):
@@ -70,6 +83,14 @@ class Object:
def summarize_state(self):
return dict()
def bind(self, entity):
# noinspection PyAttributeOutsideInit
self._bound_entity = entity
return c.VALID
def belongs_to_entity(self, entity):
return self._bound_entity == entity
class EnvObject(Object):

View File

@@ -30,7 +30,7 @@ class PlaceHolder(Object):
return "PlaceHolder"
class GlobalPosition(BoundEntityMixin, EnvObject):
class GlobalPosition(Object):
@property
def encoding(self):

View File

@@ -3,6 +3,7 @@ from typing import List
import numpy as np
from marl_factory_grid.environment import constants as c
from marl_factory_grid.environment.entity.entity import Entity
from marl_factory_grid.environment.entity.object import EnvObject
from marl_factory_grid.utils.render import RenderEntity
from marl_factory_grid.utils import helpers as h
@@ -109,7 +110,7 @@ class Floor(EnvObject):
return None
class Wall(Floor):
class Wall(Entity):
@property
def var_can_collide(self):

View File

@@ -1,9 +1,9 @@
from marl_factory_grid.environment.entity.agent import Agent
from marl_factory_grid.environment.groups.env_objects import EnvObjects
from marl_factory_grid.environment.groups.env_objects import Collection
from marl_factory_grid.environment.groups.mixins import PositionMixin
class Agents(PositionMixin, EnvObjects):
class Agents(PositionMixin, Collection):
_entity = Agent
is_blocking_light = False
can_move = True

View File

@@ -2,26 +2,27 @@ from marl_factory_grid.environment.groups.objects import Objects
from marl_factory_grid.environment.entity.object import EnvObject
class EnvObjects(Objects):
class Collection(Objects):
_entity = EnvObject
var_is_blocking_light: bool = False
var_can_collide: bool = False
var_has_position: bool = False
var_can_move: bool = False
var_can_be_bound: bool = False
var_has_position: bool = False # alles was posmixin hat true
var_has_bound = False # batteries, globalpos, inventories true
var_can_be_bound: bool = False # == ^
@property
def encodings(self):
return [x.encoding for x in self]
def __init__(self, size, *args, **kwargs):
super(EnvObjects, self).__init__(*args, **kwargs)
super(Collection, self).__init__(*args, **kwargs)
self.size = size
def add_item(self, item: EnvObject):
assert self.var_has_position or (len(self) <= self.size)
super(EnvObjects, self).add_item(item)
super(Collection, self).add_item(item)
return self
def delete_env_object(self, env_object: EnvObject):
@@ -29,3 +30,19 @@ class EnvObjects(Objects):
def delete_env_object_by_name(self, name):
del self[name]
@property
def obs_pairs(self):
return [(x.name, x) for x in self]
def by_entity(self, entity):
try:
return next((x for x in self if x.belongs_to_entity(entity)))
except (StopIteration, AttributeError):
return None
def idx_by_entity(self, entity):
try:
return next((idx for idx, x in enumerate(self) if x.belongs_to_entity(entity)))
except (StopIteration, AttributeError):
return None

View File

@@ -19,15 +19,6 @@ class PositionMixin:
def render(self):
return [y for y in [x.render() for x in self] if y is not None]
# @classmethod
# def from_tiles(cls, tiles, *args, entity_kwargs=None, **kwargs):
# collection = cls(*args, **kwargs)
# entities = [cls._entity(tile, str_ident=i,
# **entity_kwargs if entity_kwargs is not None else {})
# for i, tile in enumerate(tiles)]
# collection.add_items(entities)
# return collection
@classmethod
def from_coordinates(cls, positions: [(int, int)], *args, entity_kwargs=None, **kwargs, ):
collection = cls(*args, **kwargs)

View File

@@ -3,7 +3,7 @@ from typing import List, Union
import numpy as np
from marl_factory_grid.environment.entity.util import GlobalPosition
from marl_factory_grid.environment.groups.env_objects import EnvObjects
from marl_factory_grid.environment.groups.env_objects import Collection
from marl_factory_grid.environment.groups.mixins import PositionMixin, HasBoundMixin
from marl_factory_grid.environment.groups.objects import Objects
from marl_factory_grid.modules.zones import Zone
@@ -11,7 +11,7 @@ from marl_factory_grid.utils import helpers as h
from marl_factory_grid.environment import constants as c
class Combined(PositionMixin, EnvObjects):
class Combined(PositionMixin, Collection):
@property
def name(self):
@@ -35,7 +35,7 @@ class Combined(PositionMixin, EnvObjects):
return [(name, None) for name in self.names]
class GlobalPositions(HasBoundMixin, EnvObjects):
class GlobalPositions(Collection):
_entity = GlobalPosition
is_blocking_light = False,

View File

@@ -2,12 +2,12 @@ import random
from typing import List, Tuple
from marl_factory_grid.environment import constants as c
from marl_factory_grid.environment.groups.env_objects import EnvObjects
from marl_factory_grid.environment.groups.env_objects import Collection
from marl_factory_grid.environment.groups.mixins import PositionMixin
from marl_factory_grid.environment.entity.wall_floor import Wall, Floor
class Walls(PositionMixin, EnvObjects):
class Walls(PositionMixin, Collection):
_entity = Wall
symbol = c.SYMBOL_WALL