This commit is contained in:
Chanumask 2023-10-27 13:08:01 +02:00
parent ef2fdd5d28
commit b13dff925b
13 changed files with 141 additions and 99 deletions

View File

@ -5,6 +5,19 @@ Agents:
- Noop
- ItemAction
Observations:
- Combined:
- Other
- Walls
- GlobalPosition
- Battery
- ChargePods
- DirtPiles
- Destinations
- Doors
- Items
- Inventory
- DropOffLocations
- Maintainers
Wolfgang:
Actions:
- Noop
@ -64,8 +77,6 @@ Rules:
done_at_collisions: false
AssignGlobalPositions: {}
DestinationReachAny: {}
DestinationReach:
n_dests: 1
DestinationSpawn:
n_dests: 1
spawn_frequency: 5

View File

@ -12,8 +12,6 @@ from ...utils.utility_classes import RenderEntity
class Entity(_Object, abc.ABC):
"""Full Env Entity that lives on the environment Grid. Doors, Items, DirtPile etc..."""
_u_idx = defaultdict(lambda: 0)
@property
def state(self):
return self._status or ActionResult(entity=self, identifier=c.NOOP, validity=c.VALID, reward=0)
@ -29,7 +27,6 @@ class Entity(_Object, abc.ABC):
except AttributeError:
return False
@property
def var_can_move(self):
try:
@ -51,7 +48,6 @@ class Entity(_Object, abc.ABC):
except AttributeError:
return False
@property
def x(self):
return self.pos[0]
@ -87,7 +83,7 @@ class Entity(_Object, abc.ABC):
if valid := state.check_move_validity(self, next_pos):
for observer in self.observers:
observer.notify_del_entity(self)
self._view_directory = curr_pos[0]-next_pos[0], curr_pos[1]-next_pos[1]
self._view_directory = curr_pos[0] - next_pos[0], curr_pos[1] - next_pos[1]
self._pos = next_pos
for observer in self.observers:
observer.notify_add_entity(self)

View File

@ -14,9 +14,6 @@ class _Object:
@property
def var_has_position(self):
try:
return self.pos != c.VALUE_NO_POS or False
except AttributeError:
return False
@property
@ -36,6 +33,17 @@ class _Object:
return f'{self.__class__.__name__}[{self._str_ident}]'
return f'{self.__class__.__name__}#{self.u_int}'
# @property
# def name(self):
# name = f"{self.__class__.__name__}"
# if self.bound_entity:
# name += f"[{self.bound_entity.name}]"
# if self._str_ident is not None:
# name += f"({self._str_ident})"
# else:
# name += f"(#{self.u_int})"
# return name
@property
def identifier(self):
if self._str_ident is not None:
@ -48,6 +56,7 @@ class _Object:
return True
def __init__(self, str_ident: Union[str, None] = None, **kwargs):
self._bound_entity = None
self._observers = []
self._str_ident = str_ident
self.u_int = self._identify_and_count_up()
@ -91,73 +100,83 @@ class _Object:
def belongs_to_entity(self, entity):
return self._bound_entity == entity
class EnvObject(_Object):
"""Objects that hold Information that are observable, but have no position on the environment grid. Inventories etc..."""
_u_idx = defaultdict(lambda: 0)
@property
def obs_tag(self):
try:
return self._collection.name or self.name
except AttributeError:
return self.name
def bound_entity(self):
return self._bound_entity
@property
def var_is_blocking_light(self):
try:
return self._collection.var_is_blocking_light or False
except AttributeError:
return False
def bind_to(self, entity):
self._bound_entity = entity
@property
def var_can_be_bound(self):
try:
return self._collection.var_can_be_bound or False
except AttributeError:
return False
def unbind(self):
self._bound_entity = None
@property
def var_can_move(self):
try:
return self._collection.var_can_move or False
except AttributeError:
return False
@property
def var_is_blocking_pos(self):
try:
return self._collection.var_is_blocking_pos or False
except AttributeError:
return False
@property
def var_has_position(self):
try:
return self._collection.var_has_position or False
except AttributeError:
return False
@property
def var_can_collide(self):
try:
return self._collection.var_can_collide or False
except AttributeError:
return False
@property
def encoding(self):
return c.VALUE_OCCUPIED_CELL
def __init__(self, **kwargs):
super(EnvObject, self).__init__(**kwargs)
def change_parent_collection(self, other_collection):
other_collection.add_item(self)
self._collection.delete_env_object(self)
self._collection = other_collection
return self._collection == other_collection
def summarize_state(self):
return dict(name=str(self.name))
# class EnvObject(_Object):
# """Objects that hold Information that are observable, but have no position on the environment grid. Inventories etc..."""
#
# _u_idx = defaultdict(lambda: 0)
#
# @property
# def obs_tag(self):
# try:
# return self._collection.name or self.name
# except AttributeError:
# return self.name
#
# @property
# def var_is_blocking_light(self):
# try:
# return self._collection.var_is_blocking_light or False
# except AttributeError:
# return False
#
# @property
# def var_can_be_bound(self):
# try:
# return self._collection.var_can_be_bound or False
# except AttributeError:
# return False
#
# @property
# def var_can_move(self):
# try:
# return self._collection.var_can_move or False
# except AttributeError:
# return False
#
# @property
# def var_is_blocking_pos(self):
# try:
# return self._collection.var_is_blocking_pos or False
# except AttributeError:
# return False
#
# @property
# def var_has_position(self):
# try:
# return self._collection.var_has_position or False
# except AttributeError:
# return False
#
# @property
# def var_can_collide(self):
# try:
# return self._collection.var_can_collide or False
# except AttributeError:
# return False
#
# @property
# def encoding(self):
# return c.VALUE_OCCUPIED_CELL
#
# def __init__(self, **kwargs):
# super(EnvObject, self).__init__(**kwargs)
#
# def change_parent_collection(self, other_collection):
# other_collection.add_item(self)
# self._collection.delete_env_object(self)
# self._collection = other_collection
# return self._collection == other_collection
#
# def summarize_state(self):
# return dict(name=str(self.name))

