This commit is contained in:
Steffen Illium
2023-07-06 12:01:25 +02:00
parent dc134d71e0
commit 836495a884
72 changed files with 742 additions and 298 deletions

View File

@ -1,6 +1,6 @@
from .environment.factory import BaseFactory from .environment import *
from .environment.factory import OBSBuilder from .modules import *
from .utils import *
from .utils.tools import ConfigExplainer
from .quickstart import init from .quickstart import init

View File

@ -1,11 +1,10 @@
import itertools
from random import choice from random import choice
import numpy as np import numpy as np
import networkx as nx
from networkx.algorithms.approximation import traveling_salesman as tsp from networkx.algorithms.approximation import traveling_salesman as tsp
from marl_factory_grid.algorithms.static.utils import points_to_graph
from marl_factory_grid.modules.doors import constants as do from marl_factory_grid.modules.doors import constants as do
from marl_factory_grid.environment import constants as c from marl_factory_grid.environment import constants as c
from marl_factory_grid.utils.helpers import MOVEMAP from marl_factory_grid.utils.helpers import MOVEMAP
@ -15,41 +14,6 @@ from abc import abstractmethod, ABC
future_planning = 7 future_planning = 7
def points_to_graph(coordiniates_or_tiles, allow_euclidean_connections=True, allow_manhattan_connections=True):
"""
Given a set of coordinates, this function contructs a non-directed graph, by conncting adjected points.
There are three combinations of settings:
Allow all neigbors: Distance(a, b) <= sqrt(2)
Allow only manhattan: Distance(a, b) == 1
Allow only euclidean: Distance(a, b) == sqrt(2)
:param coordiniates_or_tiles: A set of coordinates.
:type coordiniates_or_tiles: Tiles
:param allow_euclidean_connections: Whether to regard diagonal adjected cells as neighbors
:type: bool
:param allow_manhattan_connections: Whether to regard directly adjected cells as neighbors
:type: bool
:return: A graph with nodes that are conneceted as specified by the parameters.
:rtype: nx.Graph
"""
assert allow_euclidean_connections or allow_manhattan_connections
if hasattr(coordiniates_or_tiles, 'positions'):
coordiniates_or_tiles = coordiniates_or_tiles.positions
possible_connections = itertools.combinations(coordiniates_or_tiles, 2)
graph = nx.Graph()
for a, b in possible_connections:
diff = np.linalg.norm(np.asarray(a)-np.asarray(b))
if allow_manhattan_connections and allow_euclidean_connections and diff <= np.sqrt(2):
graph.add_edge(a, b)
elif not allow_manhattan_connections and allow_euclidean_connections and diff == np.sqrt(2):
graph.add_edge(a, b)
elif allow_manhattan_connections and not allow_euclidean_connections and diff == 1:
graph.add_edge(a, b)
return graph
class TSPBaseAgent(ABC): class TSPBaseAgent(ABC):
def __init__(self, state, agent_i, static_problem: bool = True): def __init__(self, state, agent_i, static_problem: bool = True):

View File

@ -0,0 +1,39 @@
import itertools
import networkx as nx
import numpy as np
def points_to_graph(coordiniates_or_tiles, allow_euclidean_connections=True, allow_manhattan_connections=True):
"""
Given a set of coordinates, this function contructs a non-directed graph, by conncting adjected points.
There are three combinations of settings:
Allow all neigbors: Distance(a, b) <= sqrt(2)
Allow only manhattan: Distance(a, b) == 1
Allow only euclidean: Distance(a, b) == sqrt(2)
:param coordiniates_or_tiles: A set of coordinates.
:type coordiniates_or_tiles: Tiles
:param allow_euclidean_connections: Whether to regard diagonal adjected cells as neighbors
:type: bool
:param allow_manhattan_connections: Whether to regard directly adjected cells as neighbors
:type: bool
:return: A graph with nodes that are conneceted as specified by the parameters.
:rtype: nx.Graph
"""
assert allow_euclidean_connections or allow_manhattan_connections
if hasattr(coordiniates_or_tiles, 'positions'):
coordiniates_or_tiles = coordiniates_or_tiles.positions
possible_connections = itertools.combinations(coordiniates_or_tiles, 2)
graph = nx.Graph()
for a, b in possible_connections:
diff = np.linalg.norm(np.asarray(a)-np.asarray(b))
if allow_manhattan_connections and allow_euclidean_connections and diff <= np.sqrt(2):
graph.add_edge(a, b)
elif not allow_manhattan_connections and allow_euclidean_connections and diff == np.sqrt(2):
graph.add_edge(a, b)
elif allow_manhattan_connections and not allow_euclidean_connections and diff == 1:
graph.add_edge(a, b)
return graph

View File

@ -1,68 +1,89 @@
---
General:
level_name: rooms
env_seed: 69
verbose: !!bool False
pomdp_r: 5
individual_rewards: !!bool True
Entities:
Defaults: {}
DirtPiles:
initial_dirt_ratio: 0.3 # On INIT, on max how many tiles does the dirt spawn in percent.
dirt_spawn_r_var: 0.05 # How much does the dirt spawn amount vary?
initial_amount: 3
max_local_amount: 5 # Max dirt amount per tile.
max_global_amount: 20 # Max dirt amount in the whole environment.
Doors:
closed_on_init: True
auto_close_interval: 10
indicate_area: False
Agents: Agents:
Wolfgang: Wolfgang:
Actions: Actions:
- Move8 - Noop
- Noop - BtryCharge
- DoorUse - CleanUp
- CleanUp - DestAction
- DoorUse
- ItemAction
- Move8
Observations: Observations:
- Self - Combined:
- Placeholder - Other
- Walls - Walls
- DirtPiles - GlobalPosition
- Placeholder - Battery
- Doors - ChargePods
- Doors - DirtPiles
Björn: - Destinations
Actions: - Doors
# Move4, Noop - Items
- Move4 - Inventory
- DoorUse - DropOffLocations
- CleanUp - Machines
Observations: - Maintainers
- Defaults Entities:
- Combined Batteries: {}
Jürgen: ChargePods: {}
Actions: Destinations: {}
# Move4, Noop DirtPiles:
- Defaults clean_amount: 1
- DoorUse dirt_spawn_r_var: 0.1
- CleanUp initial_amount: 2
Observations: initial_dirt_ratio: 0.05
- Walls max_global_amount: 20
- Placeholder max_local_amount: 5
- Agent[Björn] Doors: {}
DropOffLocations: {}
GlobalPositions: {}
Inventories: {}
Items: {}
Machines: {}
Maintainers: {}
Zones: {}
ReachedDestinations: {}
General:
env_seed: 69
individual_rewards: true
level_name: large
pomdp_r: 3
verbose: false
Rules: Rules:
Defaults: {} Btry:
initial_charge: 0.8
per_action_costs: 0.02
BtryDoneAtDischarge: {}
Collision: Collision:
done_at_collisions: !!bool False done_at_collisions: false
DirtRespawnRule: AssignGlobalPositions: {}
spawn_freq: 5 DestinationDone: {}
DirtSmearOnMove: DestinationReach:
smear_amount: 0.12 n_dests: 1
DoorAutoClose: {} tiles: null
DestinationSpawn:
n_dests: 1
spawn_frequency: 5
spawn_mode: GROUPED
DirtAllCleanDone: {} DirtAllCleanDone: {}
Assets: DirtRespawnRule:
- Defaults spawn_freq: 15
- Dirt DirtSmearOnMove:
- Doors smear_amount: 0.2
DoorAutoClose:
close_frequency: 10
ItemRules:
max_dropoff_storage_size: 0
n_items: 5
n_locations: 5
spawn_frequency: 15
MachineRule:
n_machines: 2
MaintenanceRule:
n_maintainer: 1
MaxStepsReached:
max_steps: 500
# AgentSingleZonePlacement:
# n_zones: 4

View File

@ -98,3 +98,5 @@ class NorthWest(Move):
Move4 = [North, East, South, West] Move4 = [North, East, South, West]
# noinspection PyTypeChecker # noinspection PyTypeChecker
Move8 = Move4 + [NorthEast, SouthEast, SouthWest, NorthWest] Move8 = Move4 + [NorthEast, SouthEast, SouthWest, NorthWest]
ALL_BASEACTIONS = Move8 + [Noop]

View File

@ -9,15 +9,13 @@ WALL = 'Wall' # Identifier of Wall-objects and
WALLS = 'Walls' # Identifier of Wall-objects and groups (groups). WALLS = 'Walls' # Identifier of Wall-objects and groups (groups).
LEVEL = 'Level' # Identifier of Level-objects and groups (groups). LEVEL = 'Level' # Identifier of Level-objects and groups (groups).
AGENT = 'Agent' # Identifier of Agent-objects and groups (groups). AGENT = 'Agent' # Identifier of Agent-objects and groups (groups).
AGENTS = 'Agents' # Identifier of Agent-objects and groups (groups).
OTHERS = 'Other' OTHERS = 'Other'
COMBINED = 'Combined' COMBINED = 'Combined'
GLOBAL_POSITION = 'GLOBAL_POSITION' # Identifier of the global position slice GLOBALPOSITIONS = 'GlobalPositions' # Identifier of the global position slice
# Attributes # Attributes
IS_BLOCKING_LIGHT = 'is_blocking_light' IS_BLOCKING_LIGHT = 'var_is_blocking_light'
HAS_POSITION = 'has_position' HAS_POSITION = 'var_has_position'
HAS_NO_POSITION = 'has_no_position' HAS_NO_POSITION = 'has_no_position'
ALL = 'All' ALL = 'All'

View File

