Merge branch 'main' into unit_testing

This commit is contained in:
Chanumask
2023-11-28 12:28:20 +01:00
21 changed files with 270 additions and 171 deletions

View File

@ -58,7 +58,7 @@ General:
individual_rewards: true individual_rewards: true
level_name: large level_name: large
pomdp_r: 3 pomdp_r: 3
verbose: false verbose: False
tests: false tests: false
Rules: Rules:

View File

@ -1,60 +1,89 @@
# General env. settings.
General: General:
# Just the best seed.
env_seed: 69 env_seed: 69
# Each agent receives an individual reward.
individual_rewards: true individual_rewards: true
# level file to load from .\levels\.
level_name: eight_puzzle level_name: eight_puzzle
# Partial Observability. 0 = Full Observation.
pomdp_r: 0 pomdp_r: 0
verbose: True # Please do not spam me.
verbose: false
# Do not touch, WIP
tests: false tests: false
# RL Surrogates
Agents: Agents:
# This defines the name of the agent. UTF-8
Wolfgang: Wolfgang:
# Section which defines the available Actions per Agent
Actions: Actions:
Noop: # Move4 adds 4 actions [`North`, `East`, `South`, `West`]
fail_reward: -0
valid_reward: 0
Move4: Move4:
fail_reward: -0.1 # Reward specification which differ from the default.
valid_reward: -.01 # Agent does a valid move in the environment. He actually moves.
valid_reward: -0.1
# Agent wants to move, but fails.
fail_reward: 0
# NOOP aka agent does not do a thing.
Noop:
# The Agent decides to not do anything. Which is always valid.
valid_reward: 0
# Does not do anything, just using the same interface.
fail_reward: 0
# What the agent wants to see.
Observations: Observations:
# The agent...
# sees other agents, but himself.
- Other - Other
# wants to see walls
- Walls - Walls
# sees his associated Destination (singular). Use the Plural for `see all destinations`.
- Destination - Destination
Clones: # You want to have 7 clones, also possible to name them by giving names as list.
- Juergen Clones: 7
- Soeren # Agents are blocking their grid position from beeing entered by others.
- Walter
- Siggi
- Dennis
- Karl-Heinz
- Kevin
is_blocking_pos: true is_blocking_pos: true
# Apart from agents, which additional entities do you want to load?
Entities: Entities:
# Observable destinations, which can be reached by stepping on the same position. Has additional parameters...
Destinations: Destinations:
# Let them spawn on closed doors and agent positions # Let them spawn on closed doors and agent positions
ignore_blocking: true ignore_blocking: true
# We need a special spawn rule... # For 8-Puzzle, we need a special spawn rule...
spawnrule: spawnrule:
# ...which assigns the destinations per agent # ...which spawn a single position just underneath an associated agent.
SpawnDestinationsPerAgent: SpawnDestinationOnAgent: {} # There are no parameters, so we state empty kwargs.
# we use this parameter
coords_or_quantity:
# to enable and assign special positions per agent
Wolfgang: 1
Karl-Heinz: 1
Kevin: 1
Juergen: 1
Soeren: 1
Walter: 1
Siggi: 1
Dennis: 1
# This section defines which operations are performed beside agent action.
# Without this section nothing happens, not even Done-condition checks.
# Also, situation-based rewards are specified this way.
Rules: Rules:
# Utilities ## Utilities
# This rule defines the collision mechanic, introduces a related DoneCondition and lets you specify rewards.
# Can be omitted/ignored if you do not want to take care of collisions at all.
# This does not mean, that agents can not collide, its just ignored.
WatchCollisions: WatchCollisions:
reward: 0
done_at_collisions: false done_at_collisions: false
# Done Conditions # In 8 Puzzle, do not randomize the start positions, rather move a random agent onto the single free position n-times.
DoneAtDestinationReach: DoRandomInitialSteps:
condition: simultanious # How many times?
random_steps: 2
## Done Conditions
# Maximum steps per episode. There is no reward for failing.
DoneAtMaxStepsReached: DoneAtMaxStepsReached:
max_steps: 500 # After how many steps should the episode end?
max_steps: 200
# For 8 Puzzle we need a done condition that checks whether destinations have been reached, so...
DoneAtDestinationReach:
# On every step, should there be a reward for agents that reach their associated destination? No!
dest_reach_reward: 0 # Do not touch. This is usefull in other settings!
# Reward should only be given when all destinations are reached in parallel!
condition: "simultanious"
# Reward if this is the case. Granted to each agent when all agents are at their target position simultaneously.
reward_at_done: 1

View File

