mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-06-18 18:52:52 +02:00
Machines
This commit is contained in:
@ -1,6 +1,6 @@
|
||||
from .environment.factory import BaseFactory
|
||||
from .environment.factory import OBSBuilder
|
||||
|
||||
from .utils.tools import ConfigExplainer
|
||||
from .environment import *
|
||||
from .modules import *
|
||||
from .utils import *
|
||||
|
||||
from .quickstart import init
|
||||
|
||||
|
@ -1,11 +1,10 @@
|
||||
import itertools
|
||||
from random import choice
|
||||
|
||||
import numpy as np
|
||||
|
||||
import networkx as nx
|
||||
from networkx.algorithms.approximation import traveling_salesman as tsp
|
||||
|
||||
from marl_factory_grid.algorithms.static.utils import points_to_graph
|
||||
from marl_factory_grid.modules.doors import constants as do
|
||||
from marl_factory_grid.environment import constants as c
|
||||
from marl_factory_grid.utils.helpers import MOVEMAP
|
||||
@ -15,41 +14,6 @@ from abc import abstractmethod, ABC
|
||||
future_planning = 7
|
||||
|
||||
|
||||
def points_to_graph(coordiniates_or_tiles, allow_euclidean_connections=True, allow_manhattan_connections=True):
|
||||
"""
|
||||
Given a set of coordinates, this function contructs a non-directed graph, by conncting adjected points.
|
||||
There are three combinations of settings:
|
||||
Allow all neigbors: Distance(a, b) <= sqrt(2)
|
||||
Allow only manhattan: Distance(a, b) == 1
|
||||
Allow only euclidean: Distance(a, b) == sqrt(2)
|
||||
|
||||
|
||||
:param coordiniates_or_tiles: A set of coordinates.
|
||||
:type coordiniates_or_tiles: Tiles
|
||||
:param allow_euclidean_connections: Whether to regard diagonal adjected cells as neighbors
|
||||
:type: bool
|
||||
:param allow_manhattan_connections: Whether to regard directly adjected cells as neighbors
|
||||
:type: bool
|
||||
|
||||
:return: A graph with nodes that are conneceted as specified by the parameters.
|
||||
:rtype: nx.Graph
|
||||
"""
|
||||
assert allow_euclidean_connections or allow_manhattan_connections
|
||||
if hasattr(coordiniates_or_tiles, 'positions'):
|
||||
coordiniates_or_tiles = coordiniates_or_tiles.positions
|
||||
possible_connections = itertools.combinations(coordiniates_or_tiles, 2)
|
||||
graph = nx.Graph()
|
||||
for a, b in possible_connections:
|
||||
diff = np.linalg.norm(np.asarray(a)-np.asarray(b))
|
||||
if allow_manhattan_connections and allow_euclidean_connections and diff <= np.sqrt(2):
|
||||
graph.add_edge(a, b)
|
||||
elif not allow_manhattan_connections and allow_euclidean_connections and diff == np.sqrt(2):
|
||||
graph.add_edge(a, b)
|
||||
elif allow_manhattan_connections and not allow_euclidean_connections and diff == 1:
|
||||
graph.add_edge(a, b)
|
||||
return graph
|
||||
|
||||
|
||||
class TSPBaseAgent(ABC):
|
||||
|
||||
def __init__(self, state, agent_i, static_problem: bool = True):
|
||||
|
39
marl_factory_grid/algorithms/static/utils.py
Normal file
39
marl_factory_grid/algorithms/static/utils.py
Normal file
@ -0,0 +1,39 @@
|
||||
import itertools
|
||||
|
||||
import networkx as nx
|
||||
import numpy as np
|
||||
|
||||
|
||||
def points_to_graph(coordiniates_or_tiles, allow_euclidean_connections=True, allow_manhattan_connections=True):
    """
    Build an undirected graph by connecting adjacent points of a coordinate collection.

    Every unordered pair of points is compared once and an edge is added when the pair's
    distance matches the selected neighbourhood:

    - both flags set:  Distance(a, b) <= sqrt(2)  (direct and diagonal neighbours)
    - manhattan only:  Distance(a, b) == 1        (direct neighbours)
    - euclidean only:  Distance(a, b) == sqrt(2)  (diagonal neighbours)

    The comparison is carried out on squared distances (<= 2, == 1, == 2). For integer grid
    coordinates this is exact arithmetic, so no float equality against ``sqrt(2)`` is needed.

    Nodes enter the graph only through their edges: a point without any admissible neighbour
    is not contained in the result.

    :param coordiniates_or_tiles: Iterable of coordinates, or an object exposing them as ``positions``.
    :param allow_euclidean_connections: Whether to regard diagonally adjacent cells as neighbours.
    :type allow_euclidean_connections: bool
    :param allow_manhattan_connections: Whether to regard directly adjacent cells as neighbours.
    :type allow_manhattan_connections: bool

    :return: A graph whose nodes are connected as specified by the parameters.
    :rtype: nx.Graph
    """
    assert allow_euclidean_connections or allow_manhattan_connections, \
        'At least one of the two connection types has to be allowed.'
    if hasattr(coordiniates_or_tiles, 'positions'):
        coordiniates_or_tiles = coordiniates_or_tiles.positions

    graph = nx.Graph()

    # The neighbourhood rule does not depend on the pair, so it is selected once up front.
    if allow_manhattan_connections and allow_euclidean_connections:
        def is_adjacent(sq_dist):
            return sq_dist <= 2
    elif allow_euclidean_connections:
        def is_adjacent(sq_dist):
            return sq_dist == 2
    elif allow_manhattan_connections:
        def is_adjacent(sq_dist):
            return sq_dist == 1
    else:
        # Only reachable when assertions are disabled; as before, no edge is created then.
        return graph

    for a, b in itertools.combinations(coordiniates_or_tiles, 2):
        delta = np.asarray(a) - np.asarray(b)
        # Squared Euclidean distance; stays an exact integer for integer coordinates.
        if is_adjacent(np.sum(np.square(delta))):
            graph.add_edge(a, b)
    return graph