View File

@ -1,9 +1,6 @@
import math
import numpy as np
from marl_factory_grid.environment.entity.mixin import BoundEntityMixin
from marl_factory_grid.environment.entity.object import _Object, EnvObject
from marl_factory_grid.environment.entity.object import _Object
##########################################################################

View File

@ -2,11 +2,11 @@ from typing import List, Tuple
from marl_factory_grid.environment.entity.entity import Entity
from marl_factory_grid.environment.groups.objects import _Objects
from marl_factory_grid.environment.entity.object import EnvObject
from marl_factory_grid.environment.entity.object import _Object
class Collection(_Objects):
_entity = EnvObject # entity? object? objects?
_entity = _Object # entity?
@property
def var_is_blocking_light(self):
@ -22,13 +22,13 @@ class Collection(_Objects):
@property
def var_has_position(self):
return False # alles was posmixin hat true
return False
# @property
# def var_has_bound(self):
# return False # batteries, globalpos, inventories true
@property
def var_has_bound(self):
return False # batteries, globalpos, inventories true
@property # beide bounds hier? inventory can be bound
def var_can_be_bound(self):
return False
@ -40,12 +40,12 @@ class Collection(_Objects):
super(Collection, self).__init__(*args, **kwargs)
self.size = size
def add_item(self, item: EnvObject):
def add_item(self, item: Entity):
assert self.var_has_position or (len(self) <= self.size)
super(Collection, self).add_item(item)
return self
def delete_env_object(self, env_object: EnvObject):
def delete_env_object(self, env_object):
del self[env_object.name]
def delete_env_object_by_name(self, name):

View File

@ -36,7 +36,6 @@ class Entities(_Objects):
def guests_that_can_collide(self, pos):
return[x for val in self.pos_dict[pos] for x in val if x.var_can_collide]
@property
def empty_positions(self):
empty_positions= [key for key in self.floorlist if self.pos_dict[key]]
shuffle(empty_positions)

View File

