diff --git a/marl_factory_grid/__init__.py b/marl_factory_grid/__init__.py index 180169b..b2bbfa3 100644 --- a/marl_factory_grid/__init__.py +++ b/marl_factory_grid/__init__.py @@ -1,6 +1,6 @@ -from .environment.factory import BaseFactory -from .environment.factory import OBSBuilder - -from .utils.tools import ConfigExplainer +from .environment import * +from .modules import * +from .utils import * from .quickstart import init + diff --git a/marl_factory_grid/algorithms/static/TSP_base_agent.py b/marl_factory_grid/algorithms/static/TSP_base_agent.py index 5cfdb30..ec4cc81 100644 --- a/marl_factory_grid/algorithms/static/TSP_base_agent.py +++ b/marl_factory_grid/algorithms/static/TSP_base_agent.py @@ -1,11 +1,10 @@ -import itertools from random import choice import numpy as np -import networkx as nx from networkx.algorithms.approximation import traveling_salesman as tsp +from marl_factory_grid.algorithms.static.utils import points_to_graph from marl_factory_grid.modules.doors import constants as do from marl_factory_grid.environment import constants as c from marl_factory_grid.utils.helpers import MOVEMAP @@ -15,41 +14,6 @@ from abc import abstractmethod, ABC future_planning = 7 -def points_to_graph(coordiniates_or_tiles, allow_euclidean_connections=True, allow_manhattan_connections=True): - """ - Given a set of coordinates, this function contructs a non-directed graph, by conncting adjected points. - There are three combinations of settings: - Allow all neigbors: Distance(a, b) <= sqrt(2) - Allow only manhattan: Distance(a, b) == 1 - Allow only euclidean: Distance(a, b) == sqrt(2) - - - :param coordiniates_or_tiles: A set of coordinates. - :type coordiniates_or_tiles: Tiles - :param allow_euclidean_connections: Whether to regard diagonal adjected cells as neighbors - :type: bool - :param allow_manhattan_connections: Whether to regard directly adjected cells as neighbors - :type: bool - - :return: A graph with nodes that are conneceted as specified by the parameters. - :rtype: nx.Graph - """ - assert allow_euclidean_connections or allow_manhattan_connections - if hasattr(coordiniates_or_tiles, 'positions'): - coordiniates_or_tiles = coordiniates_or_tiles.positions - possible_connections = itertools.combinations(coordiniates_or_tiles, 2) - graph = nx.Graph() - for a, b in possible_connections: - diff = np.linalg.norm(np.asarray(a)-np.asarray(b)) - if allow_manhattan_connections and allow_euclidean_connections and diff <= np.sqrt(2): - graph.add_edge(a, b) - elif not allow_manhattan_connections and allow_euclidean_connections and diff == np.sqrt(2): - graph.add_edge(a, b) - elif allow_manhattan_connections and not allow_euclidean_connections and diff == 1: - graph.add_edge(a, b) - return graph - - class TSPBaseAgent(ABC): def __init__(self, state, agent_i, static_problem: bool = True): diff --git a/marl_factory_grid/algorithms/static/utils.py b/marl_factory_grid/algorithms/static/utils.py new file mode 100644 index 0000000..2543152 --- /dev/null +++ b/marl_factory_grid/algorithms/static/utils.py @@ -0,0 +1,39 @@ +import itertools + +import networkx as nx +import numpy as np + + +def points_to_graph(coordiniates_or_tiles, allow_euclidean_connections=True, allow_manhattan_connections=True): + """ + Given a set of coordinates, this function contructs a non-directed graph, by conncting adjected points. + There are three combinations of settings: + Allow all neigbors: Distance(a, b) <= sqrt(2) + Allow only manhattan: Distance(a, b) == 1 + Allow only euclidean: Distance(a, b) == sqrt(2) + + + :param coordiniates_or_tiles: A set of coordinates. + :type coordiniates_or_tiles: Tiles + :param allow_euclidean_connections: Whether to regard diagonal adjected cells as neighbors + :type: bool + :param allow_manhattan_connections: Whether to regard directly adjected cells as neighbors + :type: bool + + :return: A graph with nodes that are conneceted as specified by the parameters. + :rtype: nx.Graph + """ + assert allow_euclidean_connections or allow_manhattan_connections + if hasattr(coordiniates_or_tiles, 'positions'): + coordiniates_or_tiles = coordiniates_or_tiles.positions + possible_connections = itertools.combinations(coordiniates_or_tiles, 2) + graph = nx.Graph() + for a, b in possible_connections: + diff = np.linalg.norm(np.asarray(a)-np.asarray(b)) + if allow_manhattan_connections and allow_euclidean_connections and diff <= np.sqrt(2): + graph.add_edge(a, b) + elif not allow_manhattan_connections and allow_euclidean_connections and diff == np.sqrt(2): + graph.add_edge(a, b) + elif allow_manhattan_connections and not allow_euclidean_connections and diff == 1: + graph.add_edge(a, b) + return graph diff --git a/marl_factory_grid/default_config.yaml b/marl_factory_grid/default_config.yaml index 43e00a3..6bc6f6c 100644 --- a/marl_factory_grid/default_config.yaml +++ b/marl_factory_grid/default_config.yaml @@ -1,68 +1,89 @@ ---- -General: - level_name: rooms - env_seed: 69 - verbose: !!bool False - pomdp_r: 5 - individual_rewards: !!bool True - -Entities: - Defaults: {} - DirtPiles: - initial_dirt_ratio: 0.3 # On INIT, on max how many tiles does the dirt spawn in percent. - dirt_spawn_r_var: 0.05 # How much does the dirt spawn amount vary? - initial_amount: 3 - max_local_amount: 5 # Max dirt amount per tile. - max_global_amount: 20 # Max dirt amount in the whole environment. - Doors: - closed_on_init: True - auto_close_interval: 10 - indicate_area: False Agents: Wolfgang: Actions: - - Move8 - - Noop - - DoorUse - - CleanUp + - Noop + - BtryCharge + - CleanUp + - DestAction + - DoorUse + - ItemAction + - Move8 Observations: - - Self - - Placeholder + - Combined: + - Other - Walls - - DirtPiles - - Placeholder - - Doors - - Doors - Björn: - Actions: - # Move4, Noop - - Move4 - - DoorUse - - CleanUp - Observations: - - Defaults - - Combined - Jürgen: - Actions: - # Move4, Noop - - Defaults - - DoorUse - - CleanUp - Observations: - - Walls - - Placeholder - - Agent[Björn] + - GlobalPosition + - Battery + - ChargePods + - DirtPiles + - Destinations + - Doors + - Items + - Inventory + - DropOffLocations + - Machines + - Maintainers +Entities: + Batteries: {} + ChargePods: {} + Destinations: {} + DirtPiles: + clean_amount: 1 + dirt_spawn_r_var: 0.1 + initial_amount: 2 + initial_dirt_ratio: 0.05 + max_global_amount: 20 + max_local_amount: 5 + Doors: {} + DropOffLocations: {} + GlobalPositions: {} + Inventories: {} + Items: {} + Machines: {} + Maintainers: {} + Zones: {} + ReachedDestinations: {} + +General: + env_seed: 69 + individual_rewards: true + level_name: large + pomdp_r: 3 + verbose: false + Rules: - Defaults: {} + Btry: + initial_charge: 0.8 + per_action_costs: 0.02 + BtryDoneAtDischarge: {} Collision: - done_at_collisions: !!bool False - DirtRespawnRule: - spawn_freq: 5 - DirtSmearOnMove: - smear_amount: 0.12 - DoorAutoClose: {} + done_at_collisions: false + AssignGlobalPositions: {} + DestinationDone: {} + DestinationReach: + n_dests: 1 + tiles: null + DestinationSpawn: + n_dests: 1 + spawn_frequency: 5 + spawn_mode: GROUPED DirtAllCleanDone: {} -Assets: - - Defaults - - Dirt - - Doors + DirtRespawnRule: + spawn_freq: 15 + DirtSmearOnMove: + smear_amount: 0.2 + DoorAutoClose: + close_frequency: 10 + ItemRules: + max_dropoff_storage_size: 0 + n_items: 5 + n_locations: 5 + spawn_frequency: 15 + MachineRule: + n_machines: 2 + MaintenanceRule: + n_maintainer: 1 + MaxStepsReached: + max_steps: 500 +# AgentSingleZonePlacement: +# n_zones: 4 diff --git a/marl_factory_grid/environment/actions.py b/marl_factory_grid/environment/actions.py index 35b5419..21ea463 100644 --- a/marl_factory_grid/environment/actions.py +++ b/marl_factory_grid/environment/actions.py @@ -98,3 +98,5 @@ class NorthWest(Move): Move4 = [North, East, South, West] # noinspection PyTypeChecker Move8 = Move4 + [NorthEast, SouthEast, SouthWest, NorthWest] + +ALL_BASEACTIONS = Move8 + [Noop] diff --git a/marl_factory_grid/environment/constants.py b/marl_factory_grid/environment/constants.py index 5891fbd..cc494f4 100644 --- a/marl_factory_grid/environment/constants.py +++ b/marl_factory_grid/environment/constants.py @@ -9,15 +9,13 @@ WALL = 'Wall' # Identifier of Wall-objects and WALLS = 'Walls' # Identifier of Wall-objects and groups (groups). LEVEL = 'Level' # Identifier of Level-objects and groups (groups). AGENT = 'Agent' # Identifier of Agent-objects and groups (groups). -AGENTS = 'Agents' # Identifier of Agent-objects and groups (groups). OTHERS = 'Other' COMBINED = 'Combined' -GLOBAL_POSITION = 'GLOBAL_POSITION' # Identifier of the global position slice - +GLOBALPOSITIONS = 'GlobalPositions' # Identifier of the global position slice # Attributes -IS_BLOCKING_LIGHT = 'is_blocking_light' -HAS_POSITION = 'has_position' +IS_BLOCKING_LIGHT = 'var_is_blocking_light' +HAS_POSITION = 'var_has_position' HAS_NO_POSITION = 'has_no_position' ALL = 'All' diff --git a/marl_factory_grid/environment/entity/agent.py b/marl_factory_grid/environment/entity/agent.py index aa27395..73b9feb 100644 --- a/marl_factory_grid/environment/entity/agent.py +++ b/marl_factory_grid/environment/entity/agent.py @@ -1,6 +1,5 @@ from typing import List, Union -from marl_factory_grid.environment import constants as c from marl_factory_grid.environment.actions import Action from marl_factory_grid.environment.entity.entity import Entity from marl_factory_grid.utils.render import RenderEntity @@ -8,6 +7,8 @@ from marl_factory_grid.utils import renderer from marl_factory_grid.utils.helpers import is_move from marl_factory_grid.utils.results import ActionResult, Result +from marl_factory_grid.environment import constants as c + class Agent(Entity): @@ -24,7 +25,7 @@ class Agent(Entity): return self._observations @property - def can_collide(self): + def var_can_collide(self): return True def step_result(self): diff --git a/marl_factory_grid/environment/entity/entity.py b/marl_factory_grid/environment/entity/entity.py index 00944e1..bce803a 100644 --- a/marl_factory_grid/environment/entity/entity.py +++ b/marl_factory_grid/environment/entity/entity.py @@ -1,15 +1,20 @@ import abc -from marl_factory_grid.environment import constants as c -from marl_factory_grid.environment.entity.object import EnvObject -from marl_factory_grid.utils.render import RenderEntity +from .. import constants as c +from .object import EnvObject +from ...utils.render import RenderEntity +from ...utils.results import ActionResult class Entity(EnvObject, abc.ABC): """Full Env Entity that lives on the environment Grid. Doors, Items, DirtPile etc...""" @property - def has_position(self): + def state(self): + return self._status or ActionResult(entity=self, identifier=c.NOOP, validity=c.VALID, reward=0) + + @property + def var_has_position(self): return self.pos != c.VALUE_NO_POS @property @@ -64,12 +69,13 @@ class Entity(EnvObject, abc.ABC): def __init__(self, tile, **kwargs): super().__init__(**kwargs) + self._status = None self._tile = tile tile.enter(self) def summarize_state(self) -> dict: return dict(name=str(self.name), x=int(self.x), y=int(self.y), - tile=str(self.tile.name), can_collide=bool(self.can_collide)) + tile=str(self.tile.name), can_collide=bool(self.var_can_collide)) @abc.abstractmethod def render(self): diff --git a/marl_factory_grid/environment/entity/object.py b/marl_factory_grid/environment/entity/object.py index dd98539..0c65552 100644 --- a/marl_factory_grid/environment/entity/object.py +++ b/marl_factory_grid/environment/entity/object.py @@ -78,37 +78,37 @@ class EnvObject(Object): return self.name @property - def is_blocking_light(self): + def var_is_blocking_light(self): try: - return self._collection.is_blocking_light or False + return self._collection.var_is_blocking_light or False except AttributeError: return False @property - def can_move(self): + def var_can_move(self): try: - return self._collection.can_move or False + return self._collection.var_can_move or False except AttributeError: return False @property - def is_blocking_pos(self): + def var_is_blocking_pos(self): try: - return self._collection.is_blocking_pos or False + return self._collection.var_is_blocking_pos or False except AttributeError: return False @property - def has_position(self): + def var_has_position(self): try: - return self._collection.has_position or False + return self._collection.var_has_position or False except AttributeError: return False @property - def can_collide(self): + def var_can_collide(self): try: - return self._collection.can_collide or False + return self._collection.var_can_collide or False except AttributeError: return False diff --git a/marl_factory_grid/environment/entity/util.py b/marl_factory_grid/environment/entity/util.py index 7a43664..fbf0c4a 100644 --- a/marl_factory_grid/environment/entity/util.py +++ b/marl_factory_grid/environment/entity/util.py @@ -35,11 +35,11 @@ class GlobalPosition(BoundEntityMixin, EnvObject): @property def encoding(self): if self._normalized: - return tuple(np.divide(self._bound_entity.pos, self._level_shape)) + return tuple(np.divide(self._bound_entity.pos, self._shape)) else: return self.bound_entity.pos - def __init__(self, *args, normalized: bool = True, **kwargs): + def __init__(self, level_shape, *args, normalized: bool = True, **kwargs): super(GlobalPosition, self).__init__(*args, **kwargs) - self._level_shape = math.sqrt(self.size) self._normalized = normalized + self._shape = level_shape diff --git a/marl_factory_grid/environment/entity/wall_floor.py b/marl_factory_grid/environment/entity/wall_floor.py index 1c920f9..7ad8ed4 100644 --- a/marl_factory_grid/environment/entity/wall_floor.py +++ b/marl_factory_grid/environment/entity/wall_floor.py @@ -11,23 +11,23 @@ from marl_factory_grid.utils import helpers as h class Floor(EnvObject): @property - def has_position(self): + def var_has_position(self): return True @property - def can_collide(self): + def var_can_collide(self): return False @property - def can_move(self): + def var_can_move(self): return False @property - def is_blocking_pos(self): + def var_is_blocking_pos(self): return False @property - def is_blocking_light(self): + def var_is_blocking_light(self): return False @property @@ -51,7 +51,7 @@ class Floor(EnvObject): @property def guests_that_can_collide(self): - return [x for x in self.guests if x.can_collide] + return [x for x in self.guests if x.var_can_collide] @property def guests(self): @@ -67,7 +67,7 @@ class Floor(EnvObject): @property def is_blocked(self): - return any([x.is_blocking_pos for x in self.guests]) + return any([x.var_is_blocking_pos for x in self.guests]) def __init__(self, pos, **kwargs): super(Floor, self).__init__(**kwargs) @@ -86,7 +86,7 @@ class Floor(EnvObject): return bool(len(self._guests)) def enter(self, guest): - if (guest.name not in self._guests and not self.is_blocked) and not (guest.is_blocking_pos and self.is_occupied()): + if (guest.name not in self._guests and not self.is_blocked) and not (guest.var_is_blocking_pos and self.is_occupied()): self._guests.update({guest.name: guest}) return c.VALID else: @@ -112,7 +112,7 @@ class Floor(EnvObject): class Wall(Floor): @property - def can_collide(self): + def var_can_collide(self): return True @property @@ -123,9 +123,9 @@ class Wall(Floor): return RenderEntity(c.WALL, self.pos) @property - def is_blocking_pos(self): + def var_is_blocking_pos(self): return True @property - def is_blocking_light(self): + def var_is_blocking_light(self): return True diff --git a/marl_factory_grid/environment/factory.py b/marl_factory_grid/environment/factory.py index 99571ab..c09fd55 100644 --- a/marl_factory_grid/environment/factory.py +++ b/marl_factory_grid/environment/factory.py @@ -19,7 +19,7 @@ from marl_factory_grid.utils.states import Gamestate REC_TAC = 'rec_' -class BaseFactory(gym.Env): +class Factory(gym.Env): @property def action_space(self): @@ -52,11 +52,15 @@ class BaseFactory(gym.Env): def __exit__(self, exc_type, exc_val, exc_tb): self.close() - def __init__(self, config_file: Union[str, PathLike]): + def __init__(self, config_file: Union[str, PathLike], custom_modules_path: Union[None, PathLike] = None, + custom_level_path: Union[None, PathLike] = None): self._config_file = config_file - self.conf = FactoryConfigParser(self._config_file) + self.conf = FactoryConfigParser(self._config_file, custom_modules_path) # Attribute Assignment - self.level_filepath = Path(__file__).parent.parent / h.LEVELS_DIR / f'{self.conf.level_name}.txt' + if custom_level_path is not None: + self.level_filepath = Path(custom_level_path) + else: + self.level_filepath = Path(__file__).parent.parent / h.LEVELS_DIR / f'{self.conf.level_name}.txt' self._renderer = None # expensive - don't use; unless required ! parsed_entities = self.conf.load_entities() @@ -90,7 +94,7 @@ class BaseFactory(gym.Env): self.state.entities.add_item({c.AGENT: agents}) # All is set up, trigger additional init (after agent entity spawn etc) - self.state.rules.do_all_init(self.state) + self.state.rules.do_all_init(self.state, self.map) # Observations # noinspection PyAttributeOutsideInit @@ -144,7 +148,7 @@ class BaseFactory(gym.Env): try: done_reason = next(x for x in done_check_results if x.validity) done = True - self.state.print(f'Env done, Reason: {done_reason.name}.') + self.state.print(f'Env done, Reason: {done_reason.identifier}.') except StopIteration: done = False diff --git a/marl_factory_grid/environment/groups/agents.py b/marl_factory_grid/environment/groups/agents.py index 14f4a0d..f7839ba 100644 --- a/marl_factory_grid/environment/groups/agents.py +++ b/marl_factory_grid/environment/groups/agents.py @@ -1,6 +1,6 @@ +from marl_factory_grid.environment.entity.agent import Agent from marl_factory_grid.environment.groups.env_objects import EnvObjects from marl_factory_grid.environment.groups.mixins import PositionMixin -from marl_factory_grid.environment.entity.agent import Agent class Agents(PositionMixin, EnvObjects): diff --git a/marl_factory_grid/environment/groups/env_objects.py b/marl_factory_grid/environment/groups/env_objects.py index 64d5d25..b3efb2b 100644 --- a/marl_factory_grid/environment/groups/env_objects.py +++ b/marl_factory_grid/environment/groups/env_objects.py @@ -5,10 +5,10 @@ from marl_factory_grid.environment.entity.object import EnvObject class EnvObjects(Objects): _entity = EnvObject - is_blocking_light: bool = False - can_collide: bool = False - has_position: bool = False - can_move: bool = False + var_is_blocking_light: bool = False + var_can_collide: bool = False + var_has_position: bool = False + var_can_move: bool = False @property def encodings(self): @@ -19,7 +19,7 @@ class EnvObjects(Objects): self.size = size def add_item(self, item: EnvObject): - assert self.has_position or (len(self) <= self.size) + assert self.var_has_position or (len(self) <= self.size) super(EnvObjects, self).add_item(item) return self diff --git a/marl_factory_grid/environment/groups/mixins.py b/marl_factory_grid/environment/groups/mixins.py index 4c1f934..44e4ab9 100644 --- a/marl_factory_grid/environment/groups/mixins.py +++ b/marl_factory_grid/environment/groups/mixins.py @@ -1,15 +1,19 @@ +from typing import List + from marl_factory_grid.environment import constants as c - from marl_factory_grid.environment.entity.entity import Entity +from marl_factory_grid.environment.entity.wall_floor import Floor -# noinspection PyUnresolvedReferences,PyTypeChecker,PyArgumentList class PositionMixin: _entity = Entity - is_blocking_light: bool = True - can_collide: bool = True - has_position: bool = True + var_is_blocking_light: bool = True + var_can_collide: bool = True + var_has_position: bool = True + + def spawn(self, tiles: List[Floor]): + self.add_items([self._entity(tile) for tile in tiles]) def render(self): return [y for y in [x.render() for x in self] if y is not None] @@ -81,8 +85,8 @@ class IsBoundMixin: class HasBoundedMixin: @property - def obs_names(self): - return [x.name for x in self] + def obs_pairs(self): + return [(x.name, x) for x in self] def by_entity(self, entity): try: diff --git a/marl_factory_grid/environment/groups/objects.py b/marl_factory_grid/environment/groups/objects.py index 4113ee3..5182d18 100644 --- a/marl_factory_grid/environment/groups/objects.py +++ b/marl_factory_grid/environment/groups/objects.py @@ -4,6 +4,7 @@ from typing import List import numpy as np from marl_factory_grid.environment.entity.object import Object +import marl_factory_grid.environment.constants as c class Objects: @@ -116,12 +117,21 @@ class Objects: def __repr__(self): return f'{self.__class__.__name__}[{dict(self._data)}]' + def spawn(self, n: int): + self.add_items([self._entity() for _ in range(n)]) + return c.VALID + + def despawn(self, items: List[Object]): + items = [items] if isinstance(items, Object) else items + for item in items: + del self[item] + def notify_change_pos(self, entity: object): try: self.pos_dict[entity.last_pos].remove(entity) except (ValueError, AttributeError): pass - if entity.has_position: + if entity.var_has_position: try: self.pos_dict[entity.pos].append(entity) except (ValueError, AttributeError): diff --git a/marl_factory_grid/environment/groups/utils.py b/marl_factory_grid/environment/groups/utils.py index 19293a9..fd1d8e8 100644 --- a/marl_factory_grid/environment/groups/utils.py +++ b/marl_factory_grid/environment/groups/utils.py @@ -2,10 +2,11 @@ from typing import List, Union import numpy as np -from marl_factory_grid.environment.groups.env_objects import EnvObjects -from marl_factory_grid.environment.groups.objects import Objects -from marl_factory_grid.environment.groups.mixins import HasBoundedMixin, PositionMixin from marl_factory_grid.environment.entity.util import GlobalPosition +from marl_factory_grid.environment.groups.env_objects import EnvObjects +from marl_factory_grid.environment.groups.mixins import PositionMixin, HasBoundedMixin +from marl_factory_grid.environment.groups.objects import Objects +from marl_factory_grid.modules.zones import Zone from marl_factory_grid.utils import helpers as h from marl_factory_grid.environment import constants as c @@ -44,7 +45,9 @@ class GlobalPositions(HasBoundedMixin, EnvObjects): super(GlobalPositions, self).__init__(*args, **kwargs) -class Zones(Objects): +class ZonesOLD(Objects): + + _entity = Zone @property def accounting_zones(self): diff --git a/marl_factory_grid/environment/groups/wall_n_floors.py b/marl_factory_grid/environment/groups/wall_n_floors.py index c43a9e9..6acf898 100644 --- a/marl_factory_grid/environment/groups/wall_n_floors.py +++ b/marl_factory_grid/environment/groups/wall_n_floors.py @@ -30,8 +30,8 @@ class Walls(PositionMixin, EnvObjects): class Floors(Walls): _entity = Floor symbol = c.SYMBOL_FLOOR - is_blocking_light: bool = False - can_collide: bool = False + var_is_blocking_light: bool = False + var_can_collide: bool = False def __init__(self, *args, **kwargs): super(Floors, self).__init__(*args, **kwargs) diff --git a/marl_factory_grid/environment/rules.py b/marl_factory_grid/environment/rules.py index 6a329e4..9385841 100644 --- a/marl_factory_grid/environment/rules.py +++ b/marl_factory_grid/environment/rules.py @@ -17,7 +17,7 @@ class Rule(abc.ABC): def __repr__(self): return f'{self.name}' - def on_init(self, state): + def on_init(self, state, lvl_map): return [] def on_reset(self): @@ -42,7 +42,7 @@ class MaxStepsReached(Rule): super().__init__() self.max_steps = max_steps - def on_init(self, state): + def on_init(self, state, lvl_map): pass def on_check_done(self, state): @@ -51,6 +51,20 @@ class MaxStepsReached(Rule): return [DoneResult(validity=c.NOT_VALID, identifier=self.name, reward=0)] +class AssignGlobalPositions(Rule): + + def __init__(self): + super().__init__() + + def on_init(self, state, lvl_map): + from marl_factory_grid.environment.entity.util import GlobalPosition + for agent in state[c.AGENT]: + gp = GlobalPosition(lvl_map.level_shape) + gp.bind_to(agent) + state[c.GLOBALPOSITIONS].add_item(gp) + return [] + + class Collision(Rule): def __init__(self, done_at_collisions: bool = False): diff --git a/marl_factory_grid/modules/__init__.py b/marl_factory_grid/modules/__init__.py index e69de29..e802c8c 100644 --- a/marl_factory_grid/modules/__init__.py +++ b/marl_factory_grid/modules/__init__.py @@ -0,0 +1,7 @@ +from .batteries import * +from .clean_up import * +from .destinations import * +from .doors import * +from .items import * +from .machines import * +from .maintenance import * diff --git a/marl_factory_grid/modules/_template/constants.py b/marl_factory_grid/modules/_template/constants.py index 1aa157d..aace680 100644 --- a/marl_factory_grid/modules/_template/constants.py +++ b/marl_factory_grid/modules/_template/constants.py @@ -8,4 +8,4 @@ WEST = 'west' NORTHEAST = 'north_east' SOUTHEAST = 'south_east' SOUTHWEST = 'south_west' -NORTHWEST = 'north_west' \ No newline at end of file +NORTHWEST = 'north_west' diff --git a/marl_factory_grid/modules/_template/rules.py b/marl_factory_grid/modules/_template/rules.py index 2f0b65c..6ed2f2d 100644 --- a/marl_factory_grid/modules/_template/rules.py +++ b/marl_factory_grid/modules/_template/rules.py @@ -8,7 +8,7 @@ class TemplateRule(Rule): def __init__(self, *args, **kwargs): super(TemplateRule, self).__init__(*args, **kwargs) - def on_init(self, state): + def on_init(self, state, lvl_map): pass def tick_pre_step(self, state) -> List[TickResult]: diff --git a/marl_factory_grid/modules/batteries/__init__.py b/marl_factory_grid/modules/batteries/__init__.py index e69de29..52e82d5 100644 --- a/marl_factory_grid/modules/batteries/__init__.py +++ b/marl_factory_grid/modules/batteries/__init__.py @@ -0,0 +1,4 @@ +from .actions import BtryCharge +from .entitites import ChargePod, Battery +from .groups import ChargePods, Batteries +from .rules import BtryDoneAtDischarge, Btry diff --git a/marl_factory_grid/modules/batteries/groups.py b/marl_factory_grid/modules/batteries/groups.py index 61935a4..ad24eab 100644 --- a/marl_factory_grid/modules/batteries/groups.py +++ b/marl_factory_grid/modules/batteries/groups.py @@ -13,18 +13,13 @@ class Batteries(HasBoundedMixin, EnvObjects): def obs_tag(self): return self.__class__.__name__ - @property - def obs_pairs(self): - return [(x.name, x) for x in self] - def __init__(self, *args, **kwargs): super(Batteries, self).__init__(*args, **kwargs) - def spawn_batteries(self, agents, initial_charge_level): + def spawn(self, agents, initial_charge_level): batteries = [self._entity(initial_charge_level, agent) for _, agent in enumerate(agents)] self.add_items(batteries) - class ChargePods(PositionMixin, EnvObjects): _entity = ChargePod diff --git a/marl_factory_grid/modules/batteries/rules.py b/marl_factory_grid/modules/batteries/rules.py index 16e685a..e8e7816 100644 --- a/marl_factory_grid/modules/batteries/rules.py +++ b/marl_factory_grid/modules/batteries/rules.py @@ -13,8 +13,8 @@ class Btry(Rule): self.per_action_costs = per_action_costs self.initial_charge = initial_charge - def on_init(self, state): - state[b.BATTERIES].spawn_batteries(state[c.AGENT], self.initial_charge) + def on_init(self, state, lvl_map): + state[b.BATTERIES].spawn(state[c.AGENT], self.initial_charge) def tick_pre_step(self, state) -> List[TickResult]: pass diff --git a/marl_factory_grid/modules/clean_up/__init__.py b/marl_factory_grid/modules/clean_up/__init__.py index e69de29..59ce25d 100644 --- a/marl_factory_grid/modules/clean_up/__init__.py +++ b/marl_factory_grid/modules/clean_up/__init__.py @@ -0,0 +1,6 @@ +from .actions import CleanUp +from .entitites import DirtPile +from .groups import DirtPiles +from .rule_respawn import DirtRespawnRule +from .rule_smear_on_move import DirtSmearOnMove +from .rule_done_on_all_clean import DirtAllCleanDone diff --git a/marl_factory_grid/modules/clean_up/entitites.py b/marl_factory_grid/modules/clean_up/entitites.py index 984e644..8ff50b1 100644 --- a/marl_factory_grid/modules/clean_up/entitites.py +++ b/marl_factory_grid/modules/clean_up/entitites.py @@ -7,6 +7,22 @@ from marl_factory_grid.modules.clean_up import constants as d class DirtPile(Entity): + @property + def var_can_collide(self): + return False + + @property + def var_can_move(self): + return False + + @property + def var_is_blocking_light(self): + return False + + @property + def var_has_position(self): + return True + @property def amount(self): return self._amount diff --git a/marl_factory_grid/modules/clean_up/groups.py b/marl_factory_grid/modules/clean_up/groups.py index f6c532c..e51c382 100644 --- a/marl_factory_grid/modules/clean_up/groups.py +++ b/marl_factory_grid/modules/clean_up/groups.py @@ -31,7 +31,7 @@ class DirtPiles(PositionMixin, EnvObjects): self.max_global_amount = max_global_amount self.max_local_amount = max_local_amount - def spawn_dirt(self, then_dirty_tiles, amount) -> bool: + def spawn(self, then_dirty_tiles, amount) -> bool: if isinstance(then_dirty_tiles, Floor): then_dirty_tiles = [then_dirty_tiles] for tile in then_dirty_tiles: @@ -57,7 +57,7 @@ class DirtPiles(PositionMixin, EnvObjects): var = self.dirt_spawn_r_var new_spawn = abs(self.initial_dirt_ratio + (state.rng.uniform(-var, var) if initial_spawn else 0)) n_dirt_tiles = max(0, int(new_spawn * len(free_for_dirt))) - return self.spawn_dirt(free_for_dirt[:n_dirt_tiles], self.initial_amount) + return self.spawn(free_for_dirt[:n_dirt_tiles], self.initial_amount) def __repr__(self): s = super(DirtPiles, self).__repr__() diff --git a/marl_factory_grid/modules/clean_up/rule_respawn.py b/marl_factory_grid/modules/clean_up/rule_respawn.py index 5cdad95..5a2cb2e 100644 --- a/marl_factory_grid/modules/clean_up/rule_respawn.py +++ b/marl_factory_grid/modules/clean_up/rule_respawn.py @@ -11,7 +11,7 @@ class DirtRespawnRule(Rule): self.spawn_freq = spawn_freq self._next_dirt_spawn = spawn_freq - def on_init(self, state) -> str: + def on_init(self, state, lvl_map) -> str: state[d.DIRT].trigger_dirt_spawn(state, initial_spawn=True) return f'Initial Dirt was spawned on: {[x.pos for x in state[d.DIRT]]}' diff --git a/marl_factory_grid/modules/clean_up/rule_smear_on_move.py b/marl_factory_grid/modules/clean_up/rule_smear_on_move.py index 6e29f01..e6d2822 100644 --- a/marl_factory_grid/modules/clean_up/rule_smear_on_move.py +++ b/marl_factory_grid/modules/clean_up/rule_smear_on_move.py @@ -18,7 +18,7 @@ class DirtSmearOnMove(Rule): if is_move(entity.state.identifier) and entity.state.validity == c.VALID: if old_pos_dirt := state[d.DIRT].by_pos(entity.last_pos): if smeared_dirt := round(old_pos_dirt.amount * self.smear_amount, 2): - if state[d.DIRT].spawn_dirt(entity.tile, amount=smeared_dirt): + if state[d.DIRT].spawn(entity.tile, amount=smeared_dirt): results.append(TickResult(identifier=self.name, entity=entity, reward=0, validity=c.VALID)) return results diff --git a/marl_factory_grid/modules/destinations/__init__.py b/marl_factory_grid/modules/destinations/__init__.py index e69de29..31d3c27 100644 --- a/marl_factory_grid/modules/destinations/__init__.py +++ b/marl_factory_grid/modules/destinations/__init__.py @@ -0,0 +1,4 @@ +from .actions import DestAction +from .entitites import Destination +from .groups import ReachedDestinations, Destinations +from .rules import DestinationDone, DestinationReach, DestinationSpawn diff --git a/marl_factory_grid/modules/destinations/rules.py b/marl_factory_grid/modules/destinations/rules.py index b718801..55c66a9 100644 --- a/marl_factory_grid/modules/destinations/rules.py +++ b/marl_factory_grid/modules/destinations/rules.py @@ -62,7 +62,7 @@ class DestinationSpawn(Rule): self.n_dests = n_dests self.spawn_mode = spawn_mode - def on_init(self, state): + def on_init(self, state, lvl_map): # noinspection PyAttributeOutsideInit self._dest_spawn_timer = self.spawn_frequency self.trigger_destination_spawn(self.n_dests, state) diff --git a/marl_factory_grid/modules/doors/__init__.py b/marl_factory_grid/modules/doors/__init__.py index e69de29..e5dc1cf 100644 --- a/marl_factory_grid/modules/doors/__init__.py +++ b/marl_factory_grid/modules/doors/__init__.py @@ -0,0 +1,4 @@ +from .actions import DoorUse +from .entitites import Door, DoorIndicator +from .groups import Doors +from .rule_door_auto_close import DoorAutoClose diff --git a/marl_factory_grid/modules/doors/actions.py b/marl_factory_grid/modules/doors/actions.py index 80b7288..31a5bf8 100644 --- a/marl_factory_grid/modules/doors/actions.py +++ b/marl_factory_grid/modules/doors/actions.py @@ -1,10 +1,9 @@ from typing import Union from marl_factory_grid.environment.actions import Action -from marl_factory_grid.utils.results import ActionResult - from marl_factory_grid.modules.doors import constants as d, rewards as r from marl_factory_grid.environment import constants as c +from marl_factory_grid.utils.results import ActionResult class DoorUse(Action): diff --git a/marl_factory_grid/modules/doors/entitites.py b/marl_factory_grid/modules/doors/entitites.py index 4825726..36933ae 100644 --- a/marl_factory_grid/modules/doors/entitites.py +++ b/marl_factory_grid/modules/doors/entitites.py @@ -22,15 +22,15 @@ class DoorIndicator(Entity): class Door(Entity): @property - def is_blocking_pos(self): + def var_is_blocking_pos(self): return False if self.is_open else True @property - def is_blocking_light(self): + def var_is_blocking_light(self): return False if self.is_open else True @property - def can_collide(self): + def var_can_collide(self): return False if self.is_open else True @property @@ -42,12 +42,14 @@ class Door(Entity): return 'open' if self.is_open else 'closed' def __init__(self, *args, closed_on_init=True, auto_close_interval=10, indicate_area=False, **kwargs): - self._state = d.STATE_CLOSED + self._status = d.STATE_CLOSED super(Door, self).__init__(*args, **kwargs) self.auto_close_interval = auto_close_interval self.time_to_close = 0 if not closed_on_init: self._open() + else: + self._close() if indicate_area: self._collection.add_items([DoorIndicator(x) for x in self.tile.neighboring_floor]) @@ -58,22 +60,22 @@ class Door(Entity): @property def is_closed(self): - return self._state == d.STATE_CLOSED + return self._status == d.STATE_CLOSED @property def is_open(self): - return self._state == d.STATE_OPEN + return self._status == d.STATE_OPEN @property def status(self): - return self._state + return self._status def render(self): name, state = 'door_open' if self.is_open else 'door_closed', 'blank' return RenderEntity(name, self.pos, 1, 'none', state, self.identifier_int + 1) def use(self): - if self._state == d.STATE_OPEN: + if self._status == d.STATE_OPEN: self._close() else: self._open() @@ -90,8 +92,8 @@ class Door(Entity): return c.NOT_VALID def _open(self): - self._state = d.STATE_OPEN + self._status = d.STATE_OPEN self.time_to_close = self.auto_close_interval def _close(self): - self._state = d.STATE_CLOSED + self._status = d.STATE_CLOSED diff --git a/marl_factory_grid/environment/assets/__init__.py b/marl_factory_grid/modules/factory/__init__.py similarity index 100% rename from marl_factory_grid/environment/assets/__init__.py rename to marl_factory_grid/modules/factory/__init__.py diff --git a/marl_factory_grid/modules/factory/rules.py b/marl_factory_grid/modules/factory/rules.py new file mode 100644 index 0000000..82d3b38 --- /dev/null +++ b/marl_factory_grid/modules/factory/rules.py @@ -0,0 +1,32 @@ +import random +from typing import List, Union + +from marl_factory_grid.environment.rules import Rule +from marl_factory_grid.environment import constants as c +from marl_factory_grid.utils.results import TickResult + + +class AgentSingleZonePlacementBeta(Rule): + + def __init__(self): + super().__init__() + + def on_init(self, state, lvl_map): + zones = state[c.ZONES] + n_zones = state[c.ZONES] + agents = state[c.AGENT] + if len(self.coordinates) == len(agents): + coordinates = self.coordinates + elif len(self.coordinates) > len(agents): + coordinates = random.choices(self.coordinates, k=len(agents)) + else: + raise ValueError + tiles = [state[c.FLOOR].by_pos(pos) for pos in coordinates] + for agent, tile in zip(agents, tiles): + agent.move(tile) + + def tick_step(self, state): + return [] + + def tick_post_step(self, state) -> List[TickResult]: + return [] \ No newline at end of file diff --git a/marl_factory_grid/modules/items/__init__.py b/marl_factory_grid/modules/items/__init__.py index e69de29..157c385 100644 --- a/marl_factory_grid/modules/items/__init__.py +++ b/marl_factory_grid/modules/items/__init__.py @@ -0,0 +1,4 @@ +from .actions import ItemAction +from .entitites import Item, DropOffLocation +from .groups import DropOffLocations, Items, Inventory, Inventories +from .rules import ItemRules diff --git a/marl_factory_grid/modules/items/entitites.py b/marl_factory_grid/modules/items/entitites.py index b283b0b..94ddad4 100644 --- a/marl_factory_grid/modules/items/entitites.py +++ b/marl_factory_grid/modules/items/entitites.py @@ -8,6 +8,8 @@ from marl_factory_grid.modules.items import constants as i class Item(Entity): + var_can_collide = False + def render(self): return RenderEntity(i.ITEM, self.tile.pos) if self.pos != c.VALUE_NO_POS else None @@ -38,6 +40,22 @@ class Item(Entity): class DropOffLocation(Entity): + @property + def var_can_collide(self): + return False + + @property + def var_can_move(self): + return False + + @property + def var_is_blocking_light(self): + return False + + @property + def var_has_position(self): + return True + def render(self): return RenderEntity(i.DROP_OFF, self.tile.pos) diff --git a/marl_factory_grid/modules/items/groups.py b/marl_factory_grid/modules/items/groups.py index 3ed3d23..0812b47 100644 --- a/marl_factory_grid/modules/items/groups.py +++ b/marl_factory_grid/modules/items/groups.py @@ -17,15 +17,6 @@ class Items(PositionMixin, EnvObjects): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - def spawn_items(self, tiles: List[Floor]): - items = [self._entity(tile) for tile in tiles] - self.add_items(items) - - def despawn_items(self, items: List[Item]): - items = [items] if isinstance(items, Item) else items - for item in items: - del self[item] - class Inventory(IsBoundMixin, EnvObjects): @@ -58,11 +49,7 @@ class Inventory(IsBoundMixin, EnvObjects): class Inventories(HasBoundedMixin, Objects): _entity = Inventory - can_move = False - - @property - def obs_pairs(self): - return [(x.name, x) for x in self] + var_can_move = False def __init__(self, size, *args, **kwargs): super(Inventories, self).__init__(*args, **kwargs) @@ -70,7 +57,7 @@ class Inventories(HasBoundedMixin, Objects): self._obs = None self._lazy_eval_transforms = [] - def spawn_inventories(self, agents): + def spawn(self, agents): inventories = [self._entity(agent, self.size,) for _, agent in enumerate(agents)] self.add_items(inventories) diff --git a/marl_factory_grid/modules/items/rules.py b/marl_factory_grid/modules/items/rules.py index 706677f..9340dc8 100644 --- a/marl_factory_grid/modules/items/rules.py +++ b/marl_factory_grid/modules/items/rules.py @@ -18,7 +18,7 @@ class ItemRules(Rule): self.max_dropoff_storage_size = max_dropoff_storage_size self.n_locations = n_locations - def on_init(self, state): + def on_init(self, state, lvl_map): self.trigger_drop_off_location_spawn(state) self._next_item_spawn = self.spawn_frequency self.trigger_inventory_spawn(state) @@ -42,7 +42,7 @@ class ItemRules(Rule): def trigger_item_spawn(self, state): if item_to_spawns := max(0, (self.n_items - len(state[i.ITEM]))): empty_tiles = state[c.FLOOR].empty_tiles[:item_to_spawns] - state[i.ITEM].spawn_items(empty_tiles) + state[i.ITEM].spawn(empty_tiles) self._next_item_spawn = self.spawn_frequency state.print(f'{item_to_spawns} new items have been spawned; next spawn in {self._next_item_spawn}') return len(empty_tiles) @@ -52,7 +52,7 @@ class ItemRules(Rule): @staticmethod def trigger_inventory_spawn(state): - state[i.INVENTORY].spawn_inventories(state[c.AGENT]) + state[i.INVENTORY].spawn(state[c.AGENT]) def tick_post_step(self, state) -> List[TickResult]: for item in list(state[i.ITEM].values()): diff --git a/marl_factory_grid/modules/machines/__init__.py b/marl_factory_grid/modules/machines/__init__.py index e69de29..36ba51d 100644 --- a/marl_factory_grid/modules/machines/__init__.py +++ b/marl_factory_grid/modules/machines/__init__.py @@ -0,0 +1,3 @@ +from .entitites import Machine +from .groups import Machines +from .rules import MachineRule diff --git a/marl_factory_grid/modules/machines/actions.py b/marl_factory_grid/modules/machines/actions.py new file mode 100644 index 0000000..8f4eaaa --- /dev/null +++ b/marl_factory_grid/modules/machines/actions.py @@ -0,0 +1,25 @@ +from typing import Union + +from marl_factory_grid.environment.actions import Action +from marl_factory_grid.utils.results import ActionResult + +from marl_factory_grid.modules.machines import constants as m, rewards as r +from marl_factory_grid.environment import constants as c + + +class MachineAction(Action): + + def __init__(self): + super().__init__(m.MACHINE_ACTION) + + def do(self, entity, state) -> Union[None, ActionResult]: + if machine := state[m.MACHINES].by_pos(entity.pos): + if valid := machine.maintain(): + return ActionResult(entity=entity, identifier=self._identifier, validity=valid, reward=r.MAINTAIN_VALID) + else: + return ActionResult(entity=entity, identifier=self._identifier, validity=valid, reward=r.MAINTAIN_FAIL) + else: + return ActionResult(entity=entity, identifier=self._identifier, validity=c.NOT_VALID, reward=r.MAINTAIN_FAIL) + + + diff --git a/marl_factory_grid/modules/machines/constants.py b/marl_factory_grid/modules/machines/constants.py index f26a88c..29ce3bc 100644 --- a/marl_factory_grid/modules/machines/constants.py +++ b/marl_factory_grid/modules/machines/constants.py @@ -2,6 +2,8 @@ MACHINES = 'Machines' MACHINE = 'Machine' +MACHINE_ACTION = 'Maintain' + STATE_WORK = 'working' STATE_IDLE = 'idling' STATE_MAINTAIN = 'maintenance' diff --git a/marl_factory_grid/modules/machines/entitites.py b/marl_factory_grid/modules/machines/entitites.py index f53fa48..0bb42fb 100644 --- a/marl_factory_grid/modules/machines/entitites.py +++ b/marl_factory_grid/modules/machines/entitites.py @@ -2,27 +2,43 @@ from marl_factory_grid.environment.entity.entity import Entity from marl_factory_grid.utils.render import RenderEntity from marl_factory_grid.environment import constants as c from marl_factory_grid.utils.results import TickResult -from marl_factory_grid.modules.machines import constants as m, rewards as r + +from . import constants as m class Machine(Entity): + @property + def var_can_collide(self): + return False + + @property + def var_can_move(self): + return False + + @property + def var_is_blocking_light(self): + return False + + @property + def var_has_position(self): + return True + @property def encoding(self): - return self._encodings[self.state] + return self._encodings[self.status] def __init__(self, *args, work_interval: int = 10, pause_interval: int = 15, **kwargs): super(Machine, self).__init__(*args, **kwargs) self._intervals = dict({m.STATE_IDLE: pause_interval, m.STATE_WORK: work_interval}) self._encodings = dict({m.STATE_IDLE: pause_interval, m.STATE_WORK: work_interval}) - self.state = m.STATE_IDLE + self.status = m.STATE_IDLE self.health = 100 self._counter = 0 - self.__delattr__('move') def maintain(self): - if self.state == m.STATE_WORK: + if self.status == m.STATE_WORK: return c.NOT_VALID if self.health <= 98: self.health = 100 @@ -31,10 +47,10 @@ class Machine(Entity): return c.NOT_VALID def tick(self): - if self.state == m.STATE_MAINTAIN and any([c.AGENT in x.name for x in self.tile.guests]): - return TickResult(self.name, c.VALID, r.NONE, self) - elif self.state == m.STATE_MAINTAIN and not any([c.AGENT in x.name for x in self.tile.guests]): - self.state = m.STATE_WORK + if self.status == m.STATE_MAINTAIN and any([c.AGENT in x.name for x in self.tile.guests]): + return TickResult(identifier=self.name, validity=c.VALID, reward=0, entity=self) + elif self.status == m.STATE_MAINTAIN and not any([c.AGENT in x.name for x in self.tile.guests]): + self.status = m.STATE_WORK self.reset_counter() return None elif self._counter: @@ -42,12 +58,12 @@ class Machine(Entity): self.health -= 1 return None else: - self.state = m.STATE_WORK if self.state == m.STATE_IDLE else m.STATE_IDLE + self.status = m.STATE_WORK if self.status == m.STATE_IDLE else m.STATE_IDLE self.reset_counter() return None def reset_counter(self): - self._counter = self._intervals[self.state] + self._counter = self._intervals[self.status] def render(self): return RenderEntity(m.MACHINE, self.pos) diff --git a/marl_factory_grid/modules/machines/groups.py b/marl_factory_grid/modules/machines/groups.py index b4ee633..f8a27e7 100644 --- a/marl_factory_grid/modules/machines/groups.py +++ b/marl_factory_grid/modules/machines/groups.py @@ -1,6 +1,7 @@ from marl_factory_grid.environment.groups.env_objects import EnvObjects from marl_factory_grid.environment.groups.mixins import PositionMixin -from marl_factory_grid.modules.machines.entitites import Machine + +from .entitites import Machine class Machines(PositionMixin, EnvObjects): diff --git a/marl_factory_grid/modules/machines/machine.png b/marl_factory_grid/modules/machines/machine.png new file mode 100644 index 0000000..ba01458 Binary files /dev/null and b/marl_factory_grid/modules/machines/machine.png differ diff --git a/marl_factory_grid/modules/machines/rules.py b/marl_factory_grid/modules/machines/rules.py index 573eda8..04502b8 100644 --- a/marl_factory_grid/modules/machines/rules.py +++ b/marl_factory_grid/modules/machines/rules.py @@ -12,7 +12,7 @@ class MachineRule(Rule): super(MachineRule, self).__init__() self.n_machines = n_machines - def on_init(self, state): + def on_init(self, state, lvl_map): empty_tiles = state[c.FLOOR].empty_tiles[:self.n_machines] state[m.MACHINES].add_items(Machine(tile) for tile in empty_tiles) @@ -27,3 +27,9 @@ class MachineRule(Rule): def on_check_done(self, state) -> List[DoneResult]: pass + + +class DoneOnBreakRule(Rule): + + def on_check_done(self, state) -> List[DoneResult]: + pass \ No newline at end of file diff --git a/marl_factory_grid/modules/maintenance/__init__.py b/marl_factory_grid/modules/maintenance/__init__.py new file mode 100644 index 0000000..84da0db --- /dev/null +++ b/marl_factory_grid/modules/maintenance/__init__.py @@ -0,0 +1,2 @@ +from .entities import Maintainer +from .groups import Maintainers diff --git a/marl_factory_grid/modules/maintenance/constants.py b/marl_factory_grid/modules/maintenance/constants.py new file mode 100644 index 0000000..e0ab12c --- /dev/null +++ b/marl_factory_grid/modules/maintenance/constants.py @@ -0,0 +1,3 @@ +MAINTAINER = 'Maintainer' # TEMPLATE _identifier. Define your own! +MAINTAINERS = 'Maintainers' # TEMPLATE _identifier. Define your own! + diff --git a/marl_factory_grid/modules/maintenance/entities.py b/marl_factory_grid/modules/maintenance/entities.py new file mode 100644 index 0000000..546f8a5 --- /dev/null +++ b/marl_factory_grid/modules/maintenance/entities.py @@ -0,0 +1,102 @@ +import networkx as nx +import numpy as np + +from ...algorithms.static.utils import points_to_graph +from ...environment import constants as c +from ...environment.actions import Action, ALL_BASEACTIONS +from ...environment.entity.entity import Entity +from ..doors import constants as do +from ..maintenance import constants as mi +from ...utils.helpers import MOVEMAP +from ...utils.render import RenderEntity +from ...utils.states import Gamestate + + +class Maintainer(Entity): + + @property + def var_can_collide(self): + return True + + @property + def var_can_move(self): + return False + + @property + def var_is_blocking_light(self): + return False + + @property + def var_has_position(self): + return True + + def __init__(self, state: Gamestate, objective: str, action: Action, *args, **kwargs): + super().__init__(*args, **kwargs) + self.action = action + self.actions = [x() for x in ALL_BASEACTIONS] + self.objective = objective + self._path = None + self._next = [] + self._last = [] + self._last_serviced = 'None' + self._floortile_graph = points_to_graph(state[c.FLOOR].positions) + + def tick(self, state): + if found_objective := state[self.objective].by_pos(self.pos): + if found_objective.name != self._last_serviced: + self.action.do(self, state) + self._last_serviced = found_objective.name + else: + action = self.get_move_action(state) + return action.do(self, state) + else: + action = self.get_move_action(state) + return action.do(self, state) + + def get_move_action(self, state) -> Action: + if self._path is None or not self._path: + if not self._next: + self._next = list(state[self.objective].values()) + self._last = [] + self._last.append(self._next.pop()) + self._path = self.calculate_route(self._last[-1]) + + if door := self._door_is_close(): + if door.is_closed: + # Translate the action_object to an integer to have the same output as any other model + action = do.ACTION_DOOR_USE + else: + action = self._predict_move(state) + else: + action = self._predict_move(state) + # Translate the action_object to an integer to have the same output as any other model + try: + action_obj = next(x for x in self.actions if x.name == action) + except (StopIteration, UnboundLocalError): + print('Will not happen') + raise EnvironmentError + return action_obj + + def calculate_route(self, entity): + route = nx.shortest_path(self._floortile_graph, self.pos, entity.pos) + return route[1:] + + def _door_is_close(self): + try: + return next(y for x in self.tile.neighboring_floor for y in x.guests if do.DOOR in y.name) + except StopIteration: + return None + + def _predict_move(self, state): + next_pos = self._path[0] + if len(state[c.FLOOR].by_pos(next_pos).guests_that_can_collide) > 0: + action = c.NOOP + else: + next_pos = self._path.pop(0) + diff = np.subtract(next_pos, self.pos) + # Retrieve action based on the pos dif (like in: What do I have to do to get there?) + action = next(action for action, pos_diff in MOVEMAP.items() if np.all(diff == pos_diff)) + return action + + def render(self): + return RenderEntity(mi.MAINTAINER, self.pos) diff --git a/marl_factory_grid/modules/maintenance/groups.py b/marl_factory_grid/modules/maintenance/groups.py new file mode 100644 index 0000000..8f9b8b8 --- /dev/null +++ b/marl_factory_grid/modules/maintenance/groups.py @@ -0,0 +1,27 @@ +from typing import List + +from .entities import Maintainer +from marl_factory_grid.environment.entity.wall_floor import Floor +from marl_factory_grid.environment.groups.env_objects import EnvObjects +from marl_factory_grid.environment.groups.mixins import PositionMixin +from ..machines.actions import MachineAction +from ...utils.render import RenderEntity +from ...utils.states import Gamestate + +from ..machines import constants as mc +from . import constants as mi + + +class Maintainers(PositionMixin, EnvObjects): + + _entity = Maintainer + var_can_collide = True + var_can_move = True + var_is_blocking_light = False + var_has_position = True + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def spawn(self, tiles: List[Floor], state: Gamestate): + self.add_items([self._entity(state, mc.MACHINES, MachineAction(), tile) for tile in tiles]) diff --git a/marl_factory_grid/modules/maintenance/maintainer.png b/marl_factory_grid/modules/maintenance/maintainer.png new file mode 100644 index 0000000..af4a37d Binary files /dev/null and b/marl_factory_grid/modules/maintenance/maintainer.png differ diff --git a/marl_factory_grid/modules/maintenance/rewards.py b/marl_factory_grid/modules/maintenance/rewards.py new file mode 100644 index 0000000..425ac3b --- /dev/null +++ b/marl_factory_grid/modules/maintenance/rewards.py @@ -0,0 +1 @@ +MAINTAINER_COLLISION_REWARD = -5 \ No newline at end of file diff --git a/marl_factory_grid/modules/maintenance/rules.py b/marl_factory_grid/modules/maintenance/rules.py new file mode 100644 index 0000000..7cb2178 --- /dev/null +++ b/marl_factory_grid/modules/maintenance/rules.py @@ -0,0 +1,39 @@ +from typing import List +from marl_factory_grid.environment.rules import Rule +from marl_factory_grid.utils.results import TickResult, DoneResult +from marl_factory_grid.environment import constants as c +from . import rewards as r +from . import constants as M +from marl_factory_grid.utils.states import Gamestate + + +class MaintenanceRule(Rule): + + def __init__(self, n_maintainer: int = 1, *args, **kwargs): + super(MaintenanceRule, self).__init__(*args, **kwargs) + self.n_maintainer = n_maintainer + + def on_init(self, state: Gamestate, lvl_map): + state[M.MAINTAINERS].spawn(state[c.FLOOR].empty_tiles[:self.n_maintainer], state) + pass + + def tick_pre_step(self, state) -> List[TickResult]: + pass + + def tick_step(self, state) -> List[TickResult]: + for maintainer in state[M.MAINTAINERS]: + maintainer.tick(state) + return [] + + def tick_post_step(self, state) -> List[TickResult]: + pass + + def on_check_done(self, state) -> List[DoneResult]: + agents = list(state[c.AGENT].values()) + m_pos = state[M.MAINTAINERS].positions + done_results = [] + for agent in agents: + if agent.pos in m_pos: + done_results.append(DoneResult(entity=agent, validity=c.VALID, identifier=self.name, + reward=r.MAINTAINER_COLLISION_REWARD)) + return done_results diff --git a/marl_factory_grid/modules/zones/__init__.py b/marl_factory_grid/modules/zones/__init__.py new file mode 100644 index 0000000..7ae5f3c --- /dev/null +++ b/marl_factory_grid/modules/zones/__init__.py @@ -0,0 +1,3 @@ +from .entitites import Zone +from .groups import Zones +from .rules import AgentSingleZonePlacement diff --git a/marl_factory_grid/modules/zones/constants.py b/marl_factory_grid/modules/zones/constants.py new file mode 100644 index 0000000..135a471 --- /dev/null +++ b/marl_factory_grid/modules/zones/constants.py @@ -0,0 +1,4 @@ +# Names / Identifiers + +ZONES = 'Zones' # Identifier of Zone-objects and groups (groups). +ZONE = 'Zone' # -||- diff --git a/marl_factory_grid/modules/zones/entitites.py b/marl_factory_grid/modules/zones/entitites.py new file mode 100644 index 0000000..cd5aa21 --- /dev/null +++ b/marl_factory_grid/modules/zones/entitites.py @@ -0,0 +1,21 @@ +import random +from typing import List + +from marl_factory_grid.environment.entity.entity import Entity +from marl_factory_grid.environment.entity.object import Object +from marl_factory_grid.environment.entity.wall_floor import Floor +from marl_factory_grid.utils.render import RenderEntity +from marl_factory_grid.environment import constants as c + +from marl_factory_grid.modules.doors import constants as d + + +class Zone(Object): + + def __init__(self, tiles: List[Floor], *args, **kwargs): + super(Zone, self).__init__(*args, **kwargs) + self.tiles = tiles + + @property + def random_tile(self): + return random.choice(self.tiles) diff --git a/marl_factory_grid/modules/zones/groups.py b/marl_factory_grid/modules/zones/groups.py new file mode 100644 index 0000000..a26b3a4 --- /dev/null +++ b/marl_factory_grid/modules/zones/groups.py @@ -0,0 +1,12 @@ +from marl_factory_grid.environment.groups.objects import Objects +from marl_factory_grid.modules.zones import Zone + + +class Zones(Objects): + + symbol = None + _entity = Zone + var_can_move = False + + def __init__(self, *args, **kwargs): + super(Zones, self).__init__(*args, can_collide=True, **kwargs) diff --git a/marl_factory_grid/modules/zones/rules.py b/marl_factory_grid/modules/zones/rules.py new file mode 100644 index 0000000..8df8be2 --- /dev/null +++ b/marl_factory_grid/modules/zones/rules.py @@ -0,0 +1,33 @@ +from random import choices + +from marl_factory_grid.environment.rules import Rule +from marl_factory_grid.environment import constants as c +from marl_factory_grid.modules.zones import Zone +from . import constants as z + + +class AgentSingleZonePlacement(Rule): + + def __init__(self, n_zones=3): + super().__init__() + self.n_zones = n_zones + + def on_init(self, state, lvl_map): + zones = [] + + for z_idx in range(1, self.n_zones): + zone_positions = lvl_map.get_coordinates_for_symbol(z_idx) + assert len(zone_positions) + zones.append(Zone([state[c.FLOOR].by_pos(pos) for pos in zone_positions])) + state[z.ZONES].add_items(zones) + + n_agents = len(state[c.AGENT]) + assert len(state[z.ZONES]) >= n_agents + + z_idxs = choices(list(range(len(state[z.ZONES]))), k=n_agents) + for agent in state[c.AGENT]: + agent.move(state[z.ZONES][z_idxs.pop()].random_tile) + return [] + + def tick_step(self, state): + return [] diff --git a/marl_factory_grid/quickstart.py b/marl_factory_grid/quickstart.py index 3b96e2a..7e44610 100644 --- a/marl_factory_grid/quickstart.py +++ b/marl_factory_grid/quickstart.py @@ -10,10 +10,10 @@ def init(): ce = ConfigExplainer() cwd = Path(os.getcwd()) ce.save_all(cwd / 'full_config.yaml') - template_path = Path(__file__) / 'marl_factory_grid' / 'modules' / '_template' + template_path = Path(__file__).parent / 'modules' / '_template' print(f'Available config options saved to: {(cwd / "full_config.yaml").resolve()}') print('-----------------------------') print(f'Copying Templates....') shutil.copytree(template_path, cwd) - print(f'Templates copied to {template_path.resolve()}') + print(f'Templates copied to {cwd}"/"{template_path.name}') print(':wave:') diff --git a/marl_factory_grid/utils/config_parser.py b/marl_factory_grid/utils/config_parser.py index 2bf8639..9ac8234 100644 --- a/marl_factory_grid/utils/config_parser.py +++ b/marl_factory_grid/utils/config_parser.py @@ -18,11 +18,11 @@ class FactoryConfigParser(object): default_entites = [] default_rules = ['MaxStepsReached', 'Collision'] default_actions = [c.MOVE8, c.NOOP] - default_observations = [c.WALLS, c.AGENTS] + default_observations = [c.WALLS, c.AGENT] def __init__(self, config_path, custom_modules_path: Union[None, PathLike] = None): self.config_path = Path(config_path) - self.custom_modules_path = Path(config_path) if custom_modules_path is not None else custom_modules_path + self.custom_modules_path = Path(custom_modules_path) if custom_modules_path is not None else custom_modules_path self.config = yaml.safe_load(self.config_path.open()) self.do_record = False @@ -69,12 +69,20 @@ class FactoryConfigParser(object): for entity in entities: try: - folder_path = MODULE_PATH if entity not in self.default_entites else DEFAULT_PATH - folder_path = (Path(__file__) / '..' / '..' / '..' / folder_path) - entity_class = locate_and_import_class(entity, folder_path) - except AttributeError: - folder_path = self.custom_modules_path + folder_path = Path(__file__).parent.parent / DEFAULT_PATH entity_class = locate_and_import_class(entity, folder_path) + except AttributeError as e1: + try: + folder_path = Path(__file__).parent.parent / MODULE_PATH + entity_class = locate_and_import_class(entity, folder_path) + except AttributeError as e2: + try: + folder_path = self.custom_modules_path + entity_class = locate_and_import_class(entity, folder_path) + except AttributeError as e3: + ents = [y for x in [e1.argss[1], e2.argss[1], e3.argss[1]] for y in x] + raise AttributeError(e1.argss[0], e2.argss[0], e3.argss[0], 'Possible Entitys are>:', str(ents)) + entity_kwargs = self.entities.get(entity, {}) entity_symbol = entity_class.symbol if hasattr(entity_class, 'symbol') else None entity_classes.update({entity: {'class': entity_class, 'kwargs': entity_kwargs, 'symbol': entity_symbol}}) @@ -92,7 +100,7 @@ class FactoryConfigParser(object): parsed_actions = list() for action in actions: folder_path = MODULE_PATH if action not in base_env_actions else DEFAULT_PATH - folder_path = (Path(__file__) / '..' / '..' / '..' / folder_path) + folder_path = Path(__file__).parent.parent / folder_path try: class_or_classes = locate_and_import_class(action, folder_path) except AttributeError: @@ -124,12 +132,15 @@ class FactoryConfigParser(object): rules.extend(x for x in self.rules if x != c.DEFAULTS) for rule in rules: - folder_path = MODULE_PATH if rule not in self.default_rules else DEFAULT_PATH - folder_path = (Path(__file__) / '..' / '..' / '..' / folder_path) try: + folder_path = (Path(__file__).parent.parent / DEFAULT_PATH) rule_class = locate_and_import_class(rule, folder_path) except AttributeError: - rule_class = locate_and_import_class(rule, self.custom_modules_path) + try: + folder_path = (Path(__file__).parent.parent / MODULE_PATH) + rule_class = locate_and_import_class(rule, folder_path) + except AttributeError: + rule_class = locate_and_import_class(rule, self.custom_modules_path) rule_kwargs = self.rules.get(rule, {}) rules_classes.update({rule: {'class': rule_class, 'kwargs': rule_kwargs}}) return rules_classes diff --git a/marl_factory_grid/utils/helpers.py b/marl_factory_grid/utils/helpers.py index c1d469e..ca3a20c 100644 --- a/marl_factory_grid/utils/helpers.py +++ b/marl_factory_grid/utils/helpers.py @@ -176,7 +176,7 @@ def one_hot_level(level, symbol: str): grid = np.array(level) binary_grid = np.zeros(grid.shape, dtype=np.int8) - binary_grid[grid == symbol] = c.VALUE_OCCUPIED_CELL + binary_grid[grid == str(symbol)] = c.VALUE_OCCUPIED_CELL return binary_grid @@ -222,18 +222,15 @@ def locate_and_import_class(class_name, folder_path: Union[str, PurePath] = ''): for module_path in module_paths: module_parts = [x.replace('.py', '') for idx, x in enumerate(module_path.parts) if idx >= package_pos] mod = importlib.import_module('.'.join(module_parts)) - all_found_modules.extend([x for x in dir(mod) if not(x.startswith('__') or len(x) < 2 or x.isupper()) - and x not in ['Entity', 'NamedTuple', 'List', 'Rule', 'Union', 'random', 'Floor' - 'TickResult', 'ActionResult', 'Action', 'Agent', 'deque', - 'BoundEntityMixin', 'RenderEntity', 'TemplateRule', 'defaultdict', - 'is_move', 'Objects', 'PositionMixin', 'IsBoundMixin', 'EnvObject', - 'EnvObjects', 'Dict', 'locate_and_import_class', 'yaml', 'Any', - 'inspect']]) + all_found_modules.extend([x for x in dir(mod) if (not(x.startswith('__') or len(x) <= 2) and x.istitle()) + and x not in ['Entity', 'NamedTuple', 'List', 'Rule', 'Union', 'Floor' + 'TickResult', 'ActionResult', 'Action', 'Agent', 'BoundEntityMixin', + 'RenderEntity', 'TemplateRule', 'Objects', 'PositionMixin', + 'IsBoundMixin', 'EnvObject', 'EnvObjects', 'Dict', 'Any' + ]]) try: model_class = mod.__getattribute__(class_name) return model_class except AttributeError: continue - raise AttributeError(f'Class "{class_name}" was not found!!!"\n' - f'Check the {folder_path.name} name.\n' - f'Possible Options are:\n{set(all_found_modules)}') + raise AttributeError(f'Class "{class_name}" was not found in "{folder_path.name}"', list(set(all_found_modules))) diff --git a/marl_factory_grid/utils/level_parser.py b/marl_factory_grid/utils/level_parser.py index e603a63..69bcaeb 100644 --- a/marl_factory_grid/utils/level_parser.py +++ b/marl_factory_grid/utils/level_parser.py @@ -24,31 +24,40 @@ class LevelParser(object): self.level_shape = level_array.shape self.size = self.pomdp_r**2 if self.pomdp_r else np.prod(self.level_shape) + def get_coordinates_for_symbol(self, symbol, negate=False): + level_array = h.one_hot_level(self._parsed_level, symbol) + if negate: + return np.argwhere(level_array != c.VALUE_OCCUPIED_CELL) + else: + return np.argwhere(level_array == c.VALUE_OCCUPIED_CELL) + def do_init(self): entities = Entities() # Walls - level_array = h.one_hot_level(self._parsed_level, c.SYMBOL_WALL) - - walls = Walls.from_coordinates(np.argwhere(level_array == c.VALUE_OCCUPIED_CELL), self.size) + walls = Walls.from_coordinates(self.get_coordinates_for_symbol(c.SYMBOL_WALL), self.size) entities.add_items({c.WALL: walls}) # Floor - floor = Floors.from_coordinates(np.argwhere(level_array == c.VALUE_FREE_CELL), self.size) + floor = Floors.from_coordinates(self.get_coordinates_for_symbol(c.SYMBOL_WALL, negate=True), self.size) entities.add_items({c.FLOOR: floor}) # All other for es_name in self.e_p_dict: e_class, e_kwargs = self.e_p_dict[es_name]['class'], self.e_p_dict[es_name]['kwargs'] - if hasattr(e_class, 'symbol'): - level_array = h.one_hot_level(self._parsed_level, symbol=e_class.symbol) - if np.any(level_array): - e = e_class.from_coordinates(np.argwhere(level_array == c.VALUE_OCCUPIED_CELL).tolist(), - entities[c.FLOOR], self.size, entity_kwargs=e_kwargs - ) - else: - raise ValueError(f'No {e_class} (Symbol: {e_class.symbol}) could be found!\n' - f'Check your level file!') + if hasattr(e_class, 'symbol') and e_class.symbol is not None: + symbols = e_class.symbol + if isinstance(symbols, (str, int, float)): + symbols = [symbols] + for symbol in symbols: + level_array = h.one_hot_level(self._parsed_level, symbol=symbol) + if np.any(level_array): + e = e_class.from_coordinates(np.argwhere(level_array == c.VALUE_OCCUPIED_CELL).tolist(), + entities[c.FLOOR], self.size, entity_kwargs=e_kwargs + ) + else: + raise ValueError(f'No {e_class} (Symbol: {e_class.symbol}) could be found!\n' + f'Check your level file!') else: e = e_class(self.size, **e_kwargs) entities.add_items({e.name: e}) diff --git a/marl_factory_grid/utils/observation_builder.py b/marl_factory_grid/utils/observation_builder.py index 6f6717d..ed725cc 100644 --- a/marl_factory_grid/utils/observation_builder.py +++ b/marl_factory_grid/utils/observation_builder.py @@ -6,11 +6,10 @@ from typing import Dict, List import numpy as np from numba import njit +from marl_factory_grid.environment import constants as c from marl_factory_grid.environment.groups.utils import Combined from marl_factory_grid.utils.states import Gamestate -from marl_factory_grid.environment import constants as c - class OBSBuilder(object): @@ -111,10 +110,10 @@ class OBSBuilder(object): e = next(x for x in self.all_obs if l_name in x and agent.name in x) except StopIteration: raise KeyError( - f'Check typing!\n{l_name} could not be found in:\n{dict(self.all_obs).keys()}') + f'Check typing! {l_name} could not be found in: {list(dict(self.all_obs).keys())}') try: - positional = e.has_position + positional = e.var_has_position except AttributeError: positional = False if positional: @@ -172,7 +171,7 @@ class OBSBuilder(object): obs_layers.append(combined.name) elif obs_str == c.OTHERS: obs_layers.extend([x for x in self.all_obs if x != agent.name and x.startswith(f'{c.AGENT}[')]) - elif obs_str == c.AGENTS: + elif obs_str == c.AGENT: obs_layers.extend([x for x in self.all_obs if x.startswith(f'{c.AGENT}[')]) else: obs_layers.append(obs_str) @@ -222,7 +221,7 @@ class RayCaster: entities_hit = entities.pos_dict[(x, y)] hits = self.ray_block_cache(cache_blocking, (x, y), - lambda: any(e.is_blocking_light for e in entities_hit), + lambda: any(e.var_is_blocking_light for e in entities_hit), entities) try: @@ -237,8 +236,8 @@ class RayCaster: self.ray_block_cache( cache_blocking, key, - # lambda: all(False for e in entities.pos_dict[key] if not e.is_blocking_light), - lambda: any(e.is_blocking_light for e in entities.pos_dict[key]), + # lambda: all(False for e in entities.pos_dict[key] if not e.var_is_blocking_light), + lambda: any(e.var_is_blocking_light for e in entities.pos_dict[key]), entities) for key in ((x, y-cy), (x-cx, y)) ]) if (cx != 0 and cy != 0) else False diff --git a/marl_factory_grid/utils/renderer.py b/marl_factory_grid/utils/renderer.py index 84392d0..38a8e22 100644 --- a/marl_factory_grid/utils/renderer.py +++ b/marl_factory_grid/utils/renderer.py @@ -27,13 +27,13 @@ class Renderer: BG_COLOR = (178, 190, 195) # (99, 110, 114) WHITE = (223, 230, 233) # (200, 200, 200) AGENT_VIEW_COLOR = (9, 132, 227) - ASSETS = Path(__file__).parent.parent / 'assets' - MODULE_ASSETS = Path(__file__).parent.parent.parent / 'modules' + ASSETS = Path(__file__).parent.parent def __init__(self, lvl_shape: Tuple[int, int] = (16, 16), lvl_padded_shape: Union[Tuple[int, int], None] = None, cell_size: int = 40, fps: int = 7, grid_lines: bool = True, view_radius: int = 2): + # TODO: Customn_assets paths self.grid_h, self.grid_w = lvl_shape self.lvl_padded_shape = lvl_padded_shape if lvl_padded_shape is not None else lvl_shape self.cell_size = cell_size @@ -44,7 +44,7 @@ class Renderer: self.screen_size = (self.grid_w*cell_size, self.grid_h*cell_size) self.screen = pygame.display.set_mode(self.screen_size) self.clock = pygame.time.Clock() - assets = list(self.ASSETS.rglob('*.png')) + list(self.MODULE_ASSETS.rglob('*.png')) + assets = list(self.ASSETS.rglob('*.png')) self.assets = {path.stem: self.load_asset(str(path), 1) for path in assets} self.fill_bg() diff --git a/marl_factory_grid/utils/results.py b/marl_factory_grid/utils/results.py index 0b5a214..9f0fa38 100644 --- a/marl_factory_grid/utils/results.py +++ b/marl_factory_grid/utils/results.py @@ -1,8 +1,6 @@ from typing import Union from dataclasses import dataclass -from marl_factory_grid.environment.entity.entity import Entity - TYPE_VALUE = 'value' TYPE_REWARD = 'reward' types = [TYPE_VALUE, TYPE_REWARD] @@ -20,7 +18,7 @@ class Result: validity: bool reward: Union[float, None] = None value: Union[float, None] = None - entity: Union[Entity, None] = None + entity: None = None def get_infos(self): n = self.entity.name if self.entity is not None else "Global" diff --git a/marl_factory_grid/utils/states.py b/marl_factory_grid/utils/states.py index caee1a4..4eff5a7 100644 --- a/marl_factory_grid/utils/states.py +++ b/marl_factory_grid/utils/states.py @@ -2,10 +2,11 @@ from typing import List, Dict import numpy as np + +from marl_factory_grid.environment import constants as c from marl_factory_grid.environment.entity.wall_floor import Floor from marl_factory_grid.environment.rules import Rule from marl_factory_grid.utils.results import Result -from marl_factory_grid.environment import constants as c class StepRules: @@ -26,9 +27,9 @@ class StepRules: self.rules.append(item) return True - def do_all_init(self, state): + def do_all_init(self, state, lvl_map): for rule in self.rules: - if rule_init_printline := rule.on_init(state): + if rule_init_printline := rule.on_init(state, lvl_map): state.print(rule_init_printline) return c.VALID @@ -58,7 +59,7 @@ class Gamestate(object): @property def moving_entites(self): - return [y for x in self.entities for y in x if x.can_move] + return [y for x in self.entities for y in x if x.var_can_move] def __init__(self, entitites, rules: Dict[str, dict], env_seed=69, verbose=False): self.entities = entitites @@ -107,6 +108,6 @@ class Gamestate(object): def get_all_tiles_with_collisions(self) -> List[Floor]: tiles = [self[c.FLOOR].by_pos(pos) for pos, e in self.entities.pos_dict.items() - if sum([x.can_collide for x in e]) > 1] + if sum([x.var_can_collide for x in e]) > 1] # tiles = [x for x in self[c.FLOOR] if len(x.guests_that_can_collide) > 1] return tiles diff --git a/marl_factory_grid/utils/tools.py b/marl_factory_grid/utils/tools.py index eff3e01..7fc5aa5 100644 --- a/marl_factory_grid/utils/tools.py +++ b/marl_factory_grid/utils/tools.py @@ -15,7 +15,7 @@ ENTITIES = 'Objects' OBSERVATIONS = 'Observations' RULES = 'Rule' ASSETS = 'Assets' -EXCLUDED = ['identifier', 'args', 'kwargs', 'Move', +EXCLUDED = ['identifier', 'args', 'kwargs', 'Move', 'Floor', 'Agent', 'GlobalPositions', 'Walls', 'TemplateRule', 'Entities', 'EnvObjects', 'Zones', ] diff --git a/marl_factory_grid/utils/utility_classes.py b/marl_factory_grid/utils/utility_classes.py index 302535a..8fee782 100644 --- a/marl_factory_grid/utils/utility_classes.py +++ b/marl_factory_grid/utils/utility_classes.py @@ -1,21 +1,6 @@ import gymnasium as gym -class EnvCombiner(object): - - def __init__(self, *envs_cls): - self._env_dict = {env_cls.__name__: env_cls for env_cls in envs_cls} - - @staticmethod - def combine_cls(name, *envs_cls): - return type(name, envs_cls, {}) - - def build(self): - name = f'{"".join([x.lower().replace("factory").capitalize() for x in self._env_dict.keys()])}Factory' - - return self.combine_cls(name, tuple(self._env_dict.values())) - - class MarlFrameStack(gym.ObservationWrapper): """todo @romue404""" def __init__(self, env): diff --git a/reload_agent.py b/reload_agent.py index ddac858..99cddc8 100644 --- a/reload_agent.py +++ b/reload_agent.py @@ -3,7 +3,7 @@ from pathlib import Path import yaml -from marl_factory_grid.environment.factory import BaseFactory +from marl_factory_grid.environment.factory import Factory from marl_factory_grid.logging.envmonitor import EnvMonitor from marl_factory_grid.logging.recorder import EnvRecorder @@ -41,7 +41,7 @@ if __name__ == '__main__': pass # Init Env - with BaseFactory(**env_kwargs) as env: + with Factory(**env_kwargs) as env: env = EnvMonitor(env) env = EnvRecorder(env) if record else env obs_shape = env.observation_space.shape diff --git a/setup.py b/setup.py index 4749beb..2ec794c 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ long_description = (this_directory / "README.md").read_text() setup(name='Marl-Factory-Grid', - version='0.0.11', + version='0.0.12', description='A framework to research MARL agents in various setings.', author='Steffen Illium', author_email='steffen.illium@ifi.lmu.de',