|
@ -1,68 +1,89 @@
|
||||
---
|
||||
General:
|
||||
level_name: rooms
|
||||
env_seed: 69
|
||||
verbose: !!bool False
|
||||
pomdp_r: 5
|
||||
individual_rewards: !!bool True
|
||||
|
||||
Entities:
|
||||
Defaults: {}
|
||||
DirtPiles:
|
||||
initial_dirt_ratio: 0.3 # On INIT, on max how many tiles does the dirt spawn in percent.
|
||||
dirt_spawn_r_var: 0.05 # How much does the dirt spawn amount vary?
|
||||
initial_amount: 3
|
||||
max_local_amount: 5 # Max dirt amount per tile.
|
||||
max_global_amount: 20 # Max dirt amount in the whole environment.
|
||||
Doors:
|
||||
closed_on_init: True
|
||||
auto_close_interval: 10
|
||||
indicate_area: False
|
||||
Agents:
|
||||
Wolfgang:
|
||||
Actions:
|
||||
- Move8
|
||||
- Noop
|
||||
- DoorUse
|
||||
- BtryCharge
|
||||
- CleanUp
|
||||
- DestAction
|
||||
- DoorUse
|
||||
- ItemAction
|
||||
- Move8
|
||||
Observations:
|
||||
- Self
|
||||
- Placeholder
|
||||
- Combined:
|
||||
- Other
|
||||
- Walls
|
||||
- GlobalPosition
|
||||
- Battery
|
||||
- ChargePods
|
||||
- DirtPiles
|
||||
- Placeholder
|
||||
- Destinations
|
||||
- Doors
|
||||
- Doors
|
||||
Björn:
|
||||
Actions:
|
||||
# Move4, Noop
|
||||
- Move4
|
||||
- DoorUse
|
||||
- CleanUp
|
||||
Observations:
|
||||
- Defaults
|
||||
- Combined
|
||||
Jürgen:
|
||||
Actions:
|
||||
# Move4, Noop
|
||||
- Defaults
|
||||
- DoorUse
|
||||
- CleanUp
|
||||
Observations:
|
||||
- Walls
|
||||
- Placeholder
|
||||
- Agent[Björn]
|
||||
- Items
|
||||
- Inventory
|
||||
- DropOffLocations
|
||||
- Machines
|
||||
- Maintainers
|
||||
Entities:
|
||||
Batteries: {}
|
||||
ChargePods: {}
|
||||
Destinations: {}
|
||||
DirtPiles:
|
||||
clean_amount: 1
|
||||
dirt_spawn_r_var: 0.1
|
||||
initial_amount: 2
|
||||
initial_dirt_ratio: 0.05
|
||||
max_global_amount: 20
|
||||
max_local_amount: 5
|
||||
Doors: {}
|
||||
DropOffLocations: {}
|
||||
GlobalPositions: {}
|
||||
Inventories: {}
|
||||
Items: {}
|
||||
Machines: {}
|
||||
Maintainers: {}
|
||||
Zones: {}
|
||||
ReachedDestinations: {}
|
||||
|
||||
General:
|
||||
env_seed: 69
|
||||
individual_rewards: true
|
||||
level_name: large
|
||||
pomdp_r: 3
|
||||
verbose: false
|
||||
|
||||
Rules:
|
||||
Defaults: {}
|
||||
Btry:
|
||||
initial_charge: 0.8
|
||||
per_action_costs: 0.02
|
||||
BtryDoneAtDischarge: {}
|
||||
Collision:
|
||||
done_at_collisions: !!bool False
|
||||
DirtRespawnRule:
|
||||
spawn_freq: 5
|
||||
DirtSmearOnMove:
|
||||
smear_amount: 0.12
|
||||
DoorAutoClose: {}
|
||||
done_at_collisions: false
|
||||
AssignGlobalPositions: {}
|
||||
DestinationDone: {}
|
||||
DestinationReach:
|
||||
n_dests: 1
|
||||
tiles: null
|
||||
DestinationSpawn:
|
||||
n_dests: 1
|
||||
spawn_frequency: 5
|
||||
spawn_mode: GROUPED
|
||||
DirtAllCleanDone: {}
|
||||
Assets:
|
||||
- Defaults
|
||||
- Dirt
|
||||
- Doors
|
||||
DirtRespawnRule:
|
||||
spawn_freq: 15
|
||||
DirtSmearOnMove:
|
||||
smear_amount: 0.2
|
||||
DoorAutoClose:
|
||||
close_frequency: 10
|
||||
ItemRules:
|
||||
max_dropoff_storage_size: 0
|
||||
n_items: 5
|
||||
n_locations: 5
|
||||
spawn_frequency: 15
|
||||
MachineRule:
|
||||
n_machines: 2
|
||||
MaintenanceRule:
|
||||
n_maintainer: 1
|
||||
MaxStepsReached:
|
||||
max_steps: 500
|
||||
# AgentSingleZonePlacement:
|
||||
# n_zones: 4
|
||||
|
@ -98,3 +98,5 @@ class NorthWest(Move):
|
||||
Move4 = [North, East, South, West]
|
||||
# noinspection PyTypeChecker
|
||||
Move8 = Move4 + [NorthEast, SouthEast, SouthWest, NorthWest]
|
||||
|
||||
ALL_BASEACTIONS = Move8 + [Noop]
|
||||
|
@ -9,15 +9,13 @@ WALL = 'Wall' # Identifier of Wall-objects and
|
||||
WALLS = 'Walls' # Identifier of Wall-objects and groups (groups).
|
||||
LEVEL = 'Level' # Identifier of Level-objects and groups (groups).
|
||||
AGENT = 'Agent' # Identifier of Agent-objects and groups (groups).
|
||||
AGENTS = 'Agents' # Identifier of Agent-objects and groups (groups).
|
||||
OTHERS = 'Other'
|
||||
COMBINED = 'Combined'
|
||||
GLOBAL_POSITION = 'GLOBAL_POSITION' # Identifier of the global position slice
|
||||
|
||||
GLOBALPOSITIONS = 'GlobalPositions' # Identifier of the global position slice
|
||||
|
||||
# Attributes
|
||||
IS_BLOCKING_LIGHT = 'is_blocking_light'
|
||||
HAS_POSITION = 'has_position'
|
||||
IS_BLOCKING_LIGHT = 'var_is_blocking_light'
|
||||
HAS_POSITION = 'var_has_position'
|
||||
HAS_NO_POSITION = 'has_no_position'
|
||||
ALL = 'All'
|
||||
|
||||
|
@ -1,6 +1,5 @@
|
||||
from typing import List, Union
|
||||
|
||||
from marl_factory_grid.environment import constants as c
|
||||
from marl_factory_grid.environment.actions import Action
|
||||
from marl_factory_grid.environment.entity.entity import Entity
|
||||
from marl_factory_grid.utils.render import RenderEntity
|
||||
@ -8,6 +7,8 @@ from marl_factory_grid.utils import renderer
|
||||
from marl_factory_grid.utils.helpers import is_move
|
||||
from marl_factory_grid.utils.results import ActionResult, Result
|
||||
|
||||
from marl_factory_grid.environment import constants as c
|
||||
|
||||
|
||||
class Agent(Entity):
|
||||
|
||||
@ -24,7 +25,7 @@ class Agent(Entity):
|
||||
return self._observations
|
||||
|
||||
@property
|
||||
def can_collide(self):
|
||||
def var_can_collide(self):
|
||||
return True
|
||||
|
||||
def step_result(self):
|
||||
|
@ -1,15 +1,20 @@
|
||||
import abc
|
||||
|
||||
from marl_factory_grid.environment import constants as c
|
||||
from marl_factory_grid.environment.entity.object import EnvObject
|
||||
from marl_factory_grid.utils.render import RenderEntity
|
||||
from .. import constants as c
|
||||
from .object import EnvObject
|
||||
from ...utils.render import RenderEntity
|
||||
from ...utils.results import ActionResult
|
||||
|
||||
|
||||
class Entity(EnvObject, abc.ABC):
|
||||
"""Full Env Entity that lives on the environment Grid. Doors, Items, DirtPile etc..."""
|
||||
|
||||
@property
|
||||
def has_position(self):
|
||||
def state(self):
|
||||
return self._status or ActionResult(entity=self, identifier=c.NOOP, validity=c.VALID, reward=0)
|
||||
|
||||
@property
|
||||
def var_has_position(self):
|
||||
return self.pos != c.VALUE_NO_POS
|
||||
|
||||
@property
|
||||
@ -64,12 +69,13 @@ class Entity(EnvObject, abc.ABC):
|
||||
|
||||
def __init__(self, tile, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self._status = None
|
||||
self._tile = tile
|
||||
tile.enter(self)
|
||||
|
||||
def summarize_state(self) -> dict:
|
||||
return dict(name=str(self.name), x=int(self.x), y=int(self.y),
|
||||
tile=str(self.tile.name), can_collide=bool(self.can_collide))
|
||||
tile=str(self.tile.name), can_collide=bool(self.var_can_collide))
|
||||
|
||||
@abc.abstractmethod
|
||||
def render(self):
|
||||
|
@ -78,37 +78,37 @@ class EnvObject(Object):
|
||||
return self.name
|
||||
|
||||
@property
|
||||
def is_blocking_light(self):
|
||||
def var_is_blocking_light(self):
|
||||
try:
|
||||
return self._collection.is_blocking_light or False
|
||||
return self._collection.var_is_blocking_light or False
|
||||
except AttributeError:
|
||||
return False
|
||||
|
||||
@property
|
||||
def can_move(self):
|
||||
def var_can_move(self):
|
||||
try:
|
||||
return self._collection.can_move or False
|
||||
return self._collection.var_can_move or False
|
||||
except AttributeError:
|
||||
return False
|
||||
|
||||
@property
|
||||
def is_blocking_pos(self):
|
||||
def var_is_blocking_pos(self):
|
||||
try:
|
||||
return self._collection.is_blocking_pos or False
|
||||
return self._collection.var_is_blocking_pos or False
|
||||
except AttributeError:
|
||||
return False
|
||||
|
||||
@property
|
||||
def has_position(self):
|
||||
def var_has_position(self):
|
||||
try:
|
||||
return self._collection.has_position or False
|
||||
return self._collection.var_has_position or False
|
||||
except AttributeError:
|
||||
return False
|
||||
|
||||
@property
|
||||
def can_collide(self):
|
||||
def var_can_collide(self):
|
||||
try:
|
||||
return self._collection.can_collide or False
|
||||
return self._collection.var_can_collide or False
|
||||
except AttributeError:
|
||||
return False
|
||||
|
||||
|
@ -35,11 +35,11 @@ class GlobalPosition(BoundEntityMixin, EnvObject):
|
||||
@property
|
||||
def encoding(self):
|
||||
if self._normalized:
|
||||
return tuple(np.divide(self._bound_entity.pos, self._level_shape))
|
||||
return tuple(np.divide(self._bound_entity.pos, self._shape))
|
||||
else:
|
||||
return self.bound_entity.pos
|
||||
|
||||
def __init__(self, *args, normalized: bool = True, **kwargs):
|
||||
def __init__(self, level_shape, *args, normalized: bool = True, **kwargs):
|
||||
super(GlobalPosition, self).__init__(*args, **kwargs)
|
||||
self._level_shape = math.sqrt(self.size)
|
||||
self._normalized = normalized
|
||||
self._shape = level_shape
|
||||
|
@ -11,23 +11,23 @@ from marl_factory_grid.utils import helpers as h
|
||||
class Floor(EnvObject):
|
||||
|
||||
@property
|
||||
def has_position(self):
|
||||
def var_has_position(self):
|
||||
return True
|
||||
|
||||
@property
|
||||
def can_collide(self):
|
||||
def var_can_collide(self):
|
||||
return False
|
||||
|
||||
@property
|
||||
def can_move(self):
|
||||
def var_can_move(self):
|
||||
return False
|
||||
|
||||
@property
|
||||
def is_blocking_pos(self):
|
||||
def var_is_blocking_pos(self):
|
||||
return False
|
||||
|
||||
@property
|
||||
def is_blocking_light(self):
|
||||
def var_is_blocking_light(self):
|
||||
return False
|
||||
|
||||
@property
|
||||
@ -51,7 +51,7 @@ class Floor(EnvObject):
|
||||
|
||||
@property
|
||||
def guests_that_can_collide(self):
|
||||
return [x for x in self.guests if x.can_collide]
|
||||
return [x for x in self.guests if x.var_can_collide]
|
||||
|
||||
@property
|
||||
def guests(self):
|
||||
@ -67,7 +67,7 @@ class Floor(EnvObject):
|
||||
|
||||
@property
|
||||
def is_blocked(self):
|
||||
return any([x.is_blocking_pos for x in self.guests])
|
||||
return any([x.var_is_blocking_pos for x in self.guests])
|
||||
|
||||
def __init__(self, pos, **kwargs):
|
||||
super(Floor, self).__init__(**kwargs)
|
||||
@ -86,7 +86,7 @@ class Floor(EnvObject):
|
||||
return bool(len(self._guests))
|
||||
|
||||
def enter(self, guest):
|
||||
if (guest.name not in self._guests and not self.is_blocked) and not (guest.is_blocking_pos and self.is_occupied()):
|
||||
if (guest.name not in self._guests and not self.is_blocked) and not (guest.var_is_blocking_pos and self.is_occupied()):
|
||||
self._guests.update({guest.name: guest})
|
||||
return c.VALID
|
||||
else:
|
||||
@ -112,7 +112,7 @@ class Floor(EnvObject):
|
||||
class Wall(Floor):
|
||||
|
||||
@property
|
||||
def can_collide(self):
|
||||
def var_can_collide(self):
|
||||
return True
|
||||
|
||||
@property
|
||||
@ -123,9 +123,9 @@ class Wall(Floor):
|
||||
return RenderEntity(c.WALL, self.pos)
|
||||
|
||||
@property
|
||||
def is_blocking_pos(self):
|
||||
def var_is_blocking_pos(self):
|
||||
return True
|
||||
|
||||
@property
|
||||
def is_blocking_light(self):
|
||||
def var_is_blocking_light(self):
|
||||
return True
|
||||
|
@ -19,7 +19,7 @@ from marl_factory_grid.utils.states import Gamestate
|
||||
REC_TAC = 'rec_'
|
||||
|
||||
|
||||
class BaseFactory(gym.Env):
|
||||
class Factory(gym.Env):
|
||||
|
||||
@property
|
||||
def action_space(self):
|
||||
@ -52,10 +52,14 @@ class BaseFactory(gym.Env):
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
self.close()
|
||||
|
||||
def __init__(self, config_file: Union[str, PathLike]):
|
||||
def __init__(self, config_file: Union[str, PathLike], custom_modules_path: Union[None, PathLike] = None,
|
||||
custom_level_path: Union[None, PathLike] = None):
|
||||
self._config_file = config_file
|
||||
self.conf = FactoryConfigParser(self._config_file)
|
||||
self.conf = FactoryConfigParser(self._config_file, custom_modules_path)
|
||||
# Attribute Assignment
|
||||
if custom_level_path is not None:
|
||||
self.level_filepath = Path(custom_level_path)
|
||||
else:
|
||||
self.level_filepath = Path(__file__).parent.parent / h.LEVELS_DIR / f'{self.conf.level_name}.txt'
|
||||
self._renderer = None # expensive - don't use; unless required !
|
||||
|
||||
@ -90,7 +94,7 @@ class BaseFactory(gym.Env):
|
||||
self.state.entities.add_item({c.AGENT: agents})
|
||||
|
||||
# All is set up, trigger additional init (after agent entity spawn etc)
|
||||
self.state.rules.do_all_init(self.state)
|
||||
self.state.rules.do_all_init(self.state, self.map)
|
||||
|
||||
# Observations
|
||||
# noinspection PyAttributeOutsideInit
|
||||
@ -144,7 +148,7 @@ class BaseFactory(gym.Env):
|
||||
try:
|
||||
done_reason = next(x for x in done_check_results if x.validity)
|
||||
done = True
|
||||
self.state.print(f'Env done, Reason: {done_reason.name}.')
|
||||
self.state.print(f'Env done, Reason: {done_reason.identifier}.')
|
||||
except StopIteration:
|
||||
done = False
|
||||
|
||||
|
@ -1,6 +1,6 @@
|
||||
from marl_factory_grid.environment.entity.agent import Agent
|
||||
from marl_factory_grid.environment.groups.env_objects import EnvObjects
|
||||
from marl_factory_grid.environment.groups.mixins import PositionMixin
|
||||
from marl_factory_grid.environment.entity.agent import Agent
|
||||
|
||||
|
||||
class Agents(PositionMixin, EnvObjects):
|
||||
|
@ -5,10 +5,10 @@ from marl_factory_grid.environment.entity.object import EnvObject
|
||||
class EnvObjects(Objects):
|
||||
|
||||
_entity = EnvObject
|
||||
is_blocking_light: bool = False
|
||||
can_collide: bool = False
|
||||
has_position: bool = False
|
||||
can_move: bool = False
|
||||
var_is_blocking_light: bool = False
|
||||
var_can_collide: bool = False
|
||||
var_has_position: bool = False
|
||||
var_can_move: bool = False
|
||||
|
||||
@property
|
||||
def encodings(self):
|
||||
@ -19,7 +19,7 @@ class EnvObjects(Objects):
|
||||
self.size = size
|
||||
|
||||
def add_item(self, item: EnvObject):
|
||||
assert self.has_position or (len(self) <= self.size)
|
||||
assert self.var_has_position or (len(self) <= self.size)
|
||||
super(EnvObjects, self).add_item(item)
|
||||
return self
|
||||
|
||||
|
@ -1,15 +1,19 @@
|
||||
from typing import List
|
||||
|
||||
from marl_factory_grid.environment import constants as c
|
||||
|
||||
from marl_factory_grid.environment.entity.entity import Entity
|
||||
from marl_factory_grid.environment.entity.wall_floor import Floor
|
||||
|
||||
|
||||
# noinspection PyUnresolvedReferences,PyTypeChecker,PyArgumentList
|
||||
class PositionMixin:
|
||||
|
||||
_entity = Entity
|
||||
is_blocking_light: bool = True
|
||||
can_collide: bool = True
|
||||
has_position: bool = True
|
||||
var_is_blocking_light: bool = True
|
||||
var_can_collide: bool = True
|
||||
var_has_position: bool = True
|
||||
|
||||
def spawn(self, tiles: List[Floor]):
|
||||
self.add_items([self._entity(tile) for tile in tiles])
|
||||
|
||||
def render(self):
|
||||
return [y for y in [x.render() for x in self] if y is not None]
|
||||
@ -81,8 +85,8 @@ class IsBoundMixin:
|
||||
class HasBoundedMixin:
|
||||
|
||||
@property
|
||||
def obs_names(self):
|
||||
return [x.name for x in self]
|
||||
def obs_pairs(self):
|
||||
return [(x.name, x) for x in self]
|
||||
|
||||
def by_entity(self, entity):
|
||||
try:
|
||||
|
@ -4,6 +4,7 @@ from typing import List
|
||||
import numpy as np
|
||||
|
||||
from marl_factory_grid.environment.entity.object import Object
|
||||
import marl_factory_grid.environment.constants as c
|
||||
|
||||
|
||||
class Objects:
|
||||
@ -116,12 +117,21 @@ class Objects:
|
||||
def __repr__(self):
|
||||
return f'{self.__class__.__name__}[{dict(self._data)}]'
|
||||
|
||||
def spawn(self, n: int):
    """
    Create ``n`` entities by calling ``self._entity()`` without arguments and add them.

    :param n: Number of entities to create.
    :return: ``c.VALID``
    """
    fresh_entities = [self._entity() for _ in range(n)]
    self.add_items(fresh_entities)
    return c.VALID
