Resolved some warnings and style issues

This commit is contained in:
Steffen Illium
2023-11-10 09:29:54 +01:00
parent a9462a8b6f
commit 6711a0976b
64 changed files with 331 additions and 361 deletions

View File

@@ -58,7 +58,10 @@ class FactoryConfigParser(object):
return str(self.config)
def __getitem__(self, item):
return self.config[item]
try:
return self.config[item]
except KeyError:
print(f'The mandatory {item} section could not be found in your .config gile. Check Spelling!')
def load_entities(self):
entity_classes = dict()
@@ -161,7 +164,6 @@ class FactoryConfigParser(object):
def _load_smth(self, config, class_obj):
rules = list()
rules_names = list()
for rule in config:
e1 = e2 = e3 = None
try:

View File

@@ -61,8 +61,8 @@ class ObservationTranslator:
:param per_agent_named_obs_spaces: `Named observation space` one for each agent. Overloaded.
type per_agent_named_obs_spaces: Dict[str, dict]
:param placeholder_fill_value: Currently not fully implemented!!!
:type placeholder_fill_value: Union[int, str] = 'N')
:param placeholder_fill_value: Currently, not fully implemented!!!
:type placeholder_fill_value: Union[int, str] = 'N'
"""
if isinstance(placeholder_fill_value, str):

View File

@@ -9,7 +9,7 @@ from marl_factory_grid.utils.helpers import IGNORED_DF_COLUMNS
import pandas as pd
from marl_factory_grid.utils.plotting.compare_runs import plot_single_run
from marl_factory_grid.utils.plotting.plot_single_runs import plot_single_run
class EnvMonitor(Wrapper):
@@ -22,7 +22,6 @@ class EnvMonitor(Wrapper):
self._monitor_df = pd.DataFrame()
self._monitor_dict = dict()
def step(self, action):
obs_type, obs, reward, done, info = self.env.step(action)
self._read_info(info)

View File

@@ -2,11 +2,9 @@ from os import PathLike
from pathlib import Path
from typing import Union, List
import yaml
from gymnasium import Wrapper
import numpy as np
import pandas as pd
from gymnasium import Wrapper
class EnvRecorder(Wrapper):
@@ -106,7 +104,7 @@ class EnvRecorder(Wrapper):
out_dict = {'episodes': self._recorder_out_list}
out_dict.update(
{'n_episodes': self._curr_episode,
'metadata':dict(
'metadata': dict(
level_name=self.env.params['General']['level_name'],
verbose=False,
n_agents=len(self.env.params['Agents']),

View File

@@ -5,7 +5,7 @@ from typing import Dict, List
import numpy as np
from marl_factory_grid.environment import constants as c
from marl_factory_grid.environment.entity.object import _Object
from marl_factory_grid.environment.entity.object import Object
from marl_factory_grid.environment.groups.utils import Combined
from marl_factory_grid.utils.utility_classes import Floor
from marl_factory_grid.utils.ray_caster import RayCaster
@@ -13,7 +13,6 @@ from marl_factory_grid.utils.states import Gamestate
from marl_factory_grid.utils import helpers as h
class OBSBuilder(object):
default_obs = [c.WALLS, c.OTHERS]
@@ -128,7 +127,7 @@ class OBSBuilder(object):
f'{re.escape("[")}(.*){re.escape("]")}'
f'{re.escape("(")}{re.escape(agent.name)}{re.escape(")")}')
name = next((key for key, val in self.all_obs.items()
if pattern.search(str(val)) and isinstance(val, _Object)), None)
if pattern.search(str(val)) and isinstance(val, Object)), None)
e = self.all_obs[name]
except KeyError:
try:
@@ -181,11 +180,11 @@ class OBSBuilder(object):
return obs, self.obs_layers[agent.name]
def _sort_and_name_observation_conf(self, agent):
'''
"""
Builds the useable observation scheme per agent from conf.yaml.
:param agent:
:return:
'''
"""
# Fixme: no asymetric shapes possible.
self.ray_caster[agent.name] = RayCaster(agent, min(self.obs_shape))
obs_layers = []

View File

@@ -7,50 +7,11 @@ from typing import Union, List
import pandas as pd
from marl_factory_grid.utils.helpers import IGNORED_DF_COLUMNS
from marl_factory_grid.utils.plotting.plotting import prepare_plot
from marl_factory_grid.utils.plotting.plotting_utils import prepare_plot
MODEL_MAP = None
def plot_single_run(run_path: Union[str, PathLike], use_tex: bool = False, column_keys=None):
run_path = Path(run_path)
df_list = list()
if run_path.is_dir():
monitor_file = next(run_path.glob('*monitor*.pick'))
elif run_path.exists() and run_path.is_file():
monitor_file = run_path
else:
raise ValueError
with monitor_file.open('rb') as f:
monitor_df = pickle.load(f)
monitor_df = monitor_df.fillna(0)
df_list.append(monitor_df)
df = pd.concat(df_list, ignore_index=True)
df = df.fillna(0).rename(columns={'episode': 'Episode'}).sort_values(['Episode'])
if column_keys is not None:
columns = [col for col in column_keys if col in df.columns]
else:
columns = [col for col in df.columns if col not in IGNORED_DF_COLUMNS]
roll_n = 50
non_overlapp_window = df.groupby(['Episode']).rolling(roll_n, min_periods=1).mean()
df_melted = df[columns + ['Episode']].reset_index().melt(
id_vars=['Episode'], value_vars=columns, var_name="Measurement", value_name="Score"
)
if df_melted['Episode'].max() > 800:
skip_n = round(df_melted['Episode'].max() * 0.02)
df_melted = df_melted[df_melted['Episode'] % skip_n == 0]
prepare_plot(run_path.parent / f'{run_path.parent.name}_monitor_lineplot.png', df_melted, use_tex=use_tex)
print('Plotting done.')
def compare_seed_runs(run_path: Union[str, PathLike], use_tex: bool = False):
run_path = Path(run_path)
df_list = list()

View File

@@ -0,0 +1,48 @@
import pickle
from os import PathLike
from pathlib import Path
from typing import Union
import pandas as pd
from marl_factory_grid.utils.helpers import IGNORED_DF_COLUMNS
from marl_factory_grid.utils.plotting.plotting_utils import prepare_plot
def plot_single_run(run_path: Union[str, PathLike], use_tex: bool = False, column_keys=None,
file_key: str ='monitor', file_ext: str ='pkl'):
run_path = Path(run_path)
df_list = list()
if run_path.is_dir():
monitor_file = next(run_path.glob(f'*{file_key}*.{file_ext}'))
elif run_path.exists() and run_path.is_file():
monitor_file = run_path
else:
raise ValueError
with monitor_file.open('rb') as f:
monitor_df = pickle.load(f)
monitor_df = monitor_df.fillna(0)
df_list.append(monitor_df)
df = pd.concat(df_list, ignore_index=True)
df = df.fillna(0).rename(columns={'episode': 'Episode'}).sort_values(['Episode'])
if column_keys is not None:
columns = [col for col in column_keys if col in df.columns]
else:
columns = [col for col in df.columns if col not in IGNORED_DF_COLUMNS]
# roll_n = 50
# non_overlapp_window = df.groupby(['Episode']).rolling(roll_n, min_periods=1).mean()
df_melted = df[columns + ['Episode']].reset_index().melt(
id_vars=['Episode'], value_vars=columns, var_name="Measurement", value_name="Score"
)
if df_melted['Episode'].max() > 800:
skip_n = round(df_melted['Episode'].max() * 0.02)
df_melted = df_melted[df_melted['Episode'] % skip_n == 0]
prepare_plot(run_path.parent / f'{run_path.parent.name}_monitor_lineplot.png', df_melted, use_tex=use_tex)
print('Plotting done.')

View File

@@ -60,7 +60,7 @@ def prepare_center_double_column_legend(df, hue, style, hue_order):
print('Struggling to plot Figure using LaTeX - going back to normal.')
plt.close('all')
sns.set(rc={'text.usetex': False}, style='whitegrid')
fig = plt.figure(figsize=(10, 11))
_ = plt.figure(figsize=(10, 11))
lineplot = sns.lineplot(data=df, x='Episode', y='Score', hue=hue, style=style,
ci=95, palette=PALETTE, hue_order=hue_order, legend=False)
# plt.legend(bbox_to_anchor=(1.02, 1), loc='upper left', borderaxespad=0)

View File

@@ -19,7 +19,7 @@ class RayCaster:
return f'{self.__class__.__name__}({self.agent.name})'
def build_ray_targets(self):
north = np.array([0, -1])*self.pomdp_r
north = np.array([0, -1]) * self.pomdp_r
thetas = [np.deg2rad(deg) for deg in np.linspace(-self.degs // 2, self.degs // 2, self.n_rays)[::-1]]
rot_M = [
[[math.cos(theta), -math.sin(theta)],
@@ -53,9 +53,9 @@ class RayCaster:
diag_hits = all([
self.ray_block_cache(
key,
lambda: any(True for e in pos_dict[key] if e.var_is_blocking_light))
# lambda: all(False for e in pos_dict[key] if not e.var_is_blocking_light))
for key in ((x, y-cy), (x-cx, y))
# lambda: all(False for e in pos_dict[key] if not e.var_is_blocking_light)
lambda: any(True for e in pos_dict[key] if e.var_is_blocking_light))
for key in ((x, y - cy), (x - cx, y))
]) if (cx != 0 and cy != 0) else False
visible += entities_hit if not diag_hits else []
@@ -77,8 +77,8 @@ class RayCaster:
agent = self.agent
x_coords = range(agent.x - self.pomdp_r, agent.x + self.pomdp_r + 1)
y_coords = range(agent.y - self.pomdp_r, agent.y + self.pomdp_r + 1)
outline = list(product(x_coords, [agent.y - self.pomdp_r, agent.y + self.pomdp_r])) \
+ list(product([agent.x - self.pomdp_r, agent.x + self.pomdp_r], y_coords))
outline = list(product(x_coords, [agent.y - self.pomdp_r, agent.y + self.pomdp_r]))
outline += list(product([agent.x - self.pomdp_r, agent.x + self.pomdp_r], y_coords))
return outline
@staticmethod

View File

@@ -1,9 +1,12 @@
from typing import Union
from dataclasses import dataclass
from marl_factory_grid.environment.entity.object import Object
TYPE_VALUE = 'value'
TYPE_REWARD = 'reward'
types = [TYPE_VALUE, TYPE_REWARD]
TYPES = [TYPE_VALUE, TYPE_REWARD]
@dataclass
class InfoObject:
@@ -18,12 +21,13 @@ class Result:
validity: bool
reward: Union[float, None] = None
value: Union[float, None] = None
entity: None = None
entity: Object = None
def get_infos(self):
n = self.entity.name if self.entity is not None else "Global"
return [InfoObject(identifier=f'{n}_{self.identifier}_{t}',
val_type=t, value=self.__getattribute__(t)) for t in types
# Return multiple Info Dicts
return [InfoObject(identifier=f'{n}_{self.identifier}',
val_type=t, value=self.__getattribute__(t)) for t in TYPES
if self.__getattribute__(t) is not None]
def __repr__(self):
@@ -31,7 +35,7 @@ class Result:
reward = f" | Reward: {self.reward}" if self.reward is not None else ""
value = f" | Value: {self.value}" if self.value is not None else ""
entity = f" | by: {self.entity.name}" if self.entity is not None else ""
return f'{self.__class__.__name__}({self.identifier.capitalize()} {valid}valid{reward}{value})'
return f'{self.__class__.__name__}({self.identifier.capitalize()} {valid}valid{reward}{value}{entity})'
@dataclass

View File

@@ -1,11 +1,12 @@
from itertools import islice
from typing import List, Dict, Tuple
from typing import List, Tuple
import numpy as np
from marl_factory_grid.environment import constants as c
from marl_factory_grid.environment.entity.entity import Entity
from marl_factory_grid.environment.rules import Rule
from marl_factory_grid.utils.results import Result
from marl_factory_grid.utils.results import Result, DoneResult
class StepRules:
@@ -83,13 +84,51 @@ class Gamestate(object):
return f'{self.__class__.__name__}({len(self.entities)} Entitites @ Step {self.curr_step})'
@property
def random_free_position(self):
def random_free_position(self) -> (int, int):
"""
Returns a single **free** position (x, y), which is **free** for spawning or walking.
No Entity at this position posses *var_is_blocking_pos* or *var_can_collide*.
:return: Single **free** position.
"""
return self.get_n_random_free_positions(1)[0]
def get_n_random_free_positions(self, n):
def get_n_random_free_positions(self, n) -> list[tuple[int, int]]:
"""
Returns a list of *n* **free** positions [(x, y), ... ], which are **free** for spawning or walking.
No Entity at this position posses *var_is_blocking_pos* or *var_can_collide*.
:return: List of n **free** position.
"""
return list(islice(self.entities.free_positions_generator, n))
def tick(self, actions) -> List[Result]:
@property
def random_position(self) -> (int, int):
"""
Returns a single available position (x, y), ignores all entity attributes.
:return: Single random position.
"""
return self.get_n_random_positions(1)[0]
def get_n_random_positions(self, n) -> list[tuple[int, int]]:
"""
Returns a list of *n* available positions [(x, y), ... ], ignores all entity attributes.
:return: List of n random positions.
"""
return list(islice(self.entities.floorlist, n))
def tick(self, actions) -> list[Result]:
"""
Performs a single **Gamestate Tick**by calling the inner rule hooks in sequential order.
- tick_pre_step_all: Things to do before the agents do their actions. Statechange, Moving, Spawning etc...
- agent tick: Agents do their actions.
- tick_step_all: Things to do after the agents did their actions. Statechange, Moving, Spawning etc...
- tick_post_step_all: Things to do at the very end of each step. Counting, Reward calculations etc...
:return: List of *Result*-objects.
"""
results = list()
self.curr_step += 1
@@ -112,11 +151,23 @@ class Gamestate(object):
return results
def print(self, string):
def print(self, string) -> None:
"""
When *verbose* is active, print stuff.
:param string: *String* to print.
:type string: str
:return: Nothing
"""
if self.verbose:
print(string)
def check_done(self):
def check_done(self) -> List[DoneResult]:
"""
Iterate all **Rules** that override tehe *on_ckeck_done* hook.
:return: List of Results
"""
results = list()
for rule in self.rules:
if on_check_done_result := rule.on_check_done(self):
@@ -124,20 +175,44 @@ class Gamestate(object):
return results
def get_all_pos_with_collisions(self) -> List[Tuple[(int, int)]]:
positions = [pos for pos, entities in self.entities.pos_dict.items() if len(entities) >= 2 and (len([e for e in entities if e.var_can_collide]) >= 2)]
"""
Returns a list positions [(x, y), ... ] on which collisions occur. This does not include agents,
that were unable to move because their target direction was blocked, also a form of collision.
:return: List of positions.
"""
positions = [pos for pos, entities in self.entities.pos_dict.items() if
len(entities) >= 2 and (len([e for e in entities if e.var_can_collide]) >= 2)
]
return positions
def check_move_validity(self, moving_entity, position):
if moving_entity.pos != position and not any(
entity.var_is_blocking_pos for entity in self.entities.pos_dict[position]) and not (
moving_entity.var_is_blocking_pos and self.entities.is_occupied(position)):
return True
else:
return False
def check_move_validity(self, moving_entity: Entity, target_position: (int, int)) -> bool:
"""
Whether it is safe to move to the target positions and moving entity does not introduce a blocking attribute,
when position is allready occupied.
def check_pos_validity(self, position):
if not any(entity.var_is_blocking_pos for entity in self.entities.pos_dict[position]):
return True
else:
return False
:param moving_entity: Entity
:param target_position: pos
:return: Safe to move to
"""
is_not_blocked = self.check_pos_validity(target_position)
will_not_block_others = moving_entity.var_is_blocking_pos and self.entities.is_occupied(target_position)
if moving_entity.pos != target_position and is_not_blocked and not will_not_block_others:
return c.VALID
else:
return c.NOT_VALID
def check_pos_validity(self, pos: (int, int)) -> bool:
"""
Check if *pos* is a valid position to move or spawn to.
:param pos: position to check
:return: Wheter pos is a valid target.
"""
if not any(e.var_is_blocking_pos for e in self.entities.pos_dict[pos]) and pos in self.entities.floorlist:
return c.VALID
else:
return c.NOT_VALID

View File

@@ -28,7 +28,9 @@ class ConfigExplainer:
def explain_module(self, class_to_explain):
parameters = inspect.signature(class_to_explain).parameters
explained = {class_to_explain.__name__: {key: val.default for key, val in parameters.items() if key not in EXCLUDED}}
explained = {class_to_explain.__name__:
{key: val.default for key, val in parameters.items() if key not in EXCLUDED}
}
return explained
def _load_and_compare(self, compare_class, paths):