This commit is contained in:
Chanumask
2023-10-27 13:08:01 +02:00
parent ef2fdd5d28
commit b13dff925b
13 changed files with 141 additions and 99 deletions

View File

@@ -12,8 +12,6 @@ from ...utils.utility_classes import RenderEntity
class Entity(_Object, abc.ABC):
"""Full Env Entity that lives on the environment Grid. Doors, Items, DirtPile etc..."""
_u_idx = defaultdict(lambda: 0)
@property
def state(self):
return self._status or ActionResult(entity=self, identifier=c.NOOP, validity=c.VALID, reward=0)
@@ -29,7 +27,6 @@ class Entity(_Object, abc.ABC):
except AttributeError:
return False
@property
def var_can_move(self):
try:
@@ -51,7 +48,6 @@ class Entity(_Object, abc.ABC):
except AttributeError:
return False
@property
def x(self):
return self.pos[0]
@@ -87,7 +83,7 @@ class Entity(_Object, abc.ABC):
if valid := state.check_move_validity(self, next_pos):
for observer in self.observers:
observer.notify_del_entity(self)
self._view_directory = curr_pos[0]-next_pos[0], curr_pos[1]-next_pos[1]
self._view_directory = curr_pos[0] - next_pos[0], curr_pos[1] - next_pos[1]
self._pos = next_pos
for observer in self.observers:
observer.notify_add_entity(self)

View File

@@ -14,10 +14,7 @@ class _Object:
@property
def var_has_position(self):
try:
return self.pos != c.VALUE_NO_POS or False
except AttributeError:
return False
return False
@property
def var_can_be_bound(self):
@@ -36,6 +33,17 @@ class _Object:
return f'{self.__class__.__name__}[{self._str_ident}]'
return f'{self.__class__.__name__}#{self.u_int}'
# @property
# def name(self):
# name = f"{self.__class__.__name__}"
# if self.bound_entity:
# name += f"[{self.bound_entity.name}]"
# if self._str_ident is not None:
# name += f"({self._str_ident})"
# else:
# name += f"(#{self.u_int})"
# return name
@property
def identifier(self):
if self._str_ident is not None:
@@ -48,6 +56,7 @@ class _Object:
return True
def __init__(self, str_ident: Union[str, None] = None, **kwargs):
self._bound_entity = None
self._observers = []
self._str_ident = str_ident
self.u_int = self._identify_and_count_up()
@@ -91,73 +100,83 @@ class _Object:
def belongs_to_entity(self, entity):
return self._bound_entity == entity
class EnvObject(_Object):
"""Objects that hold Information that are observable, but have no position on the environment grid. Inventories etc..."""
_u_idx = defaultdict(lambda: 0)
@property
def obs_tag(self):
try:
return self._collection.name or self.name
except AttributeError:
return self.name
def bound_entity(self):
return self._bound_entity
@property
def var_is_blocking_light(self):
try:
return self._collection.var_is_blocking_light or False
except AttributeError:
return False
def bind_to(self, entity):
self._bound_entity = entity
@property
def var_can_be_bound(self):
try:
return self._collection.var_can_be_bound or False
except AttributeError:
return False
def unbind(self):
self._bound_entity = None
@property
def var_can_move(self):
try:
return self._collection.var_can_move or False
except AttributeError:
return False
@property
def var_is_blocking_pos(self):
try:
return self._collection.var_is_blocking_pos or False
except AttributeError:
return False
@property
def var_has_position(self):
try:
return self._collection.var_has_position or False
except AttributeError:
return False
@property
def var_can_collide(self):
try:
return self._collection.var_can_collide or False
except AttributeError:
return False
@property
def encoding(self):
return c.VALUE_OCCUPIED_CELL
def __init__(self, **kwargs):
super(EnvObject, self).__init__(**kwargs)
def change_parent_collection(self, other_collection):
other_collection.add_item(self)
self._collection.delete_env_object(self)
self._collection = other_collection
return self._collection == other_collection
def summarize_state(self):
return dict(name=str(self.name))
# class EnvObject(_Object):
# """Objects that hold Information that are observable, but have no position on the environment grid. Inventories etc..."""
#
# _u_idx = defaultdict(lambda: 0)
#
# @property
# def obs_tag(self):
# try:
# return self._collection.name or self.name
# except AttributeError:
# return self.name
#
# @property
# def var_is_blocking_light(self):
# try:
# return self._collection.var_is_blocking_light or False
# except AttributeError:
# return False
#
# @property
# def var_can_be_bound(self):
# try:
# return self._collection.var_can_be_bound or False
# except AttributeError:
# return False
#
# @property
# def var_can_move(self):
# try:
# return self._collection.var_can_move or False
# except AttributeError:
# return False
#
# @property
# def var_is_blocking_pos(self):
# try:
# return self._collection.var_is_blocking_pos or False
# except AttributeError:
# return False
#
# @property
# def var_has_position(self):
# try:
# return self._collection.var_has_position or False
# except AttributeError:
# return False
#
# @property
# def var_can_collide(self):
# try:
# return self._collection.var_can_collide or False
# except AttributeError:
# return False
#
# @property
# def encoding(self):
# return c.VALUE_OCCUPIED_CELL
#
# def __init__(self, **kwargs):
# super(EnvObject, self).__init__(**kwargs)
#
# def change_parent_collection(self, other_collection):
# other_collection.add_item(self)
# self._collection.delete_env_object(self)
# self._collection = other_collection
# return self._collection == other_collection
#
# def summarize_state(self):
# return dict(name=str(self.name))