@ -7,6 +7,8 @@ from marl_factory_grid.utils.helpers import MOVEMAP
from marl_factory_grid.utils.results import ActionResult from marl_factory_grid.utils.results import ActionResult
TYPE_COLLISION = 'collision'
class Action(abc.ABC): class Action(abc.ABC):
@property @property

View File

@ -13,22 +13,20 @@ from marl_factory_grid.environment import constants as c
class Agent(Entity): class Agent(Entity):
@property @property
def var_is_paralyzed(self): def var_is_paralyzed(self) -> bool:
""" """
TODO Check if the Agent is able to move and perform actions. Can be paralized by eg. damage or empty battery.
:return: Whether the Agent is paralyzed.
:return:
""" """
return len(self._paralyzed) return bool(len(self._paralyzed))
@property @property
def paralyze_reasons(self): def paralyze_reasons(self) -> list[str]:
""" """
TODO Reveals the reasons for the recent paralyzation.
:return: A list of strings.
:return:
""" """
return [x for x in self._paralyzed] return [x for x in self._paralyzed]
@ -40,43 +38,36 @@ class Agent(Entity):
@property @property
def actions(self): def actions(self):
""" """
TODO Reveals the actions this agent is capable of.
:return: List of actions.
:return:
""" """
return self._actions return self._actions
@property @property
def observations(self): def observations(self):
""" """
TODO Reveals the observations which this agent wants to see.
:return: List of observations.
:return:
""" """
return self._observations return self._observations
def step_result(self):
"""
TODO
FIXME THINK ITS LEGACY... Not Used any more
:return:
"""
pass
@property @property
def var_is_blocking_pos(self): def var_is_blocking_pos(self):
return self._is_blocking_pos return self._is_blocking_pos
def __init__(self, actions: List[Action], observations: List[str], *args, is_blocking_pos=False, **kwargs): def __init__(self, actions: List[Action], observations: List[str], *args, is_blocking_pos=False, **kwargs):
""" """
TODO This is the main agent surrogate.
Actions given to env.step() are associated with this entity and performed at `on_step`.
:return: :param kwargs: object
:param args: object
:param is_blocking_pos: object
:param observations: object
:param actions: object
""" """
super(Agent, self).__init__(*args, **kwargs) super(Agent, self).__init__(*args, **kwargs)
self._paralyzed = set() self._paralyzed = set()
@ -86,42 +77,38 @@ class Agent(Entity):
self._status: Union[Result, None] = None self._status: Union[Result, None] = None
self._is_blocking_pos = is_blocking_pos self._is_blocking_pos = is_blocking_pos
def summarize_state(self): def summarize_state(self) -> dict[str]:
""" """
TODO More or less the result of the last action. Usefull for debugging and used in renderer.
:return: Last action result
:return:
""" """
state_dict = super().summarize_state() state_dict = super().summarize_state()
state_dict.update(valid=bool(self.state.validity), action=str(self.state.identifier)) state_dict.update(valid=bool(self.state.validity), action=str(self.state.identifier))
return state_dict return state_dict
def set_state(self, state): def set_state(self, state: Result) -> bool:
""" """
TODO Place result in temp agent state.
:return: Always true
:return:
""" """
self._status = state self._status = state
return c.VALID return c.VALID
def paralyze(self, reason): def paralyze(self, reason):
""" """
TODO Paralyze an agent. Paralyzed agents are not able to do actions.
This is usefull, when battery is empty or agent is damaged.
:return: Always true
:return:
""" """
self._paralyzed.add(reason) self._paralyzed.add(reason)
return c.VALID return c.VALID
def de_paralyze(self, reason) -> bool: def de_paralyze(self, reason) -> bool:
""" """
TODO De-paralyze an agent, so that he is able to perform actions again.
:return: :return:
""" """

View File

@ -142,6 +142,7 @@ class Factory(gym.Env):
# All is set up, trigger entity spawn with variable pos # All is set up, trigger entity spawn with variable pos
self.state.rules.do_all_reset(self.state) self.state.rules.do_all_reset(self.state)
self.state.rules.do_all_post_spawn_reset(self.state)
# Build initial observations for all agents # Build initial observations for all agents
self.obs_builder.reset(self.state) self.obs_builder.reset(self.state)
@ -218,8 +219,7 @@ class Factory(gym.Env):
# Combine Info dicts into a global one # Combine Info dicts into a global one
combined_info_dict = defaultdict(lambda: 0.0) combined_info_dict = defaultdict(lambda: 0.0)
for result in chain(tick_results, done_check_results): for result in chain(tick_results, done_check_results):
if not result: assert result, 'Something returned None...'
raise ValueError()
if result.reward is not None: if result.reward is not None:
try: try:
rewards[result.entity.name] += result.reward rewards[result.entity.name] += result.reward

