mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-06-18 18:52:52 +02:00
Machines
This commit is contained in:
@ -1,6 +1,6 @@
|
|||||||
from .environment.factory import BaseFactory
|
from .environment import *
|
||||||
from .environment.factory import OBSBuilder
|
from .modules import *
|
||||||
|
from .utils import *
|
||||||
from .utils.tools import ConfigExplainer
|
|
||||||
|
|
||||||
from .quickstart import init
|
from .quickstart import init
|
||||||
|
|
||||||
|
@ -1,11 +1,10 @@
|
|||||||
import itertools
|
|
||||||
from random import choice
|
from random import choice
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
import networkx as nx
|
|
||||||
from networkx.algorithms.approximation import traveling_salesman as tsp
|
from networkx.algorithms.approximation import traveling_salesman as tsp
|
||||||
|
|
||||||
|
from marl_factory_grid.algorithms.static.utils import points_to_graph
|
||||||
from marl_factory_grid.modules.doors import constants as do
|
from marl_factory_grid.modules.doors import constants as do
|
||||||
from marl_factory_grid.environment import constants as c
|
from marl_factory_grid.environment import constants as c
|
||||||
from marl_factory_grid.utils.helpers import MOVEMAP
|
from marl_factory_grid.utils.helpers import MOVEMAP
|
||||||
@ -15,41 +14,6 @@ from abc import abstractmethod, ABC
|
|||||||
future_planning = 7
|
future_planning = 7
|
||||||
|
|
||||||
|
|
||||||
def points_to_graph(coordiniates_or_tiles, allow_euclidean_connections=True, allow_manhattan_connections=True):
|
|
||||||
"""
|
|
||||||
Given a set of coordinates, this function contructs a non-directed graph, by conncting adjected points.
|
|
||||||
There are three combinations of settings:
|
|
||||||
Allow all neigbors: Distance(a, b) <= sqrt(2)
|
|
||||||
Allow only manhattan: Distance(a, b) == 1
|
|
||||||
Allow only euclidean: Distance(a, b) == sqrt(2)
|
|
||||||
|
|
||||||
|
|
||||||
:param coordiniates_or_tiles: A set of coordinates.
|
|
||||||
:type coordiniates_or_tiles: Tiles
|
|
||||||
:param allow_euclidean_connections: Whether to regard diagonal adjected cells as neighbors
|
|
||||||
:type: bool
|
|
||||||
:param allow_manhattan_connections: Whether to regard directly adjected cells as neighbors
|
|
||||||
:type: bool
|
|
||||||
|
|
||||||
:return: A graph with nodes that are conneceted as specified by the parameters.
|
|
||||||
:rtype: nx.Graph
|
|
||||||
"""
|
|
||||||
assert allow_euclidean_connections or allow_manhattan_connections
|
|
||||||
if hasattr(coordiniates_or_tiles, 'positions'):
|
|
||||||
coordiniates_or_tiles = coordiniates_or_tiles.positions
|
|
||||||
possible_connections = itertools.combinations(coordiniates_or_tiles, 2)
|
|
||||||
graph = nx.Graph()
|
|
||||||
for a, b in possible_connections:
|
|
||||||
diff = np.linalg.norm(np.asarray(a)-np.asarray(b))
|
|
||||||
if allow_manhattan_connections and allow_euclidean_connections and diff <= np.sqrt(2):
|
|
||||||
graph.add_edge(a, b)
|
|
||||||
elif not allow_manhattan_connections and allow_euclidean_connections and diff == np.sqrt(2):
|
|
||||||
graph.add_edge(a, b)
|
|
||||||
elif allow_manhattan_connections and not allow_euclidean_connections and diff == 1:
|
|
||||||
graph.add_edge(a, b)
|
|
||||||
return graph
|
|
||||||
|
|
||||||
|
|
||||||
class TSPBaseAgent(ABC):
|
class TSPBaseAgent(ABC):
|
||||||
|
|
||||||
def __init__(self, state, agent_i, static_problem: bool = True):
|
def __init__(self, state, agent_i, static_problem: bool = True):
|
||||||
|
39
marl_factory_grid/algorithms/static/utils.py
Normal file
39
marl_factory_grid/algorithms/static/utils.py
Normal file
@ -0,0 +1,39 @@
|
|||||||
|
import itertools
|
||||||
|
|
||||||
|
import networkx as nx
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
def points_to_graph(coordiniates_or_tiles, allow_euclidean_connections=True, allow_manhattan_connections=True):
|
||||||
|
"""
|
||||||
|
Given a set of coordinates, this function contructs a non-directed graph, by conncting adjected points.
|
||||||
|
There are three combinations of settings:
|
||||||
|
Allow all neigbors: Distance(a, b) <= sqrt(2)
|
||||||
|
Allow only manhattan: Distance(a, b) == 1
|
||||||
|
Allow only euclidean: Distance(a, b) == sqrt(2)
|
||||||
|
|
||||||
|
|
||||||
|
:param coordiniates_or_tiles: A set of coordinates.
|
||||||
|
:type coordiniates_or_tiles: Tiles
|
||||||
|
:param allow_euclidean_connections: Whether to regard diagonal adjected cells as neighbors
|
||||||
|
:type: bool
|
||||||
|
:param allow_manhattan_connections: Whether to regard directly adjected cells as neighbors
|
||||||
|
:type: bool
|
||||||
|
|
||||||
|
:return: A graph with nodes that are conneceted as specified by the parameters.
|
||||||
|
:rtype: nx.Graph
|
||||||
|
"""
|
||||||
|
assert allow_euclidean_connections or allow_manhattan_connections
|
||||||
|
if hasattr(coordiniates_or_tiles, 'positions'):
|
||||||
|
coordiniates_or_tiles = coordiniates_or_tiles.positions
|
||||||
|
possible_connections = itertools.combinations(coordiniates_or_tiles, 2)
|
||||||
|
graph = nx.Graph()
|
||||||
|
for a, b in possible_connections:
|
||||||
|
diff = np.linalg.norm(np.asarray(a)-np.asarray(b))
|
||||||
|
if allow_manhattan_connections and allow_euclidean_connections and diff <= np.sqrt(2):
|
||||||
|
graph.add_edge(a, b)
|
||||||
|
elif not allow_manhattan_connections and allow_euclidean_connections and diff == np.sqrt(2):
|
||||||
|
graph.add_edge(a, b)
|
||||||
|
elif allow_manhattan_connections and not allow_euclidean_connections and diff == 1:
|
||||||
|
graph.add_edge(a, b)
|
||||||
|
return graph
|
@ -1,68 +1,89 @@
|
|||||||
---
|
|
||||||
General:
|
|
||||||
level_name: rooms
|
|
||||||
env_seed: 69
|
|
||||||
verbose: !!bool False
|
|
||||||
pomdp_r: 5
|
|
||||||
individual_rewards: !!bool True
|
|
||||||
|
|
||||||
Entities:
|
|
||||||
Defaults: {}
|
|
||||||
DirtPiles:
|
|
||||||
initial_dirt_ratio: 0.3 # On INIT, on max how many tiles does the dirt spawn in percent.
|
|
||||||
dirt_spawn_r_var: 0.05 # How much does the dirt spawn amount vary?
|
|
||||||
initial_amount: 3
|
|
||||||
max_local_amount: 5 # Max dirt amount per tile.
|
|
||||||
max_global_amount: 20 # Max dirt amount in the whole environment.
|
|
||||||
Doors:
|
|
||||||
closed_on_init: True
|
|
||||||
auto_close_interval: 10
|
|
||||||
indicate_area: False
|
|
||||||
Agents:
|
Agents:
|
||||||
Wolfgang:
|
Wolfgang:
|
||||||
Actions:
|
Actions:
|
||||||
- Move8
|
- Noop
|
||||||
- Noop
|
- BtryCharge
|
||||||
- DoorUse
|
- CleanUp
|
||||||
- CleanUp
|
- DestAction
|
||||||
|
- DoorUse
|
||||||
|
- ItemAction
|
||||||
|
- Move8
|
||||||
Observations:
|
Observations:
|
||||||
- Self
|
- Combined:
|
||||||
- Placeholder
|
- Other
|
||||||
- Walls
|
- Walls
|
||||||
- DirtPiles
|
- GlobalPosition
|
||||||
- Placeholder
|
- Battery
|
||||||
- Doors
|
- ChargePods
|
||||||
- Doors
|
- DirtPiles
|
||||||
Björn:
|
- Destinations
|
||||||
Actions:
|
- Doors
|
||||||
# Move4, Noop
|
- Items
|
||||||
- Move4
|
- Inventory
|
||||||
- DoorUse
|
- DropOffLocations
|
||||||
- CleanUp
|
- Machines
|
||||||
Observations:
|
- Maintainers
|
||||||
- Defaults
|
Entities:
|
||||||
- Combined
|
Batteries: {}
|
||||||
Jürgen:
|
ChargePods: {}
|
||||||
Actions:
|
Destinations: {}
|
||||||
# Move4, Noop
|
DirtPiles:
|
||||||
- Defaults
|
clean_amount: 1
|
||||||
- DoorUse
|
dirt_spawn_r_var: 0.1
|
||||||
- CleanUp
|
initial_amount: 2
|
||||||
Observations:
|
initial_dirt_ratio: 0.05
|
||||||
- Walls
|
max_global_amount: 20
|
||||||
- Placeholder
|
max_local_amount: 5
|
||||||
- Agent[Björn]
|
Doors: {}
|
||||||
|
DropOffLocations: {}
|
||||||
|
GlobalPositions: {}
|
||||||
|
Inventories: {}
|
||||||
|
Items: {}
|
||||||
|
Machines: {}
|
||||||
|
Maintainers: {}
|
||||||
|
Zones: {}
|
||||||
|
ReachedDestinations: {}
|
||||||
|
|
||||||
|
General:
|
||||||
|
env_seed: 69
|
||||||
|
individual_rewards: true
|
||||||
|
level_name: large
|
||||||
|
pomdp_r: 3
|
||||||
|
verbose: false
|
||||||
|
|
||||||
Rules:
|
Rules:
|
||||||
Defaults: {}
|
Btry:
|
||||||
|
initial_charge: 0.8
|
||||||
|
per_action_costs: 0.02
|
||||||
|
BtryDoneAtDischarge: {}
|
||||||
Collision:
|
Collision:
|
||||||
done_at_collisions: !!bool False
|
done_at_collisions: false
|
||||||
DirtRespawnRule:
|
AssignGlobalPositions: {}
|
||||||
spawn_freq: 5
|
DestinationDone: {}
|
||||||
DirtSmearOnMove:
|
DestinationReach:
|
||||||
smear_amount: 0.12
|
n_dests: 1
|
||||||
DoorAutoClose: {}
|
tiles: null
|
||||||
|
DestinationSpawn:
|
||||||
|
n_dests: 1
|
||||||
|
spawn_frequency: 5
|
||||||
|
spawn_mode: GROUPED
|
||||||
DirtAllCleanDone: {}
|
DirtAllCleanDone: {}
|
||||||
Assets:
|
DirtRespawnRule:
|
||||||
- Defaults
|
spawn_freq: 15
|
||||||
- Dirt
|
DirtSmearOnMove:
|
||||||
- Doors
|
smear_amount: 0.2
|
||||||
|
DoorAutoClose:
|
||||||
|
close_frequency: 10
|
||||||
|
ItemRules:
|
||||||
|
max_dropoff_storage_size: 0
|
||||||
|
n_items: 5
|
||||||
|
n_locations: 5
|
||||||
|
spawn_frequency: 15
|
||||||
|
MachineRule:
|
||||||
|
n_machines: 2
|
||||||
|
MaintenanceRule:
|
||||||
|
n_maintainer: 1
|
||||||
|
MaxStepsReached:
|
||||||
|
max_steps: 500
|
||||||
|
# AgentSingleZonePlacement:
|
||||||
|
# n_zones: 4
|
||||||
|
@ -98,3 +98,5 @@ class NorthWest(Move):
|
|||||||
Move4 = [North, East, South, West]
|
Move4 = [North, East, South, West]
|
||||||
# noinspection PyTypeChecker
|
# noinspection PyTypeChecker
|
||||||
Move8 = Move4 + [NorthEast, SouthEast, SouthWest, NorthWest]
|
Move8 = Move4 + [NorthEast, SouthEast, SouthWest, NorthWest]
|
||||||
|
|
||||||
|
ALL_BASEACTIONS = Move8 + [Noop]
|
||||||
|
@ -9,15 +9,13 @@ WALL = 'Wall' # Identifier of Wall-objects and
|
|||||||
WALLS = 'Walls' # Identifier of Wall-objects and groups (groups).
|
WALLS = 'Walls' # Identifier of Wall-objects and groups (groups).
|
||||||
LEVEL = 'Level' # Identifier of Level-objects and groups (groups).
|
LEVEL = 'Level' # Identifier of Level-objects and groups (groups).
|
||||||
AGENT = 'Agent' # Identifier of Agent-objects and groups (groups).
|
AGENT = 'Agent' # Identifier of Agent-objects and groups (groups).
|
||||||
AGENTS = 'Agents' # Identifier of Agent-objects and groups (groups).
|
|
||||||
OTHERS = 'Other'
|
OTHERS = 'Other'
|
||||||
COMBINED = 'Combined'
|
COMBINED = 'Combined'
|
||||||
GLOBAL_POSITION = 'GLOBAL_POSITION' # Identifier of the global position slice
|
GLOBALPOSITIONS = 'GlobalPositions' # Identifier of the global position slice
|
||||||
|
|
||||||
|
|
||||||
# Attributes
|
# Attributes
|
||||||
IS_BLOCKING_LIGHT = 'is_blocking_light'
|
IS_BLOCKING_LIGHT = 'var_is_blocking_light'
|
||||||
HAS_POSITION = 'has_position'
|
HAS_POSITION = 'var_has_position'
|
||||||
HAS_NO_POSITION = 'has_no_position'
|
HAS_NO_POSITION = 'has_no_position'
|
||||||
ALL = 'All'
|
ALL = 'All'
|
||||||
|
|
||||||
|
@ -1,6 +1,5 @@
|
|||||||
from typing import List, Union
|
from typing import List, Union
|
||||||
|
|
||||||
from marl_factory_grid.environment import constants as c
|
|
||||||
from marl_factory_grid.environment.actions import Action
|
from marl_factory_grid.environment.actions import Action
|
||||||
from marl_factory_grid.environment.entity.entity import Entity
|
from marl_factory_grid.environment.entity.entity import Entity
|
||||||
from marl_factory_grid.utils.render import RenderEntity
|
from marl_factory_grid.utils.render import RenderEntity
|
||||||
@ -8,6 +7,8 @@ from marl_factory_grid.utils import renderer
|
|||||||
from marl_factory_grid.utils.helpers import is_move
|
from marl_factory_grid.utils.helpers import is_move
|
||||||
from marl_factory_grid.utils.results import ActionResult, Result
|
from marl_factory_grid.utils.results import ActionResult, Result
|
||||||
|
|
||||||
|
from marl_factory_grid.environment import constants as c
|
||||||
|
|
||||||
|
|
||||||
class Agent(Entity):
|
class Agent(Entity):
|
||||||
|
|
||||||
@ -24,7 +25,7 @@ class Agent(Entity):
|
|||||||
return self._observations
|
return self._observations
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def can_collide(self):
|
def var_can_collide(self):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def step_result(self):
|
def step_result(self):
|
||||||
|
@ -1,15 +1,20 @@
|
|||||||
import abc
|
import abc
|
||||||
|
|
||||||
from marl_factory_grid.environment import constants as c
|
from .. import constants as c
|
||||||
from marl_factory_grid.environment.entity.object import EnvObject
|
from .object import EnvObject
|
||||||
from marl_factory_grid.utils.render import RenderEntity
|
from ...utils.render import RenderEntity
|
||||||
|
from ...utils.results import ActionResult
|
||||||
|
|
||||||
|
|
||||||
class Entity(EnvObject, abc.ABC):
|
class Entity(EnvObject, abc.ABC):
|
||||||
"""Full Env Entity that lives on the environment Grid. Doors, Items, DirtPile etc..."""
|
"""Full Env Entity that lives on the environment Grid. Doors, Items, DirtPile etc..."""
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def has_position(self):
|
def state(self):
|
||||||
|
return self._status or ActionResult(entity=self, identifier=c.NOOP, validity=c.VALID, reward=0)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def var_has_position(self):
|
||||||
return self.pos != c.VALUE_NO_POS
|
return self.pos != c.VALUE_NO_POS
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@ -64,12 +69,13 @@ class Entity(EnvObject, abc.ABC):
|
|||||||
|
|
||||||
def __init__(self, tile, **kwargs):
|
def __init__(self, tile, **kwargs):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
|
self._status = None
|
||||||
self._tile = tile
|
self._tile = tile
|
||||||
tile.enter(self)
|
tile.enter(self)
|
||||||
|
|
||||||
def summarize_state(self) -> dict:
|
def summarize_state(self) -> dict:
|
||||||
return dict(name=str(self.name), x=int(self.x), y=int(self.y),
|
return dict(name=str(self.name), x=int(self.x), y=int(self.y),
|
||||||
tile=str(self.tile.name), can_collide=bool(self.can_collide))
|
tile=str(self.tile.name), can_collide=bool(self.var_can_collide))
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def render(self):
|
def render(self):
|
||||||
|
@ -78,37 +78,37 @@ class EnvObject(Object):
|
|||||||
return self.name
|
return self.name
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_blocking_light(self):
|
def var_is_blocking_light(self):
|
||||||
try:
|
try:
|
||||||
return self._collection.is_blocking_light or False
|
return self._collection.var_is_blocking_light or False
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def can_move(self):
|
def var_can_move(self):
|
||||||
try:
|
try:
|
||||||
return self._collection.can_move or False
|
return self._collection.var_can_move or False
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_blocking_pos(self):
|
def var_is_blocking_pos(self):
|
||||||
try:
|
try:
|
||||||
return self._collection.is_blocking_pos or False
|
return self._collection.var_is_blocking_pos or False
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def has_position(self):
|
def var_has_position(self):
|
||||||
try:
|
try:
|
||||||
return self._collection.has_position or False
|
return self._collection.var_has_position or False
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def can_collide(self):
|
def var_can_collide(self):
|
||||||
try:
|
try:
|
||||||
return self._collection.can_collide or False
|
return self._collection.var_can_collide or False
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
@ -35,11 +35,11 @@ class GlobalPosition(BoundEntityMixin, EnvObject):
|
|||||||
@property
|
@property
|
||||||
def encoding(self):
|
def encoding(self):
|
||||||
if self._normalized:
|
if self._normalized:
|
||||||
return tuple(np.divide(self._bound_entity.pos, self._level_shape))
|
return tuple(np.divide(self._bound_entity.pos, self._shape))
|
||||||
else:
|
else:
|
||||||
return self.bound_entity.pos
|
return self.bound_entity.pos
|
||||||
|
|
||||||
def __init__(self, *args, normalized: bool = True, **kwargs):
|
def __init__(self, level_shape, *args, normalized: bool = True, **kwargs):
|
||||||
super(GlobalPosition, self).__init__(*args, **kwargs)
|
super(GlobalPosition, self).__init__(*args, **kwargs)
|
||||||
self._level_shape = math.sqrt(self.size)
|
|
||||||
self._normalized = normalized
|
self._normalized = normalized
|
||||||
|
self._shape = level_shape
|
||||||
|
@ -11,23 +11,23 @@ from marl_factory_grid.utils import helpers as h
|
|||||||
class Floor(EnvObject):
|
class Floor(EnvObject):
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def has_position(self):
|
def var_has_position(self):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def can_collide(self):
|
def var_can_collide(self):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def can_move(self):
|
def var_can_move(self):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_blocking_pos(self):
|
def var_is_blocking_pos(self):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_blocking_light(self):
|
def var_is_blocking_light(self):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@ -51,7 +51,7 @@ class Floor(EnvObject):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def guests_that_can_collide(self):
|
def guests_that_can_collide(self):
|
||||||
return [x for x in self.guests if x.can_collide]
|
return [x for x in self.guests if x.var_can_collide]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def guests(self):
|
def guests(self):
|
||||||
@ -67,7 +67,7 @@ class Floor(EnvObject):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def is_blocked(self):
|
def is_blocked(self):
|
||||||
return any([x.is_blocking_pos for x in self.guests])
|
return any([x.var_is_blocking_pos for x in self.guests])
|
||||||
|
|
||||||
def __init__(self, pos, **kwargs):
|
def __init__(self, pos, **kwargs):
|
||||||
super(Floor, self).__init__(**kwargs)
|
super(Floor, self).__init__(**kwargs)
|
||||||
@ -86,7 +86,7 @@ class Floor(EnvObject):
|
|||||||
return bool(len(self._guests))
|
return bool(len(self._guests))
|
||||||
|
|
||||||
def enter(self, guest):
|
def enter(self, guest):
|
||||||
if (guest.name not in self._guests and not self.is_blocked) and not (guest.is_blocking_pos and self.is_occupied()):
|
if (guest.name not in self._guests and not self.is_blocked) and not (guest.var_is_blocking_pos and self.is_occupied()):
|
||||||
self._guests.update({guest.name: guest})
|
self._guests.update({guest.name: guest})
|
||||||
return c.VALID
|
return c.VALID
|
||||||
else:
|
else:
|
||||||
@ -112,7 +112,7 @@ class Floor(EnvObject):
|
|||||||
class Wall(Floor):
|
class Wall(Floor):
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def can_collide(self):
|
def var_can_collide(self):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@ -123,9 +123,9 @@ class Wall(Floor):
|
|||||||
return RenderEntity(c.WALL, self.pos)
|
return RenderEntity(c.WALL, self.pos)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_blocking_pos(self):
|
def var_is_blocking_pos(self):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_blocking_light(self):
|
def var_is_blocking_light(self):
|
||||||
return True
|
return True
|
||||||
|
@ -19,7 +19,7 @@ from marl_factory_grid.utils.states import Gamestate
|
|||||||
REC_TAC = 'rec_'
|
REC_TAC = 'rec_'
|
||||||
|
|
||||||
|
|
||||||
class BaseFactory(gym.Env):
|
class Factory(gym.Env):
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def action_space(self):
|
def action_space(self):
|
||||||
@ -52,11 +52,15 @@ class BaseFactory(gym.Env):
|
|||||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||||
self.close()
|
self.close()
|
||||||
|
|
||||||
def __init__(self, config_file: Union[str, PathLike]):
|
def __init__(self, config_file: Union[str, PathLike], custom_modules_path: Union[None, PathLike] = None,
|
||||||
|
custom_level_path: Union[None, PathLike] = None):
|
||||||
self._config_file = config_file
|
self._config_file = config_file
|
||||||
self.conf = FactoryConfigParser(self._config_file)
|
self.conf = FactoryConfigParser(self._config_file, custom_modules_path)
|
||||||
# Attribute Assignment
|
# Attribute Assignment
|
||||||
self.level_filepath = Path(__file__).parent.parent / h.LEVELS_DIR / f'{self.conf.level_name}.txt'
|
if custom_level_path is not None:
|
||||||
|
self.level_filepath = Path(custom_level_path)
|
||||||
|
else:
|
||||||
|
self.level_filepath = Path(__file__).parent.parent / h.LEVELS_DIR / f'{self.conf.level_name}.txt'
|
||||||
self._renderer = None # expensive - don't use; unless required !
|
self._renderer = None # expensive - don't use; unless required !
|
||||||
|
|
||||||
parsed_entities = self.conf.load_entities()
|
parsed_entities = self.conf.load_entities()
|
||||||
@ -90,7 +94,7 @@ class BaseFactory(gym.Env):
|
|||||||
self.state.entities.add_item({c.AGENT: agents})
|
self.state.entities.add_item({c.AGENT: agents})
|
||||||
|
|
||||||
# All is set up, trigger additional init (after agent entity spawn etc)
|
# All is set up, trigger additional init (after agent entity spawn etc)
|
||||||
self.state.rules.do_all_init(self.state)
|
self.state.rules.do_all_init(self.state, self.map)
|
||||||
|
|
||||||
# Observations
|
# Observations
|
||||||
# noinspection PyAttributeOutsideInit
|
# noinspection PyAttributeOutsideInit
|
||||||
@ -144,7 +148,7 @@ class BaseFactory(gym.Env):
|
|||||||
try:
|
try:
|
||||||
done_reason = next(x for x in done_check_results if x.validity)
|
done_reason = next(x for x in done_check_results if x.validity)
|
||||||
done = True
|
done = True
|
||||||
self.state.print(f'Env done, Reason: {done_reason.name}.')
|
self.state.print(f'Env done, Reason: {done_reason.identifier}.')
|
||||||
except StopIteration:
|
except StopIteration:
|
||||||
done = False
|
done = False
|
||||||
|
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
|
from marl_factory_grid.environment.entity.agent import Agent
|
||||||
from marl_factory_grid.environment.groups.env_objects import EnvObjects
|
from marl_factory_grid.environment.groups.env_objects import EnvObjects
|
||||||
from marl_factory_grid.environment.groups.mixins import PositionMixin
|
from marl_factory_grid.environment.groups.mixins import PositionMixin
|
||||||
from marl_factory_grid.environment.entity.agent import Agent
|
|
||||||
|
|
||||||
|
|
||||||
class Agents(PositionMixin, EnvObjects):
|
class Agents(PositionMixin, EnvObjects):
|
||||||
|
@ -5,10 +5,10 @@ from marl_factory_grid.environment.entity.object import EnvObject
|
|||||||
class EnvObjects(Objects):
|
class EnvObjects(Objects):
|
||||||
|
|
||||||
_entity = EnvObject
|
_entity = EnvObject
|
||||||
is_blocking_light: bool = False
|
var_is_blocking_light: bool = False
|
||||||
can_collide: bool = False
|
var_can_collide: bool = False
|
||||||
has_position: bool = False
|
var_has_position: bool = False
|
||||||
can_move: bool = False
|
var_can_move: bool = False
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def encodings(self):
|
def encodings(self):
|
||||||
@ -19,7 +19,7 @@ class EnvObjects(Objects):
|
|||||||
self.size = size
|
self.size = size
|
||||||
|
|
||||||
def add_item(self, item: EnvObject):
|
def add_item(self, item: EnvObject):
|
||||||
assert self.has_position or (len(self) <= self.size)
|
assert self.var_has_position or (len(self) <= self.size)
|
||||||
super(EnvObjects, self).add_item(item)
|
super(EnvObjects, self).add_item(item)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
@ -1,15 +1,19 @@
|
|||||||
|
from typing import List
|
||||||
|
|
||||||
from marl_factory_grid.environment import constants as c
|
from marl_factory_grid.environment import constants as c
|
||||||
|
|
||||||
from marl_factory_grid.environment.entity.entity import Entity
|
from marl_factory_grid.environment.entity.entity import Entity
|
||||||
|
from marl_factory_grid.environment.entity.wall_floor import Floor
|
||||||
|
|
||||||
|
|
||||||
# noinspection PyUnresolvedReferences,PyTypeChecker,PyArgumentList
|
|
||||||
class PositionMixin:
|
class PositionMixin:
|
||||||
|
|
||||||
_entity = Entity
|
_entity = Entity
|
||||||
is_blocking_light: bool = True
|
var_is_blocking_light: bool = True
|
||||||
can_collide: bool = True
|
var_can_collide: bool = True
|
||||||
has_position: bool = True
|
var_has_position: bool = True
|
||||||
|
|
||||||
|
def spawn(self, tiles: List[Floor]):
|
||||||
|
self.add_items([self._entity(tile) for tile in tiles])
|
||||||
|
|
||||||
def render(self):
|
def render(self):
|
||||||
return [y for y in [x.render() for x in self] if y is not None]
|
return [y for y in [x.render() for x in self] if y is not None]
|
||||||
@ -81,8 +85,8 @@ class IsBoundMixin:
|
|||||||
class HasBoundedMixin:
|
class HasBoundedMixin:
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def obs_names(self):
|
def obs_pairs(self):
|
||||||
return [x.name for x in self]
|
return [(x.name, x) for x in self]
|
||||||
|
|
||||||
def by_entity(self, entity):
|
def by_entity(self, entity):
|
||||||
try:
|
try:
|
||||||
|
@ -4,6 +4,7 @@ from typing import List
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from marl_factory_grid.environment.entity.object import Object
|
from marl_factory_grid.environment.entity.object import Object
|
||||||
|
import marl_factory_grid.environment.constants as c
|
||||||
|
|
||||||
|
|
||||||
class Objects:
|
class Objects:
|
||||||
@ -116,12 +117,21 @@ class Objects:
|
|||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return f'{self.__class__.__name__}[{dict(self._data)}]'
|
return f'{self.__class__.__name__}[{dict(self._data)}]'
|
||||||
|
|
||||||
|
def spawn(self, n: int):
|
||||||
|
self.add_items([self._entity() for _ in range(n)])
|
||||||
|
return c.VALID
|
||||||
|
|
||||||
|
def despawn(self, items: List[Object]):
|
||||||
|
items = [items] if isinstance(items, Object) else items
|
||||||
|
for item in items:
|
||||||
|
del self[item]
|
||||||
|
|
||||||
def notify_change_pos(self, entity: object):
|
def notify_change_pos(self, entity: object):
|
||||||
try:
|
try:
|
||||||
self.pos_dict[entity.last_pos].remove(entity)
|
self.pos_dict[entity.last_pos].remove(entity)
|
||||||
except (ValueError, AttributeError):
|
except (ValueError, AttributeError):
|
||||||
pass
|
pass
|
||||||
if entity.has_position:
|
if entity.var_has_position:
|
||||||
try:
|
try:
|
||||||
self.pos_dict[entity.pos].append(entity)
|
self.pos_dict[entity.pos].append(entity)
|
||||||
except (ValueError, AttributeError):
|
except (ValueError, AttributeError):
|
||||||
|
@ -2,10 +2,11 @@ from typing import List, Union
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from marl_factory_grid.environment.groups.env_objects import EnvObjects
|
|
||||||
from marl_factory_grid.environment.groups.objects import Objects
|
|
||||||
from marl_factory_grid.environment.groups.mixins import HasBoundedMixin, PositionMixin
|
|
||||||
from marl_factory_grid.environment.entity.util import GlobalPosition
|
from marl_factory_grid.environment.entity.util import GlobalPosition
|
||||||
|
from marl_factory_grid.environment.groups.env_objects import EnvObjects
|
||||||
|
from marl_factory_grid.environment.groups.mixins import PositionMixin, HasBoundedMixin
|
||||||
|
from marl_factory_grid.environment.groups.objects import Objects
|
||||||
|
from marl_factory_grid.modules.zones import Zone
|
||||||
from marl_factory_grid.utils import helpers as h
|
from marl_factory_grid.utils import helpers as h
|
||||||
from marl_factory_grid.environment import constants as c
|
from marl_factory_grid.environment import constants as c
|
||||||
|
|
||||||
@ -44,7 +45,9 @@ class GlobalPositions(HasBoundedMixin, EnvObjects):
|
|||||||
super(GlobalPositions, self).__init__(*args, **kwargs)
|
super(GlobalPositions, self).__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
class Zones(Objects):
|
class ZonesOLD(Objects):
|
||||||
|
|
||||||
|
_entity = Zone
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def accounting_zones(self):
|
def accounting_zones(self):
|
||||||
|
@ -30,8 +30,8 @@ class Walls(PositionMixin, EnvObjects):
|
|||||||
class Floors(Walls):
|
class Floors(Walls):
|
||||||
_entity = Floor
|
_entity = Floor
|
||||||
symbol = c.SYMBOL_FLOOR
|
symbol = c.SYMBOL_FLOOR
|
||||||
is_blocking_light: bool = False
|
var_is_blocking_light: bool = False
|
||||||
can_collide: bool = False
|
var_can_collide: bool = False
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super(Floors, self).__init__(*args, **kwargs)
|
super(Floors, self).__init__(*args, **kwargs)
|
||||||
|
@ -17,7 +17,7 @@ class Rule(abc.ABC):
|
|||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return f'{self.name}'
|
return f'{self.name}'
|
||||||
|
|
||||||
def on_init(self, state):
|
def on_init(self, state, lvl_map):
|
||||||
return []
|
return []
|
||||||
|
|
||||||
def on_reset(self):
|
def on_reset(self):
|
||||||
@ -42,7 +42,7 @@ class MaxStepsReached(Rule):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self.max_steps = max_steps
|
self.max_steps = max_steps
|
||||||
|
|
||||||
def on_init(self, state):
|
def on_init(self, state, lvl_map):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def on_check_done(self, state):
|
def on_check_done(self, state):
|
||||||
@ -51,6 +51,20 @@ class MaxStepsReached(Rule):
|
|||||||
return [DoneResult(validity=c.NOT_VALID, identifier=self.name, reward=0)]
|
return [DoneResult(validity=c.NOT_VALID, identifier=self.name, reward=0)]
|
||||||
|
|
||||||
|
|
||||||
|
class AssignGlobalPositions(Rule):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
def on_init(self, state, lvl_map):
|
||||||
|
from marl_factory_grid.environment.entity.util import GlobalPosition
|
||||||
|
for agent in state[c.AGENT]:
|
||||||
|
gp = GlobalPosition(lvl_map.level_shape)
|
||||||
|
gp.bind_to(agent)
|
||||||
|
state[c.GLOBALPOSITIONS].add_item(gp)
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
class Collision(Rule):
|
class Collision(Rule):
|
||||||
|
|
||||||
def __init__(self, done_at_collisions: bool = False):
|
def __init__(self, done_at_collisions: bool = False):
|
||||||
|
@ -0,0 +1,7 @@
|
|||||||
|
from .batteries import *
|
||||||
|
from .clean_up import *
|
||||||
|
from .destinations import *
|
||||||
|
from .doors import *
|
||||||
|
from .items import *
|
||||||
|
from .machines import *
|
||||||
|
from .maintenance import *
|
||||||
|
@ -8,4 +8,4 @@ WEST = 'west'
|
|||||||
NORTHEAST = 'north_east'
|
NORTHEAST = 'north_east'
|
||||||
SOUTHEAST = 'south_east'
|
SOUTHEAST = 'south_east'
|
||||||
SOUTHWEST = 'south_west'
|
SOUTHWEST = 'south_west'
|
||||||
NORTHWEST = 'north_west'
|
NORTHWEST = 'north_west'
|
||||||
|
@ -8,7 +8,7 @@ class TemplateRule(Rule):
|
|||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super(TemplateRule, self).__init__(*args, **kwargs)
|
super(TemplateRule, self).__init__(*args, **kwargs)
|
||||||
|
|
||||||
def on_init(self, state):
|
def on_init(self, state, lvl_map):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def tick_pre_step(self, state) -> List[TickResult]:
|
def tick_pre_step(self, state) -> List[TickResult]:
|
||||||
|
@ -0,0 +1,4 @@
|
|||||||
|
from .actions import BtryCharge
|
||||||
|
from .entitites import ChargePod, Battery
|
||||||
|
from .groups import ChargePods, Batteries
|
||||||
|
from .rules import BtryDoneAtDischarge, Btry
|
||||||
|
@ -13,18 +13,13 @@ class Batteries(HasBoundedMixin, EnvObjects):
|
|||||||
def obs_tag(self):
|
def obs_tag(self):
|
||||||
return self.__class__.__name__
|
return self.__class__.__name__
|
||||||
|
|
||||||
@property
|
|
||||||
def obs_pairs(self):
|
|
||||||
return [(x.name, x) for x in self]
|
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super(Batteries, self).__init__(*args, **kwargs)
|
super(Batteries, self).__init__(*args, **kwargs)
|
||||||
|
|
||||||
def spawn_batteries(self, agents, initial_charge_level):
|
def spawn(self, agents, initial_charge_level):
|
||||||
batteries = [self._entity(initial_charge_level, agent) for _, agent in enumerate(agents)]
|
batteries = [self._entity(initial_charge_level, agent) for _, agent in enumerate(agents)]
|
||||||
self.add_items(batteries)
|
self.add_items(batteries)
|
||||||
|
|
||||||
|
|
||||||
class ChargePods(PositionMixin, EnvObjects):
|
class ChargePods(PositionMixin, EnvObjects):
|
||||||
|
|
||||||
_entity = ChargePod
|
_entity = ChargePod
|
||||||
|
@ -13,8 +13,8 @@ class Btry(Rule):
|
|||||||
self.per_action_costs = per_action_costs
|
self.per_action_costs = per_action_costs
|
||||||
self.initial_charge = initial_charge
|
self.initial_charge = initial_charge
|
||||||
|
|
||||||
def on_init(self, state):
|
def on_init(self, state, lvl_map):
|
||||||
state[b.BATTERIES].spawn_batteries(state[c.AGENT], self.initial_charge)
|
state[b.BATTERIES].spawn(state[c.AGENT], self.initial_charge)
|
||||||
|
|
||||||
def tick_pre_step(self, state) -> List[TickResult]:
|
def tick_pre_step(self, state) -> List[TickResult]:
|
||||||
pass
|
pass
|
||||||
|
@ -0,0 +1,6 @@
|
|||||||
|
from .actions import CleanUp
|
||||||
|
from .entitites import DirtPile
|
||||||
|
from .groups import DirtPiles
|
||||||
|
from .rule_respawn import DirtRespawnRule
|
||||||
|
from .rule_smear_on_move import DirtSmearOnMove
|
||||||
|
from .rule_done_on_all_clean import DirtAllCleanDone
|
||||||
|
@ -7,6 +7,22 @@ from marl_factory_grid.modules.clean_up import constants as d
|
|||||||
|
|
||||||
class DirtPile(Entity):
|
class DirtPile(Entity):
|
||||||
|
|
||||||
|
@property
|
||||||
|
def var_can_collide(self):
|
||||||
|
return False
|
||||||
|
|
||||||
|
@property
|
||||||
|
def var_can_move(self):
|
||||||
|
return False
|
||||||
|
|
||||||
|
@property
|
||||||
|
def var_is_blocking_light(self):
|
||||||
|
return False
|
||||||
|
|
||||||
|
@property
|
||||||
|
def var_has_position(self):
|
||||||
|
return True
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def amount(self):
|
def amount(self):
|
||||||
return self._amount
|
return self._amount
|
||||||
|
@ -31,7 +31,7 @@ class DirtPiles(PositionMixin, EnvObjects):
|
|||||||
self.max_global_amount = max_global_amount
|
self.max_global_amount = max_global_amount
|
||||||
self.max_local_amount = max_local_amount
|
self.max_local_amount = max_local_amount
|
||||||
|
|
||||||
def spawn_dirt(self, then_dirty_tiles, amount) -> bool:
|
def spawn(self, then_dirty_tiles, amount) -> bool:
|
||||||
if isinstance(then_dirty_tiles, Floor):
|
if isinstance(then_dirty_tiles, Floor):
|
||||||
then_dirty_tiles = [then_dirty_tiles]
|
then_dirty_tiles = [then_dirty_tiles]
|
||||||
for tile in then_dirty_tiles:
|
for tile in then_dirty_tiles:
|
||||||
@ -57,7 +57,7 @@ class DirtPiles(PositionMixin, EnvObjects):
|
|||||||
var = self.dirt_spawn_r_var
|
var = self.dirt_spawn_r_var
|
||||||
new_spawn = abs(self.initial_dirt_ratio + (state.rng.uniform(-var, var) if initial_spawn else 0))
|
new_spawn = abs(self.initial_dirt_ratio + (state.rng.uniform(-var, var) if initial_spawn else 0))
|
||||||
n_dirt_tiles = max(0, int(new_spawn * len(free_for_dirt)))
|
n_dirt_tiles = max(0, int(new_spawn * len(free_for_dirt)))
|
||||||
return self.spawn_dirt(free_for_dirt[:n_dirt_tiles], self.initial_amount)
|
return self.spawn(free_for_dirt[:n_dirt_tiles], self.initial_amount)
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
s = super(DirtPiles, self).__repr__()
|
s = super(DirtPiles, self).__repr__()
|
||||||
|
@ -11,7 +11,7 @@ class DirtRespawnRule(Rule):
|
|||||||
self.spawn_freq = spawn_freq
|
self.spawn_freq = spawn_freq
|
||||||
self._next_dirt_spawn = spawn_freq
|
self._next_dirt_spawn = spawn_freq
|
||||||
|
|
||||||
def on_init(self, state) -> str:
|
def on_init(self, state, lvl_map) -> str:
|
||||||
state[d.DIRT].trigger_dirt_spawn(state, initial_spawn=True)
|
state[d.DIRT].trigger_dirt_spawn(state, initial_spawn=True)
|
||||||
return f'Initial Dirt was spawned on: {[x.pos for x in state[d.DIRT]]}'
|
return f'Initial Dirt was spawned on: {[x.pos for x in state[d.DIRT]]}'
|
||||||
|
|
||||||
|
@ -18,7 +18,7 @@ class DirtSmearOnMove(Rule):
|
|||||||
if is_move(entity.state.identifier) and entity.state.validity == c.VALID:
|
if is_move(entity.state.identifier) and entity.state.validity == c.VALID:
|
||||||
if old_pos_dirt := state[d.DIRT].by_pos(entity.last_pos):
|
if old_pos_dirt := state[d.DIRT].by_pos(entity.last_pos):
|
||||||
if smeared_dirt := round(old_pos_dirt.amount * self.smear_amount, 2):
|
if smeared_dirt := round(old_pos_dirt.amount * self.smear_amount, 2):
|
||||||
if state[d.DIRT].spawn_dirt(entity.tile, amount=smeared_dirt):
|
if state[d.DIRT].spawn(entity.tile, amount=smeared_dirt):
|
||||||
results.append(TickResult(identifier=self.name, entity=entity,
|
results.append(TickResult(identifier=self.name, entity=entity,
|
||||||
reward=0, validity=c.VALID))
|
reward=0, validity=c.VALID))
|
||||||
return results
|
return results
|
||||||
|
@ -0,0 +1,4 @@
|
|||||||
|
from .actions import DestAction
|
||||||
|
from .entitites import Destination
|
||||||
|
from .groups import ReachedDestinations, Destinations
|
||||||
|
from .rules import DestinationDone, DestinationReach, DestinationSpawn
|
||||||
|
@ -62,7 +62,7 @@ class DestinationSpawn(Rule):
|
|||||||
self.n_dests = n_dests
|
self.n_dests = n_dests
|
||||||
self.spawn_mode = spawn_mode
|
self.spawn_mode = spawn_mode
|
||||||
|
|
||||||
def on_init(self, state):
|
def on_init(self, state, lvl_map):
|
||||||
# noinspection PyAttributeOutsideInit
|
# noinspection PyAttributeOutsideInit
|
||||||
self._dest_spawn_timer = self.spawn_frequency
|
self._dest_spawn_timer = self.spawn_frequency
|
||||||
self.trigger_destination_spawn(self.n_dests, state)
|
self.trigger_destination_spawn(self.n_dests, state)
|
||||||
|
@ -0,0 +1,4 @@
|
|||||||
|
from .actions import DoorUse
|
||||||
|
from .entitites import Door, DoorIndicator
|
||||||
|
from .groups import Doors
|
||||||
|
from .rule_door_auto_close import DoorAutoClose
|
||||||
|
@ -1,10 +1,9 @@
|
|||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
||||||
from marl_factory_grid.environment.actions import Action
|
from marl_factory_grid.environment.actions import Action
|
||||||
from marl_factory_grid.utils.results import ActionResult
|
|
||||||
|
|
||||||
from marl_factory_grid.modules.doors import constants as d, rewards as r
|
from marl_factory_grid.modules.doors import constants as d, rewards as r
|
||||||
from marl_factory_grid.environment import constants as c
|
from marl_factory_grid.environment import constants as c
|
||||||
|
from marl_factory_grid.utils.results import ActionResult
|
||||||
|
|
||||||
|
|
||||||
class DoorUse(Action):
|
class DoorUse(Action):
|
||||||
|
@ -22,15 +22,15 @@ class DoorIndicator(Entity):
|
|||||||
class Door(Entity):
|
class Door(Entity):
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_blocking_pos(self):
|
def var_is_blocking_pos(self):
|
||||||
return False if self.is_open else True
|
return False if self.is_open else True
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_blocking_light(self):
|
def var_is_blocking_light(self):
|
||||||
return False if self.is_open else True
|
return False if self.is_open else True
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def can_collide(self):
|
def var_can_collide(self):
|
||||||
return False if self.is_open else True
|
return False if self.is_open else True
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@ -42,12 +42,14 @@ class Door(Entity):
|
|||||||
return 'open' if self.is_open else 'closed'
|
return 'open' if self.is_open else 'closed'
|
||||||
|
|
||||||
def __init__(self, *args, closed_on_init=True, auto_close_interval=10, indicate_area=False, **kwargs):
|
def __init__(self, *args, closed_on_init=True, auto_close_interval=10, indicate_area=False, **kwargs):
|
||||||
self._state = d.STATE_CLOSED
|
self._status = d.STATE_CLOSED
|
||||||
super(Door, self).__init__(*args, **kwargs)
|
super(Door, self).__init__(*args, **kwargs)
|
||||||
self.auto_close_interval = auto_close_interval
|
self.auto_close_interval = auto_close_interval
|
||||||
self.time_to_close = 0
|
self.time_to_close = 0
|
||||||
if not closed_on_init:
|
if not closed_on_init:
|
||||||
self._open()
|
self._open()
|
||||||
|
else:
|
||||||
|
self._close()
|
||||||
if indicate_area:
|
if indicate_area:
|
||||||
self._collection.add_items([DoorIndicator(x) for x in self.tile.neighboring_floor])
|
self._collection.add_items([DoorIndicator(x) for x in self.tile.neighboring_floor])
|
||||||
|
|
||||||
@ -58,22 +60,22 @@ class Door(Entity):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def is_closed(self):
|
def is_closed(self):
|
||||||
return self._state == d.STATE_CLOSED
|
return self._status == d.STATE_CLOSED
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_open(self):
|
def is_open(self):
|
||||||
return self._state == d.STATE_OPEN
|
return self._status == d.STATE_OPEN
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def status(self):
|
def status(self):
|
||||||
return self._state
|
return self._status
|
||||||
|
|
||||||
def render(self):
|
def render(self):
|
||||||
name, state = 'door_open' if self.is_open else 'door_closed', 'blank'
|
name, state = 'door_open' if self.is_open else 'door_closed', 'blank'
|
||||||
return RenderEntity(name, self.pos, 1, 'none', state, self.identifier_int + 1)
|
return RenderEntity(name, self.pos, 1, 'none', state, self.identifier_int + 1)
|
||||||
|
|
||||||
def use(self):
|
def use(self):
|
||||||
if self._state == d.STATE_OPEN:
|
if self._status == d.STATE_OPEN:
|
||||||
self._close()
|
self._close()
|
||||||
else:
|
else:
|
||||||
self._open()
|
self._open()
|
||||||
@ -90,8 +92,8 @@ class Door(Entity):
|
|||||||
return c.NOT_VALID
|
return c.NOT_VALID
|
||||||
|
|
||||||
def _open(self):
|
def _open(self):
|
||||||
self._state = d.STATE_OPEN
|
self._status = d.STATE_OPEN
|
||||||
self.time_to_close = self.auto_close_interval
|
self.time_to_close = self.auto_close_interval
|
||||||
|
|
||||||
def _close(self):
|
def _close(self):
|
||||||
self._state = d.STATE_CLOSED
|
self._status = d.STATE_CLOSED
|
||||||
|
32
marl_factory_grid/modules/factory/rules.py
Normal file
32
marl_factory_grid/modules/factory/rules.py
Normal file
@ -0,0 +1,32 @@
|
|||||||
|
import random
|
||||||
|
from typing import List, Union
|
||||||
|
|
||||||
|
from marl_factory_grid.environment.rules import Rule
|
||||||
|
from marl_factory_grid.environment import constants as c
|
||||||
|
from marl_factory_grid.utils.results import TickResult
|
||||||
|
|
||||||
|
|
||||||
|
class AgentSingleZonePlacementBeta(Rule):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
def on_init(self, state, lvl_map):
|
||||||
|
zones = state[c.ZONES]
|
||||||
|
n_zones = state[c.ZONES]
|
||||||
|
agents = state[c.AGENT]
|
||||||
|
if len(self.coordinates) == len(agents):
|
||||||
|
coordinates = self.coordinates
|
||||||
|
elif len(self.coordinates) > len(agents):
|
||||||
|
coordinates = random.choices(self.coordinates, k=len(agents))
|
||||||
|
else:
|
||||||
|
raise ValueError
|
||||||
|
tiles = [state[c.FLOOR].by_pos(pos) for pos in coordinates]
|
||||||
|
for agent, tile in zip(agents, tiles):
|
||||||
|
agent.move(tile)
|
||||||
|
|
||||||
|
def tick_step(self, state):
|
||||||
|
return []
|
||||||
|
|
||||||
|
def tick_post_step(self, state) -> List[TickResult]:
|
||||||
|
return []
|
@ -0,0 +1,4 @@
|
|||||||
|
from .actions import ItemAction
|
||||||
|
from .entitites import Item, DropOffLocation
|
||||||
|
from .groups import DropOffLocations, Items, Inventory, Inventories
|
||||||
|
from .rules import ItemRules
|
||||||
|
@ -8,6 +8,8 @@ from marl_factory_grid.modules.items import constants as i
|
|||||||
|
|
||||||
class Item(Entity):
|
class Item(Entity):
|
||||||
|
|
||||||
|
var_can_collide = False
|
||||||
|
|
||||||
def render(self):
|
def render(self):
|
||||||
return RenderEntity(i.ITEM, self.tile.pos) if self.pos != c.VALUE_NO_POS else None
|
return RenderEntity(i.ITEM, self.tile.pos) if self.pos != c.VALUE_NO_POS else None
|
||||||
|
|
||||||
@ -38,6 +40,22 @@ class Item(Entity):
|
|||||||
|
|
||||||
class DropOffLocation(Entity):
|
class DropOffLocation(Entity):
|
||||||
|
|
||||||
|
@property
|
||||||
|
def var_can_collide(self):
|
||||||
|
return False
|
||||||
|
|
||||||
|
@property
|
||||||
|
def var_can_move(self):
|
||||||
|
return False
|
||||||
|
|
||||||
|
@property
|
||||||
|
def var_is_blocking_light(self):
|
||||||
|
return False
|
||||||
|
|
||||||
|
@property
|
||||||
|
def var_has_position(self):
|
||||||
|
return True
|
||||||
|
|
||||||
def render(self):
|
def render(self):
|
||||||
return RenderEntity(i.DROP_OFF, self.tile.pos)
|
return RenderEntity(i.DROP_OFF, self.tile.pos)
|
||||||
|
|
||||||
|
@ -17,15 +17,6 @@ class Items(PositionMixin, EnvObjects):
|
|||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
def spawn_items(self, tiles: List[Floor]):
|
|
||||||
items = [self._entity(tile) for tile in tiles]
|
|
||||||
self.add_items(items)
|
|
||||||
|
|
||||||
def despawn_items(self, items: List[Item]):
|
|
||||||
items = [items] if isinstance(items, Item) else items
|
|
||||||
for item in items:
|
|
||||||
del self[item]
|
|
||||||
|
|
||||||
|
|
||||||
class Inventory(IsBoundMixin, EnvObjects):
|
class Inventory(IsBoundMixin, EnvObjects):
|
||||||
|
|
||||||
@ -58,11 +49,7 @@ class Inventory(IsBoundMixin, EnvObjects):
|
|||||||
class Inventories(HasBoundedMixin, Objects):
|
class Inventories(HasBoundedMixin, Objects):
|
||||||
|
|
||||||
_entity = Inventory
|
_entity = Inventory
|
||||||
can_move = False
|
var_can_move = False
|
||||||
|
|
||||||
@property
|
|
||||||
def obs_pairs(self):
|
|
||||||
return [(x.name, x) for x in self]
|
|
||||||
|
|
||||||
def __init__(self, size, *args, **kwargs):
|
def __init__(self, size, *args, **kwargs):
|
||||||
super(Inventories, self).__init__(*args, **kwargs)
|
super(Inventories, self).__init__(*args, **kwargs)
|
||||||
@ -70,7 +57,7 @@ class Inventories(HasBoundedMixin, Objects):
|
|||||||
self._obs = None
|
self._obs = None
|
||||||
self._lazy_eval_transforms = []
|
self._lazy_eval_transforms = []
|
||||||
|
|
||||||
def spawn_inventories(self, agents):
|
def spawn(self, agents):
|
||||||
inventories = [self._entity(agent, self.size,)
|
inventories = [self._entity(agent, self.size,)
|
||||||
for _, agent in enumerate(agents)]
|
for _, agent in enumerate(agents)]
|
||||||
self.add_items(inventories)
|
self.add_items(inventories)
|
||||||
|
@ -18,7 +18,7 @@ class ItemRules(Rule):
|
|||||||
self.max_dropoff_storage_size = max_dropoff_storage_size
|
self.max_dropoff_storage_size = max_dropoff_storage_size
|
||||||
self.n_locations = n_locations
|
self.n_locations = n_locations
|
||||||
|
|
||||||
def on_init(self, state):
|
def on_init(self, state, lvl_map):
|
||||||
self.trigger_drop_off_location_spawn(state)
|
self.trigger_drop_off_location_spawn(state)
|
||||||
self._next_item_spawn = self.spawn_frequency
|
self._next_item_spawn = self.spawn_frequency
|
||||||
self.trigger_inventory_spawn(state)
|
self.trigger_inventory_spawn(state)
|
||||||
@ -42,7 +42,7 @@ class ItemRules(Rule):
|
|||||||
def trigger_item_spawn(self, state):
|
def trigger_item_spawn(self, state):
|
||||||
if item_to_spawns := max(0, (self.n_items - len(state[i.ITEM]))):
|
if item_to_spawns := max(0, (self.n_items - len(state[i.ITEM]))):
|
||||||
empty_tiles = state[c.FLOOR].empty_tiles[:item_to_spawns]
|
empty_tiles = state[c.FLOOR].empty_tiles[:item_to_spawns]
|
||||||
state[i.ITEM].spawn_items(empty_tiles)
|
state[i.ITEM].spawn(empty_tiles)
|
||||||
self._next_item_spawn = self.spawn_frequency
|
self._next_item_spawn = self.spawn_frequency
|
||||||
state.print(f'{item_to_spawns} new items have been spawned; next spawn in {self._next_item_spawn}')
|
state.print(f'{item_to_spawns} new items have been spawned; next spawn in {self._next_item_spawn}')
|
||||||
return len(empty_tiles)
|
return len(empty_tiles)
|
||||||
@ -52,7 +52,7 @@ class ItemRules(Rule):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def trigger_inventory_spawn(state):
|
def trigger_inventory_spawn(state):
|
||||||
state[i.INVENTORY].spawn_inventories(state[c.AGENT])
|
state[i.INVENTORY].spawn(state[c.AGENT])
|
||||||
|
|
||||||
def tick_post_step(self, state) -> List[TickResult]:
|
def tick_post_step(self, state) -> List[TickResult]:
|
||||||
for item in list(state[i.ITEM].values()):
|
for item in list(state[i.ITEM].values()):
|
||||||
|
@ -0,0 +1,3 @@
|
|||||||
|
from .entitites import Machine
|
||||||
|
from .groups import Machines
|
||||||
|
from .rules import MachineRule
|
||||||
|
25
marl_factory_grid/modules/machines/actions.py
Normal file
25
marl_factory_grid/modules/machines/actions.py
Normal file
@ -0,0 +1,25 @@
|
|||||||
|
from typing import Union
|
||||||
|
|
||||||
|
from marl_factory_grid.environment.actions import Action
|
||||||
|
from marl_factory_grid.utils.results import ActionResult
|
||||||
|
|
||||||
|
from marl_factory_grid.modules.machines import constants as m, rewards as r
|
||||||
|
from marl_factory_grid.environment import constants as c
|
||||||
|
|
||||||
|
|
||||||
|
class MachineAction(Action):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(m.MACHINE_ACTION)
|
||||||
|
|
||||||
|
def do(self, entity, state) -> Union[None, ActionResult]:
|
||||||
|
if machine := state[m.MACHINES].by_pos(entity.pos):
|
||||||
|
if valid := machine.maintain():
|
||||||
|
return ActionResult(entity=entity, identifier=self._identifier, validity=valid, reward=r.MAINTAIN_VALID)
|
||||||
|
else:
|
||||||
|
return ActionResult(entity=entity, identifier=self._identifier, validity=valid, reward=r.MAINTAIN_FAIL)
|
||||||
|
else:
|
||||||
|
return ActionResult(entity=entity, identifier=self._identifier, validity=c.NOT_VALID, reward=r.MAINTAIN_FAIL)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -2,6 +2,8 @@
|
|||||||
MACHINES = 'Machines'
|
MACHINES = 'Machines'
|
||||||
MACHINE = 'Machine'
|
MACHINE = 'Machine'
|
||||||
|
|
||||||
|
MACHINE_ACTION = 'Maintain'
|
||||||
|
|
||||||
STATE_WORK = 'working'
|
STATE_WORK = 'working'
|
||||||
STATE_IDLE = 'idling'
|
STATE_IDLE = 'idling'
|
||||||
STATE_MAINTAIN = 'maintenance'
|
STATE_MAINTAIN = 'maintenance'
|
||||||
|
@ -2,27 +2,43 @@ from marl_factory_grid.environment.entity.entity import Entity
|
|||||||
from marl_factory_grid.utils.render import RenderEntity
|
from marl_factory_grid.utils.render import RenderEntity
|
||||||
from marl_factory_grid.environment import constants as c
|
from marl_factory_grid.environment import constants as c
|
||||||
from marl_factory_grid.utils.results import TickResult
|
from marl_factory_grid.utils.results import TickResult
|
||||||
from marl_factory_grid.modules.machines import constants as m, rewards as r
|
|
||||||
|
from . import constants as m
|
||||||
|
|
||||||
|
|
||||||
class Machine(Entity):
|
class Machine(Entity):
|
||||||
|
|
||||||
|
@property
|
||||||
|
def var_can_collide(self):
|
||||||
|
return False
|
||||||
|
|
||||||
|
@property
|
||||||
|
def var_can_move(self):
|
||||||
|
return False
|
||||||
|
|
||||||
|
@property
|
||||||
|
def var_is_blocking_light(self):
|
||||||
|
return False
|
||||||
|
|
||||||
|
@property
|
||||||
|
def var_has_position(self):
|
||||||
|
return True
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def encoding(self):
|
def encoding(self):
|
||||||
return self._encodings[self.state]
|
return self._encodings[self.status]
|
||||||
|
|
||||||
def __init__(self, *args, work_interval: int = 10, pause_interval: int = 15, **kwargs):
|
def __init__(self, *args, work_interval: int = 10, pause_interval: int = 15, **kwargs):
|
||||||
super(Machine, self).__init__(*args, **kwargs)
|
super(Machine, self).__init__(*args, **kwargs)
|
||||||
self._intervals = dict({m.STATE_IDLE: pause_interval, m.STATE_WORK: work_interval})
|
self._intervals = dict({m.STATE_IDLE: pause_interval, m.STATE_WORK: work_interval})
|
||||||
self._encodings = dict({m.STATE_IDLE: pause_interval, m.STATE_WORK: work_interval})
|
self._encodings = dict({m.STATE_IDLE: pause_interval, m.STATE_WORK: work_interval})
|
||||||
|
|
||||||
self.state = m.STATE_IDLE
|
self.status = m.STATE_IDLE
|
||||||
self.health = 100
|
self.health = 100
|
||||||
self._counter = 0
|
self._counter = 0
|
||||||
self.__delattr__('move')
|
|
||||||
|
|
||||||
def maintain(self):
|
def maintain(self):
|
||||||
if self.state == m.STATE_WORK:
|
if self.status == m.STATE_WORK:
|
||||||
return c.NOT_VALID
|
return c.NOT_VALID
|
||||||
if self.health <= 98:
|
if self.health <= 98:
|
||||||
self.health = 100
|
self.health = 100
|
||||||
@ -31,10 +47,10 @@ class Machine(Entity):
|
|||||||
return c.NOT_VALID
|
return c.NOT_VALID
|
||||||
|
|
||||||
def tick(self):
|
def tick(self):
|
||||||
if self.state == m.STATE_MAINTAIN and any([c.AGENT in x.name for x in self.tile.guests]):
|
if self.status == m.STATE_MAINTAIN and any([c.AGENT in x.name for x in self.tile.guests]):
|
||||||
return TickResult(self.name, c.VALID, r.NONE, self)
|
return TickResult(identifier=self.name, validity=c.VALID, reward=0, entity=self)
|
||||||
elif self.state == m.STATE_MAINTAIN and not any([c.AGENT in x.name for x in self.tile.guests]):
|
elif self.status == m.STATE_MAINTAIN and not any([c.AGENT in x.name for x in self.tile.guests]):
|
||||||
self.state = m.STATE_WORK
|
self.status = m.STATE_WORK
|
||||||
self.reset_counter()
|
self.reset_counter()
|
||||||
return None
|
return None
|
||||||
elif self._counter:
|
elif self._counter:
|
||||||
@ -42,12 +58,12 @@ class Machine(Entity):
|
|||||||
self.health -= 1
|
self.health -= 1
|
||||||
return None
|
return None
|
||||||
else:
|
else:
|
||||||
self.state = m.STATE_WORK if self.state == m.STATE_IDLE else m.STATE_IDLE
|
self.status = m.STATE_WORK if self.status == m.STATE_IDLE else m.STATE_IDLE
|
||||||
self.reset_counter()
|
self.reset_counter()
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def reset_counter(self):
|
def reset_counter(self):
|
||||||
self._counter = self._intervals[self.state]
|
self._counter = self._intervals[self.status]
|
||||||
|
|
||||||
def render(self):
|
def render(self):
|
||||||
return RenderEntity(m.MACHINE, self.pos)
|
return RenderEntity(m.MACHINE, self.pos)
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
from marl_factory_grid.environment.groups.env_objects import EnvObjects
|
from marl_factory_grid.environment.groups.env_objects import EnvObjects
|
||||||
from marl_factory_grid.environment.groups.mixins import PositionMixin
|
from marl_factory_grid.environment.groups.mixins import PositionMixin
|
||||||
from marl_factory_grid.modules.machines.entitites import Machine
|
|
||||||
|
from .entitites import Machine
|
||||||
|
|
||||||
|
|
||||||
class Machines(PositionMixin, EnvObjects):
|
class Machines(PositionMixin, EnvObjects):
|
||||||
|
BIN
marl_factory_grid/modules/machines/machine.png
Normal file
BIN
marl_factory_grid/modules/machines/machine.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 8.5 KiB |
@ -12,7 +12,7 @@ class MachineRule(Rule):
|
|||||||
super(MachineRule, self).__init__()
|
super(MachineRule, self).__init__()
|
||||||
self.n_machines = n_machines
|
self.n_machines = n_machines
|
||||||
|
|
||||||
def on_init(self, state):
|
def on_init(self, state, lvl_map):
|
||||||
empty_tiles = state[c.FLOOR].empty_tiles[:self.n_machines]
|
empty_tiles = state[c.FLOOR].empty_tiles[:self.n_machines]
|
||||||
state[m.MACHINES].add_items(Machine(tile) for tile in empty_tiles)
|
state[m.MACHINES].add_items(Machine(tile) for tile in empty_tiles)
|
||||||
|
|
||||||
@ -27,3 +27,9 @@ class MachineRule(Rule):
|
|||||||
|
|
||||||
def on_check_done(self, state) -> List[DoneResult]:
|
def on_check_done(self, state) -> List[DoneResult]:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class DoneOnBreakRule(Rule):
|
||||||
|
|
||||||
|
def on_check_done(self, state) -> List[DoneResult]:
|
||||||
|
pass
|
2
marl_factory_grid/modules/maintenance/__init__.py
Normal file
2
marl_factory_grid/modules/maintenance/__init__.py
Normal file
@ -0,0 +1,2 @@
|
|||||||
|
from .entities import Maintainer
|
||||||
|
from .groups import Maintainers
|
3
marl_factory_grid/modules/maintenance/constants.py
Normal file
3
marl_factory_grid/modules/maintenance/constants.py
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
MAINTAINER = 'Maintainer' # TEMPLATE _identifier. Define your own!
|
||||||
|
MAINTAINERS = 'Maintainers' # TEMPLATE _identifier. Define your own!
|
||||||
|
|
102
marl_factory_grid/modules/maintenance/entities.py
Normal file
102
marl_factory_grid/modules/maintenance/entities.py
Normal file
@ -0,0 +1,102 @@
|
|||||||
|
import networkx as nx
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from ...algorithms.static.utils import points_to_graph
|
||||||
|
from ...environment import constants as c
|
||||||
|
from ...environment.actions import Action, ALL_BASEACTIONS
|
||||||
|
from ...environment.entity.entity import Entity
|
||||||
|
from ..doors import constants as do
|
||||||
|
from ..maintenance import constants as mi
|
||||||
|
from ...utils.helpers import MOVEMAP
|
||||||
|
from ...utils.render import RenderEntity
|
||||||
|
from ...utils.states import Gamestate
|
||||||
|
|
||||||
|
|
||||||
|
class Maintainer(Entity):
|
||||||
|
|
||||||
|
@property
|
||||||
|
def var_can_collide(self):
|
||||||
|
return True
|
||||||
|
|
||||||
|
@property
|
||||||
|
def var_can_move(self):
|
||||||
|
return False
|
||||||
|
|
||||||
|
@property
|
||||||
|
def var_is_blocking_light(self):
|
||||||
|
return False
|
||||||
|
|
||||||
|
@property
|
||||||
|
def var_has_position(self):
|
||||||
|
return True
|
||||||
|
|
||||||
|
def __init__(self, state: Gamestate, objective: str, action: Action, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.action = action
|
||||||
|
self.actions = [x() for x in ALL_BASEACTIONS]
|
||||||
|
self.objective = objective
|
||||||
|
self._path = None
|
||||||
|
self._next = []
|
||||||
|
self._last = []
|
||||||
|
self._last_serviced = 'None'
|
||||||
|
self._floortile_graph = points_to_graph(state[c.FLOOR].positions)
|
||||||
|
|
||||||
|
def tick(self, state):
|
||||||
|
if found_objective := state[self.objective].by_pos(self.pos):
|
||||||
|
if found_objective.name != self._last_serviced:
|
||||||
|
self.action.do(self, state)
|
||||||
|
self._last_serviced = found_objective.name
|
||||||
|
else:
|
||||||
|
action = self.get_move_action(state)
|
||||||
|
return action.do(self, state)
|
||||||
|
else:
|
||||||
|
action = self.get_move_action(state)
|
||||||
|
return action.do(self, state)
|
||||||
|
|
||||||
|
def get_move_action(self, state) -> Action:
|
||||||
|
if self._path is None or not self._path:
|
||||||
|
if not self._next:
|
||||||
|
self._next = list(state[self.objective].values())
|
||||||
|
self._last = []
|
||||||
|
self._last.append(self._next.pop())
|
||||||
|
self._path = self.calculate_route(self._last[-1])
|
||||||
|
|
||||||
|
if door := self._door_is_close():
|
||||||
|
if door.is_closed:
|
||||||
|
# Translate the action_object to an integer to have the same output as any other model
|
||||||
|
action = do.ACTION_DOOR_USE
|
||||||
|
else:
|
||||||
|
action = self._predict_move(state)
|
||||||
|
else:
|
||||||
|
action = self._predict_move(state)
|
||||||
|
# Translate the action_object to an integer to have the same output as any other model
|
||||||
|
try:
|
||||||
|
action_obj = next(x for x in self.actions if x.name == action)
|
||||||
|
except (StopIteration, UnboundLocalError):
|
||||||
|
print('Will not happen')
|
||||||
|
raise EnvironmentError
|
||||||
|
return action_obj
|
||||||
|
|
||||||
|
def calculate_route(self, entity):
|
||||||
|
route = nx.shortest_path(self._floortile_graph, self.pos, entity.pos)
|
||||||
|
return route[1:]
|
||||||
|
|
||||||
|
def _door_is_close(self):
|
||||||
|
try:
|
||||||
|
return next(y for x in self.tile.neighboring_floor for y in x.guests if do.DOOR in y.name)
|
||||||
|
except StopIteration:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _predict_move(self, state):
|
||||||
|
next_pos = self._path[0]
|
||||||
|
if len(state[c.FLOOR].by_pos(next_pos).guests_that_can_collide) > 0:
|
||||||
|
action = c.NOOP
|
||||||
|
else:
|
||||||
|
next_pos = self._path.pop(0)
|
||||||
|
diff = np.subtract(next_pos, self.pos)
|
||||||
|
# Retrieve action based on the pos dif (like in: What do I have to do to get there?)
|
||||||
|
action = next(action for action, pos_diff in MOVEMAP.items() if np.all(diff == pos_diff))
|
||||||
|
return action
|
||||||
|
|
||||||
|
def render(self):
|
||||||
|
return RenderEntity(mi.MAINTAINER, self.pos)
|
27
marl_factory_grid/modules/maintenance/groups.py
Normal file
27
marl_factory_grid/modules/maintenance/groups.py
Normal file
@ -0,0 +1,27 @@
|
|||||||
|
from typing import List
|
||||||
|
|
||||||
|
from .entities import Maintainer
|
||||||
|
from marl_factory_grid.environment.entity.wall_floor import Floor
|
||||||
|
from marl_factory_grid.environment.groups.env_objects import EnvObjects
|
||||||
|
from marl_factory_grid.environment.groups.mixins import PositionMixin
|
||||||
|
from ..machines.actions import MachineAction
|
||||||
|
from ...utils.render import RenderEntity
|
||||||
|
from ...utils.states import Gamestate
|
||||||
|
|
||||||
|
from ..machines import constants as mc
|
||||||
|
from . import constants as mi
|
||||||
|
|
||||||
|
|
||||||
|
class Maintainers(PositionMixin, EnvObjects):
|
||||||
|
|
||||||
|
_entity = Maintainer
|
||||||
|
var_can_collide = True
|
||||||
|
var_can_move = True
|
||||||
|
var_is_blocking_light = False
|
||||||
|
var_has_position = True
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
def spawn(self, tiles: List[Floor], state: Gamestate):
|
||||||
|
self.add_items([self._entity(state, mc.MACHINES, MachineAction(), tile) for tile in tiles])
|
BIN
marl_factory_grid/modules/maintenance/maintainer.png
Normal file
BIN
marl_factory_grid/modules/maintenance/maintainer.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 22 KiB |
1
marl_factory_grid/modules/maintenance/rewards.py
Normal file
1
marl_factory_grid/modules/maintenance/rewards.py
Normal file
@ -0,0 +1 @@
|
|||||||
|
MAINTAINER_COLLISION_REWARD = -5
|
39
marl_factory_grid/modules/maintenance/rules.py
Normal file
39
marl_factory_grid/modules/maintenance/rules.py
Normal file
@ -0,0 +1,39 @@
|
|||||||
|
from typing import List
|
||||||
|
from marl_factory_grid.environment.rules import Rule
|
||||||
|
from marl_factory_grid.utils.results import TickResult, DoneResult
|
||||||
|
from marl_factory_grid.environment import constants as c
|
||||||
|
from . import rewards as r
|
||||||
|
from . import constants as M
|
||||||
|
from marl_factory_grid.utils.states import Gamestate
|
||||||
|
|
||||||
|
|
||||||
|
class MaintenanceRule(Rule):
|
||||||
|
|
||||||
|
def __init__(self, n_maintainer: int = 1, *args, **kwargs):
|
||||||
|
super(MaintenanceRule, self).__init__(*args, **kwargs)
|
||||||
|
self.n_maintainer = n_maintainer
|
||||||
|
|
||||||
|
def on_init(self, state: Gamestate, lvl_map):
|
||||||
|
state[M.MAINTAINERS].spawn(state[c.FLOOR].empty_tiles[:self.n_maintainer], state)
|
||||||
|
pass
|
||||||
|
|
||||||
|
def tick_pre_step(self, state) -> List[TickResult]:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def tick_step(self, state) -> List[TickResult]:
|
||||||
|
for maintainer in state[M.MAINTAINERS]:
|
||||||
|
maintainer.tick(state)
|
||||||
|
return []
|
||||||
|
|
||||||
|
def tick_post_step(self, state) -> List[TickResult]:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def on_check_done(self, state) -> List[DoneResult]:
|
||||||
|
agents = list(state[c.AGENT].values())
|
||||||
|
m_pos = state[M.MAINTAINERS].positions
|
||||||
|
done_results = []
|
||||||
|
for agent in agents:
|
||||||
|
if agent.pos in m_pos:
|
||||||
|
done_results.append(DoneResult(entity=agent, validity=c.VALID, identifier=self.name,
|
||||||
|
reward=r.MAINTAINER_COLLISION_REWARD))
|
||||||
|
return done_results
|
3
marl_factory_grid/modules/zones/__init__.py
Normal file
3
marl_factory_grid/modules/zones/__init__.py
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
from .entitites import Zone
|
||||||
|
from .groups import Zones
|
||||||
|
from .rules import AgentSingleZonePlacement
|
4
marl_factory_grid/modules/zones/constants.py
Normal file
4
marl_factory_grid/modules/zones/constants.py
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
# Names / Identifiers
|
||||||
|
|
||||||
|
ZONES = 'Zones' # Identifier of Zone-objects and groups (groups).
|
||||||
|
ZONE = 'Zone' # -||-
|
21
marl_factory_grid/modules/zones/entitites.py
Normal file
21
marl_factory_grid/modules/zones/entitites.py
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
import random
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
from marl_factory_grid.environment.entity.entity import Entity
|
||||||
|
from marl_factory_grid.environment.entity.object import Object
|
||||||
|
from marl_factory_grid.environment.entity.wall_floor import Floor
|
||||||
|
from marl_factory_grid.utils.render import RenderEntity
|
||||||
|
from marl_factory_grid.environment import constants as c
|
||||||
|
|
||||||
|
from marl_factory_grid.modules.doors import constants as d
|
||||||
|
|
||||||
|
|
||||||
|
class Zone(Object):
|
||||||
|
|
||||||
|
def __init__(self, tiles: List[Floor], *args, **kwargs):
|
||||||
|
super(Zone, self).__init__(*args, **kwargs)
|
||||||
|
self.tiles = tiles
|
||||||
|
|
||||||
|
@property
|
||||||
|
def random_tile(self):
|
||||||
|
return random.choice(self.tiles)
|
12
marl_factory_grid/modules/zones/groups.py
Normal file
12
marl_factory_grid/modules/zones/groups.py
Normal file
@ -0,0 +1,12 @@
|
|||||||
|
from marl_factory_grid.environment.groups.objects import Objects
|
||||||
|
from marl_factory_grid.modules.zones import Zone
|
||||||
|
|
||||||
|
|
||||||
|
class Zones(Objects):
|
||||||
|
|
||||||
|
symbol = None
|
||||||
|
_entity = Zone
|
||||||
|
var_can_move = False
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super(Zones, self).__init__(*args, can_collide=True, **kwargs)
|
33
marl_factory_grid/modules/zones/rules.py
Normal file
33
marl_factory_grid/modules/zones/rules.py
Normal file
@ -0,0 +1,33 @@
|
|||||||
|
from random import choices
|
||||||
|
|
||||||
|
from marl_factory_grid.environment.rules import Rule
|
||||||
|
from marl_factory_grid.environment import constants as c
|
||||||
|
from marl_factory_grid.modules.zones import Zone
|
||||||
|
from . import constants as z
|
||||||
|
|
||||||
|
|
||||||
|
class AgentSingleZonePlacement(Rule):
|
||||||
|
|
||||||
|
def __init__(self, n_zones=3):
|
||||||
|
super().__init__()
|
||||||
|
self.n_zones = n_zones
|
||||||
|
|
||||||
|
def on_init(self, state, lvl_map):
|
||||||
|
zones = []
|
||||||
|
|
||||||
|
for z_idx in range(1, self.n_zones):
|
||||||
|
zone_positions = lvl_map.get_coordinates_for_symbol(z_idx)
|
||||||
|
assert len(zone_positions)
|
||||||
|
zones.append(Zone([state[c.FLOOR].by_pos(pos) for pos in zone_positions]))
|
||||||
|
state[z.ZONES].add_items(zones)
|
||||||
|
|
||||||
|
n_agents = len(state[c.AGENT])
|
||||||
|
assert len(state[z.ZONES]) >= n_agents
|
||||||
|
|
||||||
|
z_idxs = choices(list(range(len(state[z.ZONES]))), k=n_agents)
|
||||||
|
for agent in state[c.AGENT]:
|
||||||
|
agent.move(state[z.ZONES][z_idxs.pop()].random_tile)
|
||||||
|
return []
|
||||||
|
|
||||||
|
def tick_step(self, state):
|
||||||
|
return []
|
@ -10,10 +10,10 @@ def init():
|
|||||||
ce = ConfigExplainer()
|
ce = ConfigExplainer()
|
||||||
cwd = Path(os.getcwd())
|
cwd = Path(os.getcwd())
|
||||||
ce.save_all(cwd / 'full_config.yaml')
|
ce.save_all(cwd / 'full_config.yaml')
|
||||||
template_path = Path(__file__) / 'marl_factory_grid' / 'modules' / '_template'
|
template_path = Path(__file__).parent / 'modules' / '_template'
|
||||||
print(f'Available config options saved to: {(cwd / "full_config.yaml").resolve()}')
|
print(f'Available config options saved to: {(cwd / "full_config.yaml").resolve()}')
|
||||||
print('-----------------------------')
|
print('-----------------------------')
|
||||||
print(f'Copying Templates....')
|
print(f'Copying Templates....')
|
||||||
shutil.copytree(template_path, cwd)
|
shutil.copytree(template_path, cwd)
|
||||||
print(f'Templates copied to {template_path.resolve()}')
|
print(f'Templates copied to {cwd}"/"{template_path.name}')
|
||||||
print(':wave:')
|
print(':wave:')
|
||||||
|
@ -18,11 +18,11 @@ class FactoryConfigParser(object):
|
|||||||
default_entites = []
|
default_entites = []
|
||||||
default_rules = ['MaxStepsReached', 'Collision']
|
default_rules = ['MaxStepsReached', 'Collision']
|
||||||
default_actions = [c.MOVE8, c.NOOP]
|
default_actions = [c.MOVE8, c.NOOP]
|
||||||
default_observations = [c.WALLS, c.AGENTS]
|
default_observations = [c.WALLS, c.AGENT]
|
||||||
|
|
||||||
def __init__(self, config_path, custom_modules_path: Union[None, PathLike] = None):
|
def __init__(self, config_path, custom_modules_path: Union[None, PathLike] = None):
|
||||||
self.config_path = Path(config_path)
|
self.config_path = Path(config_path)
|
||||||
self.custom_modules_path = Path(config_path) if custom_modules_path is not None else custom_modules_path
|
self.custom_modules_path = Path(custom_modules_path) if custom_modules_path is not None else custom_modules_path
|
||||||
self.config = yaml.safe_load(self.config_path.open())
|
self.config = yaml.safe_load(self.config_path.open())
|
||||||
self.do_record = False
|
self.do_record = False
|
||||||
|
|
||||||
@ -69,12 +69,20 @@ class FactoryConfigParser(object):
|
|||||||
|
|
||||||
for entity in entities:
|
for entity in entities:
|
||||||
try:
|
try:
|
||||||
folder_path = MODULE_PATH if entity not in self.default_entites else DEFAULT_PATH
|
folder_path = Path(__file__).parent.parent / DEFAULT_PATH
|
||||||
folder_path = (Path(__file__) / '..' / '..' / '..' / folder_path)
|
|
||||||
entity_class = locate_and_import_class(entity, folder_path)
|
|
||||||
except AttributeError:
|
|
||||||
folder_path = self.custom_modules_path
|
|
||||||
entity_class = locate_and_import_class(entity, folder_path)
|
entity_class = locate_and_import_class(entity, folder_path)
|
||||||
|
except AttributeError as e1:
|
||||||
|
try:
|
||||||
|
folder_path = Path(__file__).parent.parent / MODULE_PATH
|
||||||
|
entity_class = locate_and_import_class(entity, folder_path)
|
||||||
|
except AttributeError as e2:
|
||||||
|
try:
|
||||||
|
folder_path = self.custom_modules_path
|
||||||
|
entity_class = locate_and_import_class(entity, folder_path)
|
||||||
|
except AttributeError as e3:
|
||||||
|
ents = [y for x in [e1.argss[1], e2.argss[1], e3.argss[1]] for y in x]
|
||||||
|
raise AttributeError(e1.argss[0], e2.argss[0], e3.argss[0], 'Possible Entitys are>:', str(ents))
|
||||||
|
|
||||||
entity_kwargs = self.entities.get(entity, {})
|
entity_kwargs = self.entities.get(entity, {})
|
||||||
entity_symbol = entity_class.symbol if hasattr(entity_class, 'symbol') else None
|
entity_symbol = entity_class.symbol if hasattr(entity_class, 'symbol') else None
|
||||||
entity_classes.update({entity: {'class': entity_class, 'kwargs': entity_kwargs, 'symbol': entity_symbol}})
|
entity_classes.update({entity: {'class': entity_class, 'kwargs': entity_kwargs, 'symbol': entity_symbol}})
|
||||||
@ -92,7 +100,7 @@ class FactoryConfigParser(object):
|
|||||||
parsed_actions = list()
|
parsed_actions = list()
|
||||||
for action in actions:
|
for action in actions:
|
||||||
folder_path = MODULE_PATH if action not in base_env_actions else DEFAULT_PATH
|
folder_path = MODULE_PATH if action not in base_env_actions else DEFAULT_PATH
|
||||||
folder_path = (Path(__file__) / '..' / '..' / '..' / folder_path)
|
folder_path = Path(__file__).parent.parent / folder_path
|
||||||
try:
|
try:
|
||||||
class_or_classes = locate_and_import_class(action, folder_path)
|
class_or_classes = locate_and_import_class(action, folder_path)
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
@ -124,12 +132,15 @@ class FactoryConfigParser(object):
|
|||||||
rules.extend(x for x in self.rules if x != c.DEFAULTS)
|
rules.extend(x for x in self.rules if x != c.DEFAULTS)
|
||||||
|
|
||||||
for rule in rules:
|
for rule in rules:
|
||||||
folder_path = MODULE_PATH if rule not in self.default_rules else DEFAULT_PATH
|
|
||||||
folder_path = (Path(__file__) / '..' / '..' / '..' / folder_path)
|
|
||||||
try:
|
try:
|
||||||
|
folder_path = (Path(__file__).parent.parent / DEFAULT_PATH)
|
||||||
rule_class = locate_and_import_class(rule, folder_path)
|
rule_class = locate_and_import_class(rule, folder_path)
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
rule_class = locate_and_import_class(rule, self.custom_modules_path)
|
try:
|
||||||
|
folder_path = (Path(__file__).parent.parent / MODULE_PATH)
|
||||||
|
rule_class = locate_and_import_class(rule, folder_path)
|
||||||
|
except AttributeError:
|
||||||
|
rule_class = locate_and_import_class(rule, self.custom_modules_path)
|
||||||
rule_kwargs = self.rules.get(rule, {})
|
rule_kwargs = self.rules.get(rule, {})
|
||||||
rules_classes.update({rule: {'class': rule_class, 'kwargs': rule_kwargs}})
|
rules_classes.update({rule: {'class': rule_class, 'kwargs': rule_kwargs}})
|
||||||
return rules_classes
|
return rules_classes
|
||||||
|
@ -176,7 +176,7 @@ def one_hot_level(level, symbol: str):
|
|||||||
|
|
||||||
grid = np.array(level)
|
grid = np.array(level)
|
||||||
binary_grid = np.zeros(grid.shape, dtype=np.int8)
|
binary_grid = np.zeros(grid.shape, dtype=np.int8)
|
||||||
binary_grid[grid == symbol] = c.VALUE_OCCUPIED_CELL
|
binary_grid[grid == str(symbol)] = c.VALUE_OCCUPIED_CELL
|
||||||
return binary_grid
|
return binary_grid
|
||||||
|
|
||||||
|
|
||||||
@ -222,18 +222,15 @@ def locate_and_import_class(class_name, folder_path: Union[str, PurePath] = ''):
|
|||||||
for module_path in module_paths:
|
for module_path in module_paths:
|
||||||
module_parts = [x.replace('.py', '') for idx, x in enumerate(module_path.parts) if idx >= package_pos]
|
module_parts = [x.replace('.py', '') for idx, x in enumerate(module_path.parts) if idx >= package_pos]
|
||||||
mod = importlib.import_module('.'.join(module_parts))
|
mod = importlib.import_module('.'.join(module_parts))
|
||||||
all_found_modules.extend([x for x in dir(mod) if not(x.startswith('__') or len(x) < 2 or x.isupper())
|
all_found_modules.extend([x for x in dir(mod) if (not(x.startswith('__') or len(x) <= 2) and x.istitle())
|
||||||
and x not in ['Entity', 'NamedTuple', 'List', 'Rule', 'Union', 'random', 'Floor'
|
and x not in ['Entity', 'NamedTuple', 'List', 'Rule', 'Union', 'Floor'
|
||||||
'TickResult', 'ActionResult', 'Action', 'Agent', 'deque',
|
'TickResult', 'ActionResult', 'Action', 'Agent', 'BoundEntityMixin',
|
||||||
'BoundEntityMixin', 'RenderEntity', 'TemplateRule', 'defaultdict',
|
'RenderEntity', 'TemplateRule', 'Objects', 'PositionMixin',
|
||||||
'is_move', 'Objects', 'PositionMixin', 'IsBoundMixin', 'EnvObject',
|
'IsBoundMixin', 'EnvObject', 'EnvObjects', 'Dict', 'Any'
|
||||||
'EnvObjects', 'Dict', 'locate_and_import_class', 'yaml', 'Any',
|
]])
|
||||||
'inspect']])
|
|
||||||
try:
|
try:
|
||||||
model_class = mod.__getattribute__(class_name)
|
model_class = mod.__getattribute__(class_name)
|
||||||
return model_class
|
return model_class
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
continue
|
continue
|
||||||
raise AttributeError(f'Class "{class_name}" was not found!!!"\n'
|
raise AttributeError(f'Class "{class_name}" was not found in "{folder_path.name}"', list(set(all_found_modules)))
|
||||||
f'Check the {folder_path.name} name.\n'
|
|
||||||
f'Possible Options are:\n{set(all_found_modules)}')
|
|
||||||
|
@ -24,31 +24,40 @@ class LevelParser(object):
|
|||||||
self.level_shape = level_array.shape
|
self.level_shape = level_array.shape
|
||||||
self.size = self.pomdp_r**2 if self.pomdp_r else np.prod(self.level_shape)
|
self.size = self.pomdp_r**2 if self.pomdp_r else np.prod(self.level_shape)
|
||||||
|
|
||||||
|
def get_coordinates_for_symbol(self, symbol, negate=False):
|
||||||
|
level_array = h.one_hot_level(self._parsed_level, symbol)
|
||||||
|
if negate:
|
||||||
|
return np.argwhere(level_array != c.VALUE_OCCUPIED_CELL)
|
||||||
|
else:
|
||||||
|
return np.argwhere(level_array == c.VALUE_OCCUPIED_CELL)
|
||||||
|
|
||||||
def do_init(self):
|
def do_init(self):
|
||||||
entities = Entities()
|
entities = Entities()
|
||||||
# Walls
|
# Walls
|
||||||
level_array = h.one_hot_level(self._parsed_level, c.SYMBOL_WALL)
|
walls = Walls.from_coordinates(self.get_coordinates_for_symbol(c.SYMBOL_WALL), self.size)
|
||||||
|
|
||||||
walls = Walls.from_coordinates(np.argwhere(level_array == c.VALUE_OCCUPIED_CELL), self.size)
|
|
||||||
entities.add_items({c.WALL: walls})
|
entities.add_items({c.WALL: walls})
|
||||||
|
|
||||||
# Floor
|
# Floor
|
||||||
floor = Floors.from_coordinates(np.argwhere(level_array == c.VALUE_FREE_CELL), self.size)
|
floor = Floors.from_coordinates(self.get_coordinates_for_symbol(c.SYMBOL_WALL, negate=True), self.size)
|
||||||
entities.add_items({c.FLOOR: floor})
|
entities.add_items({c.FLOOR: floor})
|
||||||
|
|
||||||
# All other
|
# All other
|
||||||
for es_name in self.e_p_dict:
|
for es_name in self.e_p_dict:
|
||||||
e_class, e_kwargs = self.e_p_dict[es_name]['class'], self.e_p_dict[es_name]['kwargs']
|
e_class, e_kwargs = self.e_p_dict[es_name]['class'], self.e_p_dict[es_name]['kwargs']
|
||||||
|
|
||||||
if hasattr(e_class, 'symbol'):
|
if hasattr(e_class, 'symbol') and e_class.symbol is not None:
|
||||||
level_array = h.one_hot_level(self._parsed_level, symbol=e_class.symbol)
|
symbols = e_class.symbol
|
||||||
if np.any(level_array):
|
if isinstance(symbols, (str, int, float)):
|
||||||
e = e_class.from_coordinates(np.argwhere(level_array == c.VALUE_OCCUPIED_CELL).tolist(),
|
symbols = [symbols]
|
||||||
entities[c.FLOOR], self.size, entity_kwargs=e_kwargs
|
for symbol in symbols:
|
||||||
)
|
level_array = h.one_hot_level(self._parsed_level, symbol=symbol)
|
||||||
else:
|
if np.any(level_array):
|
||||||
raise ValueError(f'No {e_class} (Symbol: {e_class.symbol}) could be found!\n'
|
e = e_class.from_coordinates(np.argwhere(level_array == c.VALUE_OCCUPIED_CELL).tolist(),
|
||||||
f'Check your level file!')
|
entities[c.FLOOR], self.size, entity_kwargs=e_kwargs
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(f'No {e_class} (Symbol: {e_class.symbol}) could be found!\n'
|
||||||
|
f'Check your level file!')
|
||||||
else:
|
else:
|
||||||
e = e_class(self.size, **e_kwargs)
|
e = e_class(self.size, **e_kwargs)
|
||||||
entities.add_items({e.name: e})
|
entities.add_items({e.name: e})
|
||||||
|
@ -6,11 +6,10 @@ from typing import Dict, List
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from numba import njit
|
from numba import njit
|
||||||
|
|
||||||
|
from marl_factory_grid.environment import constants as c
|
||||||
from marl_factory_grid.environment.groups.utils import Combined
|
from marl_factory_grid.environment.groups.utils import Combined
|
||||||
from marl_factory_grid.utils.states import Gamestate
|
from marl_factory_grid.utils.states import Gamestate
|
||||||
|
|
||||||
from marl_factory_grid.environment import constants as c
|
|
||||||
|
|
||||||
|
|
||||||
class OBSBuilder(object):
|
class OBSBuilder(object):
|
||||||
|
|
||||||
@ -111,10 +110,10 @@ class OBSBuilder(object):
|
|||||||
e = next(x for x in self.all_obs if l_name in x and agent.name in x)
|
e = next(x for x in self.all_obs if l_name in x and agent.name in x)
|
||||||
except StopIteration:
|
except StopIteration:
|
||||||
raise KeyError(
|
raise KeyError(
|
||||||
f'Check typing!\n{l_name} could not be found in:\n{dict(self.all_obs).keys()}')
|
f'Check typing! {l_name} could not be found in: {list(dict(self.all_obs).keys())}')
|
||||||
|
|
||||||
try:
|
try:
|
||||||
positional = e.has_position
|
positional = e.var_has_position
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
positional = False
|
positional = False
|
||||||
if positional:
|
if positional:
|
||||||
@ -172,7 +171,7 @@ class OBSBuilder(object):
|
|||||||
obs_layers.append(combined.name)
|
obs_layers.append(combined.name)
|
||||||
elif obs_str == c.OTHERS:
|
elif obs_str == c.OTHERS:
|
||||||
obs_layers.extend([x for x in self.all_obs if x != agent.name and x.startswith(f'{c.AGENT}[')])
|
obs_layers.extend([x for x in self.all_obs if x != agent.name and x.startswith(f'{c.AGENT}[')])
|
||||||
elif obs_str == c.AGENTS:
|
elif obs_str == c.AGENT:
|
||||||
obs_layers.extend([x for x in self.all_obs if x.startswith(f'{c.AGENT}[')])
|
obs_layers.extend([x for x in self.all_obs if x.startswith(f'{c.AGENT}[')])
|
||||||
else:
|
else:
|
||||||
obs_layers.append(obs_str)
|
obs_layers.append(obs_str)
|
||||||
@ -222,7 +221,7 @@ class RayCaster:
|
|||||||
entities_hit = entities.pos_dict[(x, y)]
|
entities_hit = entities.pos_dict[(x, y)]
|
||||||
hits = self.ray_block_cache(cache_blocking,
|
hits = self.ray_block_cache(cache_blocking,
|
||||||
(x, y),
|
(x, y),
|
||||||
lambda: any(e.is_blocking_light for e in entities_hit),
|
lambda: any(e.var_is_blocking_light for e in entities_hit),
|
||||||
entities)
|
entities)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -237,8 +236,8 @@ class RayCaster:
|
|||||||
self.ray_block_cache(
|
self.ray_block_cache(
|
||||||
cache_blocking,
|
cache_blocking,
|
||||||
key,
|
key,
|
||||||
# lambda: all(False for e in entities.pos_dict[key] if not e.is_blocking_light),
|
# lambda: all(False for e in entities.pos_dict[key] if not e.var_is_blocking_light),
|
||||||
lambda: any(e.is_blocking_light for e in entities.pos_dict[key]),
|
lambda: any(e.var_is_blocking_light for e in entities.pos_dict[key]),
|
||||||
entities)
|
entities)
|
||||||
for key in ((x, y-cy), (x-cx, y))
|
for key in ((x, y-cy), (x-cx, y))
|
||||||
]) if (cx != 0 and cy != 0) else False
|
]) if (cx != 0 and cy != 0) else False
|
||||||
|
@ -27,13 +27,13 @@ class Renderer:
|
|||||||
BG_COLOR = (178, 190, 195) # (99, 110, 114)
|
BG_COLOR = (178, 190, 195) # (99, 110, 114)
|
||||||
WHITE = (223, 230, 233) # (200, 200, 200)
|
WHITE = (223, 230, 233) # (200, 200, 200)
|
||||||
AGENT_VIEW_COLOR = (9, 132, 227)
|
AGENT_VIEW_COLOR = (9, 132, 227)
|
||||||
ASSETS = Path(__file__).parent.parent / 'assets'
|
ASSETS = Path(__file__).parent.parent
|
||||||
MODULE_ASSETS = Path(__file__).parent.parent.parent / 'modules'
|
|
||||||
|
|
||||||
def __init__(self, lvl_shape: Tuple[int, int] = (16, 16),
|
def __init__(self, lvl_shape: Tuple[int, int] = (16, 16),
|
||||||
lvl_padded_shape: Union[Tuple[int, int], None] = None,
|
lvl_padded_shape: Union[Tuple[int, int], None] = None,
|
||||||
cell_size: int = 40, fps: int = 7,
|
cell_size: int = 40, fps: int = 7,
|
||||||
grid_lines: bool = True, view_radius: int = 2):
|
grid_lines: bool = True, view_radius: int = 2):
|
||||||
|
# TODO: Customn_assets paths
|
||||||
self.grid_h, self.grid_w = lvl_shape
|
self.grid_h, self.grid_w = lvl_shape
|
||||||
self.lvl_padded_shape = lvl_padded_shape if lvl_padded_shape is not None else lvl_shape
|
self.lvl_padded_shape = lvl_padded_shape if lvl_padded_shape is not None else lvl_shape
|
||||||
self.cell_size = cell_size
|
self.cell_size = cell_size
|
||||||
@ -44,7 +44,7 @@ class Renderer:
|
|||||||
self.screen_size = (self.grid_w*cell_size, self.grid_h*cell_size)
|
self.screen_size = (self.grid_w*cell_size, self.grid_h*cell_size)
|
||||||
self.screen = pygame.display.set_mode(self.screen_size)
|
self.screen = pygame.display.set_mode(self.screen_size)
|
||||||
self.clock = pygame.time.Clock()
|
self.clock = pygame.time.Clock()
|
||||||
assets = list(self.ASSETS.rglob('*.png')) + list(self.MODULE_ASSETS.rglob('*.png'))
|
assets = list(self.ASSETS.rglob('*.png'))
|
||||||
self.assets = {path.stem: self.load_asset(str(path), 1) for path in assets}
|
self.assets = {path.stem: self.load_asset(str(path), 1) for path in assets}
|
||||||
self.fill_bg()
|
self.fill_bg()
|
||||||
|
|
||||||
|
@ -1,8 +1,6 @@
|
|||||||
from typing import Union
|
from typing import Union
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
|
||||||
from marl_factory_grid.environment.entity.entity import Entity
|
|
||||||
|
|
||||||
TYPE_VALUE = 'value'
|
TYPE_VALUE = 'value'
|
||||||
TYPE_REWARD = 'reward'
|
TYPE_REWARD = 'reward'
|
||||||
types = [TYPE_VALUE, TYPE_REWARD]
|
types = [TYPE_VALUE, TYPE_REWARD]
|
||||||
@ -20,7 +18,7 @@ class Result:
|
|||||||
validity: bool
|
validity: bool
|
||||||
reward: Union[float, None] = None
|
reward: Union[float, None] = None
|
||||||
value: Union[float, None] = None
|
value: Union[float, None] = None
|
||||||
entity: Union[Entity, None] = None
|
entity: None = None
|
||||||
|
|
||||||
def get_infos(self):
|
def get_infos(self):
|
||||||
n = self.entity.name if self.entity is not None else "Global"
|
n = self.entity.name if self.entity is not None else "Global"
|
||||||
|
@ -2,10 +2,11 @@ from typing import List, Dict
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
from marl_factory_grid.environment import constants as c
|
||||||
from marl_factory_grid.environment.entity.wall_floor import Floor
|
from marl_factory_grid.environment.entity.wall_floor import Floor
|
||||||
from marl_factory_grid.environment.rules import Rule
|
from marl_factory_grid.environment.rules import Rule
|
||||||
from marl_factory_grid.utils.results import Result
|
from marl_factory_grid.utils.results import Result
|
||||||
from marl_factory_grid.environment import constants as c
|
|
||||||
|
|
||||||
|
|
||||||
class StepRules:
|
class StepRules:
|
||||||
@ -26,9 +27,9 @@ class StepRules:
|
|||||||
self.rules.append(item)
|
self.rules.append(item)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def do_all_init(self, state):
|
def do_all_init(self, state, lvl_map):
|
||||||
for rule in self.rules:
|
for rule in self.rules:
|
||||||
if rule_init_printline := rule.on_init(state):
|
if rule_init_printline := rule.on_init(state, lvl_map):
|
||||||
state.print(rule_init_printline)
|
state.print(rule_init_printline)
|
||||||
return c.VALID
|
return c.VALID
|
||||||
|
|
||||||
@ -58,7 +59,7 @@ class Gamestate(object):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def moving_entites(self):
|
def moving_entites(self):
|
||||||
return [y for x in self.entities for y in x if x.can_move]
|
return [y for x in self.entities for y in x if x.var_can_move]
|
||||||
|
|
||||||
def __init__(self, entitites, rules: Dict[str, dict], env_seed=69, verbose=False):
|
def __init__(self, entitites, rules: Dict[str, dict], env_seed=69, verbose=False):
|
||||||
self.entities = entitites
|
self.entities = entitites
|
||||||
@ -107,6 +108,6 @@ class Gamestate(object):
|
|||||||
|
|
||||||
def get_all_tiles_with_collisions(self) -> List[Floor]:
|
def get_all_tiles_with_collisions(self) -> List[Floor]:
|
||||||
tiles = [self[c.FLOOR].by_pos(pos) for pos, e in self.entities.pos_dict.items()
|
tiles = [self[c.FLOOR].by_pos(pos) for pos, e in self.entities.pos_dict.items()
|
||||||
if sum([x.can_collide for x in e]) > 1]
|
if sum([x.var_can_collide for x in e]) > 1]
|
||||||
# tiles = [x for x in self[c.FLOOR] if len(x.guests_that_can_collide) > 1]
|
# tiles = [x for x in self[c.FLOOR] if len(x.guests_that_can_collide) > 1]
|
||||||
return tiles
|
return tiles
|
||||||
|
@ -15,7 +15,7 @@ ENTITIES = 'Objects'
|
|||||||
OBSERVATIONS = 'Observations'
|
OBSERVATIONS = 'Observations'
|
||||||
RULES = 'Rule'
|
RULES = 'Rule'
|
||||||
ASSETS = 'Assets'
|
ASSETS = 'Assets'
|
||||||
EXCLUDED = ['identifier', 'args', 'kwargs', 'Move',
|
EXCLUDED = ['identifier', 'args', 'kwargs', 'Move', 'Floor', 'Agent', 'GlobalPositions', 'Walls',
|
||||||
'TemplateRule', 'Entities', 'EnvObjects', 'Zones', ]
|
'TemplateRule', 'Entities', 'EnvObjects', 'Zones', ]
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,21 +1,6 @@
|
|||||||
import gymnasium as gym
|
import gymnasium as gym
|
||||||
|
|
||||||
|
|
||||||
class EnvCombiner(object):
|
|
||||||
|
|
||||||
def __init__(self, *envs_cls):
|
|
||||||
self._env_dict = {env_cls.__name__: env_cls for env_cls in envs_cls}
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def combine_cls(name, *envs_cls):
|
|
||||||
return type(name, envs_cls, {})
|
|
||||||
|
|
||||||
def build(self):
|
|
||||||
name = f'{"".join([x.lower().replace("factory").capitalize() for x in self._env_dict.keys()])}Factory'
|
|
||||||
|
|
||||||
return self.combine_cls(name, tuple(self._env_dict.values()))
|
|
||||||
|
|
||||||
|
|
||||||
class MarlFrameStack(gym.ObservationWrapper):
|
class MarlFrameStack(gym.ObservationWrapper):
|
||||||
"""todo @romue404"""
|
"""todo @romue404"""
|
||||||
def __init__(self, env):
|
def __init__(self, env):
|
||||||
|
@ -3,7 +3,7 @@ from pathlib import Path
|
|||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
from marl_factory_grid.environment.factory import BaseFactory
|
from marl_factory_grid.environment.factory import Factory
|
||||||
from marl_factory_grid.logging.envmonitor import EnvMonitor
|
from marl_factory_grid.logging.envmonitor import EnvMonitor
|
||||||
from marl_factory_grid.logging.recorder import EnvRecorder
|
from marl_factory_grid.logging.recorder import EnvRecorder
|
||||||
|
|
||||||
@ -41,7 +41,7 @@ if __name__ == '__main__':
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
# Init Env
|
# Init Env
|
||||||
with BaseFactory(**env_kwargs) as env:
|
with Factory(**env_kwargs) as env:
|
||||||
env = EnvMonitor(env)
|
env = EnvMonitor(env)
|
||||||
env = EnvRecorder(env) if record else env
|
env = EnvRecorder(env) if record else env
|
||||||
obs_shape = env.observation_space.shape
|
obs_shape = env.observation_space.shape
|
||||||
|
2
setup.py
2
setup.py
@ -5,7 +5,7 @@ long_description = (this_directory / "README.md").read_text()
|
|||||||
|
|
||||||
|
|
||||||
setup(name='Marl-Factory-Grid',
|
setup(name='Marl-Factory-Grid',
|
||||||
version='0.0.11',
|
version='0.0.12',
|
||||||
description='A framework to research MARL agents in various setings.',
|
description='A framework to research MARL agents in various setings.',
|
||||||
author='Steffen Illium',
|
author='Steffen Illium',
|
||||||
author_email='steffen.illium@ifi.lmu.de',
|
author_email='steffen.illium@ifi.lmu.de',
|
||||||
|
Reference in New Issue
Block a user