mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-05-23 07:16:44 +02:00
bugfixes
This commit is contained in:
parent
ef2fdd5d28
commit
b13dff925b
@ -5,6 +5,19 @@ Agents:
|
|||||||
- Noop
|
- Noop
|
||||||
- ItemAction
|
- ItemAction
|
||||||
Observations:
|
Observations:
|
||||||
|
- Combined:
|
||||||
|
- Other
|
||||||
|
- Walls
|
||||||
|
- GlobalPosition
|
||||||
|
- Battery
|
||||||
|
- ChargePods
|
||||||
|
- DirtPiles
|
||||||
|
- Destinations
|
||||||
|
- Doors
|
||||||
|
- Items
|
||||||
|
- Inventory
|
||||||
|
- DropOffLocations
|
||||||
|
- Maintainers
|
||||||
Wolfgang:
|
Wolfgang:
|
||||||
Actions:
|
Actions:
|
||||||
- Noop
|
- Noop
|
||||||
@ -64,8 +77,6 @@ Rules:
|
|||||||
done_at_collisions: false
|
done_at_collisions: false
|
||||||
AssignGlobalPositions: {}
|
AssignGlobalPositions: {}
|
||||||
DestinationReachAny: {}
|
DestinationReachAny: {}
|
||||||
DestinationReach:
|
|
||||||
n_dests: 1
|
|
||||||
DestinationSpawn:
|
DestinationSpawn:
|
||||||
n_dests: 1
|
n_dests: 1
|
||||||
spawn_frequency: 5
|
spawn_frequency: 5
|
||||||
|
@ -12,8 +12,6 @@ from ...utils.utility_classes import RenderEntity
|
|||||||
class Entity(_Object, abc.ABC):
|
class Entity(_Object, abc.ABC):
|
||||||
"""Full Env Entity that lives on the environment Grid. Doors, Items, DirtPile etc..."""
|
"""Full Env Entity that lives on the environment Grid. Doors, Items, DirtPile etc..."""
|
||||||
|
|
||||||
_u_idx = defaultdict(lambda: 0)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def state(self):
|
def state(self):
|
||||||
return self._status or ActionResult(entity=self, identifier=c.NOOP, validity=c.VALID, reward=0)
|
return self._status or ActionResult(entity=self, identifier=c.NOOP, validity=c.VALID, reward=0)
|
||||||
@ -29,7 +27,6 @@ class Entity(_Object, abc.ABC):
|
|||||||
except AttributeError:
|
except AttributeError:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def var_can_move(self):
|
def var_can_move(self):
|
||||||
try:
|
try:
|
||||||
@ -51,7 +48,6 @@ class Entity(_Object, abc.ABC):
|
|||||||
except AttributeError:
|
except AttributeError:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def x(self):
|
def x(self):
|
||||||
return self.pos[0]
|
return self.pos[0]
|
||||||
@ -87,7 +83,7 @@ class Entity(_Object, abc.ABC):
|
|||||||
if valid := state.check_move_validity(self, next_pos):
|
if valid := state.check_move_validity(self, next_pos):
|
||||||
for observer in self.observers:
|
for observer in self.observers:
|
||||||
observer.notify_del_entity(self)
|
observer.notify_del_entity(self)
|
||||||
self._view_directory = curr_pos[0]-next_pos[0], curr_pos[1]-next_pos[1]
|
self._view_directory = curr_pos[0] - next_pos[0], curr_pos[1] - next_pos[1]
|
||||||
self._pos = next_pos
|
self._pos = next_pos
|
||||||
for observer in self.observers:
|
for observer in self.observers:
|
||||||
observer.notify_add_entity(self)
|
observer.notify_add_entity(self)
|
||||||
|
@ -14,9 +14,6 @@ class _Object:
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def var_has_position(self):
|
def var_has_position(self):
|
||||||
try:
|
|
||||||
return self.pos != c.VALUE_NO_POS or False
|
|
||||||
except AttributeError:
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@ -36,6 +33,17 @@ class _Object:
|
|||||||
return f'{self.__class__.__name__}[{self._str_ident}]'
|
return f'{self.__class__.__name__}[{self._str_ident}]'
|
||||||
return f'{self.__class__.__name__}#{self.u_int}'
|
return f'{self.__class__.__name__}#{self.u_int}'
|
||||||
|
|
||||||
|
# @property
|
||||||
|
# def name(self):
|
||||||
|
# name = f"{self.__class__.__name__}"
|
||||||
|
# if self.bound_entity:
|
||||||
|
# name += f"[{self.bound_entity.name}]"
|
||||||
|
# if self._str_ident is not None:
|
||||||
|
# name += f"({self._str_ident})"
|
||||||
|
# else:
|
||||||
|
# name += f"(#{self.u_int})"
|
||||||
|
# return name
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def identifier(self):
|
def identifier(self):
|
||||||
if self._str_ident is not None:
|
if self._str_ident is not None:
|
||||||
@ -48,6 +56,7 @@ class _Object:
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
def __init__(self, str_ident: Union[str, None] = None, **kwargs):
|
def __init__(self, str_ident: Union[str, None] = None, **kwargs):
|
||||||
|
self._bound_entity = None
|
||||||
self._observers = []
|
self._observers = []
|
||||||
self._str_ident = str_ident
|
self._str_ident = str_ident
|
||||||
self.u_int = self._identify_and_count_up()
|
self.u_int = self._identify_and_count_up()
|
||||||
@ -91,73 +100,83 @@ class _Object:
|
|||||||
def belongs_to_entity(self, entity):
|
def belongs_to_entity(self, entity):
|
||||||
return self._bound_entity == entity
|
return self._bound_entity == entity
|
||||||
|
|
||||||
|
|
||||||
class EnvObject(_Object):
|
|
||||||
"""Objects that hold Information that are observable, but have no position on the environment grid. Inventories etc..."""
|
|
||||||
|
|
||||||
_u_idx = defaultdict(lambda: 0)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def obs_tag(self):
|
def bound_entity(self):
|
||||||
try:
|
return self._bound_entity
|
||||||
return self._collection.name or self.name
|
|
||||||
except AttributeError:
|
|
||||||
return self.name
|
|
||||||
|
|
||||||
@property
|
def bind_to(self, entity):
|
||||||
def var_is_blocking_light(self):
|
self._bound_entity = entity
|
||||||
try:
|
|
||||||
return self._collection.var_is_blocking_light or False
|
|
||||||
except AttributeError:
|
|
||||||
return False
|
|
||||||
|
|
||||||
@property
|
def unbind(self):
|
||||||
def var_can_be_bound(self):
|
self._bound_entity = None
|
||||||
try:
|
|
||||||
return self._collection.var_can_be_bound or False
|
|
||||||
except AttributeError:
|
|
||||||
return False
|
|
||||||
|
|
||||||
@property
|
|
||||||
def var_can_move(self):
|
|
||||||
try:
|
|
||||||
return self._collection.var_can_move or False
|
|
||||||
except AttributeError:
|
|
||||||
return False
|
|
||||||
|
|
||||||
@property
|
# class EnvObject(_Object):
|
||||||
def var_is_blocking_pos(self):
|
# """Objects that hold Information that are observable, but have no position on the environment grid. Inventories etc..."""
|
||||||
try:
|
#
|
||||||
return self._collection.var_is_blocking_pos or False
|
# _u_idx = defaultdict(lambda: 0)
|
||||||
except AttributeError:
|
#
|
||||||
return False
|
# @property
|
||||||
|
# def obs_tag(self):
|
||||||
@property
|
# try:
|
||||||
def var_has_position(self):
|
# return self._collection.name or self.name
|
||||||
try:
|
# except AttributeError:
|
||||||
return self._collection.var_has_position or False
|
# return self.name
|
||||||
except AttributeError:
|
#
|
||||||
return False
|
# @property
|
||||||
|
# def var_is_blocking_light(self):
|
||||||
@property
|
# try:
|
||||||
def var_can_collide(self):
|
# return self._collection.var_is_blocking_light or False
|
||||||
try:
|
# except AttributeError:
|
||||||
return self._collection.var_can_collide or False
|
# return False
|
||||||
except AttributeError:
|
#
|
||||||
return False
|
# @property
|
||||||
|
# def var_can_be_bound(self):
|
||||||
@property
|
# try:
|
||||||
def encoding(self):
|
# return self._collection.var_can_be_bound or False
|
||||||
return c.VALUE_OCCUPIED_CELL
|
# except AttributeError:
|
||||||
|
# return False
|
||||||
def __init__(self, **kwargs):
|
#
|
||||||
super(EnvObject, self).__init__(**kwargs)
|
# @property
|
||||||
|
# def var_can_move(self):
|
||||||
def change_parent_collection(self, other_collection):
|
# try:
|
||||||
other_collection.add_item(self)
|
# return self._collection.var_can_move or False
|
||||||
self._collection.delete_env_object(self)
|
# except AttributeError:
|
||||||
self._collection = other_collection
|
# return False
|
||||||
return self._collection == other_collection
|
#
|
||||||
|
# @property
|
||||||
def summarize_state(self):
|
# def var_is_blocking_pos(self):
|
||||||
return dict(name=str(self.name))
|
# try:
|
||||||
|
# return self._collection.var_is_blocking_pos or False
|
||||||
|
# except AttributeError:
|
||||||
|
# return False
|
||||||
|
#
|
||||||
|
# @property
|
||||||
|
# def var_has_position(self):
|
||||||
|
# try:
|
||||||
|
# return self._collection.var_has_position or False
|
||||||
|
# except AttributeError:
|
||||||
|
# return False
|
||||||
|
#
|
||||||
|
# @property
|
||||||
|
# def var_can_collide(self):
|
||||||
|
# try:
|
||||||
|
# return self._collection.var_can_collide or False
|
||||||
|
# except AttributeError:
|
||||||
|
# return False
|
||||||
|
#
|
||||||
|
# @property
|
||||||
|
# def encoding(self):
|
||||||
|
# return c.VALUE_OCCUPIED_CELL
|
||||||
|
#
|
||||||
|
# def __init__(self, **kwargs):
|
||||||
|
# super(EnvObject, self).__init__(**kwargs)
|
||||||
|
#
|
||||||
|
# def change_parent_collection(self, other_collection):
|
||||||
|
# other_collection.add_item(self)
|
||||||
|
# self._collection.delete_env_object(self)
|
||||||
|
# self._collection = other_collection
|
||||||
|
# return self._collection == other_collection
|
||||||
|
#
|
||||||
|
# def summarize_state(self):
|
||||||
|
# return dict(name=str(self.name))
|
||||||
|
@ -1,9 +1,6 @@
|
|||||||
import math
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from marl_factory_grid.environment.entity.mixin import BoundEntityMixin
|
from marl_factory_grid.environment.entity.object import _Object
|
||||||
from marl_factory_grid.environment.entity.object import _Object, EnvObject
|
|
||||||
|
|
||||||
|
|
||||||
##########################################################################
|
##########################################################################
|
||||||
|
@ -2,11 +2,11 @@ from typing import List, Tuple
|
|||||||
|
|
||||||
from marl_factory_grid.environment.entity.entity import Entity
|
from marl_factory_grid.environment.entity.entity import Entity
|
||||||
from marl_factory_grid.environment.groups.objects import _Objects
|
from marl_factory_grid.environment.groups.objects import _Objects
|
||||||
from marl_factory_grid.environment.entity.object import EnvObject
|
from marl_factory_grid.environment.entity.object import _Object
|
||||||
|
|
||||||
|
|
||||||
class Collection(_Objects):
|
class Collection(_Objects):
|
||||||
_entity = EnvObject # entity? object? objects?
|
_entity = _Object # entity?
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def var_is_blocking_light(self):
|
def var_is_blocking_light(self):
|
||||||
@ -22,13 +22,13 @@ class Collection(_Objects):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def var_has_position(self):
|
def var_has_position(self):
|
||||||
return False # alles was posmixin hat true
|
return False
|
||||||
|
|
||||||
|
# @property
|
||||||
|
# def var_has_bound(self):
|
||||||
|
# return False # batteries, globalpos, inventories true
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def var_has_bound(self):
|
|
||||||
return False # batteries, globalpos, inventories true
|
|
||||||
|
|
||||||
@property # beide bounds hier? inventory can be bound
|
|
||||||
def var_can_be_bound(self):
|
def var_can_be_bound(self):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@ -40,12 +40,12 @@ class Collection(_Objects):
|
|||||||
super(Collection, self).__init__(*args, **kwargs)
|
super(Collection, self).__init__(*args, **kwargs)
|
||||||
self.size = size
|
self.size = size
|
||||||
|
|
||||||
def add_item(self, item: EnvObject):
|
def add_item(self, item: Entity):
|
||||||
assert self.var_has_position or (len(self) <= self.size)
|
assert self.var_has_position or (len(self) <= self.size)
|
||||||
super(Collection, self).add_item(item)
|
super(Collection, self).add_item(item)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def delete_env_object(self, env_object: EnvObject):
|
def delete_env_object(self, env_object):
|
||||||
del self[env_object.name]
|
del self[env_object.name]
|
||||||
|
|
||||||
def delete_env_object_by_name(self, name):
|
def delete_env_object_by_name(self, name):
|
||||||
|
@ -36,7 +36,6 @@ class Entities(_Objects):
|
|||||||
def guests_that_can_collide(self, pos):
|
def guests_that_can_collide(self, pos):
|
||||||
return[x for val in self.pos_dict[pos] for x in val if x.var_can_collide]
|
return[x for val in self.pos_dict[pos] for x in val if x.var_can_collide]
|
||||||
|
|
||||||
@property
|
|
||||||
def empty_positions(self):
|
def empty_positions(self):
|
||||||
empty_positions= [key for key in self.floorlist if self.pos_dict[key]]
|
empty_positions= [key for key in self.floorlist if self.pos_dict[key]]
|
||||||
shuffle(empty_positions)
|
shuffle(empty_positions)
|
||||||
|
@ -122,7 +122,7 @@ class _Objects:
|
|||||||
raise TypeError
|
raise TypeError
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
repr_dict = { key: val for key, val in self._data.items() if key not in [c.WALLS]}
|
repr_dict = {key: val for key, val in self._data.items() if key not in [c.WALLS]}
|
||||||
return f'{self.__class__.__name__}[{repr_dict}]'
|
return f'{self.__class__.__name__}[{repr_dict}]'
|
||||||
|
|
||||||
def spawn(self, n: int):
|
def spawn(self, n: int):
|
||||||
@ -169,3 +169,15 @@ class _Objects:
|
|||||||
# FIXME PROTOBUFF
|
# FIXME PROTOBUFF
|
||||||
# return [e.summarize_state() for e in self]
|
# return [e.summarize_state() for e in self]
|
||||||
return [e.summarize_state() for e in self]
|
return [e.summarize_state() for e in self]
|
||||||
|
|
||||||
|
def by_entity(self, entity):
|
||||||
|
try:
|
||||||
|
return next((x for x in self if x.belongs_to_entity(entity)))
|
||||||
|
except (StopIteration, AttributeError):
|
||||||
|
return None
|
||||||
|
|
||||||
|
def idx_by_entity(self, entity):
|
||||||
|
try:
|
||||||
|
return next((idx for idx, x in enumerate(self) if x.belongs_to_entity(entity)))
|
||||||
|
except (StopIteration, AttributeError):
|
||||||
|
return None
|
||||||
|
@ -49,7 +49,7 @@ class SpawnAgents(Rule):
|
|||||||
agent_conf = state.agents_conf
|
agent_conf = state.agents_conf
|
||||||
# agents = Agents(lvl_map.size)
|
# agents = Agents(lvl_map.size)
|
||||||
agents = state[c.AGENT]
|
agents = state[c.AGENT]
|
||||||
empty_positions = state.entities.empty_positions[:len(agent_conf)]
|
empty_positions = state.entities.empty_positions()[:len(agent_conf)]
|
||||||
for agent_name in agent_conf:
|
for agent_name in agent_conf:
|
||||||
actions = agent_conf[agent_name]['actions'].copy()
|
actions = agent_conf[agent_name]['actions'].copy()
|
||||||
observations = agent_conf[agent_name]['observations'].copy()
|
observations = agent_conf[agent_name]['observations'].copy()
|
||||||
|
@ -20,7 +20,7 @@ class Batteries(Collection):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def var_has_position(self):
|
def var_has_position(self):
|
||||||
return True
|
return False
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def obs_tag(self):
|
def obs_tag(self):
|
||||||
|
@ -36,7 +36,6 @@ class DestinationReachAll(Rule):
|
|||||||
results.append(TickResult(self.name, validity=c.VALID, reward=r.DEST_REACHED, entity=agent))
|
results.append(TickResult(self.name, validity=c.VALID, reward=r.DEST_REACHED, entity=agent))
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
|
||||||
def on_check_done(self, state) -> List[DoneResult]:
|
def on_check_done(self, state) -> List[DoneResult]:
|
||||||
if all(x.was_reached() for x in state[d.DESTINATION]):
|
if all(x.was_reached() for x in state[d.DESTINATION]):
|
||||||
return [DoneResult(self.name, validity=c.VALID, reward=r.DEST_REACHED)]
|
return [DoneResult(self.name, validity=c.VALID, reward=r.DEST_REACHED)]
|
||||||
@ -56,7 +55,7 @@ class DestinationReachAny(DestinationReachAll):
|
|||||||
|
|
||||||
class DestinationSpawn(Rule):
|
class DestinationSpawn(Rule):
|
||||||
|
|
||||||
def __init__(self, n_dests: int = 1,
|
def __init__(self, n_dests: int = 1, spawn_frequency: int = 5,
|
||||||
spawn_mode: str = d.MODE_GROUPED):
|
spawn_mode: str = d.MODE_GROUPED):
|
||||||
super(DestinationSpawn, self).__init__()
|
super(DestinationSpawn, self).__init__()
|
||||||
self.n_dests = n_dests
|
self.n_dests = n_dests
|
||||||
|
@ -5,7 +5,7 @@ from marl_factory_grid.environment import constants as c
|
|||||||
|
|
||||||
from marl_factory_grid.environment.groups.collection import Collection
|
from marl_factory_grid.environment.groups.collection import Collection
|
||||||
from marl_factory_grid.environment.groups.objects import _Objects
|
from marl_factory_grid.environment.groups.objects import _Objects
|
||||||
from marl_factory_grid.environment.groups.mixins import IsBoundMixin, HasBoundMixin
|
from marl_factory_grid.environment.groups.mixins import IsBoundMixin
|
||||||
from marl_factory_grid.environment.entity.agent import Agent
|
from marl_factory_grid.environment.entity.agent import Agent
|
||||||
from marl_factory_grid.modules.items.entitites import Item, DropOffLocation
|
from marl_factory_grid.modules.items.entitites import Item, DropOffLocation
|
||||||
|
|
||||||
@ -45,6 +45,10 @@ class Items(Collection):
|
|||||||
class Inventory(IsBoundMixin, Collection):
|
class Inventory(IsBoundMixin, Collection):
|
||||||
_accepted_objects = Item
|
_accepted_objects = Item
|
||||||
|
|
||||||
|
@property
|
||||||
|
def var_can_be_bound(self):
|
||||||
|
return True
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def obs_tag(self):
|
def obs_tag(self):
|
||||||
return self.name
|
return self.name
|
||||||
@ -69,7 +73,7 @@ class Inventory(IsBoundMixin, Collection):
|
|||||||
self._collection = collection
|
self._collection = collection
|
||||||
|
|
||||||
|
|
||||||
class Inventories(HasBoundMixin, _Objects):
|
class Inventories(_Objects):
|
||||||
_entity = Inventory
|
_entity = Inventory
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@ -3,10 +3,12 @@ from marl_factory_grid.modules.zones import Zone
|
|||||||
|
|
||||||
|
|
||||||
class Zones(_Objects):
|
class Zones(_Objects):
|
||||||
|
|
||||||
symbol = None
|
symbol = None
|
||||||
_entity = Zone
|
_entity = Zone
|
||||||
var_can_move = False
|
|
||||||
|
@property
|
||||||
|
def var_can_move(self):
|
||||||
|
return False
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super(Zones, self).__init__(*args, can_collide=True, **kwargs)
|
super(Zones, self).__init__(*args, can_collide=True, **kwargs)
|
||||||
|
@ -103,6 +103,7 @@ class OBSBuilder(object):
|
|||||||
obs = np.zeros((len(agent_want_obs), self.obs_shape[0], self.obs_shape[1]))
|
obs = np.zeros((len(agent_want_obs), self.obs_shape[0], self.obs_shape[1]))
|
||||||
|
|
||||||
for idx, l_name in enumerate(agent_want_obs):
|
for idx, l_name in enumerate(agent_want_obs):
|
||||||
|
print(l_name)
|
||||||
try:
|
try:
|
||||||
obs[idx] = pre_sort_obs[l_name]
|
obs[idx] = pre_sort_obs[l_name]
|
||||||
except KeyError:
|
except KeyError:
|
||||||
@ -141,6 +142,8 @@ class OBSBuilder(object):
|
|||||||
try:
|
try:
|
||||||
v = e.encoding
|
v = e.encoding
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
|
print(e)
|
||||||
|
print(e.var_has_position)
|
||||||
raise AttributeError(f'This env. expects Entity-Clases to report their "encoding"')
|
raise AttributeError(f'This env. expects Entity-Clases to report their "encoding"')
|
||||||
try:
|
try:
|
||||||
np.put(obs[idx], range(len(v)), v, mode='raise')
|
np.put(obs[idx], range(len(v)), v, mode='raise')
|
||||||
|
Loading…
x
Reference in New Issue
Block a user