View File

@ -4,15 +4,17 @@ from random import shuffle
from typing import Dict from typing import Dict
from marl_factory_grid.environment.groups.objects import Objects from marl_factory_grid.environment.groups.objects import Objects
from marl_factory_grid.utils.helpers import POS_MASK from marl_factory_grid.utils.helpers import POS_MASK_8, POS_MASK_4
class Entities(Objects): class Entities(Objects):
_entity = Objects _entity = Objects
@staticmethod def neighboring_positions(self, pos):
def neighboring_positions(pos): return [tuple(x) for x in (POS_MASK_8 + pos).reshape(-1, 2) if tuple(x) in self._floor_positions]
return [tuple(x) for x in (POS_MASK + pos).reshape(-1, 2)]
def neighboring_4_positions(self, pos):
    """
    Collect the positions around `pos` that are addressed by `POS_MASK_4` and that are floor positions.

    :param pos: Position whose neighborhood is requested.
    :return: List of position tuples which are contained in `self._floor_positions`.
    """
    shifted = (tuple(candidate) for candidate in (POS_MASK_4 + pos))
    return [candidate for candidate in shifted if candidate in self._floor_positions]
def get_entities_near_pos(self, pos): def get_entities_near_pos(self, pos):
return [y for x in itemgetter(*self.neighboring_positions(pos))(self.pos_dict) for y in x] return [y for x in itemgetter(*self.neighboring_positions(pos))(self.pos_dict) for y in x]

View File

@ -1,7 +1,10 @@
import abc import abc
import random
from random import shuffle from random import shuffle
from typing import List, Collection from typing import List, Collection
import numpy as np
from marl_factory_grid.environment import rewards as r, constants as c from marl_factory_grid.environment import rewards as r, constants as c
from marl_factory_grid.environment.entity.agent import Agent from marl_factory_grid.environment.entity.agent import Agent
from marl_factory_grid.utils import helpers as h from marl_factory_grid.utils import helpers as h
@ -37,6 +40,15 @@ class Rule(abc.ABC):
TODO TODO
:return:
"""
return []
def on_reset_post_spawn(self, state) -> List[TickResult]:
"""
TODO
:return: :return:
""" """
return [] return []
@ -230,3 +242,33 @@ class WatchCollisions(Rule):
if inter_entity_collision_detected or collision_in_step: if inter_entity_collision_detected or collision_in_step:
return [DoneResult(validity=c.VALID, identifier=c.COLLISION, reward=self.reward_at_done)] return [DoneResult(validity=c.VALID, identifier=c.COLLISION, reward=self.reward_at_done)]
return [] return []
class DoRandomInitialSteps(Rule):

    def __init__(self, random_steps: int = 10):
        """
        Special rule which scrambles the start configuration: after all entities have been spawned, an agent
        next to a free position is moved onto that position, `random_steps` times in a row.
        Useful in the `N-Puzzle` configuration.

        !!! This rule does not introduce any reward or done condition.

        :param random_steps: Number of random steps agents perform in an environment.
        """
        super().__init__()
        self.random_steps = random_steps

    def on_reset_post_spawn(self, state):
        """
        Perform the random initial steps once all entities are spawned.

        :param state: The current game state.
        :return: An empty list, this rule does not produce any results.
        """
        state.print("Random Initial Steps initiated....")
        for _ in range(self.random_steps):
            # Find free positions
            free_pos = state.random_free_position
            neighbor_positions = state.entities.neighboring_4_positions(free_pos)
            random.shuffle(neighbor_positions)
            # Walk the shuffled list from its end, so that the first candidate is the one a plain `pop()` would
            # have selected. Neighbors without an agent are skipped instead of failing an assertion.
            chosen_agent = None
            for neighbor_pos in reversed(neighbor_positions):
                candidate = h.get_first(state[c.AGENT].by_pos(neighbor_pos))
                if isinstance(candidate, Agent):
                    chosen_agent = candidate
                    break
            if chosen_agent is None:
                state.print(f"No agent found next to {free_pos}, skipping this random step.")
                continue
            valid = chosen_agent.move(free_pos, state)
            valid_str = " not" if not valid else ""
            state.print(f"Move {chosen_agent.name} from {chosen_agent.last_pos} "
                        f"to {chosen_agent.pos} was{valid_str} valid.")
        return []

View File