@ -1,6 +1,5 @@
from typing import List, Union from typing import List, Union
from marl_factory_grid.environment import constants as c
from marl_factory_grid.environment.actions import Action from marl_factory_grid.environment.actions import Action
from marl_factory_grid.environment.entity.entity import Entity from marl_factory_grid.environment.entity.entity import Entity
from marl_factory_grid.utils.render import RenderEntity from marl_factory_grid.utils.render import RenderEntity
@ -8,6 +7,8 @@ from marl_factory_grid.utils import renderer
from marl_factory_grid.utils.helpers import is_move from marl_factory_grid.utils.helpers import is_move
from marl_factory_grid.utils.results import ActionResult, Result from marl_factory_grid.utils.results import ActionResult, Result
from marl_factory_grid.environment import constants as c
class Agent(Entity): class Agent(Entity):
@ -24,7 +25,7 @@ class Agent(Entity):
return self._observations return self._observations
@property @property
def can_collide(self): def var_can_collide(self):
return True return True
def step_result(self): def step_result(self):

View File

@ -1,15 +1,20 @@
import abc import abc
from marl_factory_grid.environment import constants as c from .. import constants as c
from marl_factory_grid.environment.entity.object import EnvObject from .object import EnvObject
from marl_factory_grid.utils.render import RenderEntity from ...utils.render import RenderEntity
from ...utils.results import ActionResult
class Entity(EnvObject, abc.ABC): class Entity(EnvObject, abc.ABC):
"""Full Env Entity that lives on the environment Grid. Doors, Items, DirtPile etc...""" """Full Env Entity that lives on the environment Grid. Doors, Items, DirtPile etc..."""
@property @property
def has_position(self): def state(self):
return self._status or ActionResult(entity=self, identifier=c.NOOP, validity=c.VALID, reward=0)
@property
def var_has_position(self):
return self.pos != c.VALUE_NO_POS return self.pos != c.VALUE_NO_POS
@property @property
@ -64,12 +69,13 @@ class Entity(EnvObject, abc.ABC):
def __init__(self, tile, **kwargs): def __init__(self, tile, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
self._status = None
self._tile = tile self._tile = tile
tile.enter(self) tile.enter(self)
def summarize_state(self) -> dict: def summarize_state(self) -> dict:
return dict(name=str(self.name), x=int(self.x), y=int(self.y), return dict(name=str(self.name), x=int(self.x), y=int(self.y),
tile=str(self.tile.name), can_collide=bool(self.can_collide)) tile=str(self.tile.name), can_collide=bool(self.var_can_collide))
@abc.abstractmethod @abc.abstractmethod
def render(self): def render(self):

View File

@ -78,37 +78,37 @@ class EnvObject(Object):
return self.name return self.name
@property @property
def is_blocking_light(self): def var_is_blocking_light(self):
try: try:
return self._collection.is_blocking_light or False return self._collection.var_is_blocking_light or False
except AttributeError: except AttributeError:
return False return False
@property @property
def can_move(self): def var_can_move(self):
try: try:
return self._collection.can_move or False return self._collection.var_can_move or False
except AttributeError: except AttributeError:
return False return False
@property @property
def is_blocking_pos(self): def var_is_blocking_pos(self):
try: try:
return self._collection.is_blocking_pos or False return self._collection.var_is_blocking_pos or False
except AttributeError: except AttributeError:
return False return False
@property @property
def has_position(self): def var_has_position(self):
try: try:
return self._collection.has_position or False return self._collection.var_has_position or False
except AttributeError: except AttributeError:
return False return False
@property @property
def can_collide(self): def var_can_collide(self):
try: try:
return self._collection.can_collide or False return self._collection.var_can_collide or False
except AttributeError: except AttributeError:
return False return False

View File

@ -35,11 +35,11 @@ class GlobalPosition(BoundEntityMixin, EnvObject):
@property @property
def encoding(self): def encoding(self):
if self._normalized: if self._normalized:
return tuple(np.divide(self._bound_entity.pos, self._level_shape)) return tuple(np.divide(self._bound_entity.pos, self._shape))
else: else:
return self.bound_entity.pos return self.bound_entity.pos
def __init__(self, *args, normalized: bool = True, **kwargs): def __init__(self, level_shape, *args, normalized: bool = True, **kwargs):
super(GlobalPosition, self).__init__(*args, **kwargs) super(GlobalPosition, self).__init__(*args, **kwargs)
self._level_shape = math.sqrt(self.size)
self._normalized = normalized self._normalized = normalized
self._shape = level_shape

View File

@ -11,23 +11,23 @@ from marl_factory_grid.utils import helpers as h
class Floor(EnvObject): class Floor(EnvObject):
@property @property
def has_position(self): def var_has_position(self):
return True return True
@property @property
def can_collide(self): def var_can_collide(self):
return False return False
@property @property
def can_move(self): def var_can_move(self):
return False return False
@property @property
def is_blocking_pos(self): def var_is_blocking_pos(self):
return False return False
@property @property
def is_blocking_light(self): def var_is_blocking_light(self):
return False return False
@property @property
@ -51,7 +51,7 @@ class Floor(EnvObject):
@property @property
def guests_that_can_collide(self): def guests_that_can_collide(self):
return [x for x in self.guests if x.can_collide] return [x for x in self.guests if x.var_can_collide]
@property @property
def guests(self): def guests(self):
@ -67,7 +67,7 @@ class Floor(EnvObject):
@property @property
def is_blocked(self): def is_blocked(self):
return any([x.is_blocking_pos for x in self.guests]) return any([x.var_is_blocking_pos for x in self.guests])
def __init__(self, pos, **kwargs): def __init__(self, pos, **kwargs):
super(Floor, self).__init__(**kwargs) super(Floor, self).__init__(**kwargs)
@ -86,7 +86,7 @@ class Floor(EnvObject):
return bool(len(self._guests)) return bool(len(self._guests))
def enter(self, guest): def enter(self, guest):
if (guest.name not in self._guests and not self.is_blocked) and not (guest.is_blocking_pos and self.is_occupied()): if (guest.name not in self._guests and not self.is_blocked) and not (guest.var_is_blocking_pos and self.is_occupied()):
self._guests.update({guest.name: guest}) self._guests.update({guest.name: guest})
return c.VALID return c.VALID
else: else:
@ -112,7 +112,7 @@ class Floor(EnvObject):
class Wall(Floor): class Wall(Floor):
@property @property
def can_collide(self): def var_can_collide(self):
return True return True
@property @property
@ -123,9 +123,9 @@ class Wall(Floor):
return RenderEntity(c.WALL, self.pos) return RenderEntity(c.WALL, self.pos)
@property @property
def is_blocking_pos(self): def var_is_blocking_pos(self):
return True return True
@property @property
def is_blocking_light(self): def var_is_blocking_light(self):
return True return True

View File

@ -19,7 +19,7 @@ from marl_factory_grid.utils.states import Gamestate
REC_TAC = 'rec_' REC_TAC = 'rec_'
class BaseFactory(gym.Env): class Factory(gym.Env):
@property @property
def action_space(self): def action_space(self):
@ -52,11 +52,15 @@ class BaseFactory(gym.Env):
def __exit__(self, exc_type, exc_val, exc_tb): def __exit__(self, exc_type, exc_val, exc_tb):
self.close() self.close()
def __init__(self, config_file: Union[str, PathLike]): def __init__(self, config_file: Union[str, PathLike], custom_modules_path: Union[None, PathLike] = None,
custom_level_path: Union[None, PathLike] = None):
self._config_file = config_file self._config_file = config_file
self.conf = FactoryConfigParser(self._config_file) self.conf = FactoryConfigParser(self._config_file, custom_modules_path)
# Attribute Assignment # Attribute Assignment
self.level_filepath = Path(__file__).parent.parent / h.LEVELS_DIR / f'{self.conf.level_name}.txt' if custom_level_path is not None:
self.level_filepath = Path(custom_level_path)
else:
self.level_filepath = Path(__file__).parent.parent / h.LEVELS_DIR / f'{self.conf.level_name}.txt'
self._renderer = None # expensive - don't use; unless required ! self._renderer = None # expensive - don't use; unless required !
parsed_entities = self.conf.load_entities() parsed_entities = self.conf.load_entities()
@ -90,7 +94,7 @@ class BaseFactory(gym.Env):
self.state.entities.add_item({c.AGENT: agents}) self.state.entities.add_item({c.AGENT: agents})
# All is set up, trigger additional init (after agent entity spawn etc) # All is set up, trigger additional init (after agent entity spawn etc)
self.state.rules.do_all_init(self.state) self.state.rules.do_all_init(self.state, self.map)
# Observations # Observations
# noinspection PyAttributeOutsideInit # noinspection PyAttributeOutsideInit
@ -144,7 +148,7 @@ class BaseFactory(gym.Env):
try: try:
done_reason = next(x for x in done_check_results if x.validity) done_reason = next(x for x in done_check_results if x.validity)
done = True done = True
self.state.print(f'Env done, Reason: {done_reason.name}.') self.state.print(f'Env done, Reason: {done_reason.identifier}.')
except StopIteration: except StopIteration:
done = False done = False

View File

@ -1,6 +1,6 @@
from marl_factory_grid.environment.entity.agent import Agent
from marl_factory_grid.environment.groups.env_objects import EnvObjects from marl_factory_grid.environment.groups.env_objects import EnvObjects
from marl_factory_grid.environment.groups.mixins import PositionMixin from marl_factory_grid.environment.groups.mixins import PositionMixin
from marl_factory_grid.environment.entity.agent import Agent
class Agents(PositionMixin, EnvObjects): class Agents(PositionMixin, EnvObjects):

View File

@ -5,10 +5,10 @@ from marl_factory_grid.environment.entity.object import EnvObject
class EnvObjects(Objects): class EnvObjects(Objects):
_entity = EnvObject _entity = EnvObject
is_blocking_light: bool = False var_is_blocking_light: bool = False
can_collide: bool = False var_can_collide: bool = False
has_position: bool = False var_has_position: bool = False
can_move: bool = False var_can_move: bool = False
@property @property
def encodings(self): def encodings(self):
@ -19,7 +19,7 @@ class EnvObjects(Objects):
self.size = size self.size = size
def add_item(self, item: EnvObject): def add_item(self, item: EnvObject):
assert self.has_position or (len(self) <= self.size) assert self.var_has_position or (len(self) <= self.size)
super(EnvObjects, self).add_item(item) super(EnvObjects, self).add_item(item)
return self return self

View File

@ -1,15 +1,19 @@
from typing import List
from marl_factory_grid.environment import constants as c from marl_factory_grid.environment import constants as c
from marl_factory_grid.environment.entity.entity import Entity from marl_factory_grid.environment.entity.entity import Entity
from marl_factory_grid.environment.entity.wall_floor import Floor
# noinspection PyUnresolvedReferences,PyTypeChecker,PyArgumentList
class PositionMixin: class PositionMixin:
_entity = Entity _entity = Entity
is_blocking_light: bool = True var_is_blocking_light: bool = True
can_collide: bool = True var_can_collide: bool = True
has_position: bool = True var_has_position: bool = True
def spawn(self, tiles: List[Floor]):
self.add_items([self._entity(tile) for tile in tiles])
def render(self): def render(self):
return [y for y in [x.render() for x in self] if y is not None] return [y for y in [x.render() for x in self] if y is not None]
@ -81,8 +85,8 @@ class IsBoundMixin:
class HasBoundedMixin: class HasBoundedMixin:
@property @property
def obs_names(self): def obs_pairs(self):
return [x.name for x in self] return [(x.name, x) for x in self]
def by_entity(self, entity): def by_entity(self, entity):
try: try:

View File

@ -4,6 +4,7 @@ from typing import List
import numpy as np import numpy as np
from marl_factory_grid.environment.entity.object import Object from marl_factory_grid.environment.entity.object import Object
import marl_factory_grid.environment.constants as c
class Objects: class Objects:
@ -116,12 +117,21 @@ class Objects:
def __repr__(self): def __repr__(self):
return f'{self.__class__.__name__}[{dict(self._data)}]' return f'{self.__class__.__name__}[{dict(self._data)}]'
def spawn(self, n: int):
self.add_items([self._entity() for _ in range(n)])
return c.VALID
def despawn(self, items: List[Object]):
items = [items] if isinstance(items, Object) else items
for item in items:
del self[item]
def notify_change_pos(self, entity: object): def notify_change_pos(self, entity: object):
try: try:
self.pos_dict[entity.last_pos].remove(entity) self.pos_dict[entity.last_pos].remove(entity)
except (ValueError, AttributeError): except (ValueError, AttributeError):
pass pass
if entity.has_position: if entity.var_has_position:
try: try:
self.pos_dict[entity.pos].append(entity) self.pos_dict[entity.pos].append(entity)
except (ValueError, AttributeError): except (ValueError, AttributeError):

View File

@ -2,10 +2,11 @@ from typing import List, Union
import numpy as np import numpy as np
from marl_factory_grid.environment.groups.env_objects import EnvObjects
from marl_factory_grid.environment.groups.objects import Objects
from marl_factory_grid.environment.groups.mixins import HasBoundedMixin, PositionMixin
from marl_factory_grid.environment.entity.util import GlobalPosition from marl_factory_grid.environment.entity.util import GlobalPosition
from marl_factory_grid.environment.groups.env_objects import EnvObjects
from marl_factory_grid.environment.groups.mixins import PositionMixin, HasBoundedMixin
from marl_factory_grid.environment.groups.objects import Objects
from marl_factory_grid.modules.zones import Zone
from marl_factory_grid.utils import helpers as h from marl_factory_grid.utils import helpers as h
from marl_factory_grid.environment import constants as c from marl_factory_grid.environment import constants as c
@ -44,7 +45,9 @@ class GlobalPositions(HasBoundedMixin, EnvObjects):
super(GlobalPositions, self).__init__(*args, **kwargs) super(GlobalPositions, self).__init__(*args, **kwargs)
class Zones(Objects): class ZonesOLD(Objects):
_entity = Zone
@property @property
def accounting_zones(self): def accounting_zones(self):

View File

@ -30,8 +30,8 @@ class Walls(PositionMixin, EnvObjects):
class Floors(Walls): class Floors(Walls):
_entity = Floor _entity = Floor
symbol = c.SYMBOL_FLOOR symbol = c.SYMBOL_FLOOR
is_blocking_light: bool = False var_is_blocking_light: bool = False
can_collide: bool = False var_can_collide: bool = False
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super(Floors, self).__init__(*args, **kwargs) super(Floors, self).__init__(*args, **kwargs)

View File

@ -17,7 +17,7 @@ class Rule(abc.ABC):
def __repr__(self): def __repr__(self):
return f'{self.name}' return f'{self.name}'
def on_init(self, state): def on_init(self, state, lvl_map):
return [] return []
def on_reset(self): def on_reset(self):
@ -42,7 +42,7 @@ class MaxStepsReached(Rule):
super().__init__() super().__init__()
self.max_steps = max_steps self.max_steps = max_steps
def on_init(self, state): def on_init(self, state, lvl_map):
pass pass
def on_check_done(self, state): def on_check_done(self, state):
@ -51,6 +51,20 @@ class MaxStepsReached(Rule):
return [DoneResult(validity=c.NOT_VALID, identifier=self.name, reward=0)] return [DoneResult(validity=c.NOT_VALID, identifier=self.name, reward=0)]
class AssignGlobalPositions(Rule):
def __init__(self):
super().__init__()
def on_init(self, state, lvl_map):
from marl_factory_grid.environment.entity.util import GlobalPosition
for agent in state[c.AGENT]:
gp = GlobalPosition(lvl_map.level_shape)
gp.bind_to(agent)
state[c.GLOBALPOSITIONS].add_item(gp)
return []
class Collision(Rule): class Collision(Rule):
def __init__(self, done_at_collisions: bool = False): def __init__(self, done_at_collisions: bool = False):

View File

@ -0,0 +1,7 @@
from .batteries import *
from .clean_up import *
from .destinations import *
from .doors import *
from .items import *
from .machines import *
from .maintenance import *

View File

@ -8,4 +8,4 @@ WEST = 'west'
NORTHEAST = 'north_east' NORTHEAST = 'north_east'
SOUTHEAST = 'south_east' SOUTHEAST = 'south_east'
SOUTHWEST = 'south_west' SOUTHWEST = 'south_west'
NORTHWEST = 'north_west' NORTHWEST = 'north_west'

View File

@ -8,7 +8,7 @@ class TemplateRule(Rule):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super(TemplateRule, self).__init__(*args, **kwargs) super(TemplateRule, self).__init__(*args, **kwargs)
def on_init(self, state): def on_init(self, state, lvl_map):
pass pass
def tick_pre_step(self, state) -> List[TickResult]: def tick_pre_step(self, state) -> List[TickResult]:

View File

@ -0,0 +1,4 @@
from .actions import BtryCharge
from .entitites import ChargePod, Battery
from .groups import ChargePods, Batteries
from .rules import BtryDoneAtDischarge, Btry

View File

@ -13,18 +13,13 @@ class Batteries(HasBoundedMixin, EnvObjects):
def obs_tag(self): def obs_tag(self):
return self.__class__.__name__ return self.__class__.__name__
@property
def obs_pairs(self):
return [(x.name, x) for x in self]
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super(Batteries, self).__init__(*args, **kwargs) super(Batteries, self).__init__(*args, **kwargs)
def spawn_batteries(self, agents, initial_charge_level): def spawn(self, agents, initial_charge_level):
batteries = [self._entity(initial_charge_level, agent) for _, agent in enumerate(agents)] batteries = [self._entity(initial_charge_level, agent) for _, agent in enumerate(agents)]
self.add_items(batteries) self.add_items(batteries)
class ChargePods(PositionMixin, EnvObjects): class ChargePods(PositionMixin, EnvObjects):
_entity = ChargePod _entity = ChargePod

View File

@ -13,8 +13,8 @@ class Btry(Rule):
self.per_action_costs = per_action_costs self.per_action_costs = per_action_costs
self.initial_charge = initial_charge self.initial_charge = initial_charge
def on_init(self, state): def on_init(self, state, lvl_map):
state[b.BATTERIES].spawn_batteries(state[c.AGENT], self.initial_charge) state[b.BATTERIES].spawn(state[c.AGENT], self.initial_charge)
def tick_pre_step(self, state) -> List[TickResult]: def tick_pre_step(self, state) -> List[TickResult]:
pass pass

View File

@ -0,0 +1,6 @@
from .actions import CleanUp
from .entitites import DirtPile
from .groups import DirtPiles
from .rule_respawn import DirtRespawnRule
from .rule_smear_on_move import DirtSmearOnMove
from .rule_done_on_all_clean import DirtAllCleanDone

View File

@ -7,6 +7,22 @@ from marl_factory_grid.modules.clean_up import constants as d
class DirtPile(Entity): class DirtPile(Entity):
@property
def var_can_collide(self):
return False
@property
def var_can_move(self):
return False
@property
def var_is_blocking_light(self):
return False
@property
def var_has_position(self):
return True
@property @property
def amount(self): def amount(self):
return self._amount return self._amount

View File

@ -31,7 +31,7 @@ class DirtPiles(PositionMixin, EnvObjects):
self.max_global_amount = max_global_amount self.max_global_amount = max_global_amount
self.max_local_amount = max_local_amount self.max_local_amount = max_local_amount
def spawn_dirt(self, then_dirty_tiles, amount) -> bool: def spawn(self, then_dirty_tiles, amount) -> bool:
if isinstance(then_dirty_tiles, Floor): if isinstance(then_dirty_tiles, Floor):
then_dirty_tiles = [then_dirty_tiles] then_dirty_tiles = [then_dirty_tiles]
for tile in then_dirty_tiles: for tile in then_dirty_tiles:
@ -57,7 +57,7 @@ class DirtPiles(PositionMixin, EnvObjects):
var = self.dirt_spawn_r_var var = self.dirt_spawn_r_var
new_spawn = abs(self.initial_dirt_ratio + (state.rng.uniform(-var, var) if initial_spawn else 0)) new_spawn = abs(self.initial_dirt_ratio + (state.rng.uniform(-var, var) if initial_spawn else 0))
n_dirt_tiles = max(0, int(new_spawn * len(free_for_dirt))) n_dirt_tiles = max(0, int(new_spawn * len(free_for_dirt)))
return self.spawn_dirt(free_for_dirt[:n_dirt_tiles], self.initial_amount) return self.spawn(free_for_dirt[:n_dirt_tiles], self.initial_amount)
def __repr__(self): def __repr__(self):
s = super(DirtPiles, self).__repr__() s = super(DirtPiles, self).__repr__()

View File

@ -11,7 +11,7 @@ class DirtRespawnRule(Rule):
self.spawn_freq = spawn_freq self.spawn_freq = spawn_freq
self._next_dirt_spawn = spawn_freq self._next_dirt_spawn = spawn_freq
def on_init(self, state) -> str: def on_init(self, state, lvl_map) -> str:
state[d.DIRT].trigger_dirt_spawn(state, initial_spawn=True) state[d.DIRT].trigger_dirt_spawn(state, initial_spawn=True)
return f'Initial Dirt was spawned on: {[x.pos for x in state[d.DIRT]]}' return f'Initial Dirt was spawned on: {[x.pos for x in state[d.DIRT]]}'

View File

@ -18,7 +18,7 @@ class DirtSmearOnMove(Rule):
if is_move(entity.state.identifier) and entity.state.validity == c.VALID: if is_move(entity.state.identifier) and entity.state.validity == c.VALID:
if old_pos_dirt := state[d.DIRT].by_pos(entity.last_pos): if old_pos_dirt := state[d.DIRT].by_pos(entity.last_pos):
if smeared_dirt := round(old_pos_dirt.amount * self.smear_amount, 2): if smeared_dirt := round(old_pos_dirt.amount * self.smear_amount, 2):
if state[d.DIRT].spawn_dirt(entity.tile, amount=smeared_dirt): if state[d.DIRT].spawn(entity.tile, amount=smeared_dirt):
results.append(TickResult(identifier=self.name, entity=entity, results.append(TickResult(identifier=self.name, entity=entity,
reward=0, validity=c.VALID)) reward=0, validity=c.VALID))
return results return results

View File

@ -0,0 +1,4 @@
from .actions import DestAction
from .entitites import Destination
from .groups import ReachedDestinations, Destinations
from .rules import DestinationDone, DestinationReach, DestinationSpawn

View File

@ -62,7 +62,7 @@ class DestinationSpawn(Rule):
self.n_dests = n_dests self.n_dests = n_dests
self.spawn_mode = spawn_mode self.spawn_mode = spawn_mode
def on_init(self, state): def on_init(self, state, lvl_map):
# noinspection PyAttributeOutsideInit # noinspection PyAttributeOutsideInit
self._dest_spawn_timer = self.spawn_frequency self._dest_spawn_timer = self.spawn_frequency
self.trigger_destination_spawn(self.n_dests, state) self.trigger_destination_spawn(self.n_dests, state)

View File

@ -0,0 +1,4 @@
from .actions import DoorUse
from .entitites import Door, DoorIndicator
from .groups import Doors
from .rule_door_auto_close import DoorAutoClose

View File

@ -1,10 +1,9 @@
from typing import Union from typing import Union
from marl_factory_grid.environment.actions import Action from marl_factory_grid.environment.actions import Action
from marl_factory_grid.utils.results import ActionResult
from marl_factory_grid.modules.doors import constants as d, rewards as r from marl_factory_grid.modules.doors import constants as d, rewards as r
from marl_factory_grid.environment import constants as c from marl_factory_grid.environment import constants as c
from marl_factory_grid.utils.results import ActionResult
class DoorUse(Action): class DoorUse(Action):

View File

@ -22,15 +22,15 @@ class DoorIndicator(Entity):
class Door(Entity): class Door(Entity):
@property @property
def is_blocking_pos(self): def var_is_blocking_pos(self):
return False if self.is_open else True return False if self.is_open else True
@property @property
def is_blocking_light(self): def var_is_blocking_light(self):
return False if self.is_open else True return False if self.is_open else True
@property @property
def can_collide(self): def var_can_collide(self):
return False if self.is_open else True return False if self.is_open else True
@property @property
@ -42,12 +42,14 @@ class Door(Entity):
return 'open' if self.is_open else 'closed' return 'open' if self.is_open else 'closed'
def __init__(self, *args, closed_on_init=True, auto_close_interval=10, indicate_area=False, **kwargs): def __init__(self, *args, closed_on_init=True, auto_close_interval=10, indicate_area=False, **kwargs):
self._state = d.STATE_CLOSED self._status = d.STATE_CLOSED
super(Door, self).__init__(*args, **kwargs) super(Door, self).__init__(*args, **kwargs)
self.auto_close_interval = auto_close_interval self.auto_close_interval = auto_close_interval
self.time_to_close = 0 self.time_to_close = 0
if not closed_on_init: if not closed_on_init:
self._open() self._open()
else:
self._close()
if indicate_area: if indicate_area:
self._collection.add_items([DoorIndicator(x) for x in self.tile.neighboring_floor]) self._collection.add_items([DoorIndicator(x) for x in self.tile.neighboring_floor])
@ -58,22 +60,22 @@ class Door(Entity):
@property @property
def is_closed(self): def is_closed(self):
return self._state == d.STATE_CLOSED return self._status == d.STATE_CLOSED
@property @property
def is_open(self): def is_open(self):
return self._state == d.STATE_OPEN return self._status == d.STATE_OPEN
@property @property
def status(self): def status(self):
return self._state return self._status
def render(self): def render(self):
name, state = 'door_open' if self.is_open else 'door_closed', 'blank' name, state = 'door_open' if self.is_open else 'door_closed', 'blank'
return RenderEntity(name, self.pos, 1, 'none', state, self.identifier_int + 1) return RenderEntity(name, self.pos, 1, 'none', state, self.identifier_int + 1)
def use(self): def use(self):
if self._state == d.STATE_OPEN: if self._status == d.STATE_OPEN:
self._close() self._close()
else: else:
self._open() self._open()
@ -90,8 +92,8 @@ class Door(Entity):
return c.NOT_VALID return c.NOT_VALID
def _open(self): def _open(self):
self._state = d.STATE_OPEN self._status = d.STATE_OPEN
self.time_to_close = self.auto_close_interval self.time_to_close = self.auto_close_interval
def _close(self): def _close(self):
self._state = d.STATE_CLOSED self._status = d.STATE_CLOSED

View File

@ -0,0 +1,32 @@
import random
from typing import List, Union
from marl_factory_grid.environment.rules import Rule
from marl_factory_grid.environment import constants as c
from marl_factory_grid.utils.results import TickResult
class AgentSingleZonePlacementBeta(Rule):
def __init__(self):
super().__init__()
def on_init(self, state, lvl_map):
zones = state[c.ZONES]
n_zones = state[c.ZONES]
agents = state[c.AGENT]
if len(self.coordinates) == len(agents):
coordinates = self.coordinates
elif len(self.coordinates) > len(agents):
coordinates = random.choices(self.coordinates, k=len(agents))
else:
raise ValueError
tiles = [state[c.FLOOR].by_pos(pos) for pos in coordinates]
for agent, tile in zip(agents, tiles):
agent.move(tile)
def tick_step(self, state):
return []
def tick_post_step(self, state) -> List[TickResult]:
return []

View File

@ -0,0 +1,4 @@
from .actions import ItemAction
from .entitites import Item, DropOffLocation
from .groups import DropOffLocations, Items, Inventory, Inventories
from .rules import ItemRules

View File

@ -8,6 +8,8 @@ from marl_factory_grid.modules.items import constants as i
class Item(Entity): class Item(Entity):
var_can_collide = False
def render(self): def render(self):
return RenderEntity(i.ITEM, self.tile.pos) if self.pos != c.VALUE_NO_POS else None return RenderEntity(i.ITEM, self.tile.pos) if self.pos != c.VALUE_NO_POS else None
@ -38,6 +40,22 @@ class Item(Entity):
class DropOffLocation(Entity): class DropOffLocation(Entity):
@property
def var_can_collide(self):
return False
@property
def var_can_move(self):
return False
@property
def var_is_blocking_light(self):
return False
@property
def var_has_position(self):
return True
def render(self): def render(self):
return RenderEntity(i.DROP_OFF, self.tile.pos) return RenderEntity(i.DROP_OFF, self.tile.pos)

View File

@ -17,15 +17,6 @@ class Items(PositionMixin, EnvObjects):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
def spawn_items(self, tiles: List[Floor]):
items = [self._entity(tile) for tile in tiles]
self.add_items(items)
def despawn_items(self, items: List[Item]):
items = [items] if isinstance(items, Item) else items
for item in items:
del self[item]
class Inventory(IsBoundMixin, EnvObjects): class Inventory(IsBoundMixin, EnvObjects):
@ -58,11 +49,7 @@ class Inventory(IsBoundMixin, EnvObjects):
class Inventories(HasBoundedMixin, Objects): class Inventories(HasBoundedMixin, Objects):
_entity = Inventory _entity = Inventory
can_move = False var_can_move = False
@property
def obs_pairs(self):
return [(x.name, x) for x in self]
def __init__(self, size, *args, **kwargs): def __init__(self, size, *args, **kwargs):
super(Inventories, self).__init__(*args, **kwargs) super(Inventories, self).__init__(*args, **kwargs)
@ -70,7 +57,7 @@ class Inventories(HasBoundedMixin, Objects):
self._obs = None self._obs = None
self._lazy_eval_transforms = [] self._lazy_eval_transforms = []
def spawn_inventories(self, agents): def spawn(self, agents):
inventories = [self._entity(agent, self.size,) inventories = [self._entity(agent, self.size,)
for _, agent in enumerate(agents)] for _, agent in enumerate(agents)]
self.add_items(inventories) self.add_items(inventories)

View File

@ -18,7 +18,7 @@ class ItemRules(Rule):
self.max_dropoff_storage_size = max_dropoff_storage_size self.max_dropoff_storage_size = max_dropoff_storage_size
self.n_locations = n_locations self.n_locations = n_locations
def on_init(self, state): def on_init(self, state, lvl_map):
self.trigger_drop_off_location_spawn(state) self.trigger_drop_off_location_spawn(state)
self._next_item_spawn = self.spawn_frequency self._next_item_spawn = self.spawn_frequency
self.trigger_inventory_spawn(state) self.trigger_inventory_spawn(state)
@ -42,7 +42,7 @@ class ItemRules(Rule):
def trigger_item_spawn(self, state): def trigger_item_spawn(self, state):
if item_to_spawns := max(0, (self.n_items - len(state[i.ITEM]))): if item_to_spawns := max(0, (self.n_items - len(state[i.ITEM]))):
empty_tiles = state[c.FLOOR].empty_tiles[:item_to_spawns] empty_tiles = state[c.FLOOR].empty_tiles[:item_to_spawns]
state[i.ITEM].spawn_items(empty_tiles) state[i.ITEM].spawn(empty_tiles)
self._next_item_spawn = self.spawn_frequency self._next_item_spawn = self.spawn_frequency
state.print(f'{item_to_spawns} new items have been spawned; next spawn in {self._next_item_spawn}') state.print(f'{item_to_spawns} new items have been spawned; next spawn in {self._next_item_spawn}')
return len(empty_tiles) return len(empty_tiles)
@ -52,7 +52,7 @@ class ItemRules(Rule):
@staticmethod @staticmethod
def trigger_inventory_spawn(state): def trigger_inventory_spawn(state):
state[i.INVENTORY].spawn_inventories(state[c.AGENT]) state[i.INVENTORY].spawn(state[c.AGENT])
def tick_post_step(self, state) -> List[TickResult]: def tick_post_step(self, state) -> List[TickResult]:
for item in list(state[i.ITEM].values()): for item in list(state[i.ITEM].values()):

View File

@ -0,0 +1,3 @@
from .entitites import Machine
from .groups import Machines
from .rules import MachineRule

View File

@ -0,0 +1,25 @@
from typing import Union
from marl_factory_grid.environment.actions import Action
from marl_factory_grid.utils.results import ActionResult
from marl_factory_grid.modules.machines import constants as m, rewards as r
from marl_factory_grid.environment import constants as c
class MachineAction(Action):
def __init__(self):
super().__init__(m.MACHINE_ACTION)
def do(self, entity, state) -> Union[None, ActionResult]:
if machine := state[m.MACHINES].by_pos(entity.pos):
if valid := machine.maintain():
return ActionResult(entity=entity, identifier=self._identifier, validity=valid, reward=r.MAINTAIN_VALID)
else:
return ActionResult(entity=entity, identifier=self._identifier, validity=valid, reward=r.MAINTAIN_FAIL)
else:
return ActionResult(entity=entity, identifier=self._identifier, validity=c.NOT_VALID, reward=r.MAINTAIN_FAIL)

View File

@ -2,6 +2,8 @@
MACHINES = 'Machines' MACHINES = 'Machines'
MACHINE = 'Machine' MACHINE = 'Machine'
MACHINE_ACTION = 'Maintain'
STATE_WORK = 'working' STATE_WORK = 'working'
STATE_IDLE = 'idling' STATE_IDLE = 'idling'
STATE_MAINTAIN = 'maintenance' STATE_MAINTAIN = 'maintenance'

View File

@ -2,27 +2,43 @@ from marl_factory_grid.environment.entity.entity import Entity
from marl_factory_grid.utils.render import RenderEntity from marl_factory_grid.utils.render import RenderEntity
from marl_factory_grid.environment import constants as c from marl_factory_grid.environment import constants as c
from marl_factory_grid.utils.results import TickResult from marl_factory_grid.utils.results import TickResult
from marl_factory_grid.modules.machines import constants as m, rewards as r
from . import constants as m
class Machine(Entity): class Machine(Entity):
@property
def var_can_collide(self):
return False
@property
def var_can_move(self):
return False
@property
def var_is_blocking_light(self):
return False
@property
def var_has_position(self):
return True
@property @property
def encoding(self): def encoding(self):
return self._encodings[self.state] return self._encodings[self.status]
def __init__(self, *args, work_interval: int = 10, pause_interval: int = 15, **kwargs): def __init__(self, *args, work_interval: int = 10, pause_interval: int = 15, **kwargs):
super(Machine, self).__init__(*args, **kwargs) super(Machine, self).__init__(*args, **kwargs)
self._intervals = dict({m.STATE_IDLE: pause_interval, m.STATE_WORK: work_interval}) self._intervals = dict({m.STATE_IDLE: pause_interval, m.STATE_WORK: work_interval})
self._encodings = dict({m.STATE_IDLE: pause_interval, m.STATE_WORK: work_interval}) self._encodings = dict({m.STATE_IDLE: pause_interval, m.STATE_WORK: work_interval})
self.state = m.STATE_IDLE self.status = m.STATE_IDLE
self.health = 100 self.health = 100
self._counter = 0 self._counter = 0
self.__delattr__('move')
def maintain(self): def maintain(self):
if self.state == m.STATE_WORK: if self.status == m.STATE_WORK:
return c.NOT_VALID return c.NOT_VALID
if self.health <= 98: if self.health <= 98:
self.health = 100 self.health = 100
@ -31,10 +47,10 @@ class Machine(Entity):
return c.NOT_VALID return c.NOT_VALID
def tick(self): def tick(self):
if self.state == m.STATE_MAINTAIN and any([c.AGENT in x.name for x in self.tile.guests]): if self.status == m.STATE_MAINTAIN and any([c.AGENT in x.name for x in self.tile.guests]):
return TickResult(self.name, c.VALID, r.NONE, self) return TickResult(identifier=self.name, validity=c.VALID, reward=0, entity=self)
elif self.state == m.STATE_MAINTAIN and not any([c.AGENT in x.name for x in self.tile.guests]): elif self.status == m.STATE_MAINTAIN and not any([c.AGENT in x.name for x in self.tile.guests]):
self.state = m.STATE_WORK self.status = m.STATE_WORK
self.reset_counter() self.reset_counter()
return None return None
elif self._counter: elif self._counter:
@ -42,12 +58,12 @@ class Machine(Entity):
self.health -= 1 self.health -= 1
return None return None
else: else:
self.state = m.STATE_WORK if self.state == m.STATE_IDLE else m.STATE_IDLE self.status = m.STATE_WORK if self.status == m.STATE_IDLE else m.STATE_IDLE
self.reset_counter() self.reset_counter()
return None return None
def reset_counter(self): def reset_counter(self):
self._counter = self._intervals[self.state] self._counter = self._intervals[self.status]
def render(self): def render(self):
return RenderEntity(m.MACHINE, self.pos) return RenderEntity(m.MACHINE, self.pos)

View File

@ -1,6 +1,7 @@
from marl_factory_grid.environment.groups.env_objects import EnvObjects from marl_factory_grid.environment.groups.env_objects import EnvObjects
from marl_factory_grid.environment.groups.mixins import PositionMixin from marl_factory_grid.environment.groups.mixins import PositionMixin
from marl_factory_grid.modules.machines.entitites import Machine
from .entitites import Machine
class Machines(PositionMixin, EnvObjects): class Machines(PositionMixin, EnvObjects):

Binary file not shown.

After

Width:  |  Height:  |  Size: 8.5 KiB

View File

@ -12,7 +12,7 @@ class MachineRule(Rule):
super(MachineRule, self).__init__() super(MachineRule, self).__init__()
self.n_machines = n_machines self.n_machines = n_machines
def on_init(self, state): def on_init(self, state, lvl_map):
empty_tiles = state[c.FLOOR].empty_tiles[:self.n_machines] empty_tiles = state[c.FLOOR].empty_tiles[:self.n_machines]
state[m.MACHINES].add_items(Machine(tile) for tile in empty_tiles) state[m.MACHINES].add_items(Machine(tile) for tile in empty_tiles)
@ -27,3 +27,9 @@ class MachineRule(Rule):
def on_check_done(self, state) -> List[DoneResult]: def on_check_done(self, state) -> List[DoneResult]:
pass pass
class DoneOnBreakRule(Rule):
def on_check_done(self, state) -> List[DoneResult]:
pass

View File

@ -0,0 +1,2 @@
from .entities import Maintainer
from .groups import Maintainers

View File

@ -0,0 +1,3 @@
MAINTAINER = 'Maintainer' # TEMPLATE _identifier. Define your own!
MAINTAINERS = 'Maintainers' # TEMPLATE _identifier. Define your own!

View File

@ -0,0 +1,102 @@
import networkx as nx
import numpy as np
from ...algorithms.static.utils import points_to_graph
from ...environment import constants as c
from ...environment.actions import Action, ALL_BASEACTIONS
from ...environment.entity.entity import Entity
from ..doors import constants as do
from ..maintenance import constants as mi
from ...utils.helpers import MOVEMAP
from ...utils.render import RenderEntity
from ...utils.states import Gamestate
class Maintainer(Entity):
@property
def var_can_collide(self):
return True
@property
def var_can_move(self):
return False
@property
def var_is_blocking_light(self):
return False
@property
def var_has_position(self):
return True
def __init__(self, state: Gamestate, objective: str, action: Action, *args, **kwargs):
super().__init__(*args, **kwargs)
self.action = action
self.actions = [x() for x in ALL_BASEACTIONS]
self.objective = objective
self._path = None
self._next = []
self._last = []
self._last_serviced = 'None'
self._floortile_graph = points_to_graph(state[c.FLOOR].positions)
def tick(self, state):
if found_objective := state[self.objective].by_pos(self.pos):
if found_objective.name != self._last_serviced:
self.action.do(self, state)
self._last_serviced = found_objective.name
else:
action = self.get_move_action(state)
return action.do(self, state)
else:
action = self.get_move_action(state)
return action.do(self, state)
def get_move_action(self, state) -> Action:
if self._path is None or not self._path:
if not self._next:
self._next = list(state[self.objective].values())
self._last = []
self._last.append(self._next.pop())
self._path = self.calculate_route(self._last[-1])
if door := self._door_is_close():
if door.is_closed:
# Translate the action_object to an integer to have the same output as any other model
action = do.ACTION_DOOR_USE
else:
action = self._predict_move(state)
else:
action = self._predict_move(state)
# Translate the action_object to an integer to have the same output as any other model
try:
action_obj = next(x for x in self.actions if x.name == action)
except (StopIteration, UnboundLocalError):
print('Will not happen')
raise EnvironmentError
return action_obj
def calculate_route(self, entity):
route = nx.shortest_path(self._floortile_graph, self.pos, entity.pos)
return route[1:]
def _door_is_close(self):
try:
return next(y for x in self.tile.neighboring_floor for y in x.guests if do.DOOR in y.name)
except StopIteration:
return None
def _predict_move(self, state):
next_pos = self._path[0]
if len(state[c.FLOOR].by_pos(next_pos).guests_that_can_collide) > 0:
action = c.NOOP
else:
next_pos = self._path.pop(0)
diff = np.subtract(next_pos, self.pos)
# Retrieve action based on the pos dif (like in: What do I have to do to get there?)
action = next(action for action, pos_diff in MOVEMAP.items() if np.all(diff == pos_diff))
return action
def render(self):
return RenderEntity(mi.MAINTAINER, self.pos)

View File

@ -0,0 +1,27 @@
from typing import List
from .entities import Maintainer
from marl_factory_grid.environment.entity.wall_floor import Floor
from marl_factory_grid.environment.groups.env_objects import EnvObjects
from marl_factory_grid.environment.groups.mixins import PositionMixin
from ..machines.actions import MachineAction
from ...utils.render import RenderEntity
from ...utils.states import Gamestate
from ..machines import constants as mc
from . import constants as mi
class Maintainers(PositionMixin, EnvObjects):
_entity = Maintainer
var_can_collide = True
var_can_move = True
var_is_blocking_light = False
var_has_position = True
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def spawn(self, tiles: List[Floor], state: Gamestate):
self.add_items([self._entity(state, mc.MACHINES, MachineAction(), tile) for tile in tiles])

Binary file not shown.

After

Width:  |  Height:  |  Size: 22 KiB

View File

@ -0,0 +1 @@
MAINTAINER_COLLISION_REWARD = -5

View File

@ -0,0 +1,39 @@
from typing import List
from marl_factory_grid.environment.rules import Rule
from marl_factory_grid.utils.results import TickResult, DoneResult
from marl_factory_grid.environment import constants as c
from . import rewards as r
from . import constants as M
from marl_factory_grid.utils.states import Gamestate
class MaintenanceRule(Rule):
def __init__(self, n_maintainer: int = 1, *args, **kwargs):
super(MaintenanceRule, self).__init__(*args, **kwargs)
self.n_maintainer = n_maintainer
def on_init(self, state: Gamestate, lvl_map):
state[M.MAINTAINERS].spawn(state[c.FLOOR].empty_tiles[:self.n_maintainer], state)
pass
def tick_pre_step(self, state) -> List[TickResult]:
pass
def tick_step(self, state) -> List[TickResult]:
for maintainer in state[M.MAINTAINERS]:
maintainer.tick(state)
return []
def tick_post_step(self, state) -> List[TickResult]:
pass
def on_check_done(self, state) -> List[DoneResult]:
agents = list(state[c.AGENT].values())
m_pos = state[M.MAINTAINERS].positions
done_results = []
for agent in agents:
if agent.pos in m_pos:
done_results.append(DoneResult(entity=agent, validity=c.VALID, identifier=self.name,
reward=r.MAINTAINER_COLLISION_REWARD))
return done_results

View File

@ -0,0 +1,3 @@
from .entitites import Zone
from .groups import Zones
from .rules import AgentSingleZonePlacement

View File

@ -0,0 +1,4 @@
# Names / Identifiers
ZONES = 'Zones' # Identifier of Zone-objects and groups (groups).
ZONE = 'Zone' # -||-

View File

@ -0,0 +1,21 @@
import random
from typing import List
from marl_factory_grid.environment.entity.entity import Entity
from marl_factory_grid.environment.entity.object import Object
from marl_factory_grid.environment.entity.wall_floor import Floor
from marl_factory_grid.utils.render import RenderEntity
from marl_factory_grid.environment import constants as c
from marl_factory_grid.modules.doors import constants as d
class Zone(Object):
def __init__(self, tiles: List[Floor], *args, **kwargs):
super(Zone, self).__init__(*args, **kwargs)
self.tiles = tiles
@property
def random_tile(self):
return random.choice(self.tiles)

View File

@ -0,0 +1,12 @@
from marl_factory_grid.environment.groups.objects import Objects
from marl_factory_grid.modules.zones import Zone
class Zones(Objects):
symbol = None
_entity = Zone
var_can_move = False
def __init__(self, *args, **kwargs):
super(Zones, self).__init__(*args, can_collide=True, **kwargs)

View File

@ -0,0 +1,33 @@
from random import choices
from marl_factory_grid.environment.rules import Rule
from marl_factory_grid.environment import constants as c
from marl_factory_grid.modules.zones import Zone
from . import constants as z
class AgentSingleZonePlacement(Rule):
def __init__(self, n_zones=3):
super().__init__()
self.n_zones = n_zones
def on_init(self, state, lvl_map):
zones = []
for z_idx in range(1, self.n_zones):
zone_positions = lvl_map.get_coordinates_for_symbol(z_idx)
assert len(zone_positions)
zones.append(Zone([state[c.FLOOR].by_pos(pos) for pos in zone_positions]))
state[z.ZONES].add_items(zones)
n_agents = len(state[c.AGENT])
assert len(state[z.ZONES]) >= n_agents
z_idxs = choices(list(range(len(state[z.ZONES]))), k=n_agents)
for agent in state[c.AGENT]:
agent.move(state[z.ZONES][z_idxs.pop()].random_tile)
return []
def tick_step(self, state):
return []

View File

@ -10,10 +10,10 @@ def init():
ce = ConfigExplainer() ce = ConfigExplainer()
cwd = Path(os.getcwd()) cwd = Path(os.getcwd())
ce.save_all(cwd / 'full_config.yaml') ce.save_all(cwd / 'full_config.yaml')
template_path = Path(__file__) / 'marl_factory_grid' / 'modules' / '_template' template_path = Path(__file__).parent / 'modules' / '_template'
print(f'Available config options saved to: {(cwd / "full_config.yaml").resolve()}') print(f'Available config options saved to: {(cwd / "full_config.yaml").resolve()}')
print('-----------------------------') print('-----------------------------')
print(f'Copying Templates....') print(f'Copying Templates....')
shutil.copytree(template_path, cwd) shutil.copytree(template_path, cwd)
print(f'Templates copied to {template_path.resolve()}') print(f'Templates copied to {cwd}"/"{template_path.name}')
print(':wave:') print(':wave:')

View File

@ -18,11 +18,11 @@ class FactoryConfigParser(object):
default_entites = [] default_entites = []
default_rules = ['MaxStepsReached', 'Collision'] default_rules = ['MaxStepsReached', 'Collision']
default_actions = [c.MOVE8, c.NOOP] default_actions = [c.MOVE8, c.NOOP]
default_observations = [c.WALLS, c.AGENTS] default_observations = [c.WALLS, c.AGENT]
def __init__(self, config_path, custom_modules_path: Union[None, PathLike] = None): def __init__(self, config_path, custom_modules_path: Union[None, PathLike] = None):
self.config_path = Path(config_path) self.config_path = Path(config_path)
self.custom_modules_path = Path(config_path) if custom_modules_path is not None else custom_modules_path self.custom_modules_path = Path(custom_modules_path) if custom_modules_path is not None else custom_modules_path
self.config = yaml.safe_load(self.config_path.open()) self.config = yaml.safe_load(self.config_path.open())
self.do_record = False self.do_record = False
@ -69,12 +69,20 @@ class FactoryConfigParser(object):
for entity in entities: for entity in entities:
try: try:
folder_path = MODULE_PATH if entity not in self.default_entites else DEFAULT_PATH folder_path = Path(__file__).parent.parent / DEFAULT_PATH
folder_path = (Path(__file__) / '..' / '..' / '..' / folder_path)
entity_class = locate_and_import_class(entity, folder_path)
except AttributeError:
folder_path = self.custom_modules_path
entity_class = locate_and_import_class(entity, folder_path) entity_class = locate_and_import_class(entity, folder_path)
except AttributeError as e1:
try:
folder_path = Path(__file__).parent.parent / MODULE_PATH
entity_class = locate_and_import_class(entity, folder_path)
except AttributeError as e2:
try:
folder_path = self.custom_modules_path
entity_class = locate_and_import_class(entity, folder_path)
except AttributeError as e3:
ents = [y for x in [e1.argss[1], e2.argss[1], e3.argss[1]] for y in x]
raise AttributeError(e1.argss[0], e2.argss[0], e3.argss[0], 'Possible Entitys are>:', str(ents))
entity_kwargs = self.entities.get(entity, {}) entity_kwargs = self.entities.get(entity, {})
entity_symbol = entity_class.symbol if hasattr(entity_class, 'symbol') else None entity_symbol = entity_class.symbol if hasattr(entity_class, 'symbol') else None
entity_classes.update({entity: {'class': entity_class, 'kwargs': entity_kwargs, 'symbol': entity_symbol}}) entity_classes.update({entity: {'class': entity_class, 'kwargs': entity_kwargs, 'symbol': entity_symbol}})
@ -92,7 +100,7 @@ class FactoryConfigParser(object):
parsed_actions = list() parsed_actions = list()
for action in actions: for action in actions:
folder_path = MODULE_PATH if action not in base_env_actions else DEFAULT_PATH folder_path = MODULE_PATH if action not in base_env_actions else DEFAULT_PATH
folder_path = (Path(__file__) / '..' / '..' / '..' / folder_path) folder_path = Path(__file__).parent.parent / folder_path
try: try:
class_or_classes = locate_and_import_class(action, folder_path) class_or_classes = locate_and_import_class(action, folder_path)
except AttributeError: except AttributeError:
@ -124,12 +132,15 @@ class FactoryConfigParser(object):
rules.extend(x for x in self.rules if x != c.DEFAULTS) rules.extend(x for x in self.rules if x != c.DEFAULTS)
for rule in rules: for rule in rules:
folder_path = MODULE_PATH if rule not in self.default_rules else DEFAULT_PATH
folder_path = (Path(__file__) / '..' / '..' / '..' / folder_path)
try: try:
folder_path = (Path(__file__).parent.parent / DEFAULT_PATH)
rule_class = locate_and_import_class(rule, folder_path) rule_class = locate_and_import_class(rule, folder_path)
except AttributeError: except AttributeError:
rule_class = locate_and_import_class(rule, self.custom_modules_path) try:
folder_path = (Path(__file__).parent.parent / MODULE_PATH)
rule_class = locate_and_import_class(rule, folder_path)
except AttributeError:
rule_class = locate_and_import_class(rule, self.custom_modules_path)
rule_kwargs = self.rules.get(rule, {}) rule_kwargs = self.rules.get(rule, {})
rules_classes.update({rule: {'class': rule_class, 'kwargs': rule_kwargs}}) rules_classes.update({rule: {'class': rule_class, 'kwargs': rule_kwargs}})
return rules_classes return rules_classes

View File

@ -176,7 +176,7 @@ def one_hot_level(level, symbol: str):
grid = np.array(level) grid = np.array(level)
binary_grid = np.zeros(grid.shape, dtype=np.int8) binary_grid = np.zeros(grid.shape, dtype=np.int8)
binary_grid[grid == symbol] = c.VALUE_OCCUPIED_CELL binary_grid[grid == str(symbol)] = c.VALUE_OCCUPIED_CELL
return binary_grid return binary_grid
@ -222,18 +222,15 @@ def locate_and_import_class(class_name, folder_path: Union[str, PurePath] = ''):
for module_path in module_paths: for module_path in module_paths:
module_parts = [x.replace('.py', '') for idx, x in enumerate(module_path.parts) if idx >= package_pos] module_parts = [x.replace('.py', '') for idx, x in enumerate(module_path.parts) if idx >= package_pos]
mod = importlib.import_module('.'.join(module_parts)) mod = importlib.import_module('.'.join(module_parts))
all_found_modules.extend([x for x in dir(mod) if not(x.startswith('__') or len(x) < 2 or x.isupper()) all_found_modules.extend([x for x in dir(mod) if (not(x.startswith('__') or len(x) <= 2) and x.istitle())
and x not in ['Entity', 'NamedTuple', 'List', 'Rule', 'Union', 'random', 'Floor' and x not in ['Entity', 'NamedTuple', 'List', 'Rule', 'Union', 'Floor'
'TickResult', 'ActionResult', 'Action', 'Agent', 'deque', 'TickResult', 'ActionResult', 'Action', 'Agent', 'BoundEntityMixin',
'BoundEntityMixin', 'RenderEntity', 'TemplateRule', 'defaultdict', 'RenderEntity', 'TemplateRule', 'Objects', 'PositionMixin',
'is_move', 'Objects', 'PositionMixin', 'IsBoundMixin', 'EnvObject', 'IsBoundMixin', 'EnvObject', 'EnvObjects', 'Dict', 'Any'
'EnvObjects', 'Dict', 'locate_and_import_class', 'yaml', 'Any', ]])
'inspect']])
try: try:
model_class = mod.__getattribute__(class_name) model_class = mod.__getattribute__(class_name)
return model_class return model_class
except AttributeError: except AttributeError:
continue continue
raise AttributeError(f'Class "{class_name}" was not found!!!"\n' raise AttributeError(f'Class "{class_name}" was not found in "{folder_path.name}"', list(set(all_found_modules)))
f'Check the {folder_path.name} name.\n'
f'Possible Options are:\n{set(all_found_modules)}')

View File

@ -24,31 +24,40 @@ class LevelParser(object):
self.level_shape = level_array.shape self.level_shape = level_array.shape
self.size = self.pomdp_r**2 if self.pomdp_r else np.prod(self.level_shape) self.size = self.pomdp_r**2 if self.pomdp_r else np.prod(self.level_shape)
def get_coordinates_for_symbol(self, symbol, negate=False):
level_array = h.one_hot_level(self._parsed_level, symbol)
if negate:
return np.argwhere(level_array != c.VALUE_OCCUPIED_CELL)
else:
return np.argwhere(level_array == c.VALUE_OCCUPIED_CELL)
def do_init(self): def do_init(self):
entities = Entities() entities = Entities()
# Walls # Walls
level_array = h.one_hot_level(self._parsed_level, c.SYMBOL_WALL) walls = Walls.from_coordinates(self.get_coordinates_for_symbol(c.SYMBOL_WALL), self.size)
walls = Walls.from_coordinates(np.argwhere(level_array == c.VALUE_OCCUPIED_CELL), self.size)
entities.add_items({c.WALL: walls}) entities.add_items({c.WALL: walls})
# Floor # Floor
floor = Floors.from_coordinates(np.argwhere(level_array == c.VALUE_FREE_CELL), self.size) floor = Floors.from_coordinates(self.get_coordinates_for_symbol(c.SYMBOL_WALL, negate=True), self.size)
entities.add_items({c.FLOOR: floor}) entities.add_items({c.FLOOR: floor})
# All other # All other
for es_name in self.e_p_dict: for es_name in self.e_p_dict:
e_class, e_kwargs = self.e_p_dict[es_name]['class'], self.e_p_dict[es_name]['kwargs'] e_class, e_kwargs = self.e_p_dict[es_name]['class'], self.e_p_dict[es_name]['kwargs']
if hasattr(e_class, 'symbol'): if hasattr(e_class, 'symbol') and e_class.symbol is not None:
level_array = h.one_hot_level(self._parsed_level, symbol=e_class.symbol) symbols = e_class.symbol
if np.any(level_array): if isinstance(symbols, (str, int, float)):
e = e_class.from_coordinates(np.argwhere(level_array == c.VALUE_OCCUPIED_CELL).tolist(), symbols = [symbols]
entities[c.FLOOR], self.size, entity_kwargs=e_kwargs for symbol in symbols:
) level_array = h.one_hot_level(self._parsed_level, symbol=symbol)
else: if np.any(level_array):
raise ValueError(f'No {e_class} (Symbol: {e_class.symbol}) could be found!\n' e = e_class.from_coordinates(np.argwhere(level_array == c.VALUE_OCCUPIED_CELL).tolist(),
f'Check your level file!') entities[c.FLOOR], self.size, entity_kwargs=e_kwargs
)
else:
raise ValueError(f'No {e_class} (Symbol: {e_class.symbol}) could be found!\n'
f'Check your level file!')
else: else:
e = e_class(self.size, **e_kwargs) e = e_class(self.size, **e_kwargs)
entities.add_items({e.name: e}) entities.add_items({e.name: e})

View File

@ -6,11 +6,10 @@ from typing import Dict, List
import numpy as np import numpy as np
from numba import njit from numba import njit
from marl_factory_grid.environment import constants as c
from marl_factory_grid.environment.groups.utils import Combined from marl_factory_grid.environment.groups.utils import Combined
from marl_factory_grid.utils.states import Gamestate from marl_factory_grid.utils.states import Gamestate
from marl_factory_grid.environment import constants as c
class OBSBuilder(object): class OBSBuilder(object):
@ -111,10 +110,10 @@ class OBSBuilder(object):
e = next(x for x in self.all_obs if l_name in x and agent.name in x) e = next(x for x in self.all_obs if l_name in x and agent.name in x)
except StopIteration: except StopIteration:
raise KeyError( raise KeyError(
f'Check typing!\n{l_name} could not be found in:\n{dict(self.all_obs).keys()}') f'Check typing! {l_name} could not be found in: {list(dict(self.all_obs).keys())}')
try: try:
positional = e.has_position positional = e.var_has_position
except AttributeError: except AttributeError:
positional = False positional = False
if positional: if positional:
@ -172,7 +171,7 @@ class OBSBuilder(object):
obs_layers.append(combined.name) obs_layers.append(combined.name)
elif obs_str == c.OTHERS: elif obs_str == c.OTHERS:
obs_layers.extend([x for x in self.all_obs if x != agent.name and x.startswith(f'{c.AGENT}[')]) obs_layers.extend([x for x in self.all_obs if x != agent.name and x.startswith(f'{c.AGENT}[')])
elif obs_str == c.AGENTS: elif obs_str == c.AGENT:
obs_layers.extend([x for x in self.all_obs if x.startswith(f'{c.AGENT}[')]) obs_layers.extend([x for x in self.all_obs if x.startswith(f'{c.AGENT}[')])
else: else:
obs_layers.append(obs_str) obs_layers.append(obs_str)
@ -222,7 +221,7 @@ class RayCaster:
entities_hit = entities.pos_dict[(x, y)] entities_hit = entities.pos_dict[(x, y)]
hits = self.ray_block_cache(cache_blocking, hits = self.ray_block_cache(cache_blocking,
(x, y), (x, y),
lambda: any(e.is_blocking_light for e in entities_hit), lambda: any(e.var_is_blocking_light for e in entities_hit),
entities) entities)
try: try:
@ -237,8 +236,8 @@ class RayCaster:
self.ray_block_cache( self.ray_block_cache(
cache_blocking, cache_blocking,
key, key,
# lambda: all(False for e in entities.pos_dict[key] if not e.is_blocking_light), # lambda: all(False for e in entities.pos_dict[key] if not e.var_is_blocking_light),
lambda: any(e.is_blocking_light for e in entities.pos_dict[key]), lambda: any(e.var_is_blocking_light for e in entities.pos_dict[key]),
entities) entities)
for key in ((x, y-cy), (x-cx, y)) for key in ((x, y-cy), (x-cx, y))
]) if (cx != 0 and cy != 0) else False ]) if (cx != 0 and cy != 0) else False