|
||||
|
||||
def despawn(self, items: List[Object]):
    """
    Remove the given item(s) from this collection via ``del self[item]``.

    :param items: A single ``Object`` or an iterable of them; a single ``Object`` is wrapped
                  into a one-element list first.
    """
    if isinstance(items, Object):
        items = [items]
    for obsolete in items:
        del self[obsolete]
|
||||
|
||||
def notify_change_pos(self, entity: object):
|
||||
try:
|
||||
self.pos_dict[entity.last_pos].remove(entity)
|
||||
except (ValueError, AttributeError):
|
||||
pass
|
||||
if entity.has_position:
|
||||
if entity.var_has_position:
|
||||
try:
|
||||
self.pos_dict[entity.pos].append(entity)
|
||||
except (ValueError, AttributeError):
|
||||
|
@ -2,10 +2,11 @@ from typing import List, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from marl_factory_grid.environment.groups.env_objects import EnvObjects
|
||||
from marl_factory_grid.environment.groups.objects import Objects
|
||||
from marl_factory_grid.environment.groups.mixins import HasBoundedMixin, PositionMixin
|
||||
from marl_factory_grid.environment.entity.util import GlobalPosition
|
||||
from marl_factory_grid.environment.groups.env_objects import EnvObjects
|
||||
from marl_factory_grid.environment.groups.mixins import PositionMixin, HasBoundedMixin
|
||||
from marl_factory_grid.environment.groups.objects import Objects
|
||||
from marl_factory_grid.modules.zones import Zone
|
||||
from marl_factory_grid.utils import helpers as h
|
||||
from marl_factory_grid.environment import constants as c
|
||||
|
||||
@ -44,7 +45,9 @@ class GlobalPositions(HasBoundedMixin, EnvObjects):
|
||||
super(GlobalPositions, self).__init__(*args, **kwargs)
|
||||
|
||||
|
||||
class Zones(Objects):
|
||||
class ZonesOLD(Objects):
|
||||
|
||||
_entity = Zone
|
||||
|
||||
@property
|
||||
def accounting_zones(self):
|
||||
|
@ -30,8 +30,8 @@ class Walls(PositionMixin, EnvObjects):
|
||||
class Floors(Walls):
|
||||
_entity = Floor
|
||||
symbol = c.SYMBOL_FLOOR
|
||||
is_blocking_light: bool = False
|
||||
can_collide: bool = False
|
||||
var_is_blocking_light: bool = False
|
||||
var_can_collide: bool = False
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(Floors, self).__init__(*args, **kwargs)
|
||||
|
@ -17,7 +17,7 @@ class Rule(abc.ABC):
|
||||
def __repr__(self):
|
||||
return f'{self.name}'
|
||||
|
||||
def on_init(self, state):
|
||||
def on_init(self, state, lvl_map):
|
||||
return []
|
||||
|
||||
def on_reset(self):
|
||||
@ -42,7 +42,7 @@ class MaxStepsReached(Rule):
|
||||
super().__init__()
|
||||
self.max_steps = max_steps
|
||||
|
||||
def on_init(self, state):
|
||||
def on_init(self, state, lvl_map):
|
||||
pass
|
||||
|
||||
def on_check_done(self, state):
|
||||
@ -51,6 +51,20 @@ class MaxStepsReached(Rule):
|
||||
return [DoneResult(validity=c.NOT_VALID, identifier=self.name, reward=0)]
|
||||
|
||||
|
||||
class AssignGlobalPositions(Rule):
    """Rule that creates one GlobalPosition per agent when the environment is initialised."""

    def __init__(self):
        super().__init__()

    def on_init(self, state, lvl_map):
        """
        Bind a new GlobalPosition to every agent and register it in the GlobalPositions group.

        :param state: Game state; accessed through ``state[c.AGENT]`` and ``state[c.GLOBALPOSITIONS]``.
        :param lvl_map: Level map; its ``level_shape`` is handed to every GlobalPosition.
        :return: An empty list.
        """
        # Function-scope import — presumably to avoid a circular import at module load (TODO confirm).
        from marl_factory_grid.environment.entity.util import GlobalPosition

        for agent in state[c.AGENT]:
            global_position = GlobalPosition(lvl_map.level_shape)
            global_position.bind_to(agent)
            state[c.GLOBALPOSITIONS].add_item(global_position)
        return []
