mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-05-22 23:06:43 +02:00
bugfixes
This commit is contained in:
parent
ef2fdd5d28
commit
b13dff925b
@ -5,6 +5,19 @@ Agents:
|
||||
- Noop
|
||||
- ItemAction
|
||||
Observations:
|
||||
- Combined:
|
||||
- Other
|
||||
- Walls
|
||||
- GlobalPosition
|
||||
- Battery
|
||||
- ChargePods
|
||||
- DirtPiles
|
||||
- Destinations
|
||||
- Doors
|
||||
- Items
|
||||
- Inventory
|
||||
- DropOffLocations
|
||||
- Maintainers
|
||||
Wolfgang:
|
||||
Actions:
|
||||
- Noop
|
||||
@ -64,8 +77,6 @@ Rules:
|
||||
done_at_collisions: false
|
||||
AssignGlobalPositions: {}
|
||||
DestinationReachAny: {}
|
||||
DestinationReach:
|
||||
n_dests: 1
|
||||
DestinationSpawn:
|
||||
n_dests: 1
|
||||
spawn_frequency: 5
|
||||
|
@ -12,8 +12,6 @@ from ...utils.utility_classes import RenderEntity
|
||||
class Entity(_Object, abc.ABC):
|
||||
"""Full Env Entity that lives on the environment Grid. Doors, Items, DirtPile etc..."""
|
||||
|
||||
_u_idx = defaultdict(lambda: 0)
|
||||
|
||||
@property
|
||||
def state(self):
|
||||
return self._status or ActionResult(entity=self, identifier=c.NOOP, validity=c.VALID, reward=0)
|
||||
@ -29,7 +27,6 @@ class Entity(_Object, abc.ABC):
|
||||
except AttributeError:
|
||||
return False
|
||||
|
||||
|
||||
@property
|
||||
def var_can_move(self):
|
||||
try:
|
||||
@ -51,7 +48,6 @@ class Entity(_Object, abc.ABC):
|
||||
except AttributeError:
|
||||
return False
|
||||
|
||||
|
||||
@property
|
||||
def x(self):
|
||||
return self.pos[0]
|
||||
@ -87,7 +83,7 @@ class Entity(_Object, abc.ABC):
|
||||
if valid := state.check_move_validity(self, next_pos):
|
||||
for observer in self.observers:
|
||||
observer.notify_del_entity(self)
|
||||
self._view_directory = curr_pos[0]-next_pos[0], curr_pos[1]-next_pos[1]
|
||||
self._view_directory = curr_pos[0] - next_pos[0], curr_pos[1] - next_pos[1]
|
||||
self._pos = next_pos
|
||||
for observer in self.observers:
|
||||
observer.notify_add_entity(self)
|
||||
|
@ -14,10 +14,7 @@ class _Object:
|
||||
|
||||
@property
|
||||
def var_has_position(self):
|
||||
try:
|
||||
return self.pos != c.VALUE_NO_POS or False
|
||||
except AttributeError:
|
||||
return False
|
||||
return False
|
||||
|
||||
@property
|
||||
def var_can_be_bound(self):
|
||||
@ -36,6 +33,17 @@ class _Object:
|
||||
return f'{self.__class__.__name__}[{self._str_ident}]'
|
||||
return f'{self.__class__.__name__}#{self.u_int}'
|
||||
|
||||
# @property
|
||||
# def name(self):
|
||||
# name = f"{self.__class__.__name__}"
|
||||
# if self.bound_entity:
|
||||
# name += f"[{self.bound_entity.name}]"
|
||||
# if self._str_ident is not None:
|
||||
# name += f"({self._str_ident})"
|
||||
# else:
|
||||
# name += f"(#{self.u_int})"
|
||||
# return name
|
||||
|
||||
@property
|
||||
def identifier(self):
|
||||
if self._str_ident is not None:
|
||||
@ -48,6 +56,7 @@ class _Object:
|
||||
return True
|
||||
|
||||
def __init__(self, str_ident: Union[str, None] = None, **kwargs):
|
||||
self._bound_entity = None
|
||||
self._observers = []
|
||||
self._str_ident = str_ident
|
||||
self.u_int = self._identify_and_count_up()
|
||||
@ -91,73 +100,83 @@ class _Object:
|
||||
def belongs_to_entity(self, entity):
|
||||
return self._bound_entity == entity
|
||||
|
||||
|
||||
class EnvObject(_Object):
|
||||
"""Objects that hold Information that are observable, but have no position on the environment grid. Inventories etc..."""
|
||||
|
||||
_u_idx = defaultdict(lambda: 0)
|
||||
|
||||
@property
|
||||
def obs_tag(self):
|
||||
try:
|
||||
return self._collection.name or self.name
|
||||
except AttributeError:
|
||||
return self.name
|
||||
def bound_entity(self):
|
||||
return self._bound_entity
|
||||
|
||||
@property
|
||||
def var_is_blocking_light(self):
|
||||
try:
|
||||
return self._collection.var_is_blocking_light or False
|
||||
except AttributeError:
|
||||
return False
|
||||
def bind_to(self, entity):
|
||||
self._bound_entity = entity
|
||||
|
||||
@property
|
||||
def var_can_be_bound(self):
|
||||
try:
|
||||
return self._collection.var_can_be_bound or False
|
||||
except AttributeError:
|
||||
return False
|
||||
def unbind(self):
|
||||
self._bound_entity = None
|
||||
|
||||
@property
|
||||
def var_can_move(self):
|
||||
try:
|
||||
return self._collection.var_can_move or False
|
||||
except AttributeError:
|
||||
return False
|
||||
|
||||
@property
|
||||
def var_is_blocking_pos(self):
|
||||
try:
|
||||
return self._collection.var_is_blocking_pos or False
|
||||
except AttributeError:
|
||||
return False
|
||||
|
||||
@property
|
||||
def var_has_position(self):
|
||||
try:
|
||||
return self._collection.var_has_position or False
|
||||
except AttributeError:
|
||||
return False
|
||||
|
||||
@property
|
||||
def var_can_collide(self):
|
||||
try:
|
||||
return self._collection.var_can_collide or False
|
||||
except AttributeError:
|
||||
return False
|
||||
|
||||
@property
|
||||
def encoding(self):
|
||||
return c.VALUE_OCCUPIED_CELL
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super(EnvObject, self).__init__(**kwargs)
|
||||
|
||||
def change_parent_collection(self, other_collection):
|
||||
other_collection.add_item(self)
|
||||
self._collection.delete_env_object(self)
|
||||
self._collection = other_collection
|
||||
return self._collection == other_collection
|
||||
|
||||
def summarize_state(self):
|
||||
return dict(name=str(self.name))
|
||||
# class EnvObject(_Object):
|
||||
# """Objects that hold Information that are observable, but have no position on the environment grid. Inventories etc..."""
|
||||
#
|
||||
# _u_idx = defaultdict(lambda: 0)
|
||||
#
|
||||
# @property
|
||||
# def obs_tag(self):
|
||||
# try:
|
||||
# return self._collection.name or self.name
|
||||
# except AttributeError:
|
||||
# return self.name
|
||||
#
|
||||
# @property
|
||||
# def var_is_blocking_light(self):
|
||||
# try:
|
||||
# return self._collection.var_is_blocking_light or False
|
||||
# except AttributeError:
|
||||
# return False
|
||||
#
|
||||
# @property
|
||||
# def var_can_be_bound(self):
|
||||
# try:
|
||||
# return self._collection.var_can_be_bound or False
|
||||
# except AttributeError:
|
||||
# return False
|
||||
#
|
||||
# @property
|
||||
# def var_can_move(self):
|
||||
# try:
|
||||
# return self._collection.var_can_move or False
|
||||
# except AttributeError:
|
||||
# return False
|
||||
#
|
||||
# @property
|
||||
# def var_is_blocking_pos(self):
|
||||
# try:
|
||||
# return self._collection.var_is_blocking_pos or False
|
||||
# except AttributeError:
|
||||
# return False
|
||||
#
|
||||
# @property
|
||||
# def var_has_position(self):
|
||||
# try:
|
||||
# return self._collection.var_has_position or False
|
||||
# except AttributeError:
|
||||
# return False
|
||||
#
|
||||
# @property
|
||||
# def var_can_collide(self):
|
||||
# try:
|
||||
# return self._collection.var_can_collide or False
|
||||
# except AttributeError:
|
||||
# return False
|
||||
#
|
||||
# @property
|
||||
# def encoding(self):
|
||||
# return c.VALUE_OCCUPIED_CELL
|
||||
#
|
||||
# def __init__(self, **kwargs):
|
||||
# super(EnvObject, self).__init__(**kwargs)
|
||||
#
|
||||
# def change_parent_collection(self, other_collection):
|
||||
# other_collection.add_item(self)
|
||||
# self._collection.delete_env_object(self)
|
||||
# self._collection = other_collection
|
||||
# return self._collection == other_collection
|
||||
#
|
||||
# def summarize_state(self):
|
||||
# return dict(name=str(self.name))
|
||||
|
@ -1,9 +1,6 @@
|
||||
import math
|
||||
|
||||
import numpy as np
|
||||
|
||||
from marl_factory_grid.environment.entity.mixin import BoundEntityMixin
|
||||
from marl_factory_grid.environment.entity.object import _Object, EnvObject
|
||||
from marl_factory_grid.environment.entity.object import _Object
|
||||
|
||||
|
||||
##########################################################################
|
||||
|
@ -2,11 +2,11 @@ from typing import List, Tuple
|
||||
|
||||
from marl_factory_grid.environment.entity.entity import Entity
|
||||
from marl_factory_grid.environment.groups.objects import _Objects
|
||||
from marl_factory_grid.environment.entity.object import EnvObject
|
||||
from marl_factory_grid.environment.entity.object import _Object
|
||||
|
||||
|
||||
class Collection(_Objects):
|
||||
_entity = EnvObject # entity? object? objects?
|
||||
_entity = _Object # entity?
|
||||
|
||||
@property
|
||||
def var_is_blocking_light(self):
|
||||
@ -22,13 +22,13 @@ class Collection(_Objects):
|
||||
|
||||
@property
|
||||
def var_has_position(self):
|
||||
return False # alles was posmixin hat true
|
||||
return False
|
||||
|
||||
# @property
|
||||
# def var_has_bound(self):
|
||||
# return False # batteries, globalpos, inventories true
|
||||
|
||||
@property
|
||||
def var_has_bound(self):
|
||||
return False # batteries, globalpos, inventories true
|
||||
|
||||
@property # beide bounds hier? inventory can be bound
|
||||
def var_can_be_bound(self):
|
||||
return False
|
||||
|
||||
@ -40,12 +40,12 @@ class Collection(_Objects):
|
||||
super(Collection, self).__init__(*args, **kwargs)
|
||||
self.size = size
|
||||
|
||||
def add_item(self, item: EnvObject):
|
||||
def add_item(self, item: Entity):
|
||||
assert self.var_has_position or (len(self) <= self.size)
|
||||
super(Collection, self).add_item(item)
|
||||
return self
|
||||
|
||||
def delete_env_object(self, env_object: EnvObject):
|
||||
def delete_env_object(self, env_object):
|
||||
del self[env_object.name]
|
||||
|
||||
def delete_env_object_by_name(self, name):
|
||||
|
@ -36,7 +36,6 @@ class Entities(_Objects):
|
||||
def guests_that_can_collide(self, pos):
|
||||
return[x for val in self.pos_dict[pos] for x in val if x.var_can_collide]
|
||||
|
||||
@property
|
||||
def empty_positions(self):
|
||||
empty_positions= [key for key in self.floorlist if self.pos_dict[key]]
|
||||
shuffle(empty_positions)
|
||||
|
@ -122,7 +122,7 @@ class _Objects:
|
||||
raise TypeError
|
||||
|
||||
def __repr__(self):
|
||||
repr_dict = { key: val for key, val in self._data.items() if key not in [c.WALLS]}
|
||||
repr_dict = {key: val for key, val in self._data.items() if key not in [c.WALLS]}
|
||||
return f'{self.__class__.__name__}[{repr_dict}]'
|
||||
|
||||
def spawn(self, n: int):
|
||||
@ -169,3 +169,15 @@ class _Objects:
|
||||
# FIXME PROTOBUFF
|
||||
# return [e.summarize_state() for e in self]
|
||||
return [e.summarize_state() for e in self]
|
||||
|
||||
def by_entity(self, entity):
|
||||
try:
|
||||
return next((x for x in self if x.belongs_to_entity(entity)))
|
||||
except (StopIteration, AttributeError):
|
||||
return None
|
||||
|
||||
def idx_by_entity(self, entity):
|
||||
try:
|
||||
return next((idx for idx, x in enumerate(self) if x.belongs_to_entity(entity)))
|
||||
except (StopIteration, AttributeError):
|
||||
return None
|
||||
|
@ -49,7 +49,7 @@ class SpawnAgents(Rule):
|
||||
agent_conf = state.agents_conf
|
||||
# agents = Agents(lvl_map.size)
|
||||
agents = state[c.AGENT]
|
||||
empty_positions = state.entities.empty_positions[:len(agent_conf)]
|
||||
empty_positions = state.entities.empty_positions()[:len(agent_conf)]
|
||||
for agent_name in agent_conf:
|
||||
actions = agent_conf[agent_name]['actions'].copy()
|
||||
observations = agent_conf[agent_name]['observations'].copy()
|
||||
|
@ -20,7 +20,7 @@ class Batteries(Collection):
|
||||
|
||||
@property
|
||||
def var_has_position(self):
|
||||
return True
|
||||
return False
|
||||
|
||||
@property
|
||||
def obs_tag(self):
|
||||
|
@ -36,7 +36,6 @@ class DestinationReachAll(Rule):
|
||||
results.append(TickResult(self.name, validity=c.VALID, reward=r.DEST_REACHED, entity=agent))
|
||||
return results
|
||||
|
||||
|
||||
def on_check_done(self, state) -> List[DoneResult]:
|
||||
if all(x.was_reached() for x in state[d.DESTINATION]):
|
||||
return [DoneResult(self.name, validity=c.VALID, reward=r.DEST_REACHED)]
|
||||
@ -56,7 +55,7 @@ class DestinationReachAny(DestinationReachAll):
|
||||
|
||||
class DestinationSpawn(Rule):
|
||||
|
||||
def __init__(self, n_dests: int = 1,
|
||||
def __init__(self, n_dests: int = 1, spawn_frequency: int = 5,
|
||||
spawn_mode: str = d.MODE_GROUPED):
|
||||
super(DestinationSpawn, self).__init__()
|
||||
self.n_dests = n_dests
|
||||
|
@ -5,7 +5,7 @@ from marl_factory_grid.environment import constants as c
|
||||
|
||||
from marl_factory_grid.environment.groups.collection import Collection
|
||||
from marl_factory_grid.environment.groups.objects import _Objects
|
||||
from marl_factory_grid.environment.groups.mixins import IsBoundMixin, HasBoundMixin
|
||||
from marl_factory_grid.environment.groups.mixins import IsBoundMixin
|
||||
from marl_factory_grid.environment.entity.agent import Agent
|
||||
from marl_factory_grid.modules.items.entitites import Item, DropOffLocation
|
||||
|
||||
@ -45,6 +45,10 @@ class Items(Collection):
|
||||
class Inventory(IsBoundMixin, Collection):
|
||||
_accepted_objects = Item
|
||||
|
||||
@property
|
||||
def var_can_be_bound(self):
|
||||
return True
|
||||
|
||||
@property
|
||||
def obs_tag(self):
|
||||
return self.name
|
||||
@ -69,7 +73,7 @@ class Inventory(IsBoundMixin, Collection):
|
||||
self._collection = collection
|
||||
|
||||
|
||||
class Inventories(HasBoundMixin, _Objects):
|
||||
class Inventories(_Objects):
|
||||
_entity = Inventory
|
||||
|
||||
@property
|
||||
|
@ -3,10 +3,12 @@ from marl_factory_grid.modules.zones import Zone
|
||||
|
||||
|
||||
class Zones(_Objects):
|
||||
|
||||
symbol = None
|
||||
_entity = Zone
|
||||
var_can_move = False
|
||||
|
||||
@property
|
||||
def var_can_move(self):
|
||||
return False
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(Zones, self).__init__(*args, can_collide=True, **kwargs)
|
||||
|
@ -103,6 +103,7 @@ class OBSBuilder(object):
|
||||
obs = np.zeros((len(agent_want_obs), self.obs_shape[0], self.obs_shape[1]))
|
||||
|
||||
for idx, l_name in enumerate(agent_want_obs):
|
||||
print(l_name)
|
||||
try:
|
||||
obs[idx] = pre_sort_obs[l_name]
|
||||
except KeyError:
|
||||
@ -141,6 +142,8 @@ class OBSBuilder(object):
|
||||
try:
|
||||
v = e.encoding
|
||||
except AttributeError:
|
||||
print(e)
|
||||
print(e.var_has_position)
|
||||
raise AttributeError(f'This env. expects Entity-Clases to report their "encoding"')
|
||||
try:
|
||||
np.put(obs[idx], range(len(v)), v, mode='raise')
|
||||
|
Loading…
x
Reference in New Issue
Block a user