@ -19,7 +19,7 @@ class Charge(Action):
def do(self, entity, state) -> Union[None, ActionResult]: def do(self, entity, state) -> Union[None, ActionResult]:
if charge_pod := h.get_first(state[b.CHARGE_PODS].by_pos(entity.pos)): if charge_pod := h.get_first(state[b.CHARGE_PODS].by_pos(entity.pos)):
valid = h.get_first(charge_pod.charge_battery(entity, state)) valid = charge_pod.charge_battery(entity, state)
if valid: if valid:
state.print(f'{entity.name} just charged batteries at {charge_pod.name}.') state.print(f'{entity.name} just charged batteries at {charge_pod.name}.')
else: else:

View File

@ -53,7 +53,7 @@ class BatteryDecharge(Rule):
for agent in state[c.AGENT]: for agent in state[c.AGENT]:
if isinstance(self.per_action_costs, dict): if isinstance(self.per_action_costs, dict):
energy_consumption = self.per_action_costs[agent.step_result()['action']] energy_consumption = self.per_action_costs[agent.state.identifier]
else: else:
energy_consumption = self.per_action_costs energy_consumption = self.per_action_costs

View File

@ -1,11 +1,13 @@
# Destination Env # Destination Env
DESTINATION = 'Destinations' DESTINATION = 'Destinations'
DEST_SYMBOL = 1 DEST_SYMBOL = 1
REACHED_DEST_SYMBOL = 1
MODE_SINGLE = 'SINGLE'
MODE_GROUPED = 'GROUPED' MODE_SINGLE = 'SINGLE'
SPAWN_MODES = [MODE_SINGLE, MODE_GROUPED] MODE_GROUPED = 'GROUPED'
SPAWN_MODES = [MODE_SINGLE, MODE_GROUPED]
REWARD_WAIT_VALID: float = 0.1 REWARD_WAIT_VALID: float = 0.1
REWARD_WAIT_FAIL: float = -0.1 REWARD_WAIT_FAIL: float = -0.1

View File

@ -11,7 +11,7 @@ class Destination(Entity):
@property @property
def encoding(self): def encoding(self):
return d.DEST_SYMBOL return d.DEST_SYMBOL if not self.was_reached() else 0
def __init__(self, *args, action_counts=0, **kwargs): def __init__(self, *args, action_counts=0, **kwargs):
""" """

View File

@ -32,8 +32,8 @@ class DestinationReachReward(Rule):
def tick_step(self, state) -> List[TickResult]: def tick_step(self, state) -> List[TickResult]:
results = [] results = []
reached = False
for dest in state[d.DESTINATION]: for dest in state[d.DESTINATION]:
reached = False
if dest.has_just_been_reached(state) and not dest.was_reached(): if dest.has_just_been_reached(state) and not dest.was_reached():
# Dest has just been reached, some agent needs to stand here # Dest has just been reached, some agent needs to stand here
for agent in state[c.AGENT].by_pos(dest.pos): for agent in state[c.AGENT].by_pos(dest.pos):
@ -66,32 +66,27 @@ class DoneAtDestinationReach(DestinationReachReward):
""" """
super().__init__(**kwargs) super().__init__(**kwargs)
self.condition = condition self.condition = condition
self.reward = reward_at_done self.reward_at_done = reward_at_done
assert condition in CONDITIONS assert condition in CONDITIONS
def on_check_done(self, state) -> List[DoneResult]: def on_check_done(self, state) -> List[DoneResult]:
if self.condition == ANY: if self.condition == ANY:
if any(x.was_reached() for x in state[d.DESTINATION]): if any(x.was_reached() for x in state[d.DESTINATION]):
return [DoneResult(self.name, validity=c.VALID, reward=self.reward)] return [DoneResult(self.name, validity=c.VALID, reward=self.reward_at_done)]
elif self.condition == ALL: elif self.condition == ALL:
if all(x.was_reached() for x in state[d.DESTINATION]): if all(x.was_reached() for x in state[d.DESTINATION]):
return [DoneResult(self.name, validity=c.VALID, reward=self.reward)] return [DoneResult(self.name, validity=c.VALID, reward=self.reward_at_done)]
elif self.condition == SIMULTANEOUS: elif self.condition == SIMULTANEOUS:
if all(x.was_reached() for x in state[d.DESTINATION]): if all(x.was_reached() for x in state[d.DESTINATION]):
return [DoneResult(self.name, validity=c.VALID, reward=self.reward)] return [DoneResult(self.name, validity=c.VALID, reward=self.reward_at_done)]
else: else:
for dest in state[d.DESTINATION]: for dest in state[d.DESTINATION]:
if dest.was_reached(): if dest.was_reached():
for agent in state[c.AGENT].by_pos(dest.pos): dest.unmark_as_reached()
if dest.bound_entity: state.print(f'{dest} unmarked as reached, not all targets are reached in parallel.')
if dest.bound_entity == agent: else:
pass pass
else: return [DoneResult(f'all_unmarked_as_reached', validity=c.NOT_VALID)]
dest.unmark_as_reached()
return [DoneResult(f'{dest}_unmarked_as_reached',
validity=c.NOT_VALID, entity=dest)]
else:
pass
else: else:
raise ValueError('Check spelling of Parameter "condition".') raise ValueError('Check spelling of Parameter "condition".')
@ -104,10 +99,10 @@ class SpawnDestinationsPerAgent(Rule):
!!! This rule does not introduce any reward or done condition. !!! This rule does not introduce any reward or done condition.
:param coords_or_quantity: Please provide a dictionary with agent names as keys; and a list of possible :param coords_or_quantity: Please provide a dictionary with agent names as keys; and a list of possible
destination coords as value. Example: {Wolfgang: [(0, 0), (1, 1), ...]} destination coords as value. Example: {Wolfgang: [(0, 0), (1, 1), ...]}
""" """
super(Rule, self).__init__() super().__init__()
self.per_agent_positions = dict() self.per_agent_positions = dict()
for agent_name, value in coords_or_quantity.items(): for agent_name, value in coords_or_quantity.items():
if isinstance(value, int): if isinstance(value, int):
@ -142,3 +137,25 @@ class SpawnDestinationsPerAgent(Rule):
continue continue
state[d.DESTINATION].add_item(destination) state[d.DESTINATION].add_item(destination)
pass pass
class SpawnDestinationOnAgent(Rule):

    def __init__(self):
        """
        Special rule which spawns a single destination per agent, placed at the agent's own position and bound
        to that agent. Useful for the `N-Puzzle` configurations.

        !!! This rule does not introduce any reward or done condition.

        This rule takes no parameters.
        """
        super().__init__()

    def on_reset(self, state: Gamestate):
        """
        Create one bound destination at the position of every agent.

        :param state: The current game state.
        """
        state.print("Spawn Desitnations")
        for agent in state[c.AGENT]:
            state[d.DESTINATION].add_item(Destination(agent.pos, bind_to=agent))
            # Exactly one destination is expected at the position of each agent.
            assert len(state[d.DESTINATION].by_pos(agent.pos)) == 1