|
||||
|
||||
|
||||
class Collision(Rule):
|
||||
|
||||
def __init__(self, done_at_collisions: bool = False):
|
||||
|
@ -0,0 +1,7 @@
|
||||
from .batteries import *
|
||||
from .clean_up import *
|
||||
from .destinations import *
|
||||
from .doors import *
|
||||
from .items import *
|
||||
from .machines import *
|
||||
from .maintenance import *
|
||||
|
@ -8,7 +8,7 @@ class TemplateRule(Rule):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(TemplateRule, self).__init__(*args, **kwargs)
|
||||
|
||||
def on_init(self, state):
|
||||
def on_init(self, state, lvl_map):
|
||||
pass
|
||||
|
||||
def tick_pre_step(self, state) -> List[TickResult]:
|
||||
|
@ -0,0 +1,4 @@
|
||||
from .actions import BtryCharge
|
||||
from .entitites import ChargePod, Battery
|
||||
from .groups import ChargePods, Batteries
|
||||
from .rules import BtryDoneAtDischarge, Btry
|
||||
|
@ -13,18 +13,13 @@ class Batteries(HasBoundedMixin, EnvObjects):
|
||||
def obs_tag(self):
|
||||
return self.__class__.__name__
|
||||
|
||||
@property
|
||||
def obs_pairs(self):
|
||||
return [(x.name, x) for x in self]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(Batteries, self).__init__(*args, **kwargs)
|
||||
|
||||
def spawn_batteries(self, agents, initial_charge_level):
|
||||
def spawn(self, agents, initial_charge_level):
|
||||
batteries = [self._entity(initial_charge_level, agent) for _, agent in enumerate(agents)]
|
||||
self.add_items(batteries)
|
||||
|
||||
|
||||
class ChargePods(PositionMixin, EnvObjects):
|
||||
|
||||
_entity = ChargePod
|
||||
|
@ -13,8 +13,8 @@ class Btry(Rule):
|
||||
self.per_action_costs = per_action_costs
|
||||
self.initial_charge = initial_charge
|
||||
|
||||
def on_init(self, state):
|
||||
state[b.BATTERIES].spawn_batteries(state[c.AGENT], self.initial_charge)
|
||||
def on_init(self, state, lvl_map):
|
||||
state[b.BATTERIES].spawn(state[c.AGENT], self.initial_charge)
|
||||
|
||||
def tick_pre_step(self, state) -> List[TickResult]:
|
||||
pass
|
||||
|
@ -0,0 +1,6 @@
|
||||
from .actions import CleanUp
|
||||
from .entitites import DirtPile
|
||||
from .groups import DirtPiles
|
||||
from .rule_respawn import DirtRespawnRule
|
||||
from .rule_smear_on_move import DirtSmearOnMove
|
||||
from .rule_done_on_all_clean import DirtAllCleanDone
|
||||
|
@ -7,6 +7,22 @@ from marl_factory_grid.modules.clean_up import constants as d
|
||||
|
||||
class DirtPile(Entity):
|
||||
|
||||
@property
|
||||
def var_can_collide(self):
|
||||
return False
|
||||
|
||||
@property
|
||||
def var_can_move(self):
|
||||
return False
|
||||
|
||||
@property
|
||||
def var_is_blocking_light(self):
|
||||
return False
|
||||
|
||||
@property
|
||||
def var_has_position(self):
|
||||
return True
|
||||
|
||||
@property
|
||||
def amount(self):
|
||||
return self._amount
|
||||
|
@ -31,7 +31,7 @@ class DirtPiles(PositionMixin, EnvObjects):
|
||||
self.max_global_amount = max_global_amount
|
||||
self.max_local_amount = max_local_amount
|
||||
|
||||
def spawn_dirt(self, then_dirty_tiles, amount) -> bool:
|
||||
def spawn(self, then_dirty_tiles, amount) -> bool:
|
||||
if isinstance(then_dirty_tiles, Floor):
|
||||
then_dirty_tiles = [then_dirty_tiles]
|
||||
for tile in then_dirty_tiles:
|
||||
@ -57,7 +57,7 @@ class DirtPiles(PositionMixin, EnvObjects):
|
||||
var = self.dirt_spawn_r_var
|
||||
new_spawn = abs(self.initial_dirt_ratio + (state.rng.uniform(-var, var) if initial_spawn else 0))
|
||||
n_dirt_tiles = max(0, int(new_spawn * len(free_for_dirt)))
|
||||
return self.spawn_dirt(free_for_dirt[:n_dirt_tiles], self.initial_amount)
|
||||
return self.spawn(free_for_dirt[:n_dirt_tiles], self.initial_amount)
|
||||
|
||||
def __repr__(self):
|
||||
s = super(DirtPiles, self).__repr__()
|
||||
|
@ -11,7 +11,7 @@ class DirtRespawnRule(Rule):
|
||||
self.spawn_freq = spawn_freq
|
||||
self._next_dirt_spawn = spawn_freq
|
||||
|
||||
def on_init(self, state) -> str:
|
||||
def on_init(self, state, lvl_map) -> str:
|
||||
state[d.DIRT].trigger_dirt_spawn(state, initial_spawn=True)
|
||||
return f'Initial Dirt was spawned on: {[x.pos for x in state[d.DIRT]]}'
|
||||
|
||||
|
@ -18,7 +18,7 @@ class DirtSmearOnMove(Rule):
|
||||
if is_move(entity.state.identifier) and entity.state.validity == c.VALID:
|
||||
if old_pos_dirt := state[d.DIRT].by_pos(entity.last_pos):
|
||||
if smeared_dirt := round(old_pos_dirt.amount * self.smear_amount, 2):
|
||||
if state[d.DIRT].spawn_dirt(entity.tile, amount=smeared_dirt):
|
||||
if state[d.DIRT].spawn(entity.tile, amount=smeared_dirt):
|
||||
results.append(TickResult(identifier=self.name, entity=entity,
|
||||
reward=0, validity=c.VALID))
|
||||
return results
|
||||
|
@ -0,0 +1,4 @@
|
||||
from .actions import DestAction
|
||||
from .entitites import Destination
|
||||
from .groups import ReachedDestinations, Destinations
|
||||
from .rules import DestinationDone, DestinationReach, DestinationSpawn
|
||||
|
@ -62,7 +62,7 @@ class DestinationSpawn(Rule):
|
||||
self.n_dests = n_dests
|
||||
self.spawn_mode = spawn_mode
|
||||
|
||||
def on_init(self, state):
|
||||
def on_init(self, state, lvl_map):
|
||||
# noinspection PyAttributeOutsideInit
|
||||
self._dest_spawn_timer = self.spawn_frequency
|
||||
self.trigger_destination_spawn(self.n_dests, state)
|
||||
|
@ -0,0 +1,4 @@
|
||||
from .actions import DoorUse
|
||||
from .entitites import Door, DoorIndicator
|
||||
from .groups import Doors
|
||||
from .rule_door_auto_close import DoorAutoClose
|
||||
|
@ -1,10 +1,9 @@
|
||||
from typing import Union
|
||||
|
||||
from marl_factory_grid.environment.actions import Action
|
||||
from marl_factory_grid.utils.results import ActionResult
|
||||
|
||||
from marl_factory_grid.modules.doors import constants as d, rewards as r
|
||||
from marl_factory_grid.environment import constants as c
|
||||
from marl_factory_grid.utils.results import ActionResult
|
||||
|
||||
|
||||
class DoorUse(Action):
|
||||
|
@ -22,15 +22,15 @@ class DoorIndicator(Entity):
|
||||
class Door(Entity):
|
||||
|
||||
@property
|
||||
def is_blocking_pos(self):
|
||||
def var_is_blocking_pos(self):
|
||||
return False if self.is_open else True
|
||||
|
||||
@property
|
||||
def is_blocking_light(self):
|
||||
def var_is_blocking_light(self):
|
||||
return False if self.is_open else True
|
||||
|
||||
@property
|
||||
def can_collide(self):
|
||||
def var_can_collide(self):
|
||||
return False if self.is_open else True
|
||||
|
||||
@property
|
||||
@ -42,12 +42,14 @@ class Door(Entity):
|
||||
return 'open' if self.is_open else 'closed'
|
||||
|
||||
def __init__(self, *args, closed_on_init=True, auto_close_interval=10, indicate_area=False, **kwargs):
|
||||
self._state = d.STATE_CLOSED
|
||||
self._status = d.STATE_CLOSED
|
||||
super(Door, self).__init__(*args, **kwargs)
|
||||
self.auto_close_interval = auto_close_interval
|
||||
self.time_to_close = 0
|
||||
if not closed_on_init:
|
||||
self._open()
|
||||
else:
|
||||
self._close()
|
||||
if indicate_area:
|
||||
self._collection.add_items([DoorIndicator(x) for x in self.tile.neighboring_floor])
|
||||
|
||||
@ -58,22 +60,22 @@ class Door(Entity):
|
||||
|
||||
@property
|
||||
def is_closed(self):
|
||||
return self._state == d.STATE_CLOSED
|
||||
return self._status == d.STATE_CLOSED
|
||||
|
||||
@property
|
||||
def is_open(self):
|
||||
return self._state == d.STATE_OPEN
|
||||
return self._status == d.STATE_OPEN
|
||||
|
||||
@property
|
||||
def status(self):
|
||||
return self._state
|
||||
return self._status
|
||||
|
||||
def render(self):
|
||||
name, state = 'door_open' if self.is_open else 'door_closed', 'blank'
|
||||
return RenderEntity(name, self.pos, 1, 'none', state, self.identifier_int + 1)
|
||||
|
||||
def use(self):
|
||||
if self._state == d.STATE_OPEN:
|
||||
if self._status == d.STATE_OPEN:
|
||||
self._close()
|
||||
else:
|
||||
self._open()
|
||||
@ -90,8 +92,8 @@ class Door(Entity):
|
||||
return c.NOT_VALID
|
||||
|
||||
def _open(self):
|
||||
self._state = d.STATE_OPEN
|
||||
self._status = d.STATE_OPEN
|
||||
self.time_to_close = self.auto_close_interval
|
||||
|
||||
def _close(self):
|
||||
self._state = d.STATE_CLOSED
|
||||
self._status = d.STATE_CLOSED
|
||||
|
32
marl_factory_grid/modules/factory/rules.py
Normal file
32
marl_factory_grid/modules/factory/rules.py
Normal file
@ -0,0 +1,32 @@
|
||||
import random
|
||||
from typing import List, Union
|
||||
|
||||
from marl_factory_grid.environment.rules import Rule
|
||||
from marl_factory_grid.environment import constants as c
|
||||
from marl_factory_grid.utils.results import TickResult
|
||||
|
||||
|
||||
class AgentSingleZonePlacementBeta(Rule):
    """Experimental rule that moves every agent onto a floor tile taken from ``self.coordinates``.

    NOTE(review): ``self.coordinates`` is read in ``on_init`` but is not assigned anywhere in this class;
    unless the ``Rule`` base class or a subclass provides it, ``on_init`` fails with an AttributeError.
    """

    def __init__(self):
        super().__init__()

    def on_init(self, state, lvl_map):
        """Move each agent to the floor tile at one of the configured coordinates.

        :param state: Game state; read for the zone, agent and floor groups.
        :param lvl_map: Level parser handed to every rule on init; not used here.
        :raises ValueError: If fewer coordinates than agents are available.
        """
        # NOTE(review): ``zones`` and ``n_zones`` are both bound to the same lookup and never used afterwards.
        zones = state[c.ZONES]
        n_zones = state[c.ZONES]
        agents = state[c.AGENT]
        if len(self.coordinates) == len(agents):
            coordinates = self.coordinates
        elif len(self.coordinates) > len(agents):
            # NOTE(review): random.choices draws WITH replacement, so two agents may receive the same
            # coordinate -- confirm whether distinct positions (random.sample) were intended.
            coordinates = random.choices(self.coordinates, k=len(agents))
        else:
            raise ValueError
        tiles = [state[c.FLOOR].by_pos(pos) for pos in coordinates]
        for agent, tile in zip(agents, tiles):
            agent.move(tile)

    def tick_step(self, state):
        """Produce no results during the step phase."""
        return []

    def tick_post_step(self, state) -> List[TickResult]:
        """Produce no results after the step phase."""
        return []
