diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index f68e750..2f8ab20 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -14,8 +14,8 @@ build-job: # This job runs in the build stage, which runs first. image: python:slim script: - echo "Compiling the code..." - - pip install -U twine + - pip install twine --upgrade - python setup.py sdist - echo "Compile complete." - - twine upload dist/* + - twine upload dist/* --username $USER_NAME --password $API_KEY --repository marl-factory-grid - echo "Upload complete." diff --git a/marl_factory_grid/configs/default_config.yaml b/marl_factory_grid/configs/default_config.yaml index d3015c9..fe89597 100644 --- a/marl_factory_grid/configs/default_config.yaml +++ b/marl_factory_grid/configs/default_config.yaml @@ -22,13 +22,6 @@ Agents: - Inventory - DropOffLocations - Maintainers - # This is special for agents, as each one is differten and can act as an adversary e.g. - Positions: - - (16, 7) - - (16, 6) - - (16, 3) - - (16, 4) - - (16, 5) Entities: Batteries: initial_charge: 0.8 diff --git a/marl_factory_grid/configs/eight_puzzle.yaml b/marl_factory_grid/configs/eight_puzzle.yaml new file mode 100644 index 0000000..51d9164 --- /dev/null +++ b/marl_factory_grid/configs/eight_puzzle.yaml @@ -0,0 +1,55 @@ +Agents: + Wolfgang: + Actions: + - Noop + - Move4 + Observations: + - Other + - Walls + - Destination + Clones: + - Juergen + - Soeren + - Walter + - Siggi + - Dennis + - Karl-Heinz + - Kevin + is_blocking_pos: true +Entities: + Destinations: + # Let them spawn on closed doors and agent positions + ignore_blocking: true + # We need a special spawn rule... + spawnrule: + # ...which assigns the destinations per agent + SpawnDestinationsPerAgent: + # we use this parameter + coords_or_quantity: + # to enable and assign special positions per agent + Wolfgang: 1 + Karl-Heinz: 1 + Kevin: 1 + Juergen: 1 + Soeren: 1 + Walter: 1 + Siggi: 1 + Dennis: 1 +General: + env_seed: 69 + individual_rewards: true + level_name: eight_puzzle + pomdp_r: 3 + verbose: True + tests: false + +Rules: + # Utilities + WatchCollisions: + done_at_collisions: false + + # Done Conditions + DoneAtDestinationReach: + condition: simultanious + DoneAtMaxStepsReached: + max_steps: 500 diff --git a/marl_factory_grid/environment/factory.py b/marl_factory_grid/environment/factory.py index f37ce3a..10dbd8a 100644 --- a/marl_factory_grid/environment/factory.py +++ b/marl_factory_grid/environment/factory.py @@ -97,20 +97,26 @@ class Factory(gym.Env): return self.state.entities[item] def reset(self) -> (dict, dict): + + # Reset information the state holds + self.state.reset() + + # Reset Information the GlobalEntity collection holds. self.state.entities.reset() # All is set up, trigger entity spawn with variable pos self.state.rules.do_all_reset(self.state) # Build initial observations for all agents - return self.obs_builder.refresh_and_build_for_all(self.state) + self.obs_builder.reset(self.state) + return self.obs_builder.build_for_all(self.state) def manual_step_init(self) -> List[Result]: self.state.curr_step += 1 # Main Agent Step pre_step_result = self.state.rules.tick_pre_step_all(self) - self.obs_builder.reset_struc_obs_block(self.state) + self.obs_builder.reset(self.state) return pre_step_result def manual_get_named_agent_obs(self, agent_name: str) -> (List[str], np.ndarray): @@ -164,7 +170,7 @@ class Factory(gym.Env): info.update(step_reward=sum(reward), step=self.state.curr_step) - obs = self.obs_builder.refresh_and_build_for_all(self.state) + obs = self.obs_builder.build_for_all(self.state) return None, [x for x in obs.values()], reward, done, info def summarize_step_results(self, tick_results: list, done_check_results: list) -> (int, dict, bool): diff --git a/marl_factory_grid/environment/groups/agents.py b/marl_factory_grid/environment/groups/agents.py index d549384..5405ab1 100644 --- a/marl_factory_grid/environment/groups/agents.py +++ b/marl_factory_grid/environment/groups/agents.py @@ -1,6 +1,5 @@ from marl_factory_grid.environment.entity.agent import Agent from marl_factory_grid.environment.groups.collection import Collection -from marl_factory_grid.environment.rules import SpawnAgents class Agents(Collection): @@ -8,7 +7,7 @@ class Agents(Collection): @property def spawn_rule(self): - return {SpawnAgents.__name__: {}} + return {} @property def var_is_blocking_light(self): diff --git a/marl_factory_grid/environment/groups/global_entities.py b/marl_factory_grid/environment/groups/global_entities.py index 601ce4d..70ea9a8 100644 --- a/marl_factory_grid/environment/groups/global_entities.py +++ b/marl_factory_grid/environment/groups/global_entities.py @@ -27,7 +27,7 @@ class Entities(Objects): @property def floorlist(self): shuffle(self._floor_positions) - return self._floor_positions + return [x for x in self._floor_positions] def __init__(self, floor_positions): self._floor_positions = floor_positions diff --git a/marl_factory_grid/environment/rules.py b/marl_factory_grid/environment/rules.py index 5884cac..5c47df5 100644 --- a/marl_factory_grid/environment/rules.py +++ b/marl_factory_grid/environment/rules.py @@ -70,28 +70,22 @@ class SpawnAgents(Rule): def on_reset(self, state): agents = state[c.AGENT] - empty_positions = state.entities.empty_positions[:len(state.agents_conf)] for agent_name, agent_conf in state.agents_conf.items(): + empty_positions = state.entities.empty_positions actions = agent_conf['actions'].copy() observations = agent_conf['observations'].copy() positions = agent_conf['positions'].copy() other = agent_conf['other'].copy() - if positions: - shuffle(positions) - while True: - try: - pos = positions.pop() - except IndexError: - raise ValueError(f'It was not possible to spawn an Agent on the available position: ' - f'\n{agent_conf["positions"].copy()}') - if bool(agents.by_pos(pos)) or not state.check_pos_validity(pos): - continue - else: - agents.add_item(Agent(actions, observations, pos, str_ident=agent_name, **other)) - break + + if position := h.get_first(x for x in positions if x in empty_positions): + assert state.check_pos_validity(position), 'smth went wrong....' + agents.add_item(Agent(actions, observations, position, str_ident=agent_name, **other)) + elif positions: + raise ValueError(f'It was not possible to spawn an Agent on the available position: ' + f'\n{agent_conf["positions"].copy()}') else: agents.add_item(Agent(actions, observations, empty_positions.pop(), str_ident=agent_name, **other)) - pass + return [] class DoneAtMaxStepsReached(Rule): @@ -103,7 +97,7 @@ class DoneAtMaxStepsReached(Rule): def on_check_done(self, state): if self.max_steps <= state.curr_step: return [DoneResult(validity=c.VALID, identifier=self.name)] - return [DoneResult(validity=c.NOT_VALID, identifier=self.name)] + return [] class AssignGlobalPositions(Rule): @@ -130,7 +124,7 @@ class WatchCollisions(Rule): def tick_post_step(self, state) -> List[TickResult]: self.curr_done = False - pos_with_collisions = state.get_all_pos_with_collisions() + pos_with_collisions = state.get_collision_positions() results = list() for pos in pos_with_collisions: guests = [x for x in state.entities.pos_dict[pos] if x.var_can_collide] diff --git a/marl_factory_grid/levels/eight_puzzle.txt b/marl_factory_grid/levels/eight_puzzle.txt new file mode 100644 index 0000000..7b77a38 --- /dev/null +++ b/marl_factory_grid/levels/eight_puzzle.txt @@ -0,0 +1,5 @@ +##### +#---# +#---# +#---# +##### \ No newline at end of file diff --git a/marl_factory_grid/modules/batteries/rules.py b/marl_factory_grid/modules/batteries/rules.py index 7314c93..4e6d892 100644 --- a/marl_factory_grid/modules/batteries/rules.py +++ b/marl_factory_grid/modules/batteries/rules.py @@ -60,7 +60,7 @@ class BatteryDecharge(Rule): batteries.by_entity(agent).decharge(energy_consumption) - results.append(TickResult(self.name, entity=agent, validity=c.VALID)) + results.append(TickResult(self.name, entity=agent, validity=c.VALID, value=energy_consumption)) return results diff --git a/marl_factory_grid/modules/clean_up/rules.py b/marl_factory_grid/modules/clean_up/rules.py index b81ee41..2f69f9e 100644 --- a/marl_factory_grid/modules/clean_up/rules.py +++ b/marl_factory_grid/modules/clean_up/rules.py @@ -22,7 +22,7 @@ class DoneOnAllDirtCleaned(Rule): def on_check_done(self, state) -> [DoneResult]: if len(state[d.DIRT]) == 0 and state.curr_step: return [DoneResult(validity=c.VALID, identifier=self.name, reward=self.reward)] - return [DoneResult(validity=c.NOT_VALID, identifier=self.name)] + return [] class RespawnDirt(Rule): @@ -81,5 +81,6 @@ class EntitiesSmearDirtOnMove(Rule): old_pos_dirt = next(iter(old_pos_dirt)) if smeared_dirt := round(old_pos_dirt.amount * self.smear_ratio, 2): if state[d.DIRT].spawn(entity.pos, amount=smeared_dirt): - results.append(TickResult(identifier=self.name, entity=entity, validity=c.VALID)) + results.append(TickResult(identifier=self.name, entity=entity, + validity=c.VALID, value=smeared_dirt)) return results diff --git a/marl_factory_grid/modules/destinations/__init__.py b/marl_factory_grid/modules/destinations/__init__.py index 4614dd7..9072a5f 100644 --- a/marl_factory_grid/modules/destinations/__init__.py +++ b/marl_factory_grid/modules/destinations/__init__.py @@ -1,7 +1,4 @@ from .actions import DestAction from .entitites import Destination from .groups import Destinations -from .rules import (DoneAtDestinationReachAll, - DoneAtDestinationReachAny, - SpawnDestinationsPerAgent, - DestinationReachReward) +from .rules import (DoneAtDestinationReach, SpawnDestinationsPerAgent, DestinationReachReward) diff --git a/marl_factory_grid/modules/destinations/entitites.py b/marl_factory_grid/modules/destinations/entitites.py index d75f9e0..9e11db0 100644 --- a/marl_factory_grid/modules/destinations/entitites.py +++ b/marl_factory_grid/modules/destinations/entitites.py @@ -54,3 +54,6 @@ class Destination(Entity): def mark_as_reached(self): self._was_reached = True + + def unmark_as_reached(self): + self._was_reached = False diff --git a/marl_factory_grid/modules/destinations/rules.py b/marl_factory_grid/modules/destinations/rules.py index ef004c3..89a1ad8 100644 --- a/marl_factory_grid/modules/destinations/rules.py +++ b/marl_factory_grid/modules/destinations/rules.py @@ -9,6 +9,13 @@ from marl_factory_grid.environment import constants as c from marl_factory_grid.modules.destinations import constants as d from marl_factory_grid.modules.destinations.entitites import Destination +from marl_factory_grid.utils.states import Gamestate + + +ANY = 'any' +ALL = 'all' +SIMULTANOIUS = 'simultanious' +CONDITIONS =[ALL, ANY, SIMULTANOIUS] class DestinationReachReward(Rule): @@ -48,9 +55,9 @@ class DestinationReachReward(Rule): return results -class DoneAtDestinationReachAll(DestinationReachReward): +class DoneAtDestinationReach(DestinationReachReward): - def __init__(self, reward_at_done=d.REWARD_DEST_DONE, **kwargs): + def __init__(self, condition='any', reward_at_done=d.REWARD_DEST_DONE, **kwargs): """ This rule triggers and sets the done flag if ALL Destinations have been reached. @@ -59,68 +66,79 @@ class DoneAtDestinationReachAll(DestinationReachReward): :type dest_reach_reward: float :param dest_reach_reward: Specify the reward, agents get when reaching a single destination. """ - super(DoneAtDestinationReachAll, self).__init__(**kwargs) + super().__init__(**kwargs) + self.condition = condition self.reward = reward_at_done + assert condition in CONDITIONS def on_check_done(self, state) -> List[DoneResult]: - if all(x.was_reached() for x in state[d.DESTINATION]): - return [DoneResult(self.name, validity=c.VALID, reward=self.reward)] - return [DoneResult(self.name, validity=c.NOT_VALID)] - - -class DoneAtDestinationReachAny(DestinationReachReward): - - def __init__(self, reward_at_done=d.REWARD_DEST_DONE, **kwargs): - f""" - This rule triggers and sets the done flag if ANY Destinations has been reached. - !!! IMPORTANT: 'reward_at_done' is shared between the agents; 'dest_reach_reward' is bound to a specific one. - - :type reward_at_done: float - :param reward_at_done: Specifies the reward, all agent get, when any destinations has been reached. - Default {d.REWARD_DEST_DONE} - :type dest_reach_reward: float - :param dest_reach_reward: Specify a single agents reward forreaching a single destination. - Default {d.REWARD_DEST_REACHED} - """ - super(DoneAtDestinationReachAny, self).__init__(**kwargs) - self.reward = reward_at_done - - def on_check_done(self, state) -> List[DoneResult]: - if any(x.was_reached() for x in state[d.DESTINATION]): - return [DoneResult(self.name, validity=c.VALID, reward=d.REWARD_DEST_REACHED)] - return [] + if self.condition == ANY: + if any(x.was_reached() for x in state[d.DESTINATION]): + return [DoneResult(self.name, validity=c.VALID, reward=self.reward)] + elif self.condition == ALL: + if all(x.was_reached() for x in state[d.DESTINATION]): + return [DoneResult(self.name, validity=c.VALID, reward=self.reward)] + elif self.condition == SIMULTANOIUS: + if all(x.was_reached() for x in state[d.DESTINATION]): + return [DoneResult(self.name, validity=c.VALID, reward=self.reward)] + else: + for dest in state[d.DESTINATION]: + if dest.was_reached(): + for agent in state[c.AGENT].by_pos(dest.pos): + if dest.bound_entity: + if dest.bound_entity == agent: + pass + else: + dest.unmark_as_reached() + return [DoneResult(f'{dest}_unmarked_as_reached', + validity=c.NOT_VALID, entity=dest)] + else: + pass + else: + raise ValueError('Check spelling of Parameter "condition".') class SpawnDestinationsPerAgent(Rule): - def __init__(self, coords_or_quantity: Dict[str, List[Tuple[int, int]]]): + def __init__(self, coords_or_quantity: Dict[str, List[Tuple[int, int] | int]]): """ Special rule, that spawn distinations, that are bound to a single agent a fixed set of positions. Usefull for introducing specialists, etc. .. !!! This rule does not introduce any reward or done condition. - :type coords_or_quantity: Dict[str, List[Tuple[int, int]] :param coords_or_quantity: Please provide a dictionary with agent names as keys; and a list of possible destiantion coords as value. Example: {Wolfgang: [(0, 0), (1, 1), ...]} """ super(Rule, self).__init__() - self.per_agent_positions = {key: [ast.literal_eval(x) for x in val] for key, val in coords_or_quantity.items()} + self.per_agent_positions = dict() + for agent_name, value in coords_or_quantity.items(): + if isinstance(value, int): + per_agent_d = {agent_name: value} + else: + per_agent_d = {agent_name: [ast.literal_eval(x) for x in value]} + self.per_agent_positions.update(**per_agent_d) - def on_reset(self, state, lvl_map): - for (agent_name, position_list) in self.per_agent_positions.items(): + def on_reset(self, state: Gamestate): + for (agent_name, coords_or_quantity) in self.per_agent_positions.items(): agent = h.get_first(state[c.AGENT], lambda x: agent_name in x.name) assert agent - position_list = position_list.copy() + if isinstance(coords_or_quantity, int): + position_list = state.entities.floorlist + pos_left_counter = coords_or_quantity + else: + position_list = coords_or_quantity.copy() + pos_left_counter = 1 # Find a better way to resolve this. shuffle(position_list) - while True: + while pos_left_counter: try: pos = position_list.pop() except IndexError: print(f"Could not spawn Destinations at: {self.per_agent_positions[agent_name]}") print(f'Check your agent placement: {state[c.AGENT]} ... Exit ...') - exit(9999) + exit(-9999) if (not pos == agent.pos) and (not state[d.DESTINATION].by_pos(pos)): destination = Destination(pos, bind_to=agent) + pos_left_counter -= 1 break else: continue diff --git a/marl_factory_grid/modules/maintenance/rules.py b/marl_factory_grid/modules/maintenance/rules.py index 92e6e75..bb5d70c 100644 --- a/marl_factory_grid/modules/maintenance/rules.py +++ b/marl_factory_grid/modules/maintenance/rules.py @@ -1,6 +1,5 @@ from typing import List -import marl_factory_grid.modules.maintenance.constants from marl_factory_grid.environment.rules import Rule from marl_factory_grid.utils.results import TickResult, DoneResult from marl_factory_grid.environment import constants as c @@ -31,5 +30,5 @@ class DoneAtMaintainerCollision(Rule): for agent in agents: if agent.pos in m_pos: done_results.append(DoneResult(entity=agent, validity=c.VALID, identifier=self.name, - reward=marl_factory_grid.modules.maintenance.constants.MAINTAINER_COLLISION_REWARD)) + reward=M.MAINTAINER_COLLISION_REWARD)) return done_results diff --git a/marl_factory_grid/modules/zones/rules.py b/marl_factory_grid/modules/zones/rules.py index c31666c..a52de61 100644 --- a/marl_factory_grid/modules/zones/rules.py +++ b/marl_factory_grid/modules/zones/rules.py @@ -43,9 +43,6 @@ class AgentSingleZonePlacement(Rule): agent.move(state[z.ZONES][z_idxs.pop()].random_pos, state) return [] - def tick_step(self, state): - return [] - class IndividualDestinationZonePlacement(Rule): diff --git a/marl_factory_grid/utils/config_parser.py b/marl_factory_grid/utils/config_parser.py index f54c9ab..c83d57d 100644 --- a/marl_factory_grid/utils/config_parser.py +++ b/marl_factory_grid/utils/config_parser.py @@ -1,4 +1,5 @@ import ast +from collections import defaultdict from os import PathLike from pathlib import Path @@ -24,13 +25,21 @@ class FactoryConfigParser(object): self.config_path = Path(config_path) self.custom_modules_path = Path(custom_modules_path) if custom_modules_path is not None else custom_modules_path self.config = yaml.safe_load(self.config_path.open()) + self._n_abbr_dict = None def __getattr__(self, item): return self['General'][item] def _get_sub_list(self, primary_key: str, sub_key: str): return [{key: [s for k, v in val.items() if k == sub_key for s in v] for key, val in x.items() - } for x in self.config[primary_key]] + } for x in self.config.get(primary_key, [])] + + def _n_abbr(self, n): + assert isinstance(n, int) + if self._n_abbr_dict is None: + self._n_abbr_dict = defaultdict(lambda: 'th', {1: 'st', 2: 'nd', 3: 'rd'}) + return self._n_abbr_dict[n] + @property def agent_actions(self): @@ -145,11 +154,18 @@ class FactoryConfigParser(object): observations.extend(x for x in self.agents[name]['Observations'] if x != c.DEFAULTS) positions = [ast.literal_eval(x) for x in self.agents[name].get('Positions', [])] other_kwargs = {k: v for k, v in self.agents[name].items() if k not in - ['Actions', 'Observations', 'Positions']} + ['Actions', 'Observations', 'Positions', 'Clones']} parsed_agents_conf[name] = dict( actions=parsed_actions, observations=observations, positions=positions, other=other_kwargs ) + clones = self.agents[name].get('Clones', 0) + if clones: + if isinstance(clones, int): + clones = [f'{name}_the_{n}{self._n_abbr(n)}' for n in range(clones)] + for clone in clones: + parsed_agents_conf[clone] = parsed_agents_conf[name].copy() + return parsed_agents_conf def load_env_rules(self) -> List[Rule]: diff --git a/marl_factory_grid/utils/logging/envmonitor.py b/marl_factory_grid/utils/logging/envmonitor.py index e2551c8..92af5ec 100644 --- a/marl_factory_grid/utils/logging/envmonitor.py +++ b/marl_factory_grid/utils/logging/envmonitor.py @@ -58,3 +58,6 @@ class EnvMonitor(Wrapper): pickle.dump(self._monitor_df.reset_index(), f, protocol=pickle.HIGHEST_PROTOCOL) if auto_plotting_keys: plot_single_run(filepath, column_keys=auto_plotting_keys) + + def report_possible_colum_keys(self): + print(self._monitor_df.columns) \ No newline at end of file diff --git a/marl_factory_grid/utils/observation_builder.py b/marl_factory_grid/utils/observation_builder.py index 55d6ec0..35497bf 100644 --- a/marl_factory_grid/utils/observation_builder.py +++ b/marl_factory_grid/utils/observation_builder.py @@ -24,11 +24,7 @@ class OBSBuilder(object): return 0 def __init__(self, level_shape: np.size, state: Gamestate, pomdp_r: int): - self._curr_env_step = None self.all_obs = dict() - self.light_blockers = defaultdict(lambda: False) - self.positional = defaultdict(lambda: False) - self.non_positional = defaultdict(lambda: False) self.ray_caster = dict() self.level_shape = level_shape @@ -37,13 +33,15 @@ class OBSBuilder(object): self.size = np.prod(self.obs_shape) self.obs_layers = dict() - - self.reset_struc_obs_block(state) self.curr_lightmaps = dict() + self._floortiles = defaultdict(list, {pos: [Floor(*pos)] for pos in state.entities.floorlist}) - def reset_struc_obs_block(self, state): - self._curr_env_step = state.curr_step + self.reset(state) + + def reset(self, state): + # Reset temporary information + self.curr_lightmaps = dict() # Construct an empty obs (array) for possible placeholders self.all_obs[c.PLACEHOLDER] = np.full(self.obs_shape, 0, dtype=float) # Fill the all_obs-dict with all available entities @@ -52,7 +50,8 @@ class OBSBuilder(object): def observation_space(self, state): from gymnasium.spaces import Tuple, Box - obsn = self.refresh_and_build_for_all(state) + self.reset(state) + obsn = self.build_for_all(state) if len(state[c.AGENT]) == 1: space = Box(low=0, high=1, shape=next(x for x in obsn.values()).shape, dtype=np.float32) else: @@ -60,14 +59,13 @@ class OBSBuilder(object): return space def named_observation_space(self, state): - return self.refresh_and_build_for_all(state) + self.reset(state) + return self.build_for_all(state) - def refresh_and_build_for_all(self, state) -> (dict, dict): - self.reset_struc_obs_block(state) + def build_for_all(self, state) -> (dict, dict): return {agent.name: self.build_for_agent(agent, state)[0] for agent in state[c.AGENT]} - def refresh_and_build_named_for_all(self, state) -> Dict[str, Dict[str, np.ndarray]]: - self.reset_struc_obs_block(state) + def build_named_for_all(self, state) -> Dict[str, Dict[str, np.ndarray]]: named_obs_dict = {} for agent in state[c.AGENT]: obs, names = self.build_for_agent(agent, state) @@ -85,9 +83,6 @@ class OBSBuilder(object): pass def build_for_agent(self, agent, state) -> (List[str], np.ndarray): - assert self._curr_env_step == state.curr_step, ( - "The observation objekt has not been reset this state! Call 'reset_struc_obs_block(state)'" - ) try: agent_want_obs = self.obs_layers[agent.name] except KeyError: @@ -166,7 +161,8 @@ class OBSBuilder(object): raise ValueError(f'Max(obs.size) for {e.name}: {obs[idx].size}, but was: {len(v)}.') if self.pomdp_r: try: - light_map = np.zeros(self.obs_shape) + light_map = self.curr_lightmaps.get(agent.name, np.zeros(self.obs_shape)) + light_map[:] = 0.0 visible_floor = self.ray_caster[agent.name].visible_entities(self._floortiles, reset_cache=False) for f in set(visible_floor): diff --git a/marl_factory_grid/utils/plotting/plotting_utils.py b/marl_factory_grid/utils/plotting/plotting_utils.py index 17bb7ff..2ae61f0 100644 --- a/marl_factory_grid/utils/plotting/plotting_utils.py +++ b/marl_factory_grid/utils/plotting/plotting_utils.py @@ -49,7 +49,7 @@ def prepare_plt(df, hue, style, hue_order): plt.close('all') sns.set(rc={'text.usetex': False}, style='whitegrid') lineplot = sns.lineplot(data=df, x='Episode', y='Score', hue=hue, style=style, - ci=95, palette=PALETTE, hue_order=hue_order, ) + errorbar=('ci', 95), palette=PALETTE, hue_order=hue_order, ) plt.legend(bbox_to_anchor=(1.02, 1), loc='upper left', borderaxespad=0) plt.tight_layout() # lineplot.set_title(f'{sorted(list(df["Measurement"].unique()))}') diff --git a/marl_factory_grid/utils/states.py b/marl_factory_grid/utils/states.py index 8baf012..fa5bc4e 100644 --- a/marl_factory_grid/utils/states.py +++ b/marl_factory_grid/utils/states.py @@ -8,7 +8,7 @@ import numpy as np from marl_factory_grid.algorithms.static.utils import points_to_graph from marl_factory_grid.environment import constants as c from marl_factory_grid.environment.entity.entity import Entity -from marl_factory_grid.environment.rules import Rule +from marl_factory_grid.environment.rules import Rule, SpawnAgents from marl_factory_grid.utils.results import Result, DoneResult from marl_factory_grid.environment.tests import Test from marl_factory_grid.utils.results import Result @@ -32,18 +32,19 @@ class StepRules: self.rules.append(item) return True - def do_all_reset(self, state): - for rule in self.rules: - if rule_reset_printline := rule.on_reset(state): - state.print(rule_reset_printline) - return c.VALID - def do_all_init(self, state, lvl_map): for rule in self.rules: if rule_init_printline := rule.on_init(state, lvl_map): state.print(rule_init_printline) return c.VALID + def do_all_reset(self, state): + SpawnAgents().on_reset(state) + for rule in self.rules: + if rule_reset_printline := rule.on_reset(state): + state.print(rule_reset_printline) + return c.VALID + def tick_step_all(self, state): results = list() for rule in self.rules: @@ -91,6 +92,10 @@ class Gamestate(object): self._floortile_graph = None self.tests = StepTests(*tests) + def reset(self): + self.curr_step = 0 + self.curr_actions = None + def __getitem__(self, item): return self.entities[item] @@ -201,7 +206,7 @@ class Gamestate(object): results.extend(on_check_done_result) return results - def get_all_pos_with_collisions(self) -> List[Tuple[(int, int)]]: + def get_collision_positions(self) -> List[Tuple[(int, int)]]: """ Returns a list positions [(x, y), ... ] on which collisions occur. This does not include agents, that were unable to move because their target direction was blocked, also a form of collision. diff --git a/random_testrun.py b/random_testrun.py index ef8df08..e727ea9 100644 --- a/random_testrun.py +++ b/random_testrun.py @@ -12,7 +12,7 @@ from marl_factory_grid.utils.tools import ConfigExplainer if __name__ == '__main__': # Render at each step? - render = False + render = True # Reveal all possible Modules (Entities, Rules, Agents[Actions, Observations], etc.) explain_config = False # Collect statistics? @@ -29,7 +29,7 @@ if __name__ == '__main__': ce.save_all(run_path / 'all_out.yaml') # Path to config File - path = Path('marl_factory_grid/configs/narrow_corridor.yaml') + path = Path('marl_factory_grid/configs/eight_puzzle.yaml') # Env Init factory = Factory(path) @@ -61,6 +61,10 @@ if __name__ == '__main__': if record: factory.save_records(run_path / 'test.pb') if plotting: - plot_single_run(run_path) + factory.report_possible_colum_keys() + plot_single_run(run_path, column_keys=['Global_DoneAtDestinationReachAll', 'step_reward', + 'Agent[Karl-Heinz]_DoneAtDestinationReachAll', + 'Agent[Wolfgang]_DoneAtDestinationReachAll', + 'Global_DoneAtDestinationReachAll']) print('Done!!! Goodbye....') diff --git a/setup.py b/setup.py index 6d12a87..c51836f 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ long_description = (this_directory / "README.md").read_text() setup(name='Marl-Factory-Grid', - version='0.1.2', + version='0.2.0', description='A framework to research MARL agents in various setings.', author='Steffen Illium', author_email='steffen.illium@ifi.lmu.de',