mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-09-18 00:21:58 +02:00
major redesign ob observations and entittes
This commit is contained in:
@@ -6,9 +6,11 @@ from typing import Union, List
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from environments.helpers import IGNORED_DF_COLUMNS, MODEL_MAP
|
||||
from environment.utils.helpers import IGNORED_DF_COLUMNS
|
||||
from plotting.plotting import prepare_plot
|
||||
|
||||
MODEL_MAP = None
|
||||
|
||||
|
||||
def plot_single_run(run_path: Union[str, PathLike], use_tex: bool = False, column_keys=None):
|
||||
run_path = Path(run_path)
|
||||
@@ -37,9 +39,9 @@ def plot_single_run(run_path: Union[str, PathLike], use_tex: bool = False, colum
|
||||
|
||||
non_overlapp_window = df.groupby(['Episode']).rolling(roll_n, min_periods=1).mean()
|
||||
|
||||
df_melted = df[columns + ['Episode']].reset_index().melt(id_vars=['Episode'],
|
||||
value_vars=columns, var_name="Measurement",
|
||||
value_name="Score")
|
||||
df_melted = df[columns + ['Episode']].reset_index().melt(
|
||||
id_vars=['Episode'], value_vars=columns, var_name="Measurement", value_name="Score"
|
||||
)
|
||||
|
||||
if df_melted['Episode'].max() > 800:
|
||||
skip_n = round(df_melted['Episode'].max() * 0.02)
|
||||
@@ -133,22 +135,22 @@ def compare_all_parameter_runs(run_root_path: Path, parameter: Union[str, List[s
|
||||
with monitor_file.open('rb') as f:
|
||||
monitor_df = pickle.load(f)
|
||||
|
||||
parameters = [x.name for x in monitor_file.parents if x.parent not in run_root_path.parents]
|
||||
params = [x.name for x in monitor_file.parents if x.parent not in run_root_path.parents]
|
||||
if str_to_ignore:
|
||||
parameters = [re.sub(f'_*({str_to_ignore})', '', param) for param in parameters]
|
||||
params = [re.sub(f'_*({str_to_ignore})', '', param) for param in params]
|
||||
|
||||
if monitor_idx == 0:
|
||||
if param_names is not None:
|
||||
if len(param_names) < len(parameters):
|
||||
if len(param_names) < len(params):
|
||||
# FIXME: Missing Seed Detection, see below @111
|
||||
param_names = [next(param_names) if param not in MODEL_MAP.keys() else 'Model' for param in parameters]
|
||||
elif len(param_names) == len(parameters):
|
||||
param_names = [next(param_names) if param not in MODEL_MAP.keys() else 'Model' for param in params]
|
||||
elif len(param_names) == len(params):
|
||||
pass
|
||||
else:
|
||||
raise ValueError
|
||||
else:
|
||||
param_names = []
|
||||
for param_idx, param in enumerate(parameters):
|
||||
for param_idx, param in enumerate(params):
|
||||
dtype = None
|
||||
if param in MODEL_MAP.keys():
|
||||
param_name = 'Model'
|
||||
|
Reference in New Issue
Block a user