Merge branch 'main' into unit_testing

This commit is contained in:
Chanumask
2023-11-13 11:00:14 +01:00
22 changed files with 205 additions and 114 deletions

View File

@@ -14,8 +14,8 @@ build-job: # This job runs in the build stage, which runs first.
image: python:slim image: python:slim
script: script:
- echo "Compiling the code..." - echo "Compiling the code..."
- pip install -U twine - pip install twine --upgrade
- python setup.py sdist - python setup.py sdist
- echo "Compile complete." - echo "Compile complete."
- twine upload dist/* - twine upload dist/* --username $USER_NAME --password $API_KEY --repository marl-factory-grid
- echo "Upload complete." - echo "Upload complete."

View File

@@ -22,13 +22,6 @@ Agents:
- Inventory - Inventory
- DropOffLocations - DropOffLocations
- Maintainers - Maintainers
# This is special for agents, as each one is differten and can act as an adversary e.g.
Positions:
- (16, 7)
- (16, 6)
- (16, 3)
- (16, 4)
- (16, 5)
Entities: Entities:
Batteries: Batteries:
initial_charge: 0.8 initial_charge: 0.8

View File

@@ -0,0 +1,55 @@
Agents:
Wolfgang:
Actions:
- Noop
- Move4
Observations:
- Other
- Walls
- Destination
Clones:
- Juergen
- Soeren
- Walter
- Siggi
- Dennis
- Karl-Heinz
- Kevin
is_blocking_pos: true
Entities:
Destinations:
# Let them spawn on closed doors and agent positions
ignore_blocking: true
# We need a special spawn rule...
spawnrule:
# ...which assigns the destinations per agent
SpawnDestinationsPerAgent:
# we use this parameter
coords_or_quantity:
# to enable and assign special positions per agent
Wolfgang: 1
Karl-Heinz: 1
Kevin: 1
Juergen: 1
Soeren: 1
Walter: 1
Siggi: 1
Dennis: 1
General:
env_seed: 69
individual_rewards: true
level_name: eight_puzzle
pomdp_r: 3
verbose: True
tests: false
Rules:
# Utilities
WatchCollisions:
done_at_collisions: false
# Done Conditions
DoneAtDestinationReach:
condition: simultanious
DoneAtMaxStepsReached:
max_steps: 500

View File

@@ -97,20 +97,26 @@ class Factory(gym.Env):
return self.state.entities[item] return self.state.entities[item]
def reset(self) -> (dict, dict): def reset(self) -> (dict, dict):
# Reset information the state holds
self.state.reset()
# Reset Information the GlobalEntity collection holds.
self.state.entities.reset() self.state.entities.reset()
# All is set up, trigger entity spawn with variable pos # All is set up, trigger entity spawn with variable pos
self.state.rules.do_all_reset(self.state) self.state.rules.do_all_reset(self.state)
# Build initial observations for all agents # Build initial observations for all agents
return self.obs_builder.refresh_and_build_for_all(self.state) self.obs_builder.reset(self.state)
return self.obs_builder.build_for_all(self.state)
def manual_step_init(self) -> List[Result]: def manual_step_init(self) -> List[Result]:
self.state.curr_step += 1 self.state.curr_step += 1
# Main Agent Step # Main Agent Step
pre_step_result = self.state.rules.tick_pre_step_all(self) pre_step_result = self.state.rules.tick_pre_step_all(self)
self.obs_builder.reset_struc_obs_block(self.state) self.obs_builder.reset(self.state)
return pre_step_result return pre_step_result
def manual_get_named_agent_obs(self, agent_name: str) -> (List[str], np.ndarray): def manual_get_named_agent_obs(self, agent_name: str) -> (List[str], np.ndarray):
@@ -164,7 +170,7 @@ class Factory(gym.Env):
info.update(step_reward=sum(reward), step=self.state.curr_step) info.update(step_reward=sum(reward), step=self.state.curr_step)
obs = self.obs_builder.refresh_and_build_for_all(self.state) obs = self.obs_builder.build_for_all(self.state)
return None, [x for x in obs.values()], reward, done, info return None, [x for x in obs.values()], reward, done, info
def summarize_step_results(self, tick_results: list, done_check_results: list) -> (int, dict, bool): def summarize_step_results(self, tick_results: list, done_check_results: list) -> (int, dict, bool):

View File

@@ -1,6 +1,5 @@
from marl_factory_grid.environment.entity.agent import Agent from marl_factory_grid.environment.entity.agent import Agent
from marl_factory_grid.environment.groups.collection import Collection from marl_factory_grid.environment.groups.collection import Collection
from marl_factory_grid.environment.rules import SpawnAgents
class Agents(Collection): class Agents(Collection):
@@ -8,7 +7,7 @@ class Agents(Collection):
@property @property
def spawn_rule(self): def spawn_rule(self):
return {SpawnAgents.__name__: {}} return {}
@property @property
def var_is_blocking_light(self): def var_is_blocking_light(self):

View File

@@ -27,7 +27,7 @@ class Entities(Objects):
@property @property
def floorlist(self): def floorlist(self):
shuffle(self._floor_positions) shuffle(self._floor_positions)
return self._floor_positions return [x for x in self._floor_positions]
def __init__(self, floor_positions): def __init__(self, floor_positions):
self._floor_positions = floor_positions self._floor_positions = floor_positions

View File

@@ -70,28 +70,22 @@ class SpawnAgents(Rule):
def on_reset(self, state): def on_reset(self, state):
agents = state[c.AGENT] agents = state[c.AGENT]
empty_positions = state.entities.empty_positions[:len(state.agents_conf)]
for agent_name, agent_conf in state.agents_conf.items(): for agent_name, agent_conf in state.agents_conf.items():
empty_positions = state.entities.empty_positions
actions = agent_conf['actions'].copy() actions = agent_conf['actions'].copy()
observations = agent_conf['observations'].copy() observations = agent_conf['observations'].copy()
positions = agent_conf['positions'].copy() positions = agent_conf['positions'].copy()
other = agent_conf['other'].copy() other = agent_conf['other'].copy()
if positions:
shuffle(positions) if position := h.get_first(x for x in positions if x in empty_positions):
while True: assert state.check_pos_validity(position), 'smth went wrong....'
try: agents.add_item(Agent(actions, observations, position, str_ident=agent_name, **other))
pos = positions.pop() elif positions:
except IndexError:
raise ValueError(f'It was not possible to spawn an Agent on the available position: ' raise ValueError(f'It was not possible to spawn an Agent on the available position: '
f'\n{agent_conf["positions"].copy()}') f'\n{agent_conf["positions"].copy()}')
if bool(agents.by_pos(pos)) or not state.check_pos_validity(pos):
continue
else:
agents.add_item(Agent(actions, observations, pos, str_ident=agent_name, **other))
break
else: else:
agents.add_item(Agent(actions, observations, empty_positions.pop(), str_ident=agent_name, **other)) agents.add_item(Agent(actions, observations, empty_positions.pop(), str_ident=agent_name, **other))
pass return []
class DoneAtMaxStepsReached(Rule): class DoneAtMaxStepsReached(Rule):
@@ -103,7 +97,7 @@ class DoneAtMaxStepsReached(Rule):
def on_check_done(self, state): def on_check_done(self, state):
if self.max_steps <= state.curr_step: if self.max_steps <= state.curr_step:
return [DoneResult(validity=c.VALID, identifier=self.name)] return [DoneResult(validity=c.VALID, identifier=self.name)]
return [DoneResult(validity=c.NOT_VALID, identifier=self.name)] return []
class AssignGlobalPositions(Rule): class AssignGlobalPositions(Rule):
@@ -130,7 +124,7 @@ class WatchCollisions(Rule):
def tick_post_step(self, state) -> List[TickResult]: def tick_post_step(self, state) -> List[TickResult]:
self.curr_done = False self.curr_done = False
pos_with_collisions = state.get_all_pos_with_collisions() pos_with_collisions = state.get_collision_positions()
results = list() results = list()
for pos in pos_with_collisions: for pos in pos_with_collisions:
guests = [x for x in state.entities.pos_dict[pos] if x.var_can_collide] guests = [x for x in state.entities.pos_dict[pos] if x.var_can_collide]

View File

@@ -0,0 +1,5 @@
#####
#---#
#---#
#---#
#####

View File

@@ -60,7 +60,7 @@ class BatteryDecharge(Rule):
batteries.by_entity(agent).decharge(energy_consumption) batteries.by_entity(agent).decharge(energy_consumption)
results.append(TickResult(self.name, entity=agent, validity=c.VALID)) results.append(TickResult(self.name, entity=agent, validity=c.VALID, value=energy_consumption))
return results return results

View File

@@ -22,7 +22,7 @@ class DoneOnAllDirtCleaned(Rule):
def on_check_done(self, state) -> [DoneResult]: def on_check_done(self, state) -> [DoneResult]:
if len(state[d.DIRT]) == 0 and state.curr_step: if len(state[d.DIRT]) == 0 and state.curr_step:
return [DoneResult(validity=c.VALID, identifier=self.name, reward=self.reward)] return [DoneResult(validity=c.VALID, identifier=self.name, reward=self.reward)]
return [DoneResult(validity=c.NOT_VALID, identifier=self.name)] return []
class RespawnDirt(Rule): class RespawnDirt(Rule):
@@ -81,5 +81,6 @@ class EntitiesSmearDirtOnMove(Rule):
old_pos_dirt = next(iter(old_pos_dirt)) old_pos_dirt = next(iter(old_pos_dirt))
if smeared_dirt := round(old_pos_dirt.amount * self.smear_ratio, 2): if smeared_dirt := round(old_pos_dirt.amount * self.smear_ratio, 2):
if state[d.DIRT].spawn(entity.pos, amount=smeared_dirt): if state[d.DIRT].spawn(entity.pos, amount=smeared_dirt):
results.append(TickResult(identifier=self.name, entity=entity, validity=c.VALID)) results.append(TickResult(identifier=self.name, entity=entity,
validity=c.VALID, value=smeared_dirt))
return results return results

View File

@@ -1,7 +1,4 @@
from .actions import DestAction from .actions import DestAction
from .entitites import Destination from .entitites import Destination
from .groups import Destinations from .groups import Destinations
from .rules import (DoneAtDestinationReachAll, from .rules import (DoneAtDestinationReach, SpawnDestinationsPerAgent, DestinationReachReward)
DoneAtDestinationReachAny,
SpawnDestinationsPerAgent,
DestinationReachReward)

View File

@@ -54,3 +54,6 @@ class Destination(Entity):
def mark_as_reached(self): def mark_as_reached(self):
self._was_reached = True self._was_reached = True
def unmark_as_reached(self):
self._was_reached = False

View File

@@ -9,6 +9,13 @@ from marl_factory_grid.environment import constants as c
from marl_factory_grid.modules.destinations import constants as d from marl_factory_grid.modules.destinations import constants as d
from marl_factory_grid.modules.destinations.entitites import Destination from marl_factory_grid.modules.destinations.entitites import Destination
from marl_factory_grid.utils.states import Gamestate
ANY = 'any'
ALL = 'all'
SIMULTANOIUS = 'simultanious'
CONDITIONS =[ALL, ANY, SIMULTANOIUS]
class DestinationReachReward(Rule): class DestinationReachReward(Rule):
@@ -48,9 +55,9 @@ class DestinationReachReward(Rule):
return results return results
class DoneAtDestinationReachAll(DestinationReachReward): class DoneAtDestinationReach(DestinationReachReward):
def __init__(self, reward_at_done=d.REWARD_DEST_DONE, **kwargs): def __init__(self, condition='any', reward_at_done=d.REWARD_DEST_DONE, **kwargs):
""" """
This rule triggers and sets the done flag if ALL Destinations have been reached. This rule triggers and sets the done flag if ALL Destinations have been reached.
@@ -59,68 +66,79 @@ class DoneAtDestinationReachAll(DestinationReachReward):
:type dest_reach_reward: float :type dest_reach_reward: float
:param dest_reach_reward: Specify the reward, agents get when reaching a single destination. :param dest_reach_reward: Specify the reward, agents get when reaching a single destination.
""" """
super(DoneAtDestinationReachAll, self).__init__(**kwargs) super().__init__(**kwargs)
self.condition = condition
self.reward = reward_at_done self.reward = reward_at_done
assert condition in CONDITIONS
def on_check_done(self, state) -> List[DoneResult]: def on_check_done(self, state) -> List[DoneResult]:
if self.condition == ANY:
if any(x.was_reached() for x in state[d.DESTINATION]):
return [DoneResult(self.name, validity=c.VALID, reward=self.reward)]
elif self.condition == ALL:
if all(x.was_reached() for x in state[d.DESTINATION]): if all(x.was_reached() for x in state[d.DESTINATION]):
return [DoneResult(self.name, validity=c.VALID, reward=self.reward)] return [DoneResult(self.name, validity=c.VALID, reward=self.reward)]
return [DoneResult(self.name, validity=c.NOT_VALID)] elif self.condition == SIMULTANOIUS:
if all(x.was_reached() for x in state[d.DESTINATION]):
return [DoneResult(self.name, validity=c.VALID, reward=self.reward)]
class DoneAtDestinationReachAny(DestinationReachReward): else:
for dest in state[d.DESTINATION]:
def __init__(self, reward_at_done=d.REWARD_DEST_DONE, **kwargs): if dest.was_reached():
f""" for agent in state[c.AGENT].by_pos(dest.pos):
This rule triggers and sets the done flag if ANY Destinations has been reached. if dest.bound_entity:
!!! IMPORTANT: 'reward_at_done' is shared between the agents; 'dest_reach_reward' is bound to a specific one. if dest.bound_entity == agent:
pass
:type reward_at_done: float else:
:param reward_at_done: Specifies the reward, all agent get, when any destinations has been reached. dest.unmark_as_reached()
Default {d.REWARD_DEST_DONE} return [DoneResult(f'{dest}_unmarked_as_reached',
:type dest_reach_reward: float validity=c.NOT_VALID, entity=dest)]
:param dest_reach_reward: Specify a single agents reward forreaching a single destination. else:
Default {d.REWARD_DEST_REACHED} pass
""" else:
super(DoneAtDestinationReachAny, self).__init__(**kwargs) raise ValueError('Check spelling of Parameter "condition".')
self.reward = reward_at_done
def on_check_done(self, state) -> List[DoneResult]:
if any(x.was_reached() for x in state[d.DESTINATION]):
return [DoneResult(self.name, validity=c.VALID, reward=d.REWARD_DEST_REACHED)]
return []
class SpawnDestinationsPerAgent(Rule): class SpawnDestinationsPerAgent(Rule):
def __init__(self, coords_or_quantity: Dict[str, List[Tuple[int, int]]]): def __init__(self, coords_or_quantity: Dict[str, List[Tuple[int, int] | int]]):
""" """
Special rule, that spawn distinations, that are bound to a single agent a fixed set of positions. Special rule, that spawn distinations, that are bound to a single agent a fixed set of positions.
Usefull for introducing specialists, etc. .. Usefull for introducing specialists, etc. ..
!!! This rule does not introduce any reward or done condition. !!! This rule does not introduce any reward or done condition.
:type coords_or_quantity: Dict[str, List[Tuple[int, int]]
:param coords_or_quantity: Please provide a dictionary with agent names as keys; and a list of possible :param coords_or_quantity: Please provide a dictionary with agent names as keys; and a list of possible
destiantion coords as value. Example: {Wolfgang: [(0, 0), (1, 1), ...]} destiantion coords as value. Example: {Wolfgang: [(0, 0), (1, 1), ...]}
""" """
super(Rule, self).__init__() super(Rule, self).__init__()
self.per_agent_positions = {key: [ast.literal_eval(x) for x in val] for key, val in coords_or_quantity.items()} self.per_agent_positions = dict()
for agent_name, value in coords_or_quantity.items():
if isinstance(value, int):
per_agent_d = {agent_name: value}
else:
per_agent_d = {agent_name: [ast.literal_eval(x) for x in value]}
self.per_agent_positions.update(**per_agent_d)
def on_reset(self, state, lvl_map): def on_reset(self, state: Gamestate):
for (agent_name, position_list) in self.per_agent_positions.items(): for (agent_name, coords_or_quantity) in self.per_agent_positions.items():
agent = h.get_first(state[c.AGENT], lambda x: agent_name in x.name) agent = h.get_first(state[c.AGENT], lambda x: agent_name in x.name)
assert agent assert agent
position_list = position_list.copy() if isinstance(coords_or_quantity, int):
position_list = state.entities.floorlist
pos_left_counter = coords_or_quantity
else:
position_list = coords_or_quantity.copy()
pos_left_counter = 1 # Find a better way to resolve this.
shuffle(position_list) shuffle(position_list)
while True: while pos_left_counter:
try: try:
pos = position_list.pop() pos = position_list.pop()
except IndexError: except IndexError:
print(f"Could not spawn Destinations at: {self.per_agent_positions[agent_name]}") print(f"Could not spawn Destinations at: {self.per_agent_positions[agent_name]}")
print(f'Check your agent placement: {state[c.AGENT]} ... Exit ...') print(f'Check your agent placement: {state[c.AGENT]} ... Exit ...')
exit(9999) exit(-9999)
if (not pos == agent.pos) and (not state[d.DESTINATION].by_pos(pos)): if (not pos == agent.pos) and (not state[d.DESTINATION].by_pos(pos)):
destination = Destination(pos, bind_to=agent) destination = Destination(pos, bind_to=agent)
pos_left_counter -= 1
break break
else: else:
continue continue

View File

@@ -1,6 +1,5 @@
from typing import List from typing import List
import marl_factory_grid.modules.maintenance.constants
from marl_factory_grid.environment.rules import Rule from marl_factory_grid.environment.rules import Rule
from marl_factory_grid.utils.results import TickResult, DoneResult from marl_factory_grid.utils.results import TickResult, DoneResult
from marl_factory_grid.environment import constants as c from marl_factory_grid.environment import constants as c
@@ -31,5 +30,5 @@ class DoneAtMaintainerCollision(Rule):
for agent in agents: for agent in agents:
if agent.pos in m_pos: if agent.pos in m_pos:
done_results.append(DoneResult(entity=agent, validity=c.VALID, identifier=self.name, done_results.append(DoneResult(entity=agent, validity=c.VALID, identifier=self.name,
reward=marl_factory_grid.modules.maintenance.constants.MAINTAINER_COLLISION_REWARD)) reward=M.MAINTAINER_COLLISION_REWARD))
return done_results return done_results

View File

@@ -43,9 +43,6 @@ class AgentSingleZonePlacement(Rule):
agent.move(state[z.ZONES][z_idxs.pop()].random_pos, state) agent.move(state[z.ZONES][z_idxs.pop()].random_pos, state)
return [] return []
def tick_step(self, state):
return []
class IndividualDestinationZonePlacement(Rule): class IndividualDestinationZonePlacement(Rule):

View File

@@ -1,4 +1,5 @@
import ast import ast
from collections import defaultdict
from os import PathLike from os import PathLike
from pathlib import Path from pathlib import Path
@@ -24,13 +25,21 @@ class FactoryConfigParser(object):
self.config_path = Path(config_path) self.config_path = Path(config_path)
self.custom_modules_path = Path(custom_modules_path) if custom_modules_path is not None else custom_modules_path self.custom_modules_path = Path(custom_modules_path) if custom_modules_path is not None else custom_modules_path
self.config = yaml.safe_load(self.config_path.open()) self.config = yaml.safe_load(self.config_path.open())
self._n_abbr_dict = None
def __getattr__(self, item): def __getattr__(self, item):
return self['General'][item] return self['General'][item]
def _get_sub_list(self, primary_key: str, sub_key: str): def _get_sub_list(self, primary_key: str, sub_key: str):
return [{key: [s for k, v in val.items() if k == sub_key for s in v] for key, val in x.items() return [{key: [s for k, v in val.items() if k == sub_key for s in v] for key, val in x.items()
} for x in self.config[primary_key]] } for x in self.config.get(primary_key, [])]
def _n_abbr(self, n):
assert isinstance(n, int)
if self._n_abbr_dict is None:
self._n_abbr_dict = defaultdict(lambda: 'th', {1: 'st', 2: 'nd', 3: 'rd'})
return self._n_abbr_dict[n]
@property @property
def agent_actions(self): def agent_actions(self):
@@ -145,11 +154,18 @@ class FactoryConfigParser(object):
observations.extend(x for x in self.agents[name]['Observations'] if x != c.DEFAULTS) observations.extend(x for x in self.agents[name]['Observations'] if x != c.DEFAULTS)
positions = [ast.literal_eval(x) for x in self.agents[name].get('Positions', [])] positions = [ast.literal_eval(x) for x in self.agents[name].get('Positions', [])]
other_kwargs = {k: v for k, v in self.agents[name].items() if k not in other_kwargs = {k: v for k, v in self.agents[name].items() if k not in
['Actions', 'Observations', 'Positions']} ['Actions', 'Observations', 'Positions', 'Clones']}
parsed_agents_conf[name] = dict( parsed_agents_conf[name] = dict(
actions=parsed_actions, observations=observations, positions=positions, other=other_kwargs actions=parsed_actions, observations=observations, positions=positions, other=other_kwargs
) )
clones = self.agents[name].get('Clones', 0)
if clones:
if isinstance(clones, int):
clones = [f'{name}_the_{n}{self._n_abbr(n)}' for n in range(clones)]
for clone in clones:
parsed_agents_conf[clone] = parsed_agents_conf[name].copy()
return parsed_agents_conf return parsed_agents_conf
def load_env_rules(self) -> List[Rule]: def load_env_rules(self) -> List[Rule]:

View File

@@ -58,3 +58,6 @@ class EnvMonitor(Wrapper):
pickle.dump(self._monitor_df.reset_index(), f, protocol=pickle.HIGHEST_PROTOCOL) pickle.dump(self._monitor_df.reset_index(), f, protocol=pickle.HIGHEST_PROTOCOL)
if auto_plotting_keys: if auto_plotting_keys:
plot_single_run(filepath, column_keys=auto_plotting_keys) plot_single_run(filepath, column_keys=auto_plotting_keys)
def report_possible_colum_keys(self):
print(self._monitor_df.columns)

View File

@@ -24,11 +24,7 @@ class OBSBuilder(object):
return 0 return 0
def __init__(self, level_shape: np.size, state: Gamestate, pomdp_r: int): def __init__(self, level_shape: np.size, state: Gamestate, pomdp_r: int):
self._curr_env_step = None
self.all_obs = dict() self.all_obs = dict()
self.light_blockers = defaultdict(lambda: False)
self.positional = defaultdict(lambda: False)
self.non_positional = defaultdict(lambda: False)
self.ray_caster = dict() self.ray_caster = dict()
self.level_shape = level_shape self.level_shape = level_shape
@@ -37,13 +33,15 @@ class OBSBuilder(object):
self.size = np.prod(self.obs_shape) self.size = np.prod(self.obs_shape)
self.obs_layers = dict() self.obs_layers = dict()
self.reset_struc_obs_block(state)
self.curr_lightmaps = dict() self.curr_lightmaps = dict()
self._floortiles = defaultdict(list, {pos: [Floor(*pos)] for pos in state.entities.floorlist}) self._floortiles = defaultdict(list, {pos: [Floor(*pos)] for pos in state.entities.floorlist})
def reset_struc_obs_block(self, state): self.reset(state)
self._curr_env_step = state.curr_step
def reset(self, state):
# Reset temporary information
self.curr_lightmaps = dict()
# Construct an empty obs (array) for possible placeholders # Construct an empty obs (array) for possible placeholders
self.all_obs[c.PLACEHOLDER] = np.full(self.obs_shape, 0, dtype=float) self.all_obs[c.PLACEHOLDER] = np.full(self.obs_shape, 0, dtype=float)
# Fill the all_obs-dict with all available entities # Fill the all_obs-dict with all available entities
@@ -52,7 +50,8 @@ class OBSBuilder(object):
def observation_space(self, state): def observation_space(self, state):
from gymnasium.spaces import Tuple, Box from gymnasium.spaces import Tuple, Box
obsn = self.refresh_and_build_for_all(state) self.reset(state)
obsn = self.build_for_all(state)
if len(state[c.AGENT]) == 1: if len(state[c.AGENT]) == 1:
space = Box(low=0, high=1, shape=next(x for x in obsn.values()).shape, dtype=np.float32) space = Box(low=0, high=1, shape=next(x for x in obsn.values()).shape, dtype=np.float32)
else: else:
@@ -60,14 +59,13 @@ class OBSBuilder(object):
return space return space
def named_observation_space(self, state): def named_observation_space(self, state):
return self.refresh_and_build_for_all(state) self.reset(state)
return self.build_for_all(state)
def refresh_and_build_for_all(self, state) -> (dict, dict): def build_for_all(self, state) -> (dict, dict):
self.reset_struc_obs_block(state)
return {agent.name: self.build_for_agent(agent, state)[0] for agent in state[c.AGENT]} return {agent.name: self.build_for_agent(agent, state)[0] for agent in state[c.AGENT]}
def refresh_and_build_named_for_all(self, state) -> Dict[str, Dict[str, np.ndarray]]: def build_named_for_all(self, state) -> Dict[str, Dict[str, np.ndarray]]:
self.reset_struc_obs_block(state)
named_obs_dict = {} named_obs_dict = {}
for agent in state[c.AGENT]: for agent in state[c.AGENT]:
obs, names = self.build_for_agent(agent, state) obs, names = self.build_for_agent(agent, state)
@@ -85,9 +83,6 @@ class OBSBuilder(object):
pass pass
def build_for_agent(self, agent, state) -> (List[str], np.ndarray): def build_for_agent(self, agent, state) -> (List[str], np.ndarray):
assert self._curr_env_step == state.curr_step, (
"The observation objekt has not been reset this state! Call 'reset_struc_obs_block(state)'"
)
try: try:
agent_want_obs = self.obs_layers[agent.name] agent_want_obs = self.obs_layers[agent.name]
except KeyError: except KeyError:
@@ -166,7 +161,8 @@ class OBSBuilder(object):
raise ValueError(f'Max(obs.size) for {e.name}: {obs[idx].size}, but was: {len(v)}.') raise ValueError(f'Max(obs.size) for {e.name}: {obs[idx].size}, but was: {len(v)}.')
if self.pomdp_r: if self.pomdp_r:
try: try:
light_map = np.zeros(self.obs_shape) light_map = self.curr_lightmaps.get(agent.name, np.zeros(self.obs_shape))
light_map[:] = 0.0
visible_floor = self.ray_caster[agent.name].visible_entities(self._floortiles, reset_cache=False) visible_floor = self.ray_caster[agent.name].visible_entities(self._floortiles, reset_cache=False)
for f in set(visible_floor): for f in set(visible_floor):

View File

@@ -49,7 +49,7 @@ def prepare_plt(df, hue, style, hue_order):
plt.close('all') plt.close('all')
sns.set(rc={'text.usetex': False}, style='whitegrid') sns.set(rc={'text.usetex': False}, style='whitegrid')
lineplot = sns.lineplot(data=df, x='Episode', y='Score', hue=hue, style=style, lineplot = sns.lineplot(data=df, x='Episode', y='Score', hue=hue, style=style,
ci=95, palette=PALETTE, hue_order=hue_order, ) errorbar=('ci', 95), palette=PALETTE, hue_order=hue_order, )
plt.legend(bbox_to_anchor=(1.02, 1), loc='upper left', borderaxespad=0) plt.legend(bbox_to_anchor=(1.02, 1), loc='upper left', borderaxespad=0)
plt.tight_layout() plt.tight_layout()
# lineplot.set_title(f'{sorted(list(df["Measurement"].unique()))}') # lineplot.set_title(f'{sorted(list(df["Measurement"].unique()))}')

View File

@@ -8,7 +8,7 @@ import numpy as np
from marl_factory_grid.algorithms.static.utils import points_to_graph from marl_factory_grid.algorithms.static.utils import points_to_graph
from marl_factory_grid.environment import constants as c from marl_factory_grid.environment import constants as c
from marl_factory_grid.environment.entity.entity import Entity from marl_factory_grid.environment.entity.entity import Entity
from marl_factory_grid.environment.rules import Rule from marl_factory_grid.environment.rules import Rule, SpawnAgents
from marl_factory_grid.utils.results import Result, DoneResult from marl_factory_grid.utils.results import Result, DoneResult
from marl_factory_grid.environment.tests import Test from marl_factory_grid.environment.tests import Test
from marl_factory_grid.utils.results import Result from marl_factory_grid.utils.results import Result
@@ -32,18 +32,19 @@ class StepRules:
self.rules.append(item) self.rules.append(item)
return True return True
def do_all_reset(self, state):
for rule in self.rules:
if rule_reset_printline := rule.on_reset(state):
state.print(rule_reset_printline)
return c.VALID
def do_all_init(self, state, lvl_map): def do_all_init(self, state, lvl_map):
for rule in self.rules: for rule in self.rules:
if rule_init_printline := rule.on_init(state, lvl_map): if rule_init_printline := rule.on_init(state, lvl_map):
state.print(rule_init_printline) state.print(rule_init_printline)
return c.VALID return c.VALID
def do_all_reset(self, state):
SpawnAgents().on_reset(state)
for rule in self.rules:
if rule_reset_printline := rule.on_reset(state):
state.print(rule_reset_printline)
return c.VALID
def tick_step_all(self, state): def tick_step_all(self, state):
results = list() results = list()
for rule in self.rules: for rule in self.rules:
@@ -91,6 +92,10 @@ class Gamestate(object):
self._floortile_graph = None self._floortile_graph = None
self.tests = StepTests(*tests) self.tests = StepTests(*tests)
def reset(self):
self.curr_step = 0
self.curr_actions = None
def __getitem__(self, item): def __getitem__(self, item):
return self.entities[item] return self.entities[item]
@@ -201,7 +206,7 @@ class Gamestate(object):
results.extend(on_check_done_result) results.extend(on_check_done_result)
return results return results
def get_all_pos_with_collisions(self) -> List[Tuple[(int, int)]]: def get_collision_positions(self) -> List[Tuple[(int, int)]]:
""" """
Returns a list positions [(x, y), ... ] on which collisions occur. This does not include agents, Returns a list positions [(x, y), ... ] on which collisions occur. This does not include agents,
that were unable to move because their target direction was blocked, also a form of collision. that were unable to move because their target direction was blocked, also a form of collision.

View File

@@ -12,7 +12,7 @@ from marl_factory_grid.utils.tools import ConfigExplainer
if __name__ == '__main__': if __name__ == '__main__':
# Render at each step? # Render at each step?
render = False render = True
# Reveal all possible Modules (Entities, Rules, Agents[Actions, Observations], etc.) # Reveal all possible Modules (Entities, Rules, Agents[Actions, Observations], etc.)
explain_config = False explain_config = False
# Collect statistics? # Collect statistics?
@@ -29,7 +29,7 @@ if __name__ == '__main__':
ce.save_all(run_path / 'all_out.yaml') ce.save_all(run_path / 'all_out.yaml')
# Path to config File # Path to config File
path = Path('marl_factory_grid/configs/narrow_corridor.yaml') path = Path('marl_factory_grid/configs/eight_puzzle.yaml')
# Env Init # Env Init
factory = Factory(path) factory = Factory(path)
@@ -61,6 +61,10 @@ if __name__ == '__main__':
if record: if record:
factory.save_records(run_path / 'test.pb') factory.save_records(run_path / 'test.pb')
if plotting: if plotting:
plot_single_run(run_path) factory.report_possible_colum_keys()
plot_single_run(run_path, column_keys=['Global_DoneAtDestinationReachAll', 'step_reward',
'Agent[Karl-Heinz]_DoneAtDestinationReachAll',
'Agent[Wolfgang]_DoneAtDestinationReachAll',
'Global_DoneAtDestinationReachAll'])
print('Done!!! Goodbye....') print('Done!!! Goodbye....')

View File

@@ -5,7 +5,7 @@ long_description = (this_directory / "README.md").read_text()
setup(name='Marl-Factory-Grid', setup(name='Marl-Factory-Grid',
version='0.1.2', version='0.2.0',
description='A framework to research MARL agents in various setings.', description='A framework to research MARL agents in various setings.',
author='Steffen Illium', author='Steffen Illium',
author_email='steffen.illium@ifi.lmu.de', author_email='steffen.illium@ifi.lmu.de',