View File

@@ -1,9 +1,6 @@
import math
import numpy as np
from marl_factory_grid.environment.entity.mixin import BoundEntityMixin
from marl_factory_grid.environment.entity.object import _Object, EnvObject
from marl_factory_grid.environment.entity.object import _Object
##########################################################################

View File

@@ -2,11 +2,11 @@ from typing import List, Tuple
from marl_factory_grid.environment.entity.entity import Entity
from marl_factory_grid.environment.groups.objects import _Objects
from marl_factory_grid.environment.entity.object import EnvObject
from marl_factory_grid.environment.entity.object import _Object
class Collection(_Objects):
_entity = EnvObject # entity? object? objects?
_entity = _Object # entity?
@property
def var_is_blocking_light(self):
@@ -22,13 +22,13 @@ class Collection(_Objects):
@property
def var_has_position(self):
return False # alles was posmixin hat true
return False
# @property
# def var_has_bound(self):
# return False # batteries, globalpos, inventories true
@property
def var_has_bound(self):
return False # batteries, globalpos, inventories true
@property # beide bounds hier? inventory can be bound
def var_can_be_bound(self):
return False
@@ -40,12 +40,12 @@ class Collection(_Objects):
super(Collection, self).__init__(*args, **kwargs)
self.size = size
def add_item(self, item: EnvObject):
def add_item(self, item: Entity):
assert self.var_has_position or (len(self) <= self.size)
super(Collection, self).add_item(item)
return self
def delete_env_object(self, env_object: EnvObject):
def delete_env_object(self, env_object):
del self[env_object.name]
def delete_env_object_by_name(self, name):

View File

@@ -36,7 +36,6 @@ class Entities(_Objects):
def guests_that_can_collide(self, pos):
return[x for val in self.pos_dict[pos] for x in val if x.var_can_collide]
@property
def empty_positions(self):
empty_positions= [key for key in self.floorlist if self.pos_dict[key]]
shuffle(empty_positions)

View File

@@ -122,7 +122,7 @@ class _Objects:
raise TypeError
def __repr__(self):
repr_dict = { key: val for key, val in self._data.items() if key not in [c.WALLS]}
repr_dict = {key: val for key, val in self._data.items() if key not in [c.WALLS]}
return f'{self.__class__.__name__}[{repr_dict}]'
def spawn(self, n: int):
@@ -169,3 +169,15 @@ class _Objects:
# FIXME PROTOBUFF
# return [e.summarize_state() for e in self]
return [e.summarize_state() for e in self]
def by_entity(self, entity):
try:
return next((x for x in self if x.belongs_to_entity(entity)))
except (StopIteration, AttributeError):
return None
def idx_by_entity(self, entity):
try:
return next((idx for idx, x in enumerate(self) if x.belongs_to_entity(entity)))
except (StopIteration, AttributeError):
return None

View File

@@ -49,7 +49,7 @@ class SpawnAgents(Rule):
agent_conf = state.agents_conf
# agents = Agents(lvl_map.size)
agents = state[c.AGENT]
empty_positions = state.entities.empty_positions[:len(agent_conf)]
empty_positions = state.entities.empty_positions()[:len(agent_conf)]
for agent_name in agent_conf:
actions = agent_conf[agent_name]['actions'].copy()
observations = agent_conf[agent_name]['observations'].copy()