View File

@ -27,13 +27,13 @@ class Renderer:
BG_COLOR = (178, 190, 195) # (99, 110, 114) BG_COLOR = (178, 190, 195) # (99, 110, 114)
WHITE = (223, 230, 233) # (200, 200, 200) WHITE = (223, 230, 233) # (200, 200, 200)
AGENT_VIEW_COLOR = (9, 132, 227) AGENT_VIEW_COLOR = (9, 132, 227)
ASSETS = Path(__file__).parent.parent / 'assets' ASSETS = Path(__file__).parent.parent
MODULE_ASSETS = Path(__file__).parent.parent.parent / 'modules'
def __init__(self, lvl_shape: Tuple[int, int] = (16, 16), def __init__(self, lvl_shape: Tuple[int, int] = (16, 16),
lvl_padded_shape: Union[Tuple[int, int], None] = None, lvl_padded_shape: Union[Tuple[int, int], None] = None,
cell_size: int = 40, fps: int = 7, cell_size: int = 40, fps: int = 7,
grid_lines: bool = True, view_radius: int = 2): grid_lines: bool = True, view_radius: int = 2):
# TODO: Customn_assets paths
self.grid_h, self.grid_w = lvl_shape self.grid_h, self.grid_w = lvl_shape
self.lvl_padded_shape = lvl_padded_shape if lvl_padded_shape is not None else lvl_shape self.lvl_padded_shape = lvl_padded_shape if lvl_padded_shape is not None else lvl_shape
self.cell_size = cell_size self.cell_size = cell_size
@ -44,7 +44,7 @@ class Renderer:
self.screen_size = (self.grid_w*cell_size, self.grid_h*cell_size) self.screen_size = (self.grid_w*cell_size, self.grid_h*cell_size)
self.screen = pygame.display.set_mode(self.screen_size) self.screen = pygame.display.set_mode(self.screen_size)
self.clock = pygame.time.Clock() self.clock = pygame.time.Clock()
assets = list(self.ASSETS.rglob('*.png')) + list(self.MODULE_ASSETS.rglob('*.png')) assets = list(self.ASSETS.rglob('*.png'))
self.assets = {path.stem: self.load_asset(str(path), 1) for path in assets} self.assets = {path.stem: self.load_asset(str(path), 1) for path in assets}
self.fill_bg() self.fill_bg()

