mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-07-06 01:21:36 +02:00
Merge branch 'main' into 'unit_testing'
# Conflicts: # marl_factory_grid/algorithms/static/TSP_dirt_agent.py # marl_factory_grid/utils/config_parser.py
This commit is contained in:
@ -1,3 +1,11 @@
|
||||
from . import helpers as h
|
||||
from . import helpers
|
||||
from .results import Result, DoneResult, ActionResult, TickResult
|
||||
|
||||
"""
|
||||
Utils
|
||||
=====
|
||||
|
||||
Todo
|
||||
|
||||
"""
|
@ -22,6 +22,12 @@ class FactoryConfigParser(object):
|
||||
default_observations = [c.WALLS, c.AGENT]
|
||||
|
||||
def __init__(self, config_path, custom_modules_path: Union[PathLike] = None):
|
||||
"""
|
||||
This class parses the factory env config file.
|
||||
|
||||
:param config_path: Path to where the 'config.yml' is.
|
||||
:param custom_modules_path: Additional search path for custom modules, levels, entities, etc.
|
||||
"""
|
||||
self.config_path = Path(config_path)
|
||||
self.custom_modules_path = Path(custom_modules_path) if custom_modules_path is not None else custom_modules_path
|
||||
self.config = yaml.safe_load(self.config_path.open())
|
||||
@ -40,7 +46,6 @@ class FactoryConfigParser(object):
|
||||
self._n_abbr_dict = defaultdict(lambda: 'th', {1: 'st', 2: 'nd', 3: 'rd'})
|
||||
return self._n_abbr_dict[n]
|
||||
|
||||
|
||||
@property
|
||||
def agent_actions(self):
|
||||
return self._get_sub_list('Agents', "Actions")
|
||||
@ -129,13 +134,25 @@ class FactoryConfigParser(object):
|
||||
# Actions
|
||||
conf_actions = self.agents[name]['Actions']
|
||||
actions = list()
|
||||
# Actions:
|
||||
# Allowed
|
||||
# - Noop
|
||||
# - Move8
|
||||
# ----
|
||||
# Noop:
|
||||
# South:
|
||||
# reward_fail: 0.5
|
||||
# ----
|
||||
# Forbidden
|
||||
# - South:
|
||||
# reward_fail: 0.5
|
||||
|
||||
if isinstance(conf_actions, dict):
|
||||
conf_kwargs = conf_actions.copy()
|
||||
conf_actions = list(conf_actions.keys())
|
||||
elif isinstance(conf_actions, list):
|
||||
conf_kwargs = {}
|
||||
if isinstance(conf_actions, dict):
|
||||
if any(isinstance(x, dict) for x in conf_actions):
|
||||
raise ValueError
|
||||
pass
|
||||
for action in conf_actions:
|
||||
@ -152,11 +169,10 @@ class FactoryConfigParser(object):
|
||||
except AttributeError:
|
||||
class_or_classes = locate_and_import_class(action, self.custom_modules_path)
|
||||
try:
|
||||
# print(action)
|
||||
# Handle Lists of Actions (e.g., Move8, Move4, Default)
|
||||
parsed_actions.extend(class_or_classes)
|
||||
# print(parsed_actions)
|
||||
for actions_class in class_or_classes:
|
||||
# break
|
||||
conf_kwargs[actions_class.__name__] = conf_kwargs.get(action, {})
|
||||
except TypeError:
|
||||
parsed_actions.append(class_or_classes)
|
||||
@ -174,7 +190,7 @@ class FactoryConfigParser(object):
|
||||
['Actions', 'Observations', 'Positions', 'Clones']}
|
||||
parsed_agents_conf[name] = dict(
|
||||
actions=parsed_actions, observations=observations, positions=positions, other=other_kwargs
|
||||
)
|
||||
)
|
||||
|
||||
clones = self.agents[name].get('Clones', 0)
|
||||
if clones:
|
||||
|
@ -10,14 +10,14 @@ from marl_factory_grid.environment import constants as c
|
||||
|
||||
"""
|
||||
This file is used for:
|
||||
1. string based definition
|
||||
Use a class like `Constants`, to define attributes, which then reveal strings.
|
||||
These can be used for naming convention along the environments as well as keys for mappings such as dicts etc.
|
||||
When defining new envs, use class inheritance.
|
||||
|
||||
2. utility function definition
|
||||
There are static utility functions which are not bound to a specific environment.
|
||||
In this file they are defined to be used across the entire package.
|
||||
1. string based definition
|
||||
Use a class like `Constants`, to define attributes, which then reveal strings.
|
||||
These can be used for naming convention along the environments as well as keys for mappings such as dicts etc.
|
||||
When defining new envs, use class inheritance.
|
||||
|
||||
2. utility function definition
|
||||
There are static utility functions which are not bound to a specific environment.
|
||||
In this file they are defined to be used across the entire package.
|
||||
"""
|
||||
|
||||
LEVELS_DIR = 'levels' # for use in studies and experiments
|
||||
@ -54,15 +54,9 @@ class ObservationTranslator:
|
||||
A string _identifier based approach is used.
|
||||
Currently, it is not possible to mix different obs shapes.
|
||||
|
||||
|
||||
:param this_named_observation_space: `Named observation space` of the joined environment.
|
||||
:type this_named_observation_space: Dict[str, dict]
|
||||
|
||||
:param per_agent_named_obs_spaces: `Named observation space` one for each agent. Overloaded.
|
||||
:type per_agent_named_obs_spaces: Dict[str, dict]
|
||||
|
||||
:param placeholder_fill_value: Currently, not fully implemented!!!
|
||||
:type placeholder_fill_value: Union[int, str] = 'N'
|
||||
"""
|
||||
|
||||
if isinstance(placeholder_fill_value, str):
|
||||
|
@ -16,7 +16,10 @@ class LevelParser(object):
|
||||
@property
|
||||
def pomdp_d(self):
|
||||
"""
|
||||
Internal Usage
|
||||
Calculates the effective diameter of the POMDP observation space.
|
||||
|
||||
:return: The calculated effective diameter.
|
||||
:rtype: int
|
||||
"""
|
||||
return self.pomdp_r * 2 + 1
|
||||
|
||||
|
@ -0,0 +1,7 @@
|
||||
"""
|
||||
logging
|
||||
=======
|
||||
|
||||
Todo
|
||||
|
||||
"""
|
@ -17,6 +17,9 @@ class EnvMonitor(Wrapper):
|
||||
ext = 'png'
|
||||
|
||||
def __init__(self, env, filepath: Union[str, PathLike] = None):
|
||||
"""
|
||||
EnvMonitor is a wrapper for Gymnasium environments that monitors and logs key information during interactions.
|
||||
"""
|
||||
super(EnvMonitor, self).__init__(env)
|
||||
self._filepath = filepath
|
||||
self._monitor_df = pd.DataFrame()
|
||||
@ -52,6 +55,14 @@ class EnvMonitor(Wrapper):
|
||||
return
|
||||
|
||||
def save_monitor(self, filepath: Union[Path, str, None] = None, auto_plotting_keys=None):
|
||||
"""
|
||||
Saves the monitoring data to a file and optionally generates plots.
|
||||
|
||||
:param filepath: The path to save the monitoring data file.
|
||||
:type filepath: Union[Path, str, None]
|
||||
:param auto_plotting_keys: Keys to use for automatic plot generation.
|
||||
:type auto_plotting_keys: Any
|
||||
"""
|
||||
filepath = Path(filepath or self._filepath)
|
||||
filepath.parent.mkdir(exist_ok=True, parents=True)
|
||||
with filepath.open('wb') as f:
|
||||
|
@ -11,6 +11,16 @@ class EnvRecorder(Wrapper):
|
||||
|
||||
def __init__(self, env, filepath: Union[str, PathLike] = None,
|
||||
episodes: Union[List[int], None] = None):
|
||||
"""
|
||||
EnvRecorder is a wrapper for OpenAI Gym environments that records state summaries during interactions.
|
||||
|
||||
:param env: The environment to record.
|
||||
:type env: gym.Env
|
||||
:param filepath: The path to save the recording data file.
|
||||
:type filepath: Union[str, PathLike]
|
||||
:param episodes: A list of episode numbers to record. If None, records all episodes.
|
||||
:type episodes: Union[List[int], None]
|
||||
"""
|
||||
super(EnvRecorder, self).__init__(env)
|
||||
self.filepath = filepath
|
||||
self.episodes = episodes
|
||||
@ -19,6 +29,9 @@ class EnvRecorder(Wrapper):
|
||||
self._recorder_out_list = list()
|
||||
|
||||
def reset(self):
|
||||
"""
|
||||
Overrides the reset method to reset the environment and recording lists.
|
||||
"""
|
||||
self._curr_ep_recorder = list()
|
||||
self._recorder_out_list = list()
|
||||
self._curr_episode += 1
|
||||
@ -26,10 +39,12 @@ class EnvRecorder(Wrapper):
|
||||
|
||||
def step(self, actions):
|
||||
"""
|
||||
Todo
|
||||
Overrides the step method to record state summaries during each step.
|
||||
|
||||
:param actions:
|
||||
:return:
|
||||
:param actions: The actions taken in the environment.
|
||||
:type actions: Any
|
||||
:return: The observation, reward, done flag, and additional information.
|
||||
:rtype: Tuple
|
||||
"""
|
||||
obs_type, obs, reward, done, info = self.env.step(actions)
|
||||
if not self.episodes or self._curr_episode in self.episodes:
|
||||
@ -55,6 +70,18 @@ class EnvRecorder(Wrapper):
|
||||
save_occupation_map=False,
|
||||
save_trajectory_map=False,
|
||||
):
|
||||
"""
|
||||
Saves the recorded data to a file.
|
||||
|
||||
:param filepath: The path to save the recording data file.
|
||||
:type filepath: Union[Path, str, None]
|
||||
:param only_deltas: If True, saves only the differences between consecutive episodes.
|
||||
:type only_deltas: bool
|
||||
:param save_occupation_map: If True, saves an occupation map as a heatmap.
|
||||
:type save_occupation_map: bool
|
||||
:param save_trajectory_map: If True, saves a trajectory map.
|
||||
:type save_trajectory_map: bool
|
||||
"""
|
||||
self._finalize()
|
||||
filepath = Path(filepath or self.filepath)
|
||||
filepath.parent.mkdir(exist_ok=True, parents=True)
|
||||
@ -73,7 +100,6 @@ class EnvRecorder(Wrapper):
|
||||
n_dests=0,
|
||||
dwell_time=0,
|
||||
spawn_frequency=0,
|
||||
spawn_in_other_zone=False,
|
||||
spawn_mode=''
|
||||
)
|
||||
rewards_dest = dict(
|
||||
|
@ -19,10 +19,10 @@ class OBSBuilder(object):
|
||||
@property
|
||||
def pomdp_d(self):
|
||||
"""
|
||||
TODO
|
||||
Calculates the effective diameter of the POMDP observation space.
|
||||
|
||||
|
||||
:return:
|
||||
:return: The calculated effective diameter.
|
||||
:rtype: int
|
||||
"""
|
||||
if self.pomdp_r:
|
||||
return (self.pomdp_r * 2) + 1
|
||||
@ -31,10 +31,17 @@ class OBSBuilder(object):
|
||||
|
||||
def __init__(self, level_shape: np.size, state: Gamestate, pomdp_r: int):
|
||||
"""
|
||||
TODO
|
||||
OBSBuilder
|
||||
==========
|
||||
|
||||
The OBSBuilder class is responsible for constructing observations in the environment.
|
||||
|
||||
:return:
|
||||
:param level_shape: The shape of the level or environment.
|
||||
:type level_shape: np.size
|
||||
:param state: The current game state.
|
||||
:type state: marl_factory_grid.environment.state.Gamestate
|
||||
:param pomdp_r: The POMDP radius, influencing the size of the observation space.
|
||||
:type pomdp_r: int
|
||||
"""
|
||||
self.all_obs = dict()
|
||||
self.ray_caster = dict()
|
||||
@ -52,6 +59,9 @@ class OBSBuilder(object):
|
||||
self.reset(state)
|
||||
|
||||
def reset(self, state):
|
||||
"""
|
||||
Resets temporary information and constructs an empty observation array with possible placeholders.
|
||||
"""
|
||||
# Reset temporary information
|
||||
self.curr_lightmaps = dict()
|
||||
# Construct an empty obs (array) for possible placeholders
|
||||
@ -61,6 +71,11 @@ class OBSBuilder(object):
|
||||
return True
|
||||
|
||||
def observation_space(self, state):
|
||||
"""
|
||||
Returns the observation space for a single agent or a tuple of spaces for multiple agents.
|
||||
:returns: The observation space for the agent(s).
|
||||
:rtype: gym.Space|Tuple
|
||||
"""
|
||||
from gymnasium.spaces import Tuple, Box
|
||||
self.reset(state)
|
||||
obsn = self.build_for_all(state)
|
||||
@ -71,13 +86,29 @@ class OBSBuilder(object):
|
||||
return space
|
||||
|
||||
def named_observation_space(self, state):
|
||||
"""
|
||||
:returns: A dictionary of named observation spaces for all agents.
|
||||
:rtype: dict
|
||||
"""
|
||||
self.reset(state)
|
||||
return self.build_for_all(state)
|
||||
|
||||
def build_for_all(self, state) -> (dict, dict):
|
||||
"""
|
||||
Builds observations for all agents in the environment.
|
||||
|
||||
:returns: A dictionary of observations for all agents.
|
||||
:rtype: dict
|
||||
"""
|
||||
return {agent.name: self.build_for_agent(agent, state)[0] for agent in state[c.AGENT]}
|
||||
|
||||
def build_named_for_all(self, state) -> Dict[str, Dict[str, np.ndarray]]:
|
||||
"""
|
||||
Builds named observations for all agents in the environment.
|
||||
|
||||
:returns: A dictionary containing named observations for all agents.
|
||||
:rtype: dict
|
||||
"""
|
||||
named_obs_dict = {}
|
||||
for agent in state[c.AGENT]:
|
||||
obs, names = self.build_for_agent(agent, state)
|
||||
@ -85,6 +116,16 @@ class OBSBuilder(object):
|
||||
return named_obs_dict
|
||||
|
||||
def place_entity_in_observation(self, obs_array, agent, e):
|
||||
"""
|
||||
Places the encoding of an entity in the observation array relative to the agent's position.
|
||||
|
||||
:param obs_array: The observation array.
|
||||
:type obs_array: np.ndarray
|
||||
:param agent: the associated agent
|
||||
:type agent: Agent
|
||||
:param e: The entity to be placed in the observation.
|
||||
:type e: Entity
|
||||
"""
|
||||
x, y = (e.x - agent.x) + self.pomdp_r, (e.y - agent.y) + self.pomdp_r
|
||||
if not min([y, x]) < 0:
|
||||
try:
|
||||
@ -95,6 +136,12 @@ class OBSBuilder(object):
|
||||
pass
|
||||
|
||||
def build_for_agent(self, agent, state) -> (List[str], np.ndarray):
|
||||
"""
|
||||
Builds observations for a specific agent.
|
||||
|
||||
:returns: A tuple containing the observation array and the list of corresponding observation names
|
||||
:rtype: Tuple[np.ndarray, List[str]]
|
||||
"""
|
||||
try:
|
||||
agent_want_obs = self.obs_layers[agent.name]
|
||||
except KeyError:
|
||||
@ -190,8 +237,8 @@ class OBSBuilder(object):
|
||||
def _sort_and_name_observation_conf(self, agent):
|
||||
"""
|
||||
Builds the useable observation scheme per agent from conf.yaml.
|
||||
:param agent:
|
||||
:return:
|
||||
|
||||
:param agent: The agent for whom the observation scheme is built.
|
||||
"""
|
||||
# Fixme: no asymmetric shapes possible.
|
||||
self.ray_caster[agent.name] = RayCaster(agent, min(self.obs_shape))
|
||||
|
@ -0,0 +1,7 @@
|
||||
"""
|
||||
Plotting
|
||||
========
|
||||
|
||||
Todo
|
||||
|
||||
"""
|
@ -13,6 +13,16 @@ MODEL_MAP = None
|
||||
|
||||
|
||||
def compare_seed_runs(run_path: Union[str, PathLike], use_tex: bool = False):
|
||||
"""
|
||||
|
||||
Compare multiple runs with different seeds by generating a line plot that shows the evolution of scores (step rewards)
|
||||
across episodes.
|
||||
|
||||
:param run_path: The path to the directory containing the monitor files for each run.
|
||||
:type run_path: Union[str, PathLike]
|
||||
:param use_tex: A boolean indicating whether to use TeX formatting in the plot.
|
||||
:type use_tex: bool
|
||||
"""
|
||||
run_path = Path(run_path)
|
||||
df_list = list()
|
||||
for run, monitor_file in enumerate(run_path.rglob('monitor*.pick')):
|
||||
@ -23,7 +33,7 @@ def compare_seed_runs(run_path: Union[str, PathLike], use_tex: bool = False):
|
||||
monitor_df = monitor_df.fillna(0)
|
||||
df_list.append(monitor_df)
|
||||
|
||||
df = pd.concat(df_list, ignore_index=True)
|
||||
df = pd.concat(df_list, ignore_index=True)
|
||||
df = df.fillna(0).rename(columns={'episode': 'Episode', 'run': 'Run'}).sort_values(['Run', 'Episode'])
|
||||
columns = [col for col in df.columns if col not in IGNORED_DF_COLUMNS]
|
||||
|
||||
@ -49,6 +59,19 @@ def compare_seed_runs(run_path: Union[str, PathLike], use_tex: bool = False):
|
||||
|
||||
def compare_model_runs(run_path: Path, run_identifier: Union[str, int], parameter: Union[str, List[str]],
|
||||
use_tex: bool = False):
|
||||
"""
|
||||
Compares multiple model runs based on specified parameters by generating a line plot showing the evolution of scores (step rewards)
|
||||
across episodes.
|
||||
|
||||
:param run_path: The path to the directory containing the monitor files for each model run.
|
||||
:type run_path: Path
|
||||
:param run_identifier: A string or integer identifying the runs to compare.
|
||||
:type run_identifier: Union[str, int]
|
||||
:param parameter: A single parameter or a list of parameters to compare.
|
||||
:type parameter: Union[str, List[str]]
|
||||
:param use_tex: A boolean indicating whether to use TeX formatting in the plot.
|
||||
:type use_tex: bool
|
||||
"""
|
||||
run_path = Path(run_path)
|
||||
df_list = list()
|
||||
parameter = [parameter] if isinstance(parameter, str) else parameter
|
||||
@ -89,6 +112,20 @@ def compare_model_runs(run_path: Path, run_identifier: Union[str, int], paramete
|
||||
|
||||
def compare_all_parameter_runs(run_root_path: Path, parameter: Union[str, List[str]],
|
||||
param_names: Union[List[str], None] = None, str_to_ignore='', use_tex: bool = False):
|
||||
"""
|
||||
Compares model runs across different parameter settings by generating a line plot showing the evolution of scores across episodes.
|
||||
|
||||
:param run_root_path: The root path to the directory containing the monitor files for all model runs.
|
||||
:type run_root_path: Path
|
||||
:param parameter: The parameter(s) to compare across different runs.
|
||||
:type parameter: Union[str, List[str]]
|
||||
:param param_names: A list of custom names for the parameters to be used as labels in the plot. If None, default names will be assigned.
|
||||
:type param_names: Union[List[str], None]
|
||||
:param str_to_ignore: A string pattern to ignore in parameter names.
|
||||
:type str_to_ignore: str
|
||||
:param use_tex: A boolean indicating whether to use TeX formatting in the plot.
|
||||
:type use_tex: bool
|
||||
"""
|
||||
run_root_path = Path(run_root_path)
|
||||
df_list = list()
|
||||
parameter = [parameter] if isinstance(parameter, str) else parameter
|
||||
|
@ -10,7 +10,21 @@ from marl_factory_grid.utils.plotting.plotting_utils import prepare_plot
|
||||
|
||||
|
||||
def plot_single_run(run_path: Union[str, PathLike], use_tex: bool = False, column_keys=None,
|
||||
file_key: str ='monitor', file_ext: str ='pkl'):
|
||||
file_key: str = 'monitor', file_ext: str = 'pkl'):
|
||||
"""
|
||||
Plots the Epoch score (step reward) over a single run based on monitoring data stored in a file.
|
||||
|
||||
:param run_path: The path to the directory containing monitoring data or directly to the monitoring file.
|
||||
:type run_path: Union[str, PathLike]
|
||||
:param use_tex: Flag indicating whether to use TeX for plotting.
|
||||
:type use_tex: bool, optional
|
||||
:param column_keys: Specific columns to include in the plot. If None, includes all columns except ignored ones.
|
||||
:type column_keys: list or None, optional
|
||||
:param file_key: The keyword to identify the monitoring file.
|
||||
:type file_key: str, optional
|
||||
:param file_ext: The extension of the monitoring file.
|
||||
:type file_ext: str, optional
|
||||
"""
|
||||
run_path = Path(run_path)
|
||||
df_list = list()
|
||||
if run_path.is_dir():
|
||||
@ -26,7 +40,7 @@ def plot_single_run(run_path: Union[str, PathLike], use_tex: bool = False, colum
|
||||
monitor_df = monitor_df.fillna(0)
|
||||
df_list.append(monitor_df)
|
||||
|
||||
df = pd.concat(df_list, ignore_index=True)
|
||||
df = pd.concat(df_list, ignore_index=True)
|
||||
df = df.fillna(0).rename(columns={'episode': 'Episode'}).sort_values(['Episode'])
|
||||
if column_keys is not None:
|
||||
columns = [col for col in column_keys if col in df.columns]
|
||||
|
@ -1,7 +1,6 @@
|
||||
import seaborn as sns
|
||||
import matplotlib as mpl
|
||||
from matplotlib import pyplot as plt
|
||||
|
||||
PALETTE = 10 * (
|
||||
"#377eb8",
|
||||
"#4daf4a",
|
||||
@ -20,6 +19,14 @@ PALETTE = 10 * (
|
||||
|
||||
|
||||
def plot(filepath, ext='png'):
|
||||
"""
|
||||
Saves the current plot to a file and displays it.
|
||||
|
||||
:param filepath: The path to save the plot file.
|
||||
:type filepath: str
|
||||
:param ext: The file extension of the saved plot. Default is 'png'.
|
||||
:type ext: str
|
||||
"""
|
||||
plt.tight_layout()
|
||||
figure = plt.gcf()
|
||||
ax = plt.gca()
|
||||
@ -35,6 +42,20 @@ def plot(filepath, ext='png'):
|
||||
|
||||
|
||||
def prepare_tex(df, hue, style, hue_order):
|
||||
"""
|
||||
Prepares a line plot for rendering in LaTeX.
|
||||
|
||||
:param df: The DataFrame containing the data to be plotted.
|
||||
:type df: pandas.DataFrame
|
||||
:param hue: Grouping variable that will produce lines with different colors.
|
||||
:type hue: str
|
||||
:param style: Grouping variable that will produce lines with different styles.
|
||||
:type style: str
|
||||
:param hue_order: Order for the levels of the hue variable in the plot.
|
||||
:type hue_order: list
|
||||
:return: The prepared line plot.
|
||||
:rtype: matplotlib.axes._subplots.AxesSubplot
|
||||
"""
|
||||
sns.set(rc={'text.usetex': True}, style='whitegrid')
|
||||
lineplot = sns.lineplot(data=df, x='Episode', y='Score', ci=95, palette=PALETTE,
|
||||
hue_order=hue_order, hue=hue, style=style)
|
||||
@ -45,6 +66,20 @@ def prepare_tex(df, hue, style, hue_order):
|
||||
|
||||
|
||||
def prepare_plt(df, hue, style, hue_order):
|
||||
"""
|
||||
Prepares a line plot using matplotlib.
|
||||
|
||||
:param df: The DataFrame containing the data to be plotted.
|
||||
:type df: pandas.DataFrame
|
||||
:param hue: Grouping variable that will produce lines with different colors.
|
||||
:type hue: str
|
||||
:param style: Grouping variable that will produce lines with different styles.
|
||||
:type style: str
|
||||
:param hue_order: Order for the levels of the hue variable in the plot.
|
||||
:type hue_order: list
|
||||
:return: The prepared line plot.
|
||||
:rtype: matplotlib.axes._subplots.AxesSubplot
|
||||
"""
|
||||
print('Struggling to plot Figure using LaTeX - going back to normal.')
|
||||
plt.close('all')
|
||||
sns.set(rc={'text.usetex': False}, style='whitegrid')
|
||||
@ -57,6 +92,20 @@ def prepare_plt(df, hue, style, hue_order):
|
||||
|
||||
|
||||
def prepare_center_double_column_legend(df, hue, style, hue_order):
|
||||
"""
|
||||
Prepares a line plot with a legend centered at the bottom and spread across two columns.
|
||||
|
||||
:param df: The DataFrame containing the data to be plotted.
|
||||
:type df: pandas.DataFrame
|
||||
:param hue: Grouping variable that will produce lines with different colors.
|
||||
:type hue: str
|
||||
:param style: Grouping variable that will produce lines with different styles.
|
||||
:type style: str
|
||||
:param hue_order: Order for the levels of the hue variable in the plot.
|
||||
:type hue_order: list
|
||||
:return: The prepared line plot.
|
||||
:rtype: matplotlib.axes._subplots.AxesSubplot
|
||||
"""
|
||||
print('Struggling to plot Figure using LaTeX - going back to normal.')
|
||||
plt.close('all')
|
||||
sns.set(rc={'text.usetex': False}, style='whitegrid')
|
||||
@ -70,6 +119,23 @@ def prepare_center_double_column_legend(df, hue, style, hue_order):
|
||||
|
||||
|
||||
def prepare_plot(filepath, results_df, ext='png', hue='Measurement', style=None, use_tex=False):
|
||||
"""
|
||||
Prepares a line plot for visualization. Depending on the `use_tex` parameter, calls the prepare_tex or prepare_plt
|
||||
function accordingly, followed by the plot function to save the plot.
|
||||
|
||||
:param filepath: The file path where the plot will be saved.
|
||||
:type filepath: str
|
||||
:param results_df: The DataFrame containing the data to be plotted.
|
||||
:type results_df: pandas.DataFrame
|
||||
:param ext: The file extension of the saved plot (default is 'png').
|
||||
:type ext: str
|
||||
:param hue: The variable to determine the color of the lines in the plot.
|
||||
:type hue: str
|
||||
:param style: The variable to determine the style of the lines in the plot (default is None).
|
||||
:type style: str or None
|
||||
:param use_tex: Whether to use LaTeX for text rendering (default is False).
|
||||
:type use_tex: bool
|
||||
"""
|
||||
df = results_df.copy()
|
||||
df[hue] = df[hue].str.replace('_', '-')
|
||||
hue_order = sorted(list(df[hue].unique()))
|
||||
|
@ -8,10 +8,17 @@ from numba import njit
|
||||
class RayCaster:
|
||||
def __init__(self, agent, pomdp_r, degs=360):
|
||||
"""
|
||||
TODO
|
||||
The RayCaster class enables agents in the environment to simulate field-of-view visibility,
|
||||
providing methods for calculating visible entities and outlining the field of view based on
|
||||
Bresenham's algorithm.
|
||||
|
||||
|
||||
:return:
|
||||
:param agent: The agent for which the RayCaster is initialized.
|
||||
:type agent: Agent
|
||||
:param pomdp_r: The range of the partially observable Markov decision process (POMDP).
|
||||
:type pomdp_r: int
|
||||
:param degs: The degrees of the field of view (FOV). Defaults to 360.
|
||||
:type degs: int
|
||||
:return: None
|
||||
"""
|
||||
self.agent = agent
|
||||
self.pomdp_r = pomdp_r
|
||||
@ -25,6 +32,12 @@ class RayCaster:
|
||||
return f'{self.__class__.__name__}({self.agent.name})'
|
||||
|
||||
def build_ray_targets(self):
|
||||
"""
|
||||
Builds the targets for the rays based on the field of view (FOV).
|
||||
|
||||
:return: The targets for the rays.
|
||||
:rtype: np.ndarray
|
||||
"""
|
||||
north = np.array([0, -1]) * self.pomdp_r
|
||||
thetas = [np.deg2rad(deg) for deg in np.linspace(-self.degs // 2, self.degs // 2, self.n_rays)[::-1]]
|
||||
rot_M = [
|
||||
@ -36,11 +49,31 @@ class RayCaster:
|
||||
return rot_M.astype(int)
|
||||
|
||||
def ray_block_cache(self, key, callback):
|
||||
"""
|
||||
Retrieves or caches a value in the cache dictionary.
|
||||
|
||||
:param key: The key for the cache dictionary.
|
||||
:type key: any
|
||||
:param callback: The callback function to obtain the value if not present in the cache.
|
||||
:type callback: callable
|
||||
:return: The cached or newly computed value.
|
||||
:rtype: any
|
||||
"""
|
||||
if key not in self._cache_dict:
|
||||
self._cache_dict[key] = callback()
|
||||
return self._cache_dict[key]
|
||||
|
||||
def visible_entities(self, pos_dict, reset_cache=True):
|
||||
"""
|
||||
Returns a list of visible entities based on the agent's field of view.
|
||||
|
||||
:param pos_dict: The dictionary containing positions of entities.
|
||||
:type pos_dict: dict
|
||||
:param reset_cache: Flag to reset the cache. Defaults to True.
|
||||
:type reset_cache: bool
|
||||
:return: A list of visible entities.
|
||||
:rtype: list
|
||||
"""
|
||||
visible = list()
|
||||
if reset_cache:
|
||||
self._cache_dict = dict()
|
||||
@ -71,15 +104,33 @@ class RayCaster:
|
||||
return visible
|
||||
|
||||
def get_rays(self):
|
||||
"""
|
||||
Gets the rays for the agent.
|
||||
|
||||
:return: The rays for the agent.
|
||||
:rtype: list
|
||||
"""
|
||||
a_pos = self.agent.pos
|
||||
outline = self.ray_targets + a_pos
|
||||
return self.bresenham_loop(a_pos, outline)
|
||||
|
||||
# todo do this once and cache the points!
|
||||
def get_fov_outline(self) -> np.ndarray:
|
||||
"""
|
||||
Gets the field of view (FOV) outline.
|
||||
|
||||
:return: The FOV outline.
|
||||
:rtype: np.ndarray
|
||||
"""
|
||||
return self.ray_targets + self.agent.pos
|
||||
|
||||
def get_square_outline(self):
|
||||
"""
|
||||
Gets the square outline for the agent.
|
||||
|
||||
:return: The square outline.
|
||||
:rtype: list
|
||||
"""
|
||||
agent = self.agent
|
||||
x_coords = range(agent.x - self.pomdp_r, agent.x + self.pomdp_r + 1)
|
||||
y_coords = range(agent.y - self.pomdp_r, agent.y + self.pomdp_r + 1)
|
||||
@ -90,6 +141,16 @@ class RayCaster:
|
||||
@staticmethod
|
||||
@njit
|
||||
def bresenham_loop(a_pos, points):
|
||||
"""
|
||||
Applies Bresenham's algorithm to calculate the points between two positions.
|
||||
|
||||
:param a_pos: The starting position.
|
||||
:type a_pos: list
|
||||
:param points: The ending positions.
|
||||
:type points: list
|
||||
:return: The list of points between the starting and ending positions.
|
||||
:rtype: list
|
||||
"""
|
||||
results = []
|
||||
for end in points:
|
||||
x1, y1 = a_pos
|
||||
|
@ -34,12 +34,26 @@ class Renderer:
|
||||
cell_size: int = 40, fps: int = 7, factor: float = 0.9,
|
||||
grid_lines: bool = True, view_radius: int = 2):
|
||||
"""
|
||||
TODO
|
||||
The Renderer class initializes and manages the rendering environment for the simulation,
|
||||
providing methods for preparing entities for display, loading assets, calculating visibility rectangles and
|
||||
rendering the entities on the screen with specified parameters.
|
||||
|
||||
|
||||
:return:
|
||||
:param lvl_shape: Tuple representing the shape of the level.
|
||||
:type lvl_shape: Tuple[int, int]
|
||||
:param lvl_padded_shape: Optional Tuple representing the padded shape of the level.
|
||||
:type lvl_padded_shape: Union[Tuple[int, int], None]
|
||||
:param cell_size: Size of each cell in pixels.
|
||||
:type cell_size: int
|
||||
:param fps: Frames per second for rendering.
|
||||
:type fps: int
|
||||
:param factor: Factor for resizing assets.
|
||||
:type factor: float
|
||||
:param grid_lines: Boolean indicating whether to display grid lines.
|
||||
:type grid_lines: bool
|
||||
:param view_radius: Radius for agent's field of view.
|
||||
:type view_radius: int
|
||||
"""
|
||||
# TODO: Customn_assets paths
|
||||
# TODO: Custom_assets paths
|
||||
self.grid_h, self.grid_w = lvl_shape
|
||||
self.lvl_padded_shape = lvl_padded_shape if lvl_padded_shape is not None else lvl_shape
|
||||
self.cell_size = cell_size
|
||||
@ -60,6 +74,9 @@ class Renderer:
|
||||
print('Loading System font with pygame.font.Font took', time.time() - now)
|
||||
|
||||
def fill_bg(self):
|
||||
"""
|
||||
Fills the background of the screen with the specified BG color.
|
||||
"""
|
||||
self.screen.fill(Renderer.BG_COLOR)
|
||||
if self.grid_lines:
|
||||
w, h = self.screen_size
|
||||
@ -69,6 +86,16 @@ class Renderer:
|
||||
pygame.draw.rect(self.screen, Renderer.WHITE, rect, 1)
|
||||
|
||||
def blit_params(self, entity):
|
||||
"""
|
||||
Prepares parameters for blitting an entity on the screen. Blitting refers to the process of combining or copying
|
||||
rectangular blocks of pixels from one part of a graphical buffer to another and is often used to efficiently
|
||||
update the display by copying pre-drawn or cached images onto the screen.
|
||||
|
||||
:param entity: The entity to be blitted.
|
||||
:type entity: Entity
|
||||
:return: Dictionary containing source and destination information for blitting.
|
||||
:rtype: dict
|
||||
"""
|
||||
offset_r, offset_c = (self.lvl_padded_shape[0] - self.grid_h) // 2, \
|
||||
(self.lvl_padded_shape[1] - self.grid_w) // 2
|
||||
|
||||
@ -90,12 +117,31 @@ class Renderer:
|
||||
return dict(source=img, dest=rect)
|
||||
|
||||
def load_asset(self, path, factor=1.0):
|
||||
"""
|
||||
Loads and resizes an asset from the specified path.
|
||||
|
||||
:param path: Path to the asset.
|
||||
:type path: str
|
||||
:param factor: Resizing factor for the asset.
|
||||
:type factor: float
|
||||
:return: Resized asset.
|
||||
"""
|
||||
s = int(factor*self.cell_size)
|
||||
asset = pygame.image.load(path).convert_alpha()
|
||||
asset = pygame.transform.smoothscale(asset, (s, s))
|
||||
return asset
|
||||
|
||||
def visibility_rects(self, bp, view):
|
||||
"""
|
||||
Calculates the visibility rectangles for an agent.
|
||||
|
||||
:param bp: Blit parameters for the agent.
|
||||
:type bp: dict
|
||||
:param view: Agent's field of view.
|
||||
:type view: np.ndarray
|
||||
:return: List of visibility rectangles.
|
||||
:rtype: List[dict]
|
||||
"""
|
||||
rects = []
|
||||
for i, j in product(range(-self.view_radius, self.view_radius+1),
|
||||
range(-self.view_radius, self.view_radius+1)):
|
||||
@ -111,6 +157,14 @@ class Renderer:
|
||||
return rects
|
||||
|
||||
def render(self, entities):
|
||||
"""
|
||||
Renders the entities on the screen.
|
||||
|
||||
:param entities: List of entities to be rendered.
|
||||
:type entities: List[Entity]
|
||||
:return: Transposed RGB observation array.
|
||||
:rtype: np.ndarray
|
||||
"""
|
||||
for event in pygame.event.get():
|
||||
if event.type == pygame.QUIT:
|
||||
pygame.quit()
|
||||
|
@ -15,10 +15,12 @@ from marl_factory_grid.utils.results import Result
|
||||
class StepRules:
|
||||
def __init__(self, *args):
|
||||
"""
|
||||
TODO
|
||||
Manages a collection of rules to be applied at each step of the environment.
|
||||
|
||||
The StepRules class allows you to organize and apply custom rules during the simulation, ensuring that the
|
||||
corresponding hooks for all rules are called at the appropriate times.
|
||||
|
||||
:return:
|
||||
:param args: Optional Rule objects to initialize the StepRules with.
|
||||
"""
|
||||
if args:
|
||||
self.rules = list(args)
|
||||
@ -92,10 +94,18 @@ class Gamestate(object):
|
||||
|
||||
def __init__(self, entities, agents_conf, rules: List[Rule], tests: [Test], lvl_shape, env_seed=69, verbose=False):
|
||||
"""
|
||||
TODO
|
||||
The `Gamestate` class represents the state of the game environment.
|
||||
|
||||
|
||||
:return:
|
||||
:param lvl_shape: The shape of the game level.
|
||||
:type lvl_shape: tuple
|
||||
:param entities: The entities present in the environment.
|
||||
:type entities: Entities
|
||||
:param agents_conf: Agent configurations for the environment.
|
||||
:type agents_conf: Any
|
||||
:param verbose: Controls verbosity in the environment.
|
||||
:type verbose: bool
|
||||
:param rules: Organizes and applies custom rules during the simulation.
|
||||
:type rules: List[Rule]
|
||||
"""
|
||||
self.lvl_shape = lvl_shape
|
||||
self.entities = entities
|
||||
@ -162,7 +172,7 @@ class Gamestate(object):
|
||||
|
||||
def tick(self, actions) -> list[Result]:
|
||||
"""
|
||||
Performs a single **Gamestate Tick**by calling the inner rule hooks in sequential order.
|
||||
Performs a single **Gamestate Tick** by calling the inner rule hooks in sequential order.
|
||||
- tick_pre_step_all: Things to do before the agents do their actions. Statechange, Moving, Spawning etc...
|
||||
- agent tick: Agents do their actions.
|
||||
- tick_step_all: Things to do after the agents did their actions. Statechange, Moving, Spawning etc...
|
||||
|
@ -15,7 +15,7 @@ OBSERVATIONS = 'Observations'
|
||||
RULES = 'Rule'
|
||||
TESTS = 'Tests'
|
||||
EXCLUDED = ['identifier', 'args', 'kwargs', 'Move', 'Agent', 'GlobalPositions', 'Walls', 'Gamestate', 'Path',
|
||||
'Iterable', 'Move', 'Result', 'TemplateRule', 'Entities', 'EnvObjects', 'Zones', 'Collection',
|
||||
'Iterable', 'Move', 'Result', 'TemplateRule', 'Entities', 'EnvObjects', 'Collection',
|
||||
'State', 'Object', 'default_valid_reward', 'default_fail_reward', 'size']
|
||||
|
||||
|
||||
|
@ -6,7 +6,10 @@ import numpy as np
|
||||
|
||||
|
||||
class MarlFrameStack(gym.ObservationWrapper):
|
||||
"""todo @romue404"""
|
||||
"""
|
||||
|
||||
todo @romue404
|
||||
"""
|
||||
def __init__(self, env):
|
||||
super().__init__(env)
|
||||
|
||||
|
Reference in New Issue
Block a user