View File

@ -153,10 +153,12 @@ class FactoryConfigParser(object):
class_or_classes = locate_and_import_class(action, self.custom_modules_path) class_or_classes = locate_and_import_class(action, self.custom_modules_path)
try: try:
parsed_actions.extend(class_or_classes) parsed_actions.extend(class_or_classes)
for actions_class in class_or_classes:
conf_kwargs[actions_class.__name__] = conf_kwargs[action]
except TypeError: except TypeError:
parsed_actions.append(class_or_classes) parsed_actions.append(class_or_classes)
parsed_actions = [x(**conf_kwargs.get(x, {})) for x in parsed_actions] parsed_actions = [x(**conf_kwargs.get(x.__name__, {})) for x in parsed_actions]
# Observation # Observation
observations = list() observations = list()

View File

@ -27,9 +27,11 @@ IGNORED_DF_COLUMNS = ['Episode', 'Run', # For plotting, which values are ignore
'train_step', 'step', 'index', 'dirt_amount', 'dirty_pos_count', 'terminal_observation', 'train_step', 'step', 'index', 'dirt_amount', 'dirty_pos_count', 'terminal_observation',
'episode'] 'episode']
POS_MASK = np.asarray([[[-1, -1], [0, -1], [1, -1]], POS_MASK_8 = np.asarray([[[-1, -1], [0, -1], [1, -1]],
[[-1, 0], [0, 0], [1, 0]], [[-1, 0], [0, 0], [1, 0]],
[[-1, 1], [0, 1], [1, 1]]]) [[-1, 1], [0, 1], [1, 1]]])
# Offsets of the four direct neighbors only; diagonal offsets are part of POS_MASK_8.
POS_MASK_4 = np.asarray([[0, -1], [-1, 0], [1, 0], [0, 1]])
MOVEMAP = defaultdict(lambda: (0, 0), MOVEMAP = defaultdict(lambda: (0, 0),
{c.NORTH: (-1, 0), c.NORTHEAST: (-1, 1), {c.NORTH: (-1, 0), c.NORTHEAST: (-1, 1),
@ -216,32 +218,6 @@ def is_move(action_name: str):
""" """
return action_name in MOVEMAP.keys() return action_name in MOVEMAP.keys()
def asset_str(agent):
"""
FIXME @ romue
"""
# What does this abonimation do?
# if any([x is None for x in [cls._slices[j] for j in agent.collisions]]):
# print('error')
if step_result := agent.step_result:
action = step_result['action_name']
valid = step_result['action_valid']
col_names = [x.name for x in step_result['collisions']]
if any(c.AGENT in name for name in col_names):
return 'agent_collision', 'blank'
elif not valid or c.LEVEL in col_names or c.AGENT in col_names:
return c.AGENT, 'invalid'
elif valid and not is_move(action):
return c.AGENT, 'valid'
elif valid and is_move(action):
return c.AGENT, 'move'
else:
return c.AGENT, 'idle'
else:
return c.AGENT, 'idle'
def locate_and_import_class(class_name, folder_path: Union[str, PurePath] = ''): def locate_and_import_class(class_name, folder_path: Union[str, PurePath] = ''):
""" """
Locate an object by name or dotted path. Locate an object by name or dotted path.

View File

@ -51,7 +51,7 @@ class EnvMonitor(Wrapper):
pass pass
return return
def save_run(self, filepath: Union[Path, str, None] = None, auto_plotting_keys=None): def save_monitor(self, filepath: Union[Path, str, None] = None, auto_plotting_keys=None):
filepath = Path(filepath or self._filepath) filepath = Path(filepath or self._filepath)
filepath.parent.mkdir(exist_ok=True, parents=True) filepath.parent.mkdir(exist_ok=True, parents=True)
with filepath.open('wb') as f: with filepath.open('wb') as f:

View File

@ -25,6 +25,12 @@ class EnvRecorder(Wrapper):
return self.env.reset() return self.env.reset()
def step(self, actions): def step(self, actions):
"""
Todo
:param actions:
:return:
"""
obs_type, obs, reward, done, info = self.env.step(actions) obs_type, obs, reward, done, info = self.env.step(actions)
if not self.episodes or self._curr_episode in self.episodes: if not self.episodes or self._curr_episode in self.episodes:
summary: dict = self.env.summarize_state() summary: dict = self.env.summarize_state()

View File

@ -2,9 +2,11 @@ from typing import Union
from dataclasses import dataclass from dataclasses import dataclass
from marl_factory_grid.environment.entity.object import Object from marl_factory_grid.environment.entity.object import Object
import marl_factory_grid.environment.constants as c
TYPE_VALUE = 'value' TYPE_VALUE = 'value'
TYPE_REWARD = 'reward' TYPE_REWARD = 'reward'
TYPES = [TYPE_VALUE, TYPE_REWARD] TYPES = [TYPE_VALUE, TYPE_REWARD]
@ -32,8 +34,9 @@ class Result:
""" """
identifier: str identifier: str
validity: bool validity: bool
reward: Union[float, None] = None reward: float | None = None
value: Union[float, None] = None value: float | None = None
collision: bool | None = None
entity: Object = None entity: Object = None
def get_infos(self): def get_infos(self):
@ -68,8 +71,17 @@ class ActionResult(Result):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.action_introduced_collision = action_introduced_collision self.action_introduced_collision = action_introduced_collision
def __repr__(self):
    """
    Build the textual representation of this result.

    :return: The representation of the parent class, extended by a collision marker if the action
             introduced a collision.
    """
    sr = super().__repr__()
    # The conditional must only guard the suffix. Otherwise the whole representation collapses to an empty
    # string whenever no collision was recorded. Truthiness is used, in line with `get_infos`.
    if self.action_introduced_collision:
        return f"{sr} | {c.COLLISION}"
    return sr
def get_infos(self):
    """
    Collect the info objects of this result.

    :return: The infos of the parent class; if the action introduced a collision, an additional
             collision info object named after the acting entity is appended.
    """
    infos = super().get_infos()
    if not self.action_introduced_collision:
        return infos
    collision_info = InfoObject(identifier=f'{self.entity.name}_{c.COLLISION}', val_type=TYPE_VALUE, value=1)
    return infos + [collision_info]
@dataclass @dataclass
class DoneResult(Result): class DoneResult(Result):

View File

@ -49,6 +49,12 @@ class StepRules:
state.print(rule_reset_printline) state.print(rule_reset_printline)
return c.VALID return c.VALID
def do_all_post_spawn_reset(self, state):
    """
    Call `on_reset_post_spawn` on every registered rule and print whatever a rule returns, if it is truthy.

    :param state: The current game state.
    :return: Always `c.VALID`.
    """
    for rule in self.rules:
        printline = rule.on_reset_post_spawn(state)
        if printline:
            state.print(printline)
    return c.VALID
def tick_step_all(self, state): def tick_step_all(self, state):
results = list() results = list()
for rule in self.rules: for rule in self.rules:

View File

@ -14,8 +14,9 @@ ENTITIES = 'Objects'
OBSERVATIONS = 'Observations' OBSERVATIONS = 'Observations'
RULES = 'Rule' RULES = 'Rule'
TESTS = 'Tests' TESTS = 'Tests'
EXCLUDED = ['identifier', 'args', 'kwargs', 'Move', 'Agent', 'GlobalPositions', 'Walls', EXCLUDED = ['identifier', 'args', 'kwargs', 'Move', 'Agent', 'GlobalPositions', 'Walls', 'Gamestate', 'Path',
'TemplateRule', 'Entities', 'EnvObjects', 'Zones', ] 'Iterable', 'Move', 'Result', 'TemplateRule', 'Entities', 'EnvObjects', 'Zones', 'Collection',
'State', 'Object', 'default_valid_reward', 'default_fail_reward', 'size']
class ConfigExplainer: class ConfigExplainer:
@ -32,7 +33,9 @@ class ConfigExplainer:
:param custom_path: Path to your custom module folder. :param custom_path: Path to your custom module folder.
""" """
self.base_path = Path(__file__).parent.parent.resolve()
self.base_path = Path(__file__).parent.parent.resolve() /'environment'
self.modules_path = Path(__file__).parent.parent.resolve() / 'modules'
self.custom_path = Path(custom_path) if custom_path is not None else custom_path self.custom_path = Path(custom_path) if custom_path is not None else custom_path
self.searchspace = [ACTION, GENERAL, ENTITIES, OBSERVATIONS, RULES, TESTS] self.searchspace = [ACTION, GENERAL, ENTITIES, OBSERVATIONS, RULES, TESTS]
@ -41,9 +44,16 @@ class ConfigExplainer:
""" """
INTERNAL USE ONLY INTERNAL USE ONLY
""" """
parameters = inspect.signature(class_to_explain).parameters this_search = class_to_explain
parameters = dict(inspect.signature(class_to_explain).parameters)
while this_search.__bases__:
base_class = this_search.__bases__[0]
parameters.update(dict(inspect.signature(base_class).parameters))
this_search = base_class
explained = {class_to_explain.__name__: explained = {class_to_explain.__name__:
{key: val.default for key, val in parameters.items() if key not in EXCLUDED} {key: val.default if val.default != inspect._empty else '!' for key, val in parameters.items()
if key not in EXCLUDED}
} }
return explained return explained
@ -52,8 +62,10 @@ class ConfigExplainer:
INTERNAL USE ONLY INTERNAL USE ONLY
""" """
entities_base_cls = locate_and_import_class(identifier, self.base_path) entities_base_cls = locate_and_import_class(identifier, self.base_path)
module_paths = [x.resolve() for x in self.base_path.rglob('*.py') if x.is_file() and '__init__' not in x.name] module_paths = [x.resolve() for x in self.modules_path.rglob('*.py') if x.is_file() and '__init__' not in x.name]
found_entities = self._load_and_compare(entities_base_cls, module_paths) base_paths = [x.resolve() for x in self.base_path.rglob('*.py') if x.is_file() and '__init__' not in x.name]
found_entities = self._load_and_compare(entities_base_cls, base_paths)
found_entities.update(self._load_and_compare(entities_base_cls, module_paths))
if self.custom_path is not None: if self.custom_path is not None:
module_paths = [x.resolve() for x in self.custom_path.rglob('*.py') if x.is_file() module_paths = [x.resolve() for x in self.custom_path.rglob('*.py') if x.is_file()
and '__init__' not in x.name] and '__init__' not in x.name]
@ -91,16 +103,14 @@ class ConfigExplainer:
print(f'Example config {"for " + tag + " " if tag else " "}dumped') print(f'Example config {"for " + tag + " " if tag else " "}dumped')
print(f'See file: {filepath}') print(f'See file: {filepath}')
def get_actions(self) -> list[str]: def get_actions(self) -> dict[str]:
""" """
Retrieve all actions from module folders. Retrieve all actions from module folders.
:returns: A list of all available actions. :returns: A list of all available actions.
""" """
actions = self._get_by_identifier(ACTION) actions = self._get_by_identifier(ACTION)
assert all(not x for x in actions.values()), 'Please only provide Names, no Mappings.' actions.update({c.MOVE8: {}, c.MOVE4: {}})
actions = list(actions.keys())
actions.extend([c.MOVE8, c.MOVE4])
return actions return actions
def get_all(self) -> dict[str]: def get_all(self) -> dict[str]:
@ -125,6 +135,8 @@ class ConfigExplainer:
:returns: A list of all available entities. :returns: A list of all available entities.
""" """
entities = self._get_by_identifier(ENTITIES) entities = self._get_by_identifier(ENTITIES)
for key in ['Combined', 'Agents', 'Inventory']:
del entities[key]
return entities return entities
@staticmethod @staticmethod
@ -172,13 +184,20 @@ class ConfigExplainer:
except TypeError: except TypeError:
e = [key] e = [key]
except AttributeError as err: except AttributeError as err:
if self.custom_path is not None: try:
try: e = locate_and_import_class(key, self.modules_path)(level_shape=(0, 0), pomdp_r=0).obs_pairs
e = locate_and_import_class(key, self.base_path)(level_shape=(0, 0), pomdp_r=0).obs_pairs except TypeError:
except TypeError: e = [key]
e = [key] except AttributeError as err2:
if self.custom_path is not None:
try:
e = locate_and_import_class(key, self.base_path)(level_shape=(0, 0), pomdp_r=0).obs_pairs
except TypeError:
e = [key]
else: else:
raise err print(err.args)
print(err2.args)
exit(-9999)
names.extend(e) names.extend(e)
return names return names

View File

@ -12,9 +12,9 @@ from marl_factory_grid.utils.tools import ConfigExplainer
if __name__ == '__main__': if __name__ == '__main__':
# Render at each step? # Render at each step?
render = True render = False
# Reveal all possible Modules (Entities, Rules, Agents[Actions, Observations], etc.) # Reveal all possible Modules (Entities, Rules, Agents[Actions, Observations], etc.)
explain_config = False explain_config = True
# Collect statistics? # Collect statistics?
monitor = True monitor = True
# Record as Protobuf? # Record as Protobuf?
@ -26,10 +26,10 @@ if __name__ == '__main__':
if explain_config: if explain_config:
ce = ConfigExplainer() ce = ConfigExplainer()
ce.save_all(run_path / 'all_out.yaml') ce.save_all(run_path / 'all_available_configs.yaml')
# Path to config File # Path to config File
path = Path('marl_factory_grid/configs/clean_and_bring.yaml') path = Path('marl_factory_grid/configs/eight_puzzle.yaml')
# Env Init # Env Init
factory = Factory(path) factory = Factory(path)
@ -49,7 +49,7 @@ if __name__ == '__main__':
action_spaces = factory.action_space action_spaces = factory.action_space
while not done: while not done:
a = [randint(0, x.n - 1) for x in action_spaces] a = [randint(0, x.n - 1) for x in action_spaces]
obs_type, _, _, done, info = factory.step(a) obs_type, _, reward, done, info = factory.step(a)
if render: if render:
factory.render() factory.render()
if done: if done:
@ -57,14 +57,11 @@ if __name__ == '__main__':
break break
if monitor: if monitor:
factory.save_run(run_path / 'test_monitor.pkl') factory.save_monitor(run_path / 'test_monitor.pkl')
if record: if record:
factory.save_records(run_path / 'test.pb') factory.save_records(run_path / 'test.pb')
if plotting: if plotting:
factory.report_possible_colum_keys() factory.report_possible_colum_keys()
plot_single_run(run_path, column_keys=['Global_DoneAtDestinationReachAll', 'step_reward', plot_single_run(run_path, column_keys=['step_reward'])
'Agent[Karl-Heinz]_DoneAtDestinationReachAll',
'Agent[Wolfgang]_DoneAtDestinationReachAll',
'Global_DoneAtDestinationReachAll'])
print('Done!!! Goodbye....') print('Done!!! Goodbye....')

View File

@ -71,8 +71,8 @@ if __name__ == '__main__':
if done_bool: if done_bool:
break break
print(f'Factory run {episode} done, steps taken {env.unwrapped.unwrapped._steps}, reward is:\n {rew}') print(f'Factory run {episode} done, steps taken {env.unwrapped.unwrapped._steps}, reward is:\n {rew}')
env.save_run(out_path / 'reload_monitor.pick', env.save_monitor(out_path / 'reload_monitor.pick',
auto_plotting_keys=['step_reward', 'cleanup_valid', 'cleanup_fail']) auto_plotting_keys=['step_reward', 'cleanup_valid', 'cleanup_fail'])
if record: if record:
env.save_records(out_path / 'reload_recorder.pick', save_occupation_map=True) env.save_records(out_path / 'reload_recorder.pick', save_occupation_map=True)
print('all done') print('all done')