View File

@ -1,8 +1,6 @@
from typing import Union from typing import Union
from dataclasses import dataclass from dataclasses import dataclass
from marl_factory_grid.environment.entity.entity import Entity
TYPE_VALUE = 'value' TYPE_VALUE = 'value'
TYPE_REWARD = 'reward' TYPE_REWARD = 'reward'
types = [TYPE_VALUE, TYPE_REWARD] types = [TYPE_VALUE, TYPE_REWARD]
@ -20,7 +18,7 @@ class Result:
validity: bool validity: bool
reward: Union[float, None] = None reward: Union[float, None] = None
value: Union[float, None] = None value: Union[float, None] = None
entity: Union[Entity, None] = None entity: None = None
def get_infos(self): def get_infos(self):
n = self.entity.name if self.entity is not None else "Global" n = self.entity.name if self.entity is not None else "Global"

View File

@ -2,10 +2,11 @@ from typing import List, Dict
import numpy as np import numpy as np
from marl_factory_grid.environment import constants as c
from marl_factory_grid.environment.entity.wall_floor import Floor from marl_factory_grid.environment.entity.wall_floor import Floor
from marl_factory_grid.environment.rules import Rule from marl_factory_grid.environment.rules import Rule
from marl_factory_grid.utils.results import Result from marl_factory_grid.utils.results import Result
from marl_factory_grid.environment import constants as c
class StepRules: class StepRules:
@ -26,9 +27,9 @@ class StepRules:
self.rules.append(item) self.rules.append(item)
return True return True
def do_all_init(self, state): def do_all_init(self, state, lvl_map):
for rule in self.rules: for rule in self.rules:
if rule_init_printline := rule.on_init(state): if rule_init_printline := rule.on_init(state, lvl_map):
state.print(rule_init_printline) state.print(rule_init_printline)
return c.VALID return c.VALID
@ -58,7 +59,7 @@ class Gamestate(object):
@property @property
def moving_entites(self): def moving_entites(self):
return [y for x in self.entities for y in x if x.can_move] return [y for x in self.entities for y in x if x.var_can_move]
def __init__(self, entitites, rules: Dict[str, dict], env_seed=69, verbose=False): def __init__(self, entitites, rules: Dict[str, dict], env_seed=69, verbose=False):
self.entities = entitites self.entities = entitites
@ -107,6 +108,6 @@ class Gamestate(object):
def get_all_tiles_with_collisions(self) -> List[Floor]: def get_all_tiles_with_collisions(self) -> List[Floor]:
tiles = [self[c.FLOOR].by_pos(pos) for pos, e in self.entities.pos_dict.items() tiles = [self[c.FLOOR].by_pos(pos) for pos, e in self.entities.pos_dict.items()
if sum([x.can_collide for x in e]) > 1] if sum([x.var_can_collide for x in e]) > 1]
# tiles = [x for x in self[c.FLOOR] if len(x.guests_that_can_collide) > 1] # tiles = [x for x in self[c.FLOOR] if len(x.guests_that_can_collide) > 1]
return tiles return tiles