@ -122,7 +122,7 @@ class _Objects:
raise TypeError
def __repr__(self):
repr_dict = { key: val for key, val in self._data.items() if key not in [c.WALLS]}
repr_dict = {key: val for key, val in self._data.items() if key not in [c.WALLS]}
return f'{self.__class__.__name__}[{repr_dict}]'
def spawn(self, n: int):
@ -169,3 +169,15 @@ class _Objects:
# FIXME PROTOBUFF
# return [e.summarize_state() for e in self]
return [e.summarize_state() for e in self]
def by_entity(self, entity):
try:
return next((x for x in self if x.belongs_to_entity(entity)))
except (StopIteration, AttributeError):
return None
def idx_by_entity(self, entity):
try:
return next((idx for idx, x in enumerate(self) if x.belongs_to_entity(entity)))
except (StopIteration, AttributeError):
return None

View File

@ -49,7 +49,7 @@ class SpawnAgents(Rule):
agent_conf = state.agents_conf
# agents = Agents(lvl_map.size)
agents = state[c.AGENT]
empty_positions = state.entities.empty_positions[:len(agent_conf)]
empty_positions = state.entities.empty_positions()[:len(agent_conf)]
for agent_name in agent_conf:
actions = agent_conf[agent_name]['actions'].copy()
observations = agent_conf[agent_name]['observations'].copy()

View File

@ -20,7 +20,7 @@ class Batteries(Collection):
@property
def var_has_position(self):
return True
return False
@property
def obs_tag(self):

View File

@ -36,7 +36,6 @@ class DestinationReachAll(Rule):
results.append(TickResult(self.name, validity=c.VALID, reward=r.DEST_REACHED, entity=agent))
return results
def on_check_done(self, state) -> List[DoneResult]:
if all(x.was_reached() for x in state[d.DESTINATION]):
return [DoneResult(self.name, validity=c.VALID, reward=r.DEST_REACHED)]
@ -56,7 +55,7 @@ class DestinationReachAny(DestinationReachAll):
class DestinationSpawn(Rule):
def __init__(self, n_dests: int = 1,
def __init__(self, n_dests: int = 1, spawn_frequency: int = 5,
spawn_mode: str = d.MODE_GROUPED):
super(DestinationSpawn, self).__init__()
self.n_dests = n_dests

View File

@ -5,7 +5,7 @@ from marl_factory_grid.environment import constants as c
from marl_factory_grid.environment.groups.collection import Collection
from marl_factory_grid.environment.groups.objects import _Objects
from marl_factory_grid.environment.groups.mixins import IsBoundMixin, HasBoundMixin
from marl_factory_grid.environment.groups.mixins import IsBoundMixin
from marl_factory_grid.environment.entity.agent import Agent
from marl_factory_grid.modules.items.entitites import Item, DropOffLocation
@ -45,6 +45,10 @@ class Items(Collection):
class Inventory(IsBoundMixin, Collection):
_accepted_objects = Item
@property
def var_can_be_bound(self):
return True
@property
def obs_tag(self):
return self.name
@ -69,7 +73,7 @@ class Inventory(IsBoundMixin, Collection):
self._collection = collection
class Inventories(HasBoundMixin, _Objects):
class Inventories(_Objects):
_entity = Inventory
@property

View File

@ -3,10 +3,12 @@ from marl_factory_grid.modules.zones import Zone
class Zones(_Objects):
symbol = None
_entity = Zone
var_can_move = False
@property
def var_can_move(self):
return False
def __init__(self, *args, **kwargs):
super(Zones, self).__init__(*args, can_collide=True, **kwargs)

View File

@ -103,6 +103,7 @@ class OBSBuilder(object):
obs = np.zeros((len(agent_want_obs), self.obs_shape[0], self.obs_shape[1]))
for idx, l_name in enumerate(agent_want_obs):
print(l_name)
try:
obs[idx] = pre_sort_obs[l_name]
except KeyError:
@ -141,6 +142,8 @@ class OBSBuilder(object):
try:
v = e.encoding
except AttributeError:
print(e)
print(e.var_has_position)
raise AttributeError(f'This env. expects Entity-Clases to report their "encoding"')
try:
np.put(obs[idx], range(len(v)), v, mode='raise')