mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-11-02 13:37:27 +01:00
Resolved some warnings and style issues
This commit is contained in:
@@ -58,7 +58,10 @@ class FactoryConfigParser(object):
|
||||
return str(self.config)
|
||||
|
||||
def __getitem__(self, item):
    """
    Return the section of the parsed configuration that is stored under *item*.

    :param item: Key (section name) to look up in ``self.config``.
    :return: Whatever ``self.config`` holds under *item*.
    :raises KeyError: If *item* is not present; the message names the missing section.
    """
    try:
        return self.config[item]
    except KeyError as err:
        # Fail loudly: printing and implicitly returning None would only defer the crash to the caller.
        raise KeyError(
            f'The mandatory {item} section could not be found in your .config file. Check Spelling!'
        ) from err
|
||||
|
||||
def load_entities(self):
|
||||
entity_classes = dict()
|
||||
@@ -161,7 +164,6 @@ class FactoryConfigParser(object):
|
||||
|
||||
def _load_smth(self, config, class_obj):
|
||||
rules = list()
|
||||
rules_names = list()
|
||||
for rule in config:
|
||||
e1 = e2 = e3 = None
|
||||
try:
|
||||
|
||||
@@ -61,8 +61,8 @@ class ObservationTranslator:
|
||||
:param per_agent_named_obs_spaces: `Named observation space` one for each agent. Overloaded.
|
||||
:type per_agent_named_obs_spaces: Dict[str, dict]
|
||||
|
||||
:param placeholder_fill_value: Currently not fully implemented!!!
|
||||
:type placeholder_fill_value: Union[int, str] = 'N')
|
||||
:param placeholder_fill_value: Currently, not fully implemented!!!
|
||||
:type placeholder_fill_value: Union[int, str] = 'N'
|
||||
"""
|
||||
|
||||
if isinstance(placeholder_fill_value, str):
|
||||
|
||||
@@ -9,7 +9,7 @@ from marl_factory_grid.utils.helpers import IGNORED_DF_COLUMNS
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from marl_factory_grid.utils.plotting.compare_runs import plot_single_run
|
||||
from marl_factory_grid.utils.plotting.plot_single_runs import plot_single_run
|
||||
|
||||
|
||||
class EnvMonitor(Wrapper):
|
||||
@@ -22,7 +22,6 @@ class EnvMonitor(Wrapper):
|
||||
self._monitor_df = pd.DataFrame()
|
||||
self._monitor_dict = dict()
|
||||
|
||||
|
||||
def step(self, action):
|
||||
obs_type, obs, reward, done, info = self.env.step(action)
|
||||
self._read_info(info)
|
||||
|
||||
@@ -2,11 +2,9 @@ from os import PathLike
|
||||
from pathlib import Path
|
||||
from typing import Union, List
|
||||
|
||||
import yaml
|
||||
from gymnasium import Wrapper
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from gymnasium import Wrapper
|
||||
|
||||
|
||||
class EnvRecorder(Wrapper):
|
||||
@@ -106,7 +104,7 @@ class EnvRecorder(Wrapper):
|
||||
out_dict = {'episodes': self._recorder_out_list}
|
||||
out_dict.update(
|
||||
{'n_episodes': self._curr_episode,
|
||||
'metadata':dict(
|
||||
'metadata': dict(
|
||||
level_name=self.env.params['General']['level_name'],
|
||||
verbose=False,
|
||||
n_agents=len(self.env.params['Agents']),
|
||||
|
||||
@@ -5,7 +5,7 @@ from typing import Dict, List
|
||||
import numpy as np
|
||||
|
||||
from marl_factory_grid.environment import constants as c
|
||||
from marl_factory_grid.environment.entity.object import _Object
|
||||
from marl_factory_grid.environment.entity.object import Object
|
||||
from marl_factory_grid.environment.groups.utils import Combined
|
||||
from marl_factory_grid.utils.utility_classes import Floor
|
||||
from marl_factory_grid.utils.ray_caster import RayCaster
|
||||
@@ -13,7 +13,6 @@ from marl_factory_grid.utils.states import Gamestate
|
||||
from marl_factory_grid.utils import helpers as h
|
||||
|
||||
|
||||
|
||||
class OBSBuilder(object):
|
||||
default_obs = [c.WALLS, c.OTHERS]
|
||||
|
||||
@@ -128,7 +127,7 @@ class OBSBuilder(object):
|
||||
f'{re.escape("[")}(.*){re.escape("]")}'
|
||||
f'{re.escape("(")}{re.escape(agent.name)}{re.escape(")")}')
|
||||
name = next((key for key, val in self.all_obs.items()
|
||||
if pattern.search(str(val)) and isinstance(val, _Object)), None)
|
||||
if pattern.search(str(val)) and isinstance(val, Object)), None)
|
||||
e = self.all_obs[name]
|
||||
except KeyError:
|
||||
try:
|
||||
@@ -181,11 +180,11 @@ class OBSBuilder(object):
|
||||
return obs, self.obs_layers[agent.name]
|
||||
|
||||
def _sort_and_name_observation_conf(self, agent):
|
||||
'''
|
||||
"""
|
||||
Builds the useable observation scheme per agent from conf.yaml.
|
||||
:param agent:
|
||||
:return:
|
||||
'''
|
||||
"""
|
||||
# Fixme: no asymmetric shapes possible.
|
||||
self.ray_caster[agent.name] = RayCaster(agent, min(self.obs_shape))
|
||||
obs_layers = []
|
||||
|
||||
@@ -7,50 +7,11 @@ from typing import Union, List
|
||||
import pandas as pd
|
||||
|
||||
from marl_factory_grid.utils.helpers import IGNORED_DF_COLUMNS
|
||||
from marl_factory_grid.utils.plotting.plotting import prepare_plot
|
||||
from marl_factory_grid.utils.plotting.plotting_utils import prepare_plot
|
||||
|
||||
MODEL_MAP = None
|
||||
|
||||
|
||||
def plot_single_run(run_path: Union[str, PathLike], use_tex: bool = False, column_keys=None):
|
||||
run_path = Path(run_path)
|
||||
df_list = list()
|
||||
if run_path.is_dir():
|
||||
monitor_file = next(run_path.glob('*monitor*.pick'))
|
||||
elif run_path.exists() and run_path.is_file():
|
||||
monitor_file = run_path
|
||||
else:
|
||||
raise ValueError
|
||||
|
||||
with monitor_file.open('rb') as f:
|
||||
monitor_df = pickle.load(f)
|
||||
|
||||
monitor_df = monitor_df.fillna(0)
|
||||
df_list.append(monitor_df)
|
||||
|
||||
df = pd.concat(df_list, ignore_index=True)
|
||||
df = df.fillna(0).rename(columns={'episode': 'Episode'}).sort_values(['Episode'])
|
||||
if column_keys is not None:
|
||||
columns = [col for col in column_keys if col in df.columns]
|
||||
else:
|
||||
columns = [col for col in df.columns if col not in IGNORED_DF_COLUMNS]
|
||||
|
||||
roll_n = 50
|
||||
|
||||
non_overlapp_window = df.groupby(['Episode']).rolling(roll_n, min_periods=1).mean()
|
||||
|
||||
df_melted = df[columns + ['Episode']].reset_index().melt(
|
||||
id_vars=['Episode'], value_vars=columns, var_name="Measurement", value_name="Score"
|
||||
)
|
||||
|
||||
if df_melted['Episode'].max() > 800:
|
||||
skip_n = round(df_melted['Episode'].max() * 0.02)
|
||||
df_melted = df_melted[df_melted['Episode'] % skip_n == 0]
|
||||
|
||||
prepare_plot(run_path.parent / f'{run_path.parent.name}_monitor_lineplot.png', df_melted, use_tex=use_tex)
|
||||
print('Plotting done.')
|
||||
|
||||
|
||||
def compare_seed_runs(run_path: Union[str, PathLike], use_tex: bool = False):
|
||||
run_path = Path(run_path)
|
||||
df_list = list()
|
||||
48
marl_factory_grid/utils/plotting/plot_single_runs.py
Normal file
48
marl_factory_grid/utils/plotting/plot_single_runs.py
Normal file
@@ -0,0 +1,48 @@
|
||||
import pickle
|
||||
from os import PathLike
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from marl_factory_grid.utils.helpers import IGNORED_DF_COLUMNS
|
||||
from marl_factory_grid.utils.plotting.plotting_utils import prepare_plot
|
||||
|
||||
|
||||
def plot_single_run(run_path: Union[str, PathLike], use_tex: bool = False, column_keys=None,
                    file_key: str = 'monitor', file_ext: str = 'pkl'):
    """
    Plot the measurements of a single pickled monitor file as a line plot.

    :param run_path: Either a directory that contains a monitor file, or the path of the monitor file itself.
    :param use_tex: Forwarded to ``prepare_plot``.
    :param column_keys: Names of the columns to plot; names that are not columns of the loaded frame are
                        dropped. If None, every column that is not listed in ``IGNORED_DF_COLUMNS`` is used.
    :param file_key: Substring the monitor file name must contain (only used when *run_path* is a directory).
    :param file_ext: File extension, without the dot, of the monitor file (only used for directories).
    :raises ValueError: If *run_path* is neither an existing directory nor an existing file.
    :raises FileNotFoundError: If *run_path* is a directory without a matching monitor file.
    :return: None
    """
    run_path = Path(run_path)
    if run_path.is_dir():
        pattern = f'*{file_key}*.{file_ext}'
        # Use a default so that an empty glob does not leak a bare StopIteration to the caller.
        monitor_file = next(run_path.glob(pattern), None)
        if monitor_file is None:
            raise FileNotFoundError(f'No file matching "{pattern}" found in {run_path}.')
    elif run_path.is_file():
        monitor_file = run_path
    else:
        raise ValueError(f'{run_path} is neither an existing directory nor an existing file.')

    # NOTE: unpickling can execute arbitrary code; only load monitor files from a trusted source.
    with monitor_file.open('rb') as f:
        monitor_df = pickle.load(f)

    df = monitor_df.fillna(0).rename(columns={'episode': 'Episode'}).sort_values(['Episode'])
    if column_keys is not None:
        columns = [col for col in column_keys if col in df.columns]
    else:
        columns = [col for col in df.columns if col not in IGNORED_DF_COLUMNS]

    # Long format: one row per (Episode, Measurement) pair, as handed to prepare_plot below.
    df_melted = df[columns + ['Episode']].reset_index().melt(
        id_vars=['Episode'], value_vars=columns, var_name="Measurement", value_name="Score"
    )

    # Thin out long runs: keep only every skip_n-th episode (2 % of the highest episode number).
    if df_melted['Episode'].max() > 800:
        skip_n = round(df_melted['Episode'].max() * 0.02)
        df_melted = df_melted[df_melted['Episode'] % skip_n == 0]

    prepare_plot(run_path.parent / f'{run_path.parent.name}_monitor_lineplot.png', df_melted, use_tex=use_tex)
    print('Plotting done.')
|
||||
@@ -60,7 +60,7 @@ def prepare_center_double_column_legend(df, hue, style, hue_order):
|
||||
print('Struggling to plot Figure using LaTeX - going back to normal.')
|
||||
plt.close('all')
|
||||
sns.set(rc={'text.usetex': False}, style='whitegrid')
|
||||
fig = plt.figure(figsize=(10, 11))
|
||||
_ = plt.figure(figsize=(10, 11))
|
||||
lineplot = sns.lineplot(data=df, x='Episode', y='Score', hue=hue, style=style,
|
||||
ci=95, palette=PALETTE, hue_order=hue_order, legend=False)
|
||||
# plt.legend(bbox_to_anchor=(1.02, 1), loc='upper left', borderaxespad=0)
|
||||
@@ -19,7 +19,7 @@ class RayCaster:
|
||||
return f'{self.__class__.__name__}({self.agent.name})'
|
||||
|
||||
def build_ray_targets(self):
|
||||
north = np.array([0, -1])*self.pomdp_r
|
||||
north = np.array([0, -1]) * self.pomdp_r
|
||||
thetas = [np.deg2rad(deg) for deg in np.linspace(-self.degs // 2, self.degs // 2, self.n_rays)[::-1]]
|
||||
rot_M = [
|
||||
[[math.cos(theta), -math.sin(theta)],
|
||||
@@ -53,9 +53,9 @@ class RayCaster:
|
||||
diag_hits = all([
|
||||
self.ray_block_cache(
|
||||
key,
|
||||
lambda: any(True for e in pos_dict[key] if e.var_is_blocking_light))
|
||||
# lambda: all(False for e in pos_dict[key] if not e.var_is_blocking_light))
|
||||
for key in ((x, y-cy), (x-cx, y))
|
||||
# lambda: all(False for e in pos_dict[key] if not e.var_is_blocking_light)
|
||||
lambda: any(True for e in pos_dict[key] if e.var_is_blocking_light))
|
||||
for key in ((x, y - cy), (x - cx, y))
|
||||
]) if (cx != 0 and cy != 0) else False
|
||||
|
||||
visible += entities_hit if not diag_hits else []
|
||||
@@ -77,8 +77,8 @@ class RayCaster:
|
||||
agent = self.agent
|
||||
x_coords = range(agent.x - self.pomdp_r, agent.x + self.pomdp_r + 1)
|
||||
y_coords = range(agent.y - self.pomdp_r, agent.y + self.pomdp_r + 1)
|
||||
outline = list(product(x_coords, [agent.y - self.pomdp_r, agent.y + self.pomdp_r])) \
|
||||
+ list(product([agent.x - self.pomdp_r, agent.x + self.pomdp_r], y_coords))
|
||||
outline = list(product(x_coords, [agent.y - self.pomdp_r, agent.y + self.pomdp_r]))
|
||||
outline += list(product([agent.x - self.pomdp_r, agent.x + self.pomdp_r], y_coords))
|
||||
return outline
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -1,9 +1,12 @@
|
||||
from typing import Union
|
||||
from dataclasses import dataclass
|
||||
|
||||
from marl_factory_grid.environment.entity.object import Object
|
||||
|
||||
TYPE_VALUE = 'value'
|
||||
TYPE_REWARD = 'reward'
|
||||
types = [TYPE_VALUE, TYPE_REWARD]
|
||||
TYPES = [TYPE_VALUE, TYPE_REWARD]
|
||||
|
||||
|
||||
@dataclass
|
||||
class InfoObject:
|
||||
@@ -18,12 +21,13 @@ class Result:
|
||||
validity: bool
|
||||
reward: Union[float, None] = None
|
||||
value: Union[float, None] = None
|
||||
entity: None = None
|
||||
entity: Object = None
|
||||
|
||||
def get_infos(self):
|
||||
n = self.entity.name if self.entity is not None else "Global"
|
||||
return [InfoObject(identifier=f'{n}_{self.identifier}_{t}',
|
||||
val_type=t, value=self.__getattribute__(t)) for t in types
|
||||
# Return multiple Info Dicts
|
||||
return [InfoObject(identifier=f'{n}_{self.identifier}',
|
||||
val_type=t, value=self.__getattribute__(t)) for t in TYPES
|
||||
if self.__getattribute__(t) is not None]
|
||||
|
||||
def __repr__(self):
|
||||
@@ -31,7 +35,7 @@ class Result:
|
||||
reward = f" | Reward: {self.reward}" if self.reward is not None else ""
|
||||
value = f" | Value: {self.value}" if self.value is not None else ""
|
||||
entity = f" | by: {self.entity.name}" if self.entity is not None else ""
|
||||
return f'{self.__class__.__name__}({self.identifier.capitalize()} {valid}valid{reward}{value})'
|
||||
return f'{self.__class__.__name__}({self.identifier.capitalize()} {valid}valid{reward}{value}{entity})'
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
from itertools import islice
|
||||
from typing import List, Dict, Tuple
|
||||
from typing import List, Tuple
|
||||
|
||||
import numpy as np
|
||||
|
||||
from marl_factory_grid.environment import constants as c
|
||||
from marl_factory_grid.environment.entity.entity import Entity
|
||||
from marl_factory_grid.environment.rules import Rule
|
||||
from marl_factory_grid.utils.results import Result
|
||||
from marl_factory_grid.utils.results import Result, DoneResult
|
||||
|
||||
|
||||
class StepRules:
|
||||
@@ -83,13 +84,51 @@ class Gamestate(object):
|
||||
return f'{self.__class__.__name__}({len(self.entities)} Entitites @ Step {self.curr_step})'
|
||||
|
||||
@property
|
||||
def random_free_position(self):
|
||||
def random_free_position(self) -> (int, int):
|
||||
"""
|
||||
Returns a single **free** position (x, y), which is **free** for spawning or walking.
|
||||
No Entity at this position possesses *var_is_blocking_pos* or *var_can_collide*.
|
||||
|
||||
:return: Single **free** position.
|
||||
"""
|
||||
return self.get_n_random_free_positions(1)[0]
|
||||
|
||||
def get_n_random_free_positions(self, n):
|
||||
def get_n_random_free_positions(self, n) -> list[tuple[int, int]]:
|
||||
"""
|
||||
Returns a list of *n* **free** positions [(x, y), ... ], which are **free** for spawning or walking.
|
||||
No Entity at this position possesses *var_is_blocking_pos* or *var_can_collide*.
|
||||
|
||||
:return: List of n **free** position.
|
||||
"""
|
||||
return list(islice(self.entities.free_positions_generator, n))
|
||||
|
||||
def tick(self, actions) -> List[Result]:
|
||||
@property
|
||||
def random_position(self) -> (int, int):
|
||||
"""
|
||||
Returns a single available position (x, y), ignores all entity attributes.
|
||||
|
||||
:return: Single random position.
|
||||
"""
|
||||
return self.get_n_random_positions(1)[0]
|
||||
|
||||
def get_n_random_positions(self, n) -> list[tuple[int, int]]:
|
||||
"""
|
||||
Returns a list of *n* available positions [(x, y), ... ], ignores all entity attributes.
|
||||
|
||||
:return: List of n random positions.
|
||||
"""
|
||||
return list(islice(self.entities.floorlist, n))
|
||||
|
||||
def tick(self, actions) -> list[Result]:
|
||||
"""
|
||||
Performs a single **Gamestate Tick** by calling the inner rule hooks in sequential order.
|
||||
- tick_pre_step_all: Things to do before the agents do their actions. Statechange, Moving, Spawning etc...
|
||||
- agent tick: Agents do their actions.
|
||||
- tick_step_all: Things to do after the agents did their actions. Statechange, Moving, Spawning etc...
|
||||
- tick_post_step_all: Things to do at the very end of each step. Counting, Reward calculations etc...
|
||||
|
||||
:return: List of *Result*-objects.
|
||||
"""
|
||||
results = list()
|
||||
self.curr_step += 1
|
||||
|
||||
@@ -112,11 +151,23 @@ class Gamestate(object):
|
||||
|
||||
return results
|
||||
|
||||
def print(self, string):
|
||||
def print(self, string) -> None:
|
||||
"""
|
||||
When *verbose* is active, print stuff.
|
||||
|
||||
:param string: *String* to print.
|
||||
:type string: str
|
||||
:return: Nothing
|
||||
"""
|
||||
if self.verbose:
|
||||
print(string)
|
||||
|
||||
def check_done(self):
|
||||
def check_done(self) -> List[DoneResult]:
|
||||
"""
|
||||
Iterate all **Rules** that override the *on_check_done* hook.
|
||||
|
||||
:return: List of Results
|
||||
"""
|
||||
results = list()
|
||||
for rule in self.rules:
|
||||
if on_check_done_result := rule.on_check_done(self):
|
||||
@@ -124,20 +175,44 @@ class Gamestate(object):
|
||||
return results
|
||||
|
||||
def get_all_pos_with_collisions(self) -> List[Tuple[(int, int)]]:
|
||||
positions = [pos for pos, entities in self.entities.pos_dict.items() if len(entities) >= 2 and (len([e for e in entities if e.var_can_collide]) >= 2)]
|
||||
"""
|
||||
Returns a list of positions [(x, y), ... ] on which collisions occur. This does not include agents,
|
||||
that were unable to move because their target direction was blocked, also a form of collision.
|
||||
|
||||
:return: List of positions.
|
||||
"""
|
||||
positions = [pos for pos, entities in self.entities.pos_dict.items() if
|
||||
len(entities) >= 2 and (len([e for e in entities if e.var_can_collide]) >= 2)
|
||||
]
|
||||
return positions
|
||||
|
||||
def check_move_validity(self, moving_entity: Entity, target_position: Tuple[int, int]) -> bool:
    """
    Whether it is safe for *moving_entity* to move to *target_position*.

    The move is valid only if all of the following hold:
      - the target differs from the current position of the entity,
      - ``check_pos_validity`` accepts the target position,
      - the entity does not carry *var_is_blocking_pos* onto a position for which
        ``self.entities.is_occupied`` already reports an occupation.

    :param moving_entity: Entity that is about to move.
    :param target_position: Position (x, y) the entity wants to move to.
    :return: ``c.VALID`` if the move is safe, ``c.NOT_VALID`` otherwise.
    """
    is_not_blocked = self.check_pos_validity(target_position)
    # Truthy when the mover itself blocks positions and the target is already occupied.
    would_block_others = moving_entity.var_is_blocking_pos and self.entities.is_occupied(target_position)

    if moving_entity.pos != target_position and is_not_blocked and not would_block_others:
        return c.VALID
    return c.NOT_VALID
|
||||
|
||||
def check_pos_validity(self, pos: (int, int)) -> bool:
|
||||
"""
|
||||
Check if *pos* is a valid position to move or spawn to.
|
||||
|
||||
:param pos: position to check
|
||||
:return: Whether pos is a valid target.
|
||||
"""
|
||||
|
||||
if not any(e.var_is_blocking_pos for e in self.entities.pos_dict[pos]) and pos in self.entities.floorlist:
|
||||
return c.VALID
|
||||
else:
|
||||
return c.NOT_VALID
|
||||
|
||||
@@ -28,7 +28,9 @@ class ConfigExplainer:
|
||||
|
||||
def explain_module(self, class_to_explain):
|
||||
parameters = inspect.signature(class_to_explain).parameters
|
||||
explained = {class_to_explain.__name__: {key: val.default for key, val in parameters.items() if key not in EXCLUDED}}
|
||||
explained = {class_to_explain.__name__:
|
||||
{key: val.default for key, val in parameters.items() if key not in EXCLUDED}
|
||||
}
|
||||
return explained
|
||||
|
||||
def _load_and_compare(self, compare_class, paths):
|
||||
|
||||
Reference in New Issue
Block a user