View File

@ -15,7 +15,7 @@ ENTITIES = 'Objects'
OBSERVATIONS = 'Observations' OBSERVATIONS = 'Observations'
RULES = 'Rule' RULES = 'Rule'
ASSETS = 'Assets' ASSETS = 'Assets'
EXCLUDED = ['identifier', 'args', 'kwargs', 'Move', EXCLUDED = ['identifier', 'args', 'kwargs', 'Move', 'Floor', 'Agent', 'GlobalPositions', 'Walls',
'TemplateRule', 'Entities', 'EnvObjects', 'Zones', ] 'TemplateRule', 'Entities', 'EnvObjects', 'Zones', ]

View File

@ -1,21 +1,6 @@
import gymnasium as gym import gymnasium as gym
class EnvCombiner(object):
def __init__(self, *envs_cls):
self._env_dict = {env_cls.__name__: env_cls for env_cls in envs_cls}
@staticmethod
def combine_cls(name, *envs_cls):
return type(name, envs_cls, {})
def build(self):
name = f'{"".join([x.lower().replace("factory").capitalize() for x in self._env_dict.keys()])}Factory'
return self.combine_cls(name, tuple(self._env_dict.values()))
class MarlFrameStack(gym.ObservationWrapper): class MarlFrameStack(gym.ObservationWrapper):
"""todo @romue404""" """todo @romue404"""
def __init__(self, env): def __init__(self, env):

