mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-05-22 14:56:43 +02:00
Multiple Fixes:
- Config Explainer - Rewards - Destination Reach Condition - Additional Step Callback
This commit is contained in:
parent
0ec260f6a2
commit
803d0dae7f
@ -58,7 +58,7 @@ General:
|
|||||||
individual_rewards: true
|
individual_rewards: true
|
||||||
level_name: large
|
level_name: large
|
||||||
pomdp_r: 3
|
pomdp_r: 3
|
||||||
verbose: false
|
verbose: False
|
||||||
tests: false
|
tests: false
|
||||||
|
|
||||||
Rules:
|
Rules:
|
||||||
|
@ -1,53 +1,89 @@
|
|||||||
|
# General env. settings.
|
||||||
General:
|
General:
|
||||||
|
# Just the best seed.
|
||||||
env_seed: 69
|
env_seed: 69
|
||||||
|
# Each agent receives an individual reward.
|
||||||
individual_rewards: true
|
individual_rewards: true
|
||||||
|
# level file to load from .\levels\.
|
||||||
level_name: eight_puzzle
|
level_name: eight_puzzle
|
||||||
|
# Partial Observability. 0 = Full Observation.
|
||||||
pomdp_r: 0
|
pomdp_r: 0
|
||||||
verbose: True
|
# Please do not spam me.
|
||||||
|
verbose: false
|
||||||
|
# Do not touch, WIP
|
||||||
tests: false
|
tests: false
|
||||||
|
|
||||||
|
# RL Surrogates
|
||||||
Agents:
|
Agents:
|
||||||
|
# This defines the name of the agent. UTF-8
|
||||||
Wolfgang:
|
Wolfgang:
|
||||||
|
# Section which defines the available Actions per Agent
|
||||||
Actions:
|
Actions:
|
||||||
Noop:
|
# Move4 adds 4 actions [`North`, `East`, `South`, `West`]
|
||||||
fail_reward: -0
|
|
||||||
valid_reward: 0
|
|
||||||
Move4:
|
Move4:
|
||||||
fail_reward: -0.1
|
# Reward specification which differ from the default.
|
||||||
valid_reward: -.01
|
# Agent does a valid move in the environment. He actually moves.
|
||||||
|
valid_reward: -0.1
|
||||||
|
# Agent wants to move, but fails.
|
||||||
|
fail_reward: 0
|
||||||
|
# NOOP aka agent does not do a thing.
|
||||||
|
Noop:
|
||||||
|
# The Agent decides to not do anything. Which is always valid.
|
||||||
|
valid_reward: 0
|
||||||
|
# Does not do anything, just using the same interface.
|
||||||
|
fail_reward: 0
|
||||||
|
# What the agent wants to see.
|
||||||
Observations:
|
Observations:
|
||||||
|
# The agent...
|
||||||
|
# sees other agents, but not himself.
|
||||||
- Other
|
- Other
|
||||||
|
# wants to see walls
|
||||||
- Walls
|
- Walls
|
||||||
|
# sees his associated Destination (singular). Use the Plural for `see all destinations`.
|
||||||
- Destination
|
- Destination
|
||||||
Clones:
|
# You want to have 7 clones; it is also possible to name them by giving a list of names.
|
||||||
- Juergen
|
Clones: 7
|
||||||
- Soeren
|
# Agents are blocking their grid position from being entered by others.
|
||||||
- Walter
|
|
||||||
- Siggi
|
|
||||||
- Dennis
|
|
||||||
- Karl-Heinz
|
|
||||||
- Kevin
|
|
||||||
is_blocking_pos: true
|
is_blocking_pos: true
|
||||||
|
# Apart from agents, which additional entities do you want to load?
|
||||||
Entities:
|
Entities:
|
||||||
|
# Observable destinations, which can be reached by stepping on the same position. Has additional parameters...
|
||||||
Destinations:
|
Destinations:
|
||||||
# Let them spawn on closed doors and agent positions
|
# Let them spawn on closed doors and agent positions
|
||||||
ignore_blocking: true
|
ignore_blocking: true
|
||||||
# We need a special spawn rule...
|
# For 8-Puzzle, we need a special spawn rule...
|
||||||
spawnrule:
|
spawnrule:
|
||||||
# ...which assigns the destinations per agent
|
# ...which spawns a single position just underneath an associated agent.
|
||||||
SpawnDestinationOnAgent: {}
|
SpawnDestinationOnAgent: {} # There are no parameters, so we state empty kwargs.
|
||||||
|
|
||||||
|
# This section defines which operations are performed beside agent action.
|
||||||
|
# Without this section nothing happens, not even Done-condition checks.
|
||||||
|
# Also, situation based rewards are specified this way.
|
||||||
Rules:
|
Rules:
|
||||||
# Utilities
|
## Utilities
|
||||||
|
# This rule defines the collision mechanic, introduces a related DoneCondition and lets you specify rewards.
|
||||||
|
# Can be omitted/ignored if you do not want to take care of collisions at all.
|
||||||
|
# This does not mean that agents cannot collide; collisions are just ignored.
|
||||||
WatchCollisions:
|
WatchCollisions:
|
||||||
|
reward: 0
|
||||||
done_at_collisions: false
|
done_at_collisions: false
|
||||||
|
|
||||||
# Initial random walk
|
# In 8 Puzzle, do not randomize the start positions, rather move a random agent onto the single free position n-times.
|
||||||
DoRandomInitialSteps:
|
DoRandomInitialSteps:
|
||||||
random_steps: 10
|
# How many times?
|
||||||
|
random_steps: 2
|
||||||
|
|
||||||
# Done Conditions
|
## Done Conditions
|
||||||
DoneAtDestinationReach:
|
# Maximum steps per episode. There is no reward for failing.
|
||||||
condition: simultanious
|
|
||||||
DoneAtMaxStepsReached:
|
DoneAtMaxStepsReached:
|
||||||
max_steps: 500
|
# After how many steps should the episode end?
|
||||||
|
max_steps: 200
|
||||||
|
|
||||||
|
# For 8 Puzzle we need a done condition that checks whether destinations have been reached, so...
|
||||||
|
DoneAtDestinationReach:
|
||||||
|
# On every step, should there be a reward for agents that reach their associated destination? No!
|
||||||
|
dest_reach_reward: 0 # Do not touch. This is useful in other settings!
|
||||||
|
# Reward should only be given when all destinations are reached in parallel!
|
||||||
|
condition: "simultanious"
|
||||||
|
# Reward if this is the case. Granted to each agent when all agents are at their target position simultaneously.
|
||||||
|
reward_at_done: 1
|
||||||
|
@ -13,22 +13,20 @@ from marl_factory_grid.environment import constants as c
|
|||||||
class Agent(Entity):
|
class Agent(Entity):
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def var_is_paralyzed(self):
|
def var_is_paralyzed(self) -> bool:
|
||||||
"""
|
"""
|
||||||
TODO
|
Check if the Agent is able to move and perform actions. Can be paralyzed by e.g. damage or empty battery.
|
||||||
|
|
||||||
|
:return: Whether the Agent is paralyzed.
|
||||||
:return:
|
|
||||||
"""
|
"""
|
||||||
return len(self._paralyzed)
|
return bool(len(self._paralyzed))
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def paralyze_reasons(self):
|
def paralyze_reasons(self) -> list[str]:
|
||||||
"""
|
"""
|
||||||
TODO
|
Reveals the reasons for the recent paralyzation.
|
||||||
|
|
||||||
|
:return: A list of strings.
|
||||||
:return:
|
|
||||||
"""
|
"""
|
||||||
return [x for x in self._paralyzed]
|
return [x for x in self._paralyzed]
|
||||||
|
|
||||||
@ -40,43 +38,36 @@ class Agent(Entity):
|
|||||||
@property
|
@property
|
||||||
def actions(self):
|
def actions(self):
|
||||||
"""
|
"""
|
||||||
TODO
|
Reveals the actions this agent is capable of.
|
||||||
|
|
||||||
|
:return: List of actions.
|
||||||
:return:
|
|
||||||
"""
|
"""
|
||||||
return self._actions
|
return self._actions
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def observations(self):
|
def observations(self):
|
||||||
"""
|
"""
|
||||||
TODO
|
Reveals the observations which this agent wants to see.
|
||||||
|
|
||||||
|
:return: List of observations.
|
||||||
:return:
|
|
||||||
"""
|
"""
|
||||||
return self._observations
|
return self._observations
|
||||||
|
|
||||||
def step_result(self):
|
|
||||||
"""
|
|
||||||
TODO
|
|
||||||
FIXME THINK ITS LEGACY... Not Used any more
|
|
||||||
|
|
||||||
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def var_is_blocking_pos(self):
|
def var_is_blocking_pos(self):
|
||||||
return self._is_blocking_pos
|
return self._is_blocking_pos
|
||||||
|
|
||||||
def __init__(self, actions: List[Action], observations: List[str], *args, is_blocking_pos=False, **kwargs):
|
def __init__(self, actions: List[Action], observations: List[str], *args, is_blocking_pos=False, **kwargs):
|
||||||
"""
|
"""
|
||||||
TODO
|
This is the main agent surrogate.
|
||||||
|
Actions given to env.step() are associated with this entity and performed at `on_step`.
|
||||||
|
|
||||||
|
|
||||||
:return:
|
:param kwargs: object
|
||||||
|
:param args: object
|
||||||
|
:param is_blocking_pos: object
|
||||||
|
:param observations: object
|
||||||
|
:param actions: object
|
||||||
"""
|
"""
|
||||||
super(Agent, self).__init__(*args, **kwargs)
|
super(Agent, self).__init__(*args, **kwargs)
|
||||||
self._paralyzed = set()
|
self._paralyzed = set()
|
||||||
@ -86,42 +77,38 @@ class Agent(Entity):
|
|||||||
self._status: Union[Result, None] = None
|
self._status: Union[Result, None] = None
|
||||||
self._is_blocking_pos = is_blocking_pos
|
self._is_blocking_pos = is_blocking_pos
|
||||||
|
|
||||||
def summarize_state(self):
|
def summarize_state(self) -> dict[str]:
|
||||||
"""
|
"""
|
||||||
TODO
|
More or less the result of the last action. Useful for debugging and used in renderer.
|
||||||
|
|
||||||
|
:return: Last action result
|
||||||
:return:
|
|
||||||
"""
|
"""
|
||||||
state_dict = super().summarize_state()
|
state_dict = super().summarize_state()
|
||||||
state_dict.update(valid=bool(self.state.validity), action=str(self.state.identifier))
|
state_dict.update(valid=bool(self.state.validity), action=str(self.state.identifier))
|
||||||
return state_dict
|
return state_dict
|
||||||
|
|
||||||
def set_state(self, state):
|
def set_state(self, state: Result) -> bool:
|
||||||
"""
|
"""
|
||||||
TODO
|
Place result in temp agent state.
|
||||||
|
|
||||||
|
:return: Always true
|
||||||
:return:
|
|
||||||
"""
|
"""
|
||||||
self._status = state
|
self._status = state
|
||||||
return c.VALID
|
return c.VALID
|
||||||
|
|
||||||
|
|
||||||
def paralyze(self, reason):
|
def paralyze(self, reason):
|
||||||
"""
|
"""
|
||||||
TODO
|
Paralyze an agent. Paralyzed agents are not able to do actions.
|
||||||
|
This is useful when the battery is empty or the agent is damaged.
|
||||||
|
|
||||||
|
:return: Always true
|
||||||
:return:
|
|
||||||
"""
|
"""
|
||||||
self._paralyzed.add(reason)
|
self._paralyzed.add(reason)
|
||||||
return c.VALID
|
return c.VALID
|
||||||
|
|
||||||
def de_paralyze(self, reason) -> bool:
|
def de_paralyze(self, reason) -> bool:
|
||||||
"""
|
"""
|
||||||
TODO
|
De-paralyze an agent, so that he is able to perform actions again.
|
||||||
|
|
||||||
|
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
|
@ -19,7 +19,7 @@ class Charge(Action):
|
|||||||
|
|
||||||
def do(self, entity, state) -> Union[None, ActionResult]:
|
def do(self, entity, state) -> Union[None, ActionResult]:
|
||||||
if charge_pod := h.get_first(state[b.CHARGE_PODS].by_pos(entity.pos)):
|
if charge_pod := h.get_first(state[b.CHARGE_PODS].by_pos(entity.pos)):
|
||||||
valid = h.get_first(charge_pod.charge_battery(entity, state))
|
valid = charge_pod.charge_battery(entity, state)
|
||||||
if valid:
|
if valid:
|
||||||
state.print(f'{entity.name} just charged batteries at {charge_pod.name}.')
|
state.print(f'{entity.name} just charged batteries at {charge_pod.name}.')
|
||||||
else:
|
else:
|
||||||
|
@ -53,7 +53,7 @@ class BatteryDecharge(Rule):
|
|||||||
|
|
||||||
for agent in state[c.AGENT]:
|
for agent in state[c.AGENT]:
|
||||||
if isinstance(self.per_action_costs, dict):
|
if isinstance(self.per_action_costs, dict):
|
||||||
energy_consumption = self.per_action_costs[agent.step_result()['action']]
|
energy_consumption = self.per_action_costs[agent.state.identifier]
|
||||||
else:
|
else:
|
||||||
energy_consumption = self.per_action_costs
|
energy_consumption = self.per_action_costs
|
||||||
|
|
||||||
|
@ -1,11 +1,13 @@
|
|||||||
|
|
||||||
# Destination Env
|
# Destination Env
|
||||||
DESTINATION = 'Destinations'
|
DESTINATION = 'Destinations'
|
||||||
DEST_SYMBOL = 1
|
DEST_SYMBOL = 1
|
||||||
|
REACHED_DEST_SYMBOL = 1
|
||||||
|
|
||||||
MODE_SINGLE = 'SINGLE'
|
|
||||||
MODE_GROUPED = 'GROUPED'
|
MODE_SINGLE = 'SINGLE'
|
||||||
SPAWN_MODES = [MODE_SINGLE, MODE_GROUPED]
|
MODE_GROUPED = 'GROUPED'
|
||||||
|
SPAWN_MODES = [MODE_SINGLE, MODE_GROUPED]
|
||||||
|
|
||||||
REWARD_WAIT_VALID: float = 0.1
|
REWARD_WAIT_VALID: float = 0.1
|
||||||
REWARD_WAIT_FAIL: float = -0.1
|
REWARD_WAIT_FAIL: float = -0.1
|
||||||
|
@ -11,7 +11,7 @@ class Destination(Entity):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def encoding(self):
|
def encoding(self):
|
||||||
return d.DEST_SYMBOL
|
return d.DEST_SYMBOL if not self.was_reached() else 0
|
||||||
|
|
||||||
def __init__(self, *args, action_counts=0, **kwargs):
|
def __init__(self, *args, action_counts=0, **kwargs):
|
||||||
"""
|
"""
|
||||||
|
@ -33,8 +33,8 @@ class DestinationReachReward(Rule):
|
|||||||
|
|
||||||
def tick_step(self, state) -> List[TickResult]:
|
def tick_step(self, state) -> List[TickResult]:
|
||||||
results = []
|
results = []
|
||||||
reached = False
|
|
||||||
for dest in state[d.DESTINATION]:
|
for dest in state[d.DESTINATION]:
|
||||||
|
reached = False
|
||||||
if dest.has_just_been_reached(state) and not dest.was_reached():
|
if dest.has_just_been_reached(state) and not dest.was_reached():
|
||||||
# Dest has just been reached, some agent needs to stand here
|
# Dest has just been reached, some agent needs to stand here
|
||||||
for agent in state[c.AGENT].by_pos(dest.pos):
|
for agent in state[c.AGENT].by_pos(dest.pos):
|
||||||
@ -67,32 +67,27 @@ class DoneAtDestinationReach(DestinationReachReward):
|
|||||||
"""
|
"""
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
self.condition = condition
|
self.condition = condition
|
||||||
self.reward = reward_at_done
|
self.reward_at_done = reward_at_done
|
||||||
assert condition in CONDITIONS
|
assert condition in CONDITIONS
|
||||||
|
|
||||||
def on_check_done(self, state) -> List[DoneResult]:
|
def on_check_done(self, state) -> List[DoneResult]:
|
||||||
if self.condition == ANY:
|
if self.condition == ANY:
|
||||||
if any(x.was_reached() for x in state[d.DESTINATION]):
|
if any(x.was_reached() for x in state[d.DESTINATION]):
|
||||||
return [DoneResult(self.name, validity=c.VALID, reward=self.reward)]
|
return [DoneResult(self.name, validity=c.VALID, reward=self.reward_at_done)]
|
||||||
elif self.condition == ALL:
|
elif self.condition == ALL:
|
||||||
if all(x.was_reached() for x in state[d.DESTINATION]):
|
if all(x.was_reached() for x in state[d.DESTINATION]):
|
||||||
return [DoneResult(self.name, validity=c.VALID, reward=self.reward)]
|
return [DoneResult(self.name, validity=c.VALID, reward=self.reward_at_done)]
|
||||||
elif self.condition == SIMULTANEOUS:
|
elif self.condition == SIMULTANEOUS:
|
||||||
if all(x.was_reached() for x in state[d.DESTINATION]):
|
if all(x.was_reached() for x in state[d.DESTINATION]):
|
||||||
return [DoneResult(self.name, validity=c.VALID, reward=self.reward)]
|
return [DoneResult(self.name, validity=c.VALID, reward=self.reward_at_done)]
|
||||||
else:
|
else:
|
||||||
for dest in state[d.DESTINATION]:
|
for dest in state[d.DESTINATION]:
|
||||||
if dest.was_reached():
|
if dest.was_reached():
|
||||||
for agent in state[c.AGENT].by_pos(dest.pos):
|
dest.unmark_as_reached()
|
||||||
if dest.bound_entity:
|
state.print(f'{dest} unmarked as reached, not all targets are reached in parallel.')
|
||||||
if dest.bound_entity == agent:
|
else:
|
||||||
pass
|
pass
|
||||||
else:
|
return [DoneResult(f'all_unmarked_as_reached', validity=c.NOT_VALID)]
|
||||||
dest.unmark_as_reached()
|
|
||||||
return [DoneResult(f'{dest}_unmarked_as_reached',
|
|
||||||
validity=c.NOT_VALID, entity=dest)]
|
|
||||||
else:
|
|
||||||
pass
|
|
||||||
else:
|
else:
|
||||||
raise ValueError('Check spelling of Parameter "condition".')
|
raise ValueError('Check spelling of Parameter "condition".')
|
||||||
|
|
||||||
|
@ -151,10 +151,12 @@ class FactoryConfigParser(object):
|
|||||||
class_or_classes = locate_and_import_class(action, self.custom_modules_path)
|
class_or_classes = locate_and_import_class(action, self.custom_modules_path)
|
||||||
try:
|
try:
|
||||||
parsed_actions.extend(class_or_classes)
|
parsed_actions.extend(class_or_classes)
|
||||||
|
for actions_class in class_or_classes:
|
||||||
|
conf_kwargs[actions_class.__name__] = conf_kwargs[action]
|
||||||
except TypeError:
|
except TypeError:
|
||||||
parsed_actions.append(class_or_classes)
|
parsed_actions.append(class_or_classes)
|
||||||
|
|
||||||
parsed_actions = [x(**conf_kwargs.get(x, {})) for x in parsed_actions]
|
parsed_actions = [x(**conf_kwargs.get(x.__name__, {})) for x in parsed_actions]
|
||||||
|
|
||||||
# Observation
|
# Observation
|
||||||
observations = list()
|
observations = list()
|
||||||
|
@ -218,32 +218,6 @@ def is_move(action_name: str):
|
|||||||
"""
|
"""
|
||||||
return action_name in MOVEMAP.keys()
|
return action_name in MOVEMAP.keys()
|
||||||
|
|
||||||
|
|
||||||
def asset_str(agent):
|
|
||||||
"""
|
|
||||||
FIXME @ romue
|
|
||||||
"""
|
|
||||||
# What does this abonimation do?
|
|
||||||
# if any([x is None for x in [cls._slices[j] for j in agent.collisions]]):
|
|
||||||
# print('error')
|
|
||||||
if step_result := agent.step_result:
|
|
||||||
action = step_result['action_name']
|
|
||||||
valid = step_result['action_valid']
|
|
||||||
col_names = [x.name for x in step_result['collisions']]
|
|
||||||
if any(c.AGENT in name for name in col_names):
|
|
||||||
return 'agent_collision', 'blank'
|
|
||||||
elif not valid or c.LEVEL in col_names or c.AGENT in col_names:
|
|
||||||
return c.AGENT, 'invalid'
|
|
||||||
elif valid and not is_move(action):
|
|
||||||
return c.AGENT, 'valid'
|
|
||||||
elif valid and is_move(action):
|
|
||||||
return c.AGENT, 'move'
|
|
||||||
else:
|
|
||||||
return c.AGENT, 'idle'
|
|
||||||
else:
|
|
||||||
return c.AGENT, 'idle'
|
|
||||||
|
|
||||||
|
|
||||||
def locate_and_import_class(class_name, folder_path: Union[str, PurePath] = ''):
|
def locate_and_import_class(class_name, folder_path: Union[str, PurePath] = ''):
|
||||||
"""
|
"""
|
||||||
Locate an object by name or dotted path.
|
Locate an object by name or dotted path.
|
||||||
|
@ -51,7 +51,7 @@ class EnvMonitor(Wrapper):
|
|||||||
pass
|
pass
|
||||||
return
|
return
|
||||||
|
|
||||||
def save_run(self, filepath: Union[Path, str, None] = None, auto_plotting_keys=None):
|
def save_monitor(self, filepath: Union[Path, str, None] = None, auto_plotting_keys=None):
|
||||||
filepath = Path(filepath or self._filepath)
|
filepath = Path(filepath or self._filepath)
|
||||||
filepath.parent.mkdir(exist_ok=True, parents=True)
|
filepath.parent.mkdir(exist_ok=True, parents=True)
|
||||||
with filepath.open('wb') as f:
|
with filepath.open('wb') as f:
|
||||||
|
@ -25,6 +25,12 @@ class EnvRecorder(Wrapper):
|
|||||||
return self.env.reset()
|
return self.env.reset()
|
||||||
|
|
||||||
def step(self, actions):
|
def step(self, actions):
|
||||||
|
"""
|
||||||
|
Todo
|
||||||
|
|
||||||
|
:param actions:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
obs_type, obs, reward, done, info = self.env.step(actions)
|
obs_type, obs, reward, done, info = self.env.step(actions)
|
||||||
if not self.episodes or self._curr_episode in self.episodes:
|
if not self.episodes or self._curr_episode in self.episodes:
|
||||||
summary: dict = self.env.summarize_state()
|
summary: dict = self.env.summarize_state()
|
||||||
|
@ -14,8 +14,9 @@ ENTITIES = 'Objects'
|
|||||||
OBSERVATIONS = 'Observations'
|
OBSERVATIONS = 'Observations'
|
||||||
RULES = 'Rule'
|
RULES = 'Rule'
|
||||||
TESTS = 'Tests'
|
TESTS = 'Tests'
|
||||||
EXCLUDED = ['identifier', 'args', 'kwargs', 'Move', 'Agent', 'GlobalPositions', 'Walls',
|
EXCLUDED = ['identifier', 'args', 'kwargs', 'Move', 'Agent', 'GlobalPositions', 'Walls', 'Gamestate', 'Path',
|
||||||
'TemplateRule', 'Entities', 'EnvObjects', 'Zones', ]
|
'Iterable', 'Move', 'Result', 'TemplateRule', 'Entities', 'EnvObjects', 'Zones', 'Collection',
|
||||||
|
'State', 'Object', 'default_valid_reward', 'default_fail_reward', 'size']
|
||||||
|
|
||||||
|
|
||||||
class ConfigExplainer:
|
class ConfigExplainer:
|
||||||
@ -32,7 +33,9 @@ class ConfigExplainer:
|
|||||||
|
|
||||||
:param custom_path: Path to your custom module folder.
|
:param custom_path: Path to your custom module folder.
|
||||||
"""
|
"""
|
||||||
self.base_path = Path(__file__).parent.parent.resolve()
|
|
||||||
|
self.base_path = Path(__file__).parent.parent.resolve() /'environment'
|
||||||
|
self.modules_path = Path(__file__).parent.parent.resolve() / 'modules'
|
||||||
self.custom_path = Path(custom_path) if custom_path is not None else custom_path
|
self.custom_path = Path(custom_path) if custom_path is not None else custom_path
|
||||||
self.searchspace = [ACTION, GENERAL, ENTITIES, OBSERVATIONS, RULES, TESTS]
|
self.searchspace = [ACTION, GENERAL, ENTITIES, OBSERVATIONS, RULES, TESTS]
|
||||||
|
|
||||||
@ -41,7 +44,13 @@ class ConfigExplainer:
|
|||||||
"""
|
"""
|
||||||
INTERNAL USE ONLY
|
INTERNAL USE ONLY
|
||||||
"""
|
"""
|
||||||
parameters = inspect.signature(class_to_explain).parameters
|
this_search = class_to_explain
|
||||||
|
parameters = dict(inspect.signature(class_to_explain).parameters)
|
||||||
|
while this_search.__bases__:
|
||||||
|
base_class = this_search.__bases__[0]
|
||||||
|
parameters.update(dict(inspect.signature(base_class).parameters))
|
||||||
|
this_search = base_class
|
||||||
|
|
||||||
explained = {class_to_explain.__name__:
|
explained = {class_to_explain.__name__:
|
||||||
{key: val.default for key, val in parameters.items() if key not in EXCLUDED}
|
{key: val.default for key, val in parameters.items() if key not in EXCLUDED}
|
||||||
}
|
}
|
||||||
@ -52,8 +61,10 @@ class ConfigExplainer:
|
|||||||
INTERNAL USE ONLY
|
INTERNAL USE ONLY
|
||||||
"""
|
"""
|
||||||
entities_base_cls = locate_and_import_class(identifier, self.base_path)
|
entities_base_cls = locate_and_import_class(identifier, self.base_path)
|
||||||
module_paths = [x.resolve() for x in self.base_path.rglob('*.py') if x.is_file() and '__init__' not in x.name]
|
module_paths = [x.resolve() for x in self.modules_path.rglob('*.py') if x.is_file() and '__init__' not in x.name]
|
||||||
found_entities = self._load_and_compare(entities_base_cls, module_paths)
|
base_paths = [x.resolve() for x in self.base_path.rglob('*.py') if x.is_file() and '__init__' not in x.name]
|
||||||
|
found_entities = self._load_and_compare(entities_base_cls, base_paths)
|
||||||
|
found_entities.update(self._load_and_compare(entities_base_cls, module_paths))
|
||||||
if self.custom_path is not None:
|
if self.custom_path is not None:
|
||||||
module_paths = [x.resolve() for x in self.custom_path.rglob('*.py') if x.is_file()
|
module_paths = [x.resolve() for x in self.custom_path.rglob('*.py') if x.is_file()
|
||||||
and '__init__' not in x.name]
|
and '__init__' not in x.name]
|
||||||
@ -91,16 +102,14 @@ class ConfigExplainer:
|
|||||||
print(f'Example config {"for " + tag + " " if tag else " "}dumped')
|
print(f'Example config {"for " + tag + " " if tag else " "}dumped')
|
||||||
print(f'See file: {filepath}')
|
print(f'See file: {filepath}')
|
||||||
|
|
||||||
def get_actions(self) -> list[str]:
|
def get_actions(self) -> dict[str]:
|
||||||
"""
|
"""
|
||||||
Retrieve all actions from module folders.
|
Retrieve all actions from module folders.
|
||||||
|
|
||||||
:returns: A list of all available actions.
|
:returns: A dict of all available actions, keyed by action name.
|
||||||
"""
|
"""
|
||||||
actions = self._get_by_identifier(ACTION)
|
actions = self._get_by_identifier(ACTION)
|
||||||
assert all(not x for x in actions.values()), 'Please only provide Names, no Mappings.'
|
actions.update({c.MOVE8: {}, c.MOVE4: {}})
|
||||||
actions = list(actions.keys())
|
|
||||||
actions.extend([c.MOVE8, c.MOVE4])
|
|
||||||
return actions
|
return actions
|
||||||
|
|
||||||
def get_all(self) -> dict[str]:
|
def get_all(self) -> dict[str]:
|
||||||
@ -172,13 +181,20 @@ class ConfigExplainer:
|
|||||||
except TypeError:
|
except TypeError:
|
||||||
e = [key]
|
e = [key]
|
||||||
except AttributeError as err:
|
except AttributeError as err:
|
||||||
if self.custom_path is not None:
|
try:
|
||||||
try:
|
e = locate_and_import_class(key, self.modules_path)(level_shape=(0, 0), pomdp_r=0).obs_pairs
|
||||||
e = locate_and_import_class(key, self.base_path)(level_shape=(0, 0), pomdp_r=0).obs_pairs
|
except TypeError:
|
||||||
except TypeError:
|
e = [key]
|
||||||
e = [key]
|
except AttributeError as err2:
|
||||||
|
if self.custom_path is not None:
|
||||||
|
try:
|
||||||
|
e = locate_and_import_class(key, self.base_path)(level_shape=(0, 0), pomdp_r=0).obs_pairs
|
||||||
|
except TypeError:
|
||||||
|
e = [key]
|
||||||
else:
|
else:
|
||||||
raise err
|
print(err.args)
|
||||||
|
print(err2.args)
|
||||||
|
exit(-9999)
|
||||||
names.extend(e)
|
names.extend(e)
|
||||||
return names
|
return names
|
||||||
|
|
||||||
|
@ -12,9 +12,9 @@ from marl_factory_grid.utils.tools import ConfigExplainer
|
|||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
# Render at each step?
|
# Render at each step?
|
||||||
render = True
|
render = False
|
||||||
# Reveal all possible Modules (Entities, Rules, Agents[Actions, Observations], etc.)
|
# Reveal all possible Modules (Entities, Rules, Agents[Actions, Observations], etc.)
|
||||||
explain_config = False
|
explain_config = True
|
||||||
# Collect statistics?
|
# Collect statistics?
|
||||||
monitor = True
|
monitor = True
|
||||||
# Record as Protobuf?
|
# Record as Protobuf?
|
||||||
@ -49,7 +49,7 @@ if __name__ == '__main__':
|
|||||||
action_spaces = factory.action_space
|
action_spaces = factory.action_space
|
||||||
while not done:
|
while not done:
|
||||||
a = [randint(0, x.n - 1) for x in action_spaces]
|
a = [randint(0, x.n - 1) for x in action_spaces]
|
||||||
obs_type, _, _, done, info = factory.step(a)
|
obs_type, _, reward, done, info = factory.step(a)
|
||||||
if render:
|
if render:
|
||||||
factory.render()
|
factory.render()
|
||||||
if done:
|
if done:
|
||||||
@ -57,14 +57,11 @@ if __name__ == '__main__':
|
|||||||
break
|
break
|
||||||
|
|
||||||
if monitor:
|
if monitor:
|
||||||
factory.save_run(run_path / 'test_monitor.pkl')
|
factory.save_monitor(run_path / 'test_monitor.pkl')
|
||||||
if record:
|
if record:
|
||||||
factory.save_records(run_path / 'test.pb')
|
factory.save_records(run_path / 'test.pb')
|
||||||
if plotting:
|
if plotting:
|
||||||
factory.report_possible_colum_keys()
|
factory.report_possible_colum_keys()
|
||||||
plot_single_run(run_path, column_keys=['Global_DoneAtDestinationReachAll', 'step_reward',
|
plot_single_run(run_path, column_keys=['step_reward'])
|
||||||
'Agent[Karl-Heinz]_DoneAtDestinationReachAll',
|
|
||||||
'Agent[Wolfgang]_DoneAtDestinationReachAll',
|
|
||||||
'Global_DoneAtDestinationReachAll'])
|
|
||||||
|
|
||||||
print('Done!!! Goodbye....')
|
print('Done!!! Goodbye....')
|
||||||
|
@ -71,8 +71,8 @@ if __name__ == '__main__':
|
|||||||
if done_bool:
|
if done_bool:
|
||||||
break
|
break
|
||||||
print(f'Factory run {episode} done, steps taken {env.unwrapped.unwrapped._steps}, reward is:\n {rew}')
|
print(f'Factory run {episode} done, steps taken {env.unwrapped.unwrapped._steps}, reward is:\n {rew}')
|
||||||
env.save_run(out_path / 'reload_monitor.pick',
|
env.save_monitor(out_path / 'reload_monitor.pick',
|
||||||
auto_plotting_keys=['step_reward', 'cleanup_valid', 'cleanup_fail'])
|
auto_plotting_keys=['step_reward', 'cleanup_valid', 'cleanup_fail'])
|
||||||
if record:
|
if record:
|
||||||
env.save_records(out_path / 'reload_recorder.pick', save_occupation_map=True)
|
env.save_records(out_path / 'reload_recorder.pick', save_occupation_map=True)
|
||||||
print('all done')
|
print('all done')
|
||||||
|
Loading…
x
Reference in New Issue
Block a user