mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-12-20 05:56:07 +01:00
bugfixes
This commit is contained in:
@@ -12,8 +12,6 @@ from ...utils.utility_classes import RenderEntity
|
||||
class Entity(_Object, abc.ABC):
|
||||
"""Full Env Entity that lives on the environment Grid. Doors, Items, DirtPile etc..."""
|
||||
|
||||
_u_idx = defaultdict(lambda: 0)
|
||||
|
||||
@property
|
||||
def state(self):
|
||||
return self._status or ActionResult(entity=self, identifier=c.NOOP, validity=c.VALID, reward=0)
|
||||
@@ -29,7 +27,6 @@ class Entity(_Object, abc.ABC):
|
||||
except AttributeError:
|
||||
return False
|
||||
|
||||
|
||||
@property
|
||||
def var_can_move(self):
|
||||
try:
|
||||
@@ -51,7 +48,6 @@ class Entity(_Object, abc.ABC):
|
||||
except AttributeError:
|
||||
return False
|
||||
|
||||
|
||||
@property
|
||||
def x(self):
|
||||
return self.pos[0]
|
||||
@@ -87,7 +83,7 @@ class Entity(_Object, abc.ABC):
|
||||
if valid := state.check_move_validity(self, next_pos):
|
||||
for observer in self.observers:
|
||||
observer.notify_del_entity(self)
|
||||
self._view_directory = curr_pos[0]-next_pos[0], curr_pos[1]-next_pos[1]
|
||||
self._view_directory = curr_pos[0] - next_pos[0], curr_pos[1] - next_pos[1]
|
||||
self._pos = next_pos
|
||||
for observer in self.observers:
|
||||
observer.notify_add_entity(self)
|
||||
|
||||
@@ -14,10 +14,7 @@ class _Object:
|
||||
|
||||
@property
|
||||
def var_has_position(self):
|
||||
try:
|
||||
return self.pos != c.VALUE_NO_POS or False
|
||||
except AttributeError:
|
||||
return False
|
||||
return False
|
||||
|
||||
@property
|
||||
def var_can_be_bound(self):
|
||||
@@ -36,6 +33,17 @@ class _Object:
|
||||
return f'{self.__class__.__name__}[{self._str_ident}]'
|
||||
return f'{self.__class__.__name__}#{self.u_int}'
|
||||
|
||||
# @property
|
||||
# def name(self):
|
||||
# name = f"{self.__class__.__name__}"
|
||||
# if self.bound_entity:
|
||||
# name += f"[{self.bound_entity.name}]"
|
||||
# if self._str_ident is not None:
|
||||
# name += f"({self._str_ident})"
|
||||
# else:
|
||||
# name += f"(#{self.u_int})"
|
||||
# return name
|
||||
|
||||
@property
|
||||
def identifier(self):
|
||||
if self._str_ident is not None:
|
||||
@@ -48,6 +56,7 @@ class _Object:
|
||||
return True
|
||||
|
||||
def __init__(self, str_ident: Union[str, None] = None, **kwargs):
|
||||
self._bound_entity = None
|
||||
self._observers = []
|
||||
self._str_ident = str_ident
|
||||
self.u_int = self._identify_and_count_up()
|
||||
@@ -91,73 +100,83 @@ class _Object:
|
||||
def belongs_to_entity(self, entity):
|
||||
return self._bound_entity == entity
|
||||
|
||||
|
||||
class EnvObject(_Object):
|
||||
"""Objects that hold Information that are observable, but have no position on the environment grid. Inventories etc..."""
|
||||
|
||||
_u_idx = defaultdict(lambda: 0)
|
||||
|
||||
@property
|
||||
def obs_tag(self):
|
||||
try:
|
||||
return self._collection.name or self.name
|
||||
except AttributeError:
|
||||
return self.name
|
||||
def bound_entity(self):
|
||||
return self._bound_entity
|
||||
|
||||
@property
|
||||
def var_is_blocking_light(self):
|
||||
try:
|
||||
return self._collection.var_is_blocking_light or False
|
||||
except AttributeError:
|
||||
return False
|
||||
def bind_to(self, entity):
|
||||
self._bound_entity = entity
|
||||
|
||||
@property
|
||||
def var_can_be_bound(self):
|
||||
try:
|
||||
return self._collection.var_can_be_bound or False
|
||||
except AttributeError:
|
||||
return False
|
||||
def unbind(self):
|
||||
self._bound_entity = None
|
||||
|
||||
@property
|
||||
def var_can_move(self):
|
||||
try:
|
||||
return self._collection.var_can_move or False
|
||||
except AttributeError:
|
||||
return False
|
||||
|
||||
@property
|
||||
def var_is_blocking_pos(self):
|
||||
try:
|
||||
return self._collection.var_is_blocking_pos or False
|
||||
except AttributeError:
|
||||
return False
|
||||
|
||||
@property
|
||||
def var_has_position(self):
|
||||
try:
|
||||
return self._collection.var_has_position or False
|
||||
except AttributeError:
|
||||
return False
|
||||
|
||||
@property
|
||||
def var_can_collide(self):
|
||||
try:
|
||||
return self._collection.var_can_collide or False
|
||||
except AttributeError:
|
||||
return False
|
||||
|
||||
@property
|
||||
def encoding(self):
|
||||
return c.VALUE_OCCUPIED_CELL
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super(EnvObject, self).__init__(**kwargs)
|
||||
|
||||
def change_parent_collection(self, other_collection):
|
||||
other_collection.add_item(self)
|
||||
self._collection.delete_env_object(self)
|
||||
self._collection = other_collection
|
||||
return self._collection == other_collection
|
||||
|
||||
def summarize_state(self):
|
||||
return dict(name=str(self.name))
|
||||
# class EnvObject(_Object):
|
||||
# """Objects that hold Information that are observable, but have no position on the environment grid. Inventories etc..."""
|
||||
#
|
||||
# _u_idx = defaultdict(lambda: 0)
|
||||
#
|
||||
# @property
|
||||
# def obs_tag(self):
|
||||
# try:
|
||||
# return self._collection.name or self.name
|
||||
# except AttributeError:
|
||||
# return self.name
|
||||
#
|
||||
# @property
|
||||
# def var_is_blocking_light(self):
|
||||
# try:
|
||||
# return self._collection.var_is_blocking_light or False
|
||||
# except AttributeError:
|
||||
# return False
|
||||
#
|
||||
# @property
|
||||
# def var_can_be_bound(self):
|
||||
# try:
|
||||
# return self._collection.var_can_be_bound or False
|
||||
# except AttributeError:
|
||||
# return False
|
||||
#
|
||||
# @property
|
||||
# def var_can_move(self):
|
||||
# try:
|
||||
# return self._collection.var_can_move or False
|
||||
# except AttributeError:
|
||||
# return False
|
||||
#
|
||||
# @property
|
||||
# def var_is_blocking_pos(self):
|
||||
# try:
|
||||
# return self._collection.var_is_blocking_pos or False
|
||||
# except AttributeError:
|
||||
# return False
|
||||
#
|
||||
# @property
|
||||
# def var_has_position(self):
|
||||
# try:
|
||||
# return self._collection.var_has_position or False
|
||||
# except AttributeError:
|
||||
# return False
|
||||
#
|
||||
# @property
|
||||
# def var_can_collide(self):
|
||||
# try:
|
||||
# return self._collection.var_can_collide or False
|
||||
# except AttributeError:
|
||||
# return False
|
||||
#
|
||||
# @property
|
||||
# def encoding(self):
|
||||
# return c.VALUE_OCCUPIED_CELL
|
||||
#
|
||||
# def __init__(self, **kwargs):
|
||||
# super(EnvObject, self).__init__(**kwargs)
|
||||
#
|
||||
# def change_parent_collection(self, other_collection):
|
||||
# other_collection.add_item(self)
|
||||
# self._collection.delete_env_object(self)
|
||||
# self._collection = other_collection
|
||||
# return self._collection == other_collection
|
||||
#
|
||||
# def summarize_state(self):
|
||||
# return dict(name=str(self.name))
|
||||
|
||||
@@ -1,9 +1,6 @@
|
||||
import math
|
||||
|
||||
import numpy as np
|
||||
|
||||
from marl_factory_grid.environment.entity.mixin import BoundEntityMixin
|
||||
from marl_factory_grid.environment.entity.object import _Object, EnvObject
|
||||
from marl_factory_grid.environment.entity.object import _Object
|
||||
|
||||
|
||||
##########################################################################
|
||||
|
||||
@@ -2,11 +2,11 @@ from typing import List, Tuple
|
||||
|
||||
from marl_factory_grid.environment.entity.entity import Entity
|
||||
from marl_factory_grid.environment.groups.objects import _Objects
|
||||
from marl_factory_grid.environment.entity.object import EnvObject
|
||||
from marl_factory_grid.environment.entity.object import _Object
|
||||
|
||||
|
||||
class Collection(_Objects):
|
||||
_entity = EnvObject # entity? object? objects?
|
||||
_entity = _Object # entity?
|
||||
|
||||
@property
|
||||
def var_is_blocking_light(self):
|
||||
@@ -22,13 +22,13 @@ class Collection(_Objects):
|
||||
|
||||
@property
|
||||
def var_has_position(self):
|
||||
return False # alles was posmixin hat true
|
||||
return False
|
||||
|
||||
# @property
|
||||
# def var_has_bound(self):
|
||||
# return False # batteries, globalpos, inventories true
|
||||
|
||||
@property
|
||||
def var_has_bound(self):
|
||||
return False # batteries, globalpos, inventories true
|
||||
|
||||
@property # beide bounds hier? inventory can be bound
|
||||
def var_can_be_bound(self):
|
||||
return False
|
||||
|
||||
@@ -40,12 +40,12 @@ class Collection(_Objects):
|
||||
super(Collection, self).__init__(*args, **kwargs)
|
||||
self.size = size
|
||||
|
||||
def add_item(self, item: EnvObject):
|
||||
def add_item(self, item: Entity):
|
||||
assert self.var_has_position or (len(self) <= self.size)
|
||||
super(Collection, self).add_item(item)
|
||||
return self
|
||||
|
||||
def delete_env_object(self, env_object: EnvObject):
|
||||
def delete_env_object(self, env_object):
|
||||
del self[env_object.name]
|
||||
|
||||
def delete_env_object_by_name(self, name):
|
||||
|
||||
@@ -36,7 +36,6 @@ class Entities(_Objects):
|
||||
def guests_that_can_collide(self, pos):
|
||||
return[x for val in self.pos_dict[pos] for x in val if x.var_can_collide]
|
||||
|
||||
@property
|
||||
def empty_positions(self):
|
||||
empty_positions= [key for key in self.floorlist if self.pos_dict[key]]
|
||||
shuffle(empty_positions)
|
||||
|
||||
@@ -122,7 +122,7 @@ class _Objects:
|
||||
raise TypeError
|
||||
|
||||
def __repr__(self):
|
||||
repr_dict = { key: val for key, val in self._data.items() if key not in [c.WALLS]}
|
||||
repr_dict = {key: val for key, val in self._data.items() if key not in [c.WALLS]}
|
||||
return f'{self.__class__.__name__}[{repr_dict}]'
|
||||
|
||||
def spawn(self, n: int):
|
||||
@@ -169,3 +169,15 @@ class _Objects:
|
||||
# FIXME PROTOBUFF
|
||||
# return [e.summarize_state() for e in self]
|
||||
return [e.summarize_state() for e in self]
|
||||
|
||||
def by_entity(self, entity):
|
||||
try:
|
||||
return next((x for x in self if x.belongs_to_entity(entity)))
|
||||
except (StopIteration, AttributeError):
|
||||
return None
|
||||
|
||||
def idx_by_entity(self, entity):
|
||||
try:
|
||||
return next((idx for idx, x in enumerate(self) if x.belongs_to_entity(entity)))
|
||||
except (StopIteration, AttributeError):
|
||||
return None
|
||||
|
||||
@@ -49,7 +49,7 @@ class SpawnAgents(Rule):
|
||||
agent_conf = state.agents_conf
|
||||
# agents = Agents(lvl_map.size)
|
||||
agents = state[c.AGENT]
|
||||
empty_positions = state.entities.empty_positions[:len(agent_conf)]
|
||||
empty_positions = state.entities.empty_positions()[:len(agent_conf)]
|
||||
for agent_name in agent_conf:
|
||||
actions = agent_conf[agent_name]['actions'].copy()
|
||||
observations = agent_conf[agent_name]['observations'].copy()
|
||||
|
||||
Reference in New Issue
Block a user