View File

@ -3,7 +3,7 @@ from pathlib import Path
import yaml import yaml
from marl_factory_grid.environment.factory import BaseFactory from marl_factory_grid.environment.factory import Factory
from marl_factory_grid.logging.envmonitor import EnvMonitor from marl_factory_grid.logging.envmonitor import EnvMonitor
from marl_factory_grid.logging.recorder import EnvRecorder from marl_factory_grid.logging.recorder import EnvRecorder
@ -41,7 +41,7 @@ if __name__ == '__main__':
pass pass
# Init Env # Init Env
with BaseFactory(**env_kwargs) as env: with Factory(**env_kwargs) as env:
env = EnvMonitor(env) env = EnvMonitor(env)
env = EnvRecorder(env) if record else env env = EnvRecorder(env) if record else env
obs_shape = env.observation_space.shape obs_shape = env.observation_space.shape

View File

@ -5,7 +5,7 @@ long_description = (this_directory / "README.md").read_text()
setup(name='Marl-Factory-Grid', setup(name='Marl-Factory-Grid',
version='0.0.11', version='0.0.12',
description='A framework to research MARL agents in various setings.', description='A framework to research MARL agents in various setings.',
author='Steffen Illium', author='Steffen Illium',
author_email='steffen.illium@ifi.lmu.de', author_email='steffen.illium@ifi.lmu.de',