|
@ -0,0 +1,4 @@
|
||||
from .actions import ItemAction
|
||||
from .entitites import Item, DropOffLocation
|
||||
from .groups import DropOffLocations, Items, Inventory, Inventories
|
||||
from .rules import ItemRules
|
||||
|
@ -8,6 +8,8 @@ from marl_factory_grid.modules.items import constants as i
|
||||
|
||||
class Item(Entity):
|
||||
|
||||
var_can_collide = False
|
||||
|
||||
def render(self):
|
||||
return RenderEntity(i.ITEM, self.tile.pos) if self.pos != c.VALUE_NO_POS else None
|
||||
|
||||
@ -38,6 +40,22 @@ class Item(Entity):
|
||||
|
||||
class DropOffLocation(Entity):
|
||||
|
||||
@property
|
||||
def var_can_collide(self):
|
||||
return False
|
||||
|
||||
@property
|
||||
def var_can_move(self):
|
||||
return False
|
||||
|
||||
@property
|
||||
def var_is_blocking_light(self):
|
||||
return False
|
||||
|
||||
@property
|
||||
def var_has_position(self):
|
||||
return True
|
||||
|
||||
def render(self):
|
||||
return RenderEntity(i.DROP_OFF, self.tile.pos)
|
||||
|
||||
|
@ -17,15 +17,6 @@ class Items(PositionMixin, EnvObjects):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def spawn_items(self, tiles: List[Floor]):
|
||||
items = [self._entity(tile) for tile in tiles]
|
||||
self.add_items(items)
|
||||
|
||||
def despawn_items(self, items: List[Item]):
|
||||
items = [items] if isinstance(items, Item) else items
|
||||
for item in items:
|
||||
del self[item]
|
||||
|
||||
|
||||
class Inventory(IsBoundMixin, EnvObjects):
|
||||
|
||||
@ -58,11 +49,7 @@ class Inventory(IsBoundMixin, EnvObjects):
|
||||
class Inventories(HasBoundedMixin, Objects):
|
||||
|
||||
_entity = Inventory
|
||||
can_move = False
|
||||
|
||||
@property
|
||||
def obs_pairs(self):
|
||||
return [(x.name, x) for x in self]
|
||||
var_can_move = False
|
||||
|
||||
def __init__(self, size, *args, **kwargs):
|
||||
super(Inventories, self).__init__(*args, **kwargs)
|
||||
@ -70,7 +57,7 @@ class Inventories(HasBoundedMixin, Objects):
|
||||
self._obs = None
|
||||
self._lazy_eval_transforms = []
|
||||
|
||||
def spawn_inventories(self, agents):
|
||||
def spawn(self, agents):
|
||||
inventories = [self._entity(agent, self.size,)
|
||||
for _, agent in enumerate(agents)]
|
||||
self.add_items(inventories)
|
||||
|
@ -18,7 +18,7 @@ class ItemRules(Rule):
|
||||
self.max_dropoff_storage_size = max_dropoff_storage_size
|
||||
self.n_locations = n_locations
|
||||
|
||||
def on_init(self, state):
|
||||
def on_init(self, state, lvl_map):
|
||||
self.trigger_drop_off_location_spawn(state)
|
||||
self._next_item_spawn = self.spawn_frequency
|
||||
self.trigger_inventory_spawn(state)
|
||||
@ -42,7 +42,7 @@ class ItemRules(Rule):
|
||||
def trigger_item_spawn(self, state):
|
||||
if item_to_spawns := max(0, (self.n_items - len(state[i.ITEM]))):
|
||||
empty_tiles = state[c.FLOOR].empty_tiles[:item_to_spawns]
|
||||
state[i.ITEM].spawn_items(empty_tiles)
|
||||
state[i.ITEM].spawn(empty_tiles)
|
||||
self._next_item_spawn = self.spawn_frequency
|
||||
state.print(f'{item_to_spawns} new items have been spawned; next spawn in {self._next_item_spawn}')
|
||||
return len(empty_tiles)
|
||||
@ -52,7 +52,7 @@ class ItemRules(Rule):
|
||||
|
||||
@staticmethod
|
||||
def trigger_inventory_spawn(state):
|
||||
state[i.INVENTORY].spawn_inventories(state[c.AGENT])
|
||||
state[i.INVENTORY].spawn(state[c.AGENT])
|
||||
|
||||
def tick_post_step(self, state) -> List[TickResult]:
|
||||
for item in list(state[i.ITEM].values()):
|
||||
|
@ -0,0 +1,3 @@
|
||||
from .entitites import Machine
|
||||
from .groups import Machines
|
||||
from .rules import MachineRule
|
||||
|
25
marl_factory_grid/modules/machines/actions.py
Normal file
25
marl_factory_grid/modules/machines/actions.py
Normal file
@ -0,0 +1,25 @@
|
||||
from typing import Union
|
||||
|
||||
from marl_factory_grid.environment.actions import Action
|
||||
from marl_factory_grid.utils.results import ActionResult
|
||||
|
||||
from marl_factory_grid.modules.machines import constants as m, rewards as r
|
||||
from marl_factory_grid.environment import constants as c
|
||||
|
||||
|
||||
class MachineAction(Action):
    """Action that calls ``maintain()`` on the machine located at the acting entity's position."""

    def __init__(self):
        super().__init__(m.MACHINE_ACTION)

    def do(self, entity, state) -> Union[None, ActionResult]:
        """Maintain the machine the entity is standing on, if there is one.

        :param entity: The acting entity; its ``pos`` selects the machine.
        :param state: Game state holding the machines group.
        :return: An ActionResult whose validity is the value returned by ``machine.maintain()``
                 (``c.NOT_VALID`` when no machine is found); the reward is ``r.MAINTAIN_VALID`` for a
                 truthy outcome and ``r.MAINTAIN_FAIL`` otherwise.
        """
        machine = state[m.MACHINES].by_pos(entity.pos)
        if not machine:
            # Nothing to maintain at this position.
            return ActionResult(entity=entity, identifier=self._identifier,
                                validity=c.NOT_VALID, reward=r.MAINTAIN_FAIL)

        outcome = machine.maintain()
        reward = r.MAINTAIN_VALID if outcome else r.MAINTAIN_FAIL
        return ActionResult(entity=entity, identifier=self._identifier, validity=outcome, reward=reward)
|
||||
|
||||
|
||||
|
@ -2,6 +2,8 @@
|
||||
MACHINES = 'Machines'
|
||||
MACHINE = 'Machine'
|
||||
|
||||
MACHINE_ACTION = 'Maintain'
|
||||
|
||||
STATE_WORK = 'working'
|
||||
STATE_IDLE = 'idling'
|
||||
STATE_MAINTAIN = 'maintenance'
|
||||
|
@ -2,27 +2,43 @@ from marl_factory_grid.environment.entity.entity import Entity
|
||||
from marl_factory_grid.utils.render import RenderEntity
|
||||
from marl_factory_grid.environment import constants as c
|
||||
from marl_factory_grid.utils.results import TickResult
|
||||
from marl_factory_grid.modules.machines import constants as m, rewards as r
|
||||
|
||||
from . import constants as m
|
||||
|
||||
|
||||
class Machine(Entity):
|
||||
|
||||
@property
|
||||
def var_can_collide(self):
|
||||
return False
|
||||
|
||||
@property
|
||||
def var_can_move(self):
|
||||
return False
|
||||
|
||||
@property
|
||||
def var_is_blocking_light(self):
|
||||
return False
|
||||
|
||||
@property
|
||||
def var_has_position(self):
|
||||
return True
|
||||
|
||||
@property
|
||||
def encoding(self):
|
||||
return self._encodings[self.state]
|
||||
return self._encodings[self.status]
|
||||
|
||||
def __init__(self, *args, work_interval: int = 10, pause_interval: int = 15, **kwargs):
|
||||
super(Machine, self).__init__(*args, **kwargs)
|
||||
self._intervals = dict({m.STATE_IDLE: pause_interval, m.STATE_WORK: work_interval})
|
||||
self._encodings = dict({m.STATE_IDLE: pause_interval, m.STATE_WORK: work_interval})
|
||||
|
||||
self.state = m.STATE_IDLE
|
||||
self.status = m.STATE_IDLE
|
||||
self.health = 100
|
||||
self._counter = 0
|
||||
self.__delattr__('move')
|
||||
|
||||
def maintain(self):
|
||||
if self.state == m.STATE_WORK:
|
||||
if self.status == m.STATE_WORK:
|
||||
return c.NOT_VALID
|
||||
if self.health <= 98:
|
||||
self.health = 100
|
||||
@ -31,10 +47,10 @@ class Machine(Entity):
|
||||
return c.NOT_VALID
|
||||
|
||||
def tick(self):
|
||||
if self.state == m.STATE_MAINTAIN and any([c.AGENT in x.name for x in self.tile.guests]):
|
||||
return TickResult(self.name, c.VALID, r.NONE, self)
|
||||
elif self.state == m.STATE_MAINTAIN and not any([c.AGENT in x.name for x in self.tile.guests]):
|
||||
self.state = m.STATE_WORK
|
||||
if self.status == m.STATE_MAINTAIN and any([c.AGENT in x.name for x in self.tile.guests]):
|
||||
return TickResult(identifier=self.name, validity=c.VALID, reward=0, entity=self)
|
||||
elif self.status == m.STATE_MAINTAIN and not any([c.AGENT in x.name for x in self.tile.guests]):
|
||||
self.status = m.STATE_WORK
|
||||
self.reset_counter()
|
||||
return None
|
||||
elif self._counter:
|
||||
@ -42,12 +58,12 @@ class Machine(Entity):
|
||||
self.health -= 1
|
||||
return None
|
||||
else:
|
||||
self.state = m.STATE_WORK if self.state == m.STATE_IDLE else m.STATE_IDLE
|
||||
self.status = m.STATE_WORK if self.status == m.STATE_IDLE else m.STATE_IDLE
|
||||
self.reset_counter()
|
||||
return None
|
||||
|
||||
def reset_counter(self):
|
||||
self._counter = self._intervals[self.state]
|
||||
self._counter = self._intervals[self.status]
|
||||
|
||||
def render(self):
|
||||
return RenderEntity(m.MACHINE, self.pos)
|
||||
|
@ -1,6 +1,7 @@
|
||||
from marl_factory_grid.environment.groups.env_objects import EnvObjects
|
||||
from marl_factory_grid.environment.groups.mixins import PositionMixin
|
||||
from marl_factory_grid.modules.machines.entitites import Machine
|
||||
|
||||
from .entitites import Machine
|
||||
|
||||
|
||||
class Machines(PositionMixin, EnvObjects):
|
||||
|
BIN
marl_factory_grid/modules/machines/machine.png
Normal file
BIN
marl_factory_grid/modules/machines/machine.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 8.5 KiB |
@ -12,7 +12,7 @@ class MachineRule(Rule):
|
||||
super(MachineRule, self).__init__()
|
||||
self.n_machines = n_machines
|
||||
|
||||
def on_init(self, state):
|
||||
def on_init(self, state, lvl_map):
|
||||
empty_tiles = state[c.FLOOR].empty_tiles[:self.n_machines]
|
||||
state[m.MACHINES].add_items(Machine(tile) for tile in empty_tiles)
|
||||
|
||||
@ -27,3 +27,9 @@ class MachineRule(Rule):
|
||||
|
||||
def on_check_done(self, state) -> List[DoneResult]:
|
||||
pass
|
||||
|
||||
|
||||
class DoneOnBreakRule(Rule):
    """Placeholder rule; its done-check is not implemented yet."""

    def on_check_done(self, state) -> List[DoneResult]:
        """Stub: performs no check and returns ``None``.

        NOTE(review): the annotation promises a list of DoneResult, but nothing is returned yet.
        """
        return None
|
2
marl_factory_grid/modules/maintenance/__init__.py
Normal file
2
marl_factory_grid/modules/maintenance/__init__.py
Normal file
@ -0,0 +1,2 @@
|
||||
from .entities import Maintainer
|
||||
from .groups import Maintainers
|
3
marl_factory_grid/modules/maintenance/constants.py
Normal file
3
marl_factory_grid/modules/maintenance/constants.py
Normal file
@ -0,0 +1,3 @@
|
||||
MAINTAINER = 'Maintainer' # TEMPLATE _identifier. Define your own!
|
||||
MAINTAINERS = 'Maintainers' # TEMPLATE _identifier. Define your own!
|
||||
|
102
marl_factory_grid/modules/maintenance/entities.py
Normal file
102
marl_factory_grid/modules/maintenance/entities.py
Normal file
@ -0,0 +1,102 @@
|
||||
import networkx as nx
|
||||
import numpy as np
|
||||
|
||||
from ...algorithms.static.utils import points_to_graph
|
||||
from ...environment import constants as c
|
||||
from ...environment.actions import Action, ALL_BASEACTIONS
|
||||
from ...environment.entity.entity import Entity
|
||||
from ..doors import constants as do
|
||||
from ..maintenance import constants as mi
|
||||
from ...utils.helpers import MOVEMAP
|
||||
from ...utils.render import RenderEntity
|
||||
from ...utils.states import Gamestate
|
||||
|
||||
|
||||
class Maintainer(Entity):
    """Scripted entity that walks from one objective entity to the next and applies ``action`` on arrival.

    Routes are shortest paths on a graph that is built once, at construction time, from the floor positions.
    """

    @property
    def var_can_collide(self):
        return True

    @property
    def var_can_move(self):
        # NOTE(review): reports False although tick() executes move actions, and the Maintainers group
        # declares var_can_move = True -- confirm which value is intended.
        return False

    @property
    def var_is_blocking_light(self):
        return False

    @property
    def var_has_position(self):
        return True

    def __init__(self, state: Gamestate, objective: str, action: Action, *args, **kwargs):
        """
        :param state: Game state; its floor positions are used to build the routing graph.
        :param objective: Key of the group in ``state`` whose members are visited one after another.
        :param action: Action executed when the maintainer stands on a member of that group.
        :param args: Forwarded to ``Entity.__init__``.
        :param kwargs: Forwarded to ``Entity.__init__``.
        """
        super().__init__(*args, **kwargs)
        self.action = action
        self.actions = [x() for x in ALL_BASEACTIONS]
        self.objective = objective
        self._path = None           # Remaining positions on the way to the current target.
        self._next = []             # Objectives still to be visited in the current round.
        self._last = []             # Objectives already targeted in the current round.
        self._last_serviced = 'None'
        self._floortile_graph = points_to_graph(state[c.FLOOR].positions)

    def tick(self, state):
        """Service the objective at the current position, or take one step towards the next one.

        :param state: Current game state.
        :return: The result of the executed action (service action or move action).
        """
        found_objective = state[self.objective].by_pos(self.pos)
        if found_objective and found_objective.name != self._last_serviced:
            # Return the service result as well, so both branches hand a result back to the caller.
            result = self.action.do(self, state)
            self._last_serviced = found_objective.name
            return result
        return self.get_move_action(state).do(self, state)

    def get_move_action(self, state) -> Action:
        """Select the next action object: use a nearby closed door, otherwise follow the planned path.

        When the current path is used up, the next objective is taken from the pending list (which is
        refilled from ``state`` once empty) and a new route is calculated.

        :param state: Current game state.
        :return: The action object from ``self.actions`` whose name matches the chosen action.
        :raises EnvironmentError: If the chosen action name is not among ``self.actions``.
        """
        if not self._path:
            if not self._next:
                self._next = list(state[self.objective].values())
                self._last = []
            # Guard against an empty objective group; _predict_move then falls back to NOOP.
            if self._next:
                self._last.append(self._next.pop())
                self._path = self.calculate_route(self._last[-1])

        door = self._door_is_close()
        if door and door.is_closed:
            # NOTE(review): the lookup below only succeeds if an action with this name is part of
            # ALL_BASEACTIONS -- confirm, otherwise a closed door ends in the EnvironmentError.
            action = do.ACTION_DOOR_USE
        else:
            action = self._predict_move(state)

        # Translate the action name into the matching action object.
        try:
            action_obj = next(x for x in self.actions if x.name == action)
        except (StopIteration, UnboundLocalError):
            print('Will not happen')
            raise EnvironmentError
        return action_obj

    def calculate_route(self, entity):
        """Return the shortest path from the own position to ``entity.pos``, excluding the own position."""
        route = nx.shortest_path(self._floortile_graph, self.pos, entity.pos)
        return route[1:]

    def _door_is_close(self):
        """Return the first guest on a neighboring floor tile whose name contains the door identifier, or None."""
        try:
            return next(y for x in self.tile.neighboring_floor for y in x.guests if do.DOOR in y.name)
        except StopIteration:
            return None

    def _predict_move(self, state):
        """Return the name of the action that advances one step along the path.

        Returns ``c.NOOP`` when there is nothing to follow (e.g. the target is on the own tile) or when
        the next tile holds a guest that can collide; the path is only consumed on an actual move.
        """
        if not self._path:
            # An empty route previously raised an IndexError on self._path[0].
            return c.NOOP
        next_pos = self._path[0]
        if len(state[c.FLOOR].by_pos(next_pos).guests_that_can_collide) > 0:
            action = c.NOOP
        else:
            next_pos = self._path.pop(0)
            diff = np.subtract(next_pos, self.pos)
            # Retrieve action based on the pos dif (like in: What do I have to do to get there?)
            action = next(action for action, pos_diff in MOVEMAP.items() if np.all(diff == pos_diff))
        return action

    def render(self):
        return RenderEntity(mi.MAINTAINER, self.pos)
|
27
marl_factory_grid/modules/maintenance/groups.py
Normal file
27
marl_factory_grid/modules/maintenance/groups.py
Normal file
@ -0,0 +1,27 @@
|
||||
from typing import List
|
||||
|
||||
from .entities import Maintainer
|
||||
from marl_factory_grid.environment.entity.wall_floor import Floor
|
||||
from marl_factory_grid.environment.groups.env_objects import EnvObjects
|
||||
from marl_factory_grid.environment.groups.mixins import PositionMixin
|
||||
from ..machines.actions import MachineAction
|
||||
from ...utils.render import RenderEntity
|
||||
from ...utils.states import Gamestate
|
||||
|
||||
from ..machines import constants as mc
|
||||
from . import constants as mi
|
||||
|
||||
|
||||
class Maintainers(PositionMixin, EnvObjects):
    """Group holding all Maintainer entities."""

    _entity = Maintainer
    var_can_collide = True
    var_can_move = True
    var_is_blocking_light = False
    var_has_position = True

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def spawn(self, tiles: List[Floor], state: Gamestate):
        """Create one maintainer per given tile and add them all to this group.

        Every maintainer targets the machines group and receives its own MachineAction instance.

        :param tiles: Floor tiles on which the maintainers are placed.
        :param state: Game state handed to every new maintainer.
        """
        spawned = []
        for tile in tiles:
            spawned.append(self._entity(state, mc.MACHINES, MachineAction(), tile))
        self.add_items(spawned)
|
BIN
marl_factory_grid/modules/maintenance/maintainer.png
Normal file
BIN
marl_factory_grid/modules/maintenance/maintainer.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 22 KiB |
1
marl_factory_grid/modules/maintenance/rewards.py
Normal file
1
marl_factory_grid/modules/maintenance/rewards.py
Normal file
@ -0,0 +1 @@
|
||||
MAINTAINER_COLLISION_REWARD = -5
|
39
marl_factory_grid/modules/maintenance/rules.py
Normal file
39
marl_factory_grid/modules/maintenance/rules.py
Normal file
@ -0,0 +1,39 @@
|
||||
from typing import List
|
||||
from marl_factory_grid.environment.rules import Rule
|
||||
from marl_factory_grid.utils.results import TickResult, DoneResult
|
||||
from marl_factory_grid.environment import constants as c
|
||||
from . import rewards as r
|
||||
from . import constants as M
|
||||
from marl_factory_grid.utils.states import Gamestate
|
||||
|
||||
|
||||
class MaintenanceRule(Rule):
    """Spawns maintainers on init, ticks them every step and reports agents that share a maintainer's position."""

    def __init__(self, n_maintainer: int = 1, *args, **kwargs):
        """
        :param n_maintainer: Number of maintainers to spawn.
        :param args: Forwarded to ``Rule.__init__``.
        :param kwargs: Forwarded to ``Rule.__init__``.
        """
        super().__init__(*args, **kwargs)
        self.n_maintainer = n_maintainer

    def on_init(self, state: Gamestate, lvl_map):
        """Spawn the maintainers on the first ``n_maintainer`` empty floor tiles."""
        state[M.MAINTAINERS].spawn(state[c.FLOOR].empty_tiles[:self.n_maintainer], state)

    def tick_pre_step(self, state) -> List[TickResult]:
        """Nothing to do before the step; return an empty list to honour the annotated return type."""
        return []

    def tick_step(self, state) -> List[TickResult]:
        """Let every maintainer act once; the results of their actions are not reported."""
        for maintainer in state[M.MAINTAINERS]:
            maintainer.tick(state)
        return []

    def tick_post_step(self, state) -> List[TickResult]:
        """Nothing to do after the step; return an empty list to honour the annotated return type."""
        return []

    def on_check_done(self, state) -> List[DoneResult]:
        """Return one DoneResult for every agent whose position equals a maintainer position."""
        m_pos = state[M.MAINTAINERS].positions
        return [DoneResult(entity=agent, validity=c.VALID, identifier=self.name,
                           reward=r.MAINTAINER_COLLISION_REWARD)
                for agent in state[c.AGENT].values() if agent.pos in m_pos]
|
3
marl_factory_grid/modules/zones/__init__.py
Normal file
3
marl_factory_grid/modules/zones/__init__.py
Normal file
@ -0,0 +1,3 @@
|
||||
from .entitites import Zone
|
||||
from .groups import Zones
|
||||
from .rules import AgentSingleZonePlacement
|
4
marl_factory_grid/modules/zones/constants.py
Normal file
4
marl_factory_grid/modules/zones/constants.py
Normal file
@ -0,0 +1,4 @@
|
||||
# Names / Identifiers
|
||||
|
||||
ZONES = 'Zones' # Identifier of Zone-objects and groups (groups).
|
||||
ZONE = 'Zone' # -||-
|
21
marl_factory_grid/modules/zones/entitites.py
Normal file
21
marl_factory_grid/modules/zones/entitites.py
Normal file
@ -0,0 +1,21 @@
|
||||
import random
|
||||
from typing import List
|
||||
|
||||
from marl_factory_grid.environment.entity.entity import Entity
|
||||
from marl_factory_grid.environment.entity.object import Object
|
||||
from marl_factory_grid.environment.entity.wall_floor import Floor
|
||||
from marl_factory_grid.utils.render import RenderEntity
|
||||
from marl_factory_grid.environment import constants as c
|
||||
|
||||
from marl_factory_grid.modules.doors import constants as d
|
||||
|
||||
|
||||
class Zone(Object):
    """A named collection of floor tiles."""

    def __init__(self, tiles: List[Floor], *args, **kwargs):
        """
        :param tiles: Floor tiles that make up this zone.
        :param args: Forwarded to ``Object.__init__``.
        :param kwargs: Forwarded to ``Object.__init__``.
        """
        super().__init__(*args, **kwargs)
        self.tiles = tiles

    @property
    def random_tile(self):
        """One tile of this zone, drawn with ``random.choice`` on every access."""
        picked = random.choice(self.tiles)
        return picked
|
12
marl_factory_grid/modules/zones/groups.py
Normal file
12
marl_factory_grid/modules/zones/groups.py
Normal file
@ -0,0 +1,12 @@
|
||||
from marl_factory_grid.environment.groups.objects import Objects
|
||||
from marl_factory_grid.modules.zones import Zone
|
||||
|
||||
|
||||
class Zones(Objects):
    """Group holding all Zone objects."""

    _entity = Zone
    symbol = None           # No level symbol is bound to this group.
    var_can_move = False

    def __init__(self, *args, **kwargs):
        """Forward all arguments to ``Objects.__init__`` and always pass ``can_collide=True``."""
        super().__init__(*args, can_collide=True, **kwargs)
|
33
marl_factory_grid/modules/zones/rules.py
Normal file
33
marl_factory_grid/modules/zones/rules.py
Normal file
@ -0,0 +1,33 @@
|
||||
from random import choices
|
||||
|
||||
from marl_factory_grid.environment.rules import Rule
|
||||
from marl_factory_grid.environment import constants as c
|
||||
from marl_factory_grid.modules.zones import Zone
|
||||
from . import constants as z
|
||||
|
||||
|
||||
class AgentSingleZonePlacement(Rule):
    """Builds zones from numbered level symbols and places every agent on a random tile of its own zone."""

    def __init__(self, n_zones=3):
        """
        :param n_zones: Upper (exclusive) bound of the zone symbols that are looked up in the level.
        """
        super().__init__()
        self.n_zones = n_zones

    def on_init(self, state, lvl_map):
        """Create the zones and move each agent onto a random tile of a distinct zone.

        :param state: Game state holding the floor, agent and zone groups.
        :param lvl_map: Level parser; queried for the coordinates of every zone symbol.
        :return: An empty list (nothing to print).
        """
        # Local import: the module itself only imports ``choices`` from random.
        from random import sample

        zones = []
        # NOTE(review): range(1, n_zones) yields only n_zones - 1 zones (symbols 1 .. n_zones - 1);
        # confirm whether the upper bound should be inclusive.
        for z_idx in range(1, self.n_zones):
            zone_positions = lvl_map.get_coordinates_for_symbol(z_idx)
            assert len(zone_positions)
            zones.append(Zone([state[c.FLOOR].by_pos(pos) for pos in zone_positions]))
        state[z.ZONES].add_items(zones)

        n_agents = len(state[c.AGENT])
        assert len(state[z.ZONES]) >= n_agents

        # Draw zone indices WITHOUT replacement so that no two agents end up in the same zone;
        # ``choices`` would allow duplicates and make the assertion above pointless.
        zone_idxs = sample(range(len(state[z.ZONES])), k=n_agents)
        for agent, zone_idx in zip(state[c.AGENT], zone_idxs):
            agent.move(state[z.ZONES][zone_idx].random_tile)
        return []

    def tick_step(self, state):
        """Produce no results during the step phase."""
        return []
|
@ -10,10 +10,10 @@ def init():
|
||||
ce = ConfigExplainer()
|
||||
cwd = Path(os.getcwd())
|
||||
ce.save_all(cwd / 'full_config.yaml')
|
||||
template_path = Path(__file__) / 'marl_factory_grid' / 'modules' / '_template'
|
||||
template_path = Path(__file__).parent / 'modules' / '_template'
|
||||
print(f'Available config options saved to: {(cwd / "full_config.yaml").resolve()}')
|
||||
print('-----------------------------')
|
||||
print(f'Copying Templates....')
|
||||
shutil.copytree(template_path, cwd)
|
||||
print(f'Templates copied to {template_path.resolve()}')
|
||||
print(f'Templates copied to {cwd}"/"{template_path.name}')
|
||||
print(':wave:')
|
||||
|
@ -18,11 +18,11 @@ class FactoryConfigParser(object):
|
||||
default_entites = []
|
||||
default_rules = ['MaxStepsReached', 'Collision']
|
||||
default_actions = [c.MOVE8, c.NOOP]
|
||||
default_observations = [c.WALLS, c.AGENTS]
|
||||
default_observations = [c.WALLS, c.AGENT]
|
||||
|
||||
def __init__(self, config_path, custom_modules_path: Union[None, PathLike] = None):
|
||||
self.config_path = Path(config_path)
|
||||
self.custom_modules_path = Path(config_path) if custom_modules_path is not None else custom_modules_path
|
||||
self.custom_modules_path = Path(custom_modules_path) if custom_modules_path is not None else custom_modules_path
|
||||
self.config = yaml.safe_load(self.config_path.open())
|
||||
self.do_record = False
|
||||
|
||||
@ -69,12 +69,20 @@ class FactoryConfigParser(object):
|
||||
|
||||
for entity in entities:
|
||||
try:
|
||||
folder_path = MODULE_PATH if entity not in self.default_entites else DEFAULT_PATH
|
||||
folder_path = (Path(__file__) / '..' / '..' / '..' / folder_path)
|
||||
folder_path = Path(__file__).parent.parent / DEFAULT_PATH
|
||||
entity_class = locate_and_import_class(entity, folder_path)
|
||||
except AttributeError:
|
||||
except AttributeError as e1:
|
||||
try:
|
||||
folder_path = Path(__file__).parent.parent / MODULE_PATH
|
||||
entity_class = locate_and_import_class(entity, folder_path)
|
||||
except AttributeError as e2:
|
||||
try:
|
||||
folder_path = self.custom_modules_path
|
||||
entity_class = locate_and_import_class(entity, folder_path)
|
||||
except AttributeError as e3:
|
||||
ents = [y for x in [e1.argss[1], e2.argss[1], e3.argss[1]] for y in x]
|
||||
raise AttributeError(e1.argss[0], e2.argss[0], e3.argss[0], 'Possible Entitys are>:', str(ents))
|
||||
|
||||
entity_kwargs = self.entities.get(entity, {})
|
||||
entity_symbol = entity_class.symbol if hasattr(entity_class, 'symbol') else None
|
||||
entity_classes.update({entity: {'class': entity_class, 'kwargs': entity_kwargs, 'symbol': entity_symbol}})
|
||||
@ -92,7 +100,7 @@ class FactoryConfigParser(object):
|
||||
parsed_actions = list()
|
||||
for action in actions:
|
||||
folder_path = MODULE_PATH if action not in base_env_actions else DEFAULT_PATH
|
||||
folder_path = (Path(__file__) / '..' / '..' / '..' / folder_path)
|
||||
folder_path = Path(__file__).parent.parent / folder_path
|
||||
try:
|
||||
class_or_classes = locate_and_import_class(action, folder_path)
|
||||
except AttributeError:
|
||||
@ -124,9 +132,12 @@ class FactoryConfigParser(object):
|
||||
rules.extend(x for x in self.rules if x != c.DEFAULTS)
|
||||
|
||||
for rule in rules:
|
||||
folder_path = MODULE_PATH if rule not in self.default_rules else DEFAULT_PATH
|
||||
folder_path = (Path(__file__) / '..' / '..' / '..' / folder_path)
|
||||
try:
|
||||
folder_path = (Path(__file__).parent.parent / DEFAULT_PATH)
|
||||
rule_class = locate_and_import_class(rule, folder_path)
|
||||
except AttributeError:
|
||||
try:
|
||||
folder_path = (Path(__file__).parent.parent / MODULE_PATH)
|
||||
rule_class = locate_and_import_class(rule, folder_path)
|
||||
except AttributeError:
|
||||
rule_class = locate_and_import_class(rule, self.custom_modules_path)
|
||||
|
@ -176,7 +176,7 @@ def one_hot_level(level, symbol: str):
|
||||
|
||||
grid = np.array(level)
|
||||
binary_grid = np.zeros(grid.shape, dtype=np.int8)
|
||||
binary_grid[grid == symbol] = c.VALUE_OCCUPIED_CELL
|
||||
binary_grid[grid == str(symbol)] = c.VALUE_OCCUPIED_CELL
|
||||
return binary_grid
|
||||
|
||||
|
||||
@ -222,18 +222,15 @@ def locate_and_import_class(class_name, folder_path: Union[str, PurePath] = ''):
|
||||
for module_path in module_paths:
|
||||
module_parts = [x.replace('.py', '') for idx, x in enumerate(module_path.parts) if idx >= package_pos]
|
||||
mod = importlib.import_module('.'.join(module_parts))
|
||||
all_found_modules.extend([x for x in dir(mod) if not(x.startswith('__') or len(x) < 2 or x.isupper())
|
||||
and x not in ['Entity', 'NamedTuple', 'List', 'Rule', 'Union', 'random', 'Floor'
|
||||
'TickResult', 'ActionResult', 'Action', 'Agent', 'deque',
|
||||
'BoundEntityMixin', 'RenderEntity', 'TemplateRule', 'defaultdict',
|
||||
'is_move', 'Objects', 'PositionMixin', 'IsBoundMixin', 'EnvObject',
|
||||
'EnvObjects', 'Dict', 'locate_and_import_class', 'yaml', 'Any',
|
||||
'inspect']])
|
||||
all_found_modules.extend([x for x in dir(mod) if (not(x.startswith('__') or len(x) <= 2) and x.istitle())
|
||||
and x not in ['Entity', 'NamedTuple', 'List', 'Rule', 'Union', 'Floor'
|
||||
'TickResult', 'ActionResult', 'Action', 'Agent', 'BoundEntityMixin',
|
||||
'RenderEntity', 'TemplateRule', 'Objects', 'PositionMixin',
|
||||
'IsBoundMixin', 'EnvObject', 'EnvObjects', 'Dict', 'Any'
|
||||
]])
|
||||
try:
|
||||
model_class = mod.__getattribute__(class_name)
|
||||
return model_class
|
||||
except AttributeError:
|
||||
continue
|
||||
raise AttributeError(f'Class "{class_name}" was not found!!!"\n'
|
||||
f'Check the {folder_path.name} name.\n'
|
||||
f'Possible Options are:\n{set(all_found_modules)}')
|
||||
raise AttributeError(f'Class "{class_name}" was not found in "{folder_path.name}"', list(set(all_found_modules)))
|
||||
|
@ -24,24 +24,33 @@ class LevelParser(object):
|
||||
self.level_shape = level_array.shape
|
||||
self.size = self.pomdp_r**2 if self.pomdp_r else np.prod(self.level_shape)
|
||||
|
||||
def get_coordinates_for_symbol(self, symbol, negate=False):
    """Return the array indices of all level cells that hold ``symbol``.

    :param symbol: Level symbol to look for.
    :param negate: If truthy, return the indices of all cells that do NOT hold the symbol instead.
    :return: Result of ``np.argwhere`` on the selected cells.
    """
    one_hot = h.one_hot_level(self._parsed_level, symbol)
    if negate:
        selected = one_hot != c.VALUE_OCCUPIED_CELL
    else:
        selected = one_hot == c.VALUE_OCCUPIED_CELL
    return np.argwhere(selected)
|
||||
|
||||
def do_init(self):
|
||||
entities = Entities()
|
||||
# Walls
|
||||
level_array = h.one_hot_level(self._parsed_level, c.SYMBOL_WALL)
|
||||
|
||||
walls = Walls.from_coordinates(np.argwhere(level_array == c.VALUE_OCCUPIED_CELL), self.size)
|
||||
walls = Walls.from_coordinates(self.get_coordinates_for_symbol(c.SYMBOL_WALL), self.size)
|
||||
entities.add_items({c.WALL: walls})
|
||||
|
||||
# Floor
|
||||
floor = Floors.from_coordinates(np.argwhere(level_array == c.VALUE_FREE_CELL), self.size)
|
||||
floor = Floors.from_coordinates(self.get_coordinates_for_symbol(c.SYMBOL_WALL, negate=True), self.size)
|
||||
entities.add_items({c.FLOOR: floor})
|
||||
|
||||
# All other
|
||||
for es_name in self.e_p_dict:
|
||||
e_class, e_kwargs = self.e_p_dict[es_name]['class'], self.e_p_dict[es_name]['kwargs']
|
||||
|
||||
if hasattr(e_class, 'symbol'):
|
||||
level_array = h.one_hot_level(self._parsed_level, symbol=e_class.symbol)
|
||||
if hasattr(e_class, 'symbol') and e_class.symbol is not None:
|
||||
symbols = e_class.symbol
|
||||
if isinstance(symbols, (str, int, float)):
|
||||
symbols = [symbols]
|
||||
for symbol in symbols:
|
||||
level_array = h.one_hot_level(self._parsed_level, symbol=symbol)
|
||||
if np.any(level_array):
|
||||
e = e_class.from_coordinates(np.argwhere(level_array == c.VALUE_OCCUPIED_CELL).tolist(),
|
||||
entities[c.FLOOR], self.size, entity_kwargs=e_kwargs
|
||||
|
@ -6,11 +6,10 @@ from typing import Dict, List
|
||||
import numpy as np
|
||||
from numba import njit
|
||||
|
||||
from marl_factory_grid.environment import constants as c
|
||||
from marl_factory_grid.environment.groups.utils import Combined
|
||||
from marl_factory_grid.utils.states import Gamestate
|
||||
|
||||
from marl_factory_grid.environment import constants as c
|
||||
|
||||
|
||||
class OBSBuilder(object):
|
||||
|
||||
@ -111,10 +110,10 @@ class OBSBuilder(object):
|
||||
e = next(x for x in self.all_obs if l_name in x and agent.name in x)
|
||||
except StopIteration:
|
||||
raise KeyError(
|
||||
f'Check typing!\n{l_name} could not be found in:\n{dict(self.all_obs).keys()}')
|
||||
f'Check typing! {l_name} could not be found in: {list(dict(self.all_obs).keys())}')
|
||||
|
||||
try:
|
||||
positional = e.has_position
|
||||
positional = e.var_has_position
|
||||
except AttributeError:
|
||||
positional = False
|
||||
if positional:
|
||||
@ -172,7 +171,7 @@ class OBSBuilder(object):
|
||||
obs_layers.append(combined.name)
|
||||
elif obs_str == c.OTHERS:
|
||||
obs_layers.extend([x for x in self.all_obs if x != agent.name and x.startswith(f'{c.AGENT}[')])
|
||||
elif obs_str == c.AGENTS:
|
||||
elif obs_str == c.AGENT:
|
||||
obs_layers.extend([x for x in self.all_obs if x.startswith(f'{c.AGENT}[')])
|
||||
else:
|
||||
obs_layers.append(obs_str)
|
||||
@ -222,7 +221,7 @@ class RayCaster:
|
||||
entities_hit = entities.pos_dict[(x, y)]
|
||||
hits = self.ray_block_cache(cache_blocking,
|
||||
(x, y),
|
||||
lambda: any(e.is_blocking_light for e in entities_hit),
|
||||
lambda: any(e.var_is_blocking_light for e in entities_hit),
|
||||
entities)
|
||||
|
||||
try:
|
||||
@ -237,8 +236,8 @@ class RayCaster:
|
||||
self.ray_block_cache(
|
||||
cache_blocking,
|
||||
key,
|
||||
# lambda: all(False for e in entities.pos_dict[key] if not e.is_blocking_light),
|
||||
lambda: any(e.is_blocking_light for e in entities.pos_dict[key]),
|
||||
# lambda: all(False for e in entities.pos_dict[key] if not e.var_is_blocking_light),
|
||||
lambda: any(e.var_is_blocking_light for e in entities.pos_dict[key]),
|
||||
entities)
|
||||
for key in ((x, y-cy), (x-cx, y))
|
||||
]) if (cx != 0 and cy != 0) else False
|
||||
|
@ -27,13 +27,13 @@ class Renderer:
|
||||
BG_COLOR = (178, 190, 195) # (99, 110, 114)
|
||||
WHITE = (223, 230, 233) # (200, 200, 200)
|
||||
AGENT_VIEW_COLOR = (9, 132, 227)
|
||||
ASSETS = Path(__file__).parent.parent / 'assets'
|
||||
MODULE_ASSETS = Path(__file__).parent.parent.parent / 'modules'
|
||||
ASSETS = Path(__file__).parent.parent
|
||||
|
||||
def __init__(self, lvl_shape: Tuple[int, int] = (16, 16),
|
||||
lvl_padded_shape: Union[Tuple[int, int], None] = None,
|
||||
cell_size: int = 40, fps: int = 7,
|
||||
grid_lines: bool = True, view_radius: int = 2):
|
||||
# TODO: Custom asset paths
|
||||
self.grid_h, self.grid_w = lvl_shape
|
||||
self.lvl_padded_shape = lvl_padded_shape if lvl_padded_shape is not None else lvl_shape
|
||||
self.cell_size = cell_size
|
||||
@ -44,7 +44,7 @@ class Renderer:
|
||||
self.screen_size = (self.grid_w*cell_size, self.grid_h*cell_size)
|
||||
self.screen = pygame.display.set_mode(self.screen_size)
|
||||
self.clock = pygame.time.Clock()
|
||||
assets = list(self.ASSETS.rglob('*.png')) + list(self.MODULE_ASSETS.rglob('*.png'))
|
||||
assets = list(self.ASSETS.rglob('*.png'))
|
||||
self.assets = {path.stem: self.load_asset(str(path), 1) for path in assets}
|
||||
self.fill_bg()
|
||||
|
||||
|
@ -1,8 +1,6 @@
|
||||
from typing import Union
|
||||
from dataclasses import dataclass
|
||||
|
||||
from marl_factory_grid.environment.entity.entity import Entity
|
||||
|
||||
TYPE_VALUE = 'value'
|
||||
TYPE_REWARD = 'reward'
|
||||
types = [TYPE_VALUE, TYPE_REWARD]
|
||||
@ -20,7 +18,7 @@ class Result:
|
||||
validity: bool
|
||||
reward: Union[float, None] = None
|
||||
value: Union[float, None] = None
|
||||
entity: Union[Entity, None] = None
|
||||
entity: None = None
|
||||
|
||||
def get_infos(self):
|
||||
n = self.entity.name if self.entity is not None else "Global"
|
||||
|
@ -2,10 +2,11 @@ from typing import List, Dict
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
from marl_factory_grid.environment import constants as c
|
||||
from marl_factory_grid.environment.entity.wall_floor import Floor
|
||||
from marl_factory_grid.environment.rules import Rule
|
||||
from marl_factory_grid.utils.results import Result
|
||||
from marl_factory_grid.environment import constants as c
|
||||
|
||||
|
||||
class StepRules:
|
||||
@ -26,9 +27,9 @@ class StepRules:
|
||||
self.rules.append(item)
|
||||
return True
|
||||
|
||||
def do_all_init(self, state):
|
||||
def do_all_init(self, state, lvl_map):
|
||||
for rule in self.rules:
|
||||
if rule_init_printline := rule.on_init(state):
|
||||
if rule_init_printline := rule.on_init(state, lvl_map):
|
||||
state.print(rule_init_printline)
|
||||
return c.VALID
|
||||
|
||||
@ -58,7 +59,7 @@ class Gamestate(object):
|
||||
|
||||
@property
|
||||
def moving_entites(self):
|
||||
return [y for x in self.entities for y in x if x.can_move]
|
||||
return [y for x in self.entities for y in x if x.var_can_move]
|
||||
|
||||
def __init__(self, entitites, rules: Dict[str, dict], env_seed=69, verbose=False):
|
||||
self.entities = entitites
|
||||
@ -107,6 +108,6 @@ class Gamestate(object):
|
||||
|
||||
def get_all_tiles_with_collisions(self) -> List[Floor]:
|
||||
tiles = [self[c.FLOOR].by_pos(pos) for pos, e in self.entities.pos_dict.items()
|
||||
if sum([x.can_collide for x in e]) > 1]
|
||||
if sum([x.var_can_collide for x in e]) > 1]
|
||||
# tiles = [x for x in self[c.FLOOR] if len(x.guests_that_can_collide) > 1]
|
||||
return tiles
|
||||
|
@ -15,7 +15,7 @@ ENTITIES = 'Objects'
|
||||
OBSERVATIONS = 'Observations'
|
||||
RULES = 'Rule'
|
||||
ASSETS = 'Assets'
|
||||
EXCLUDED = ['identifier', 'args', 'kwargs', 'Move',
|
||||
EXCLUDED = ['identifier', 'args', 'kwargs', 'Move', 'Floor', 'Agent', 'GlobalPositions', 'Walls',
|
||||
'TemplateRule', 'Entities', 'EnvObjects', 'Zones', ]
|
||||
|
||||
|
||||
|
@ -1,21 +1,6 @@
|
||||
import gymnasium as gym
|
||||
|
||||
|
||||
class EnvCombiner(object):
    """Combine several environment classes into one dynamically created class.

    The classes handed to the constructor become the base classes of the type
    returned by :meth:`build`.
    """

    def __init__(self, *envs_cls):
        """Register the classes to combine.

        :param envs_cls: Classes that should become bases of the combined class.
        """
        # Keyed by class name; a later class replaces an earlier one with the same name.
        self._env_dict = {env_cls.__name__: env_cls for env_cls in envs_cls}

    @staticmethod
    def combine_cls(name, *envs_cls):
        """Create a new class called ``name`` inheriting from ``envs_cls`` (in order).

        :param name: Name of the new class.
        :param envs_cls: Base classes of the new class.
        :return: The newly created class.
        """
        return type(name, envs_cls, {})

    def build(self):
        """Build the combined class from all registered classes.

        The class name is assembled by lower-casing every registered class
        name, removing the substring "factory", capitalizing the remainder,
        concatenating the parts and appending "Factory".

        :return: A new class with all registered classes as bases.
        """
        # str.replace requires the replacement argument; without it a TypeError is raised.
        prefix = "".join(x.lower().replace("factory", "").capitalize() for x in self._env_dict)
        name = f'{prefix}Factory'
        # Unpack the bases: combine_cls collects them through *envs_cls, so passing
        # one tuple would make the tuple itself the (invalid) single base.
        return self.combine_cls(name, *self._env_dict.values())
|
||||
|
||||
|
||||
class MarlFrameStack(gym.ObservationWrapper):
|
||||
"""todo @romue404"""
|
||||
def __init__(self, env):
|
||||
|
@ -3,7 +3,7 @@ from pathlib import Path
|
||||
|
||||
import yaml
|
||||
|
||||
from marl_factory_grid.environment.factory import BaseFactory
|
||||
from marl_factory_grid.environment.factory import Factory
|
||||
from marl_factory_grid.logging.envmonitor import EnvMonitor
|
||||
from marl_factory_grid.logging.recorder import EnvRecorder
|
||||
|
||||
@ -41,7 +41,7 @@ if __name__ == '__main__':
|
||||
pass
|
||||
|
||||
# Init Env
|
||||
with BaseFactory(**env_kwargs) as env:
|
||||
with Factory(**env_kwargs) as env:
|
||||
env = EnvMonitor(env)
|
||||
env = EnvRecorder(env) if record else env
|
||||
obs_shape = env.observation_space.shape
|
||||
|
2
setup.py
2
setup.py
@ -5,7 +5,7 @@ long_description = (this_directory / "README.md").read_text()
|
||||
|
||||
|
||||
setup(name='Marl-Factory-Grid',
|
||||
version='0.0.11',
|
||||
version='0.0.12',
|
||||
description='A framework to research MARL agents in various setings.',
|
||||
author='Steffen Illium',
|
||||
author_email='steffen.illium@ifi.lmu.de',
|
||||
|
Reference in New Issue
Block a user