mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-05-23 07:16:44 +02:00
Smaller fixes, now running.
This commit is contained in:
parent
4fb32c98c6
commit
84eb381307
@ -74,13 +74,13 @@ class MonitorCallback(BaseCallback):
|
||||
dones = alt_dones
|
||||
elif self.locals.get('dones', None) is not None:
|
||||
dones =self.locals.get('dones', None)
|
||||
elif self.locals.get('dones', None) is not None:
|
||||
elif self.locals.get('done', None) is not None:
|
||||
dones = self.locals.get('done', [None])
|
||||
else:
|
||||
dones = []
|
||||
|
||||
for env_idx, (info, done) in enumerate(zip(infos, dones)):
|
||||
self._monitor_dicts[env_idx][self.num_timesteps - env_idx] = {key: val for key, val in info.items()
|
||||
self._monitor_dicts[env_idx][len(self._monitor_dicts[env_idx])] = {key: val for key, val in info.items()
|
||||
if key not in ['terminal_observation', 'episode']
|
||||
and not key.startswith('rec_')}
|
||||
if done:
|
||||
|
@ -34,7 +34,7 @@ def prepare_plot(filepath, results_df, ext='png', hue='Measurement', style=None)
|
||||
sns.set(rc={'text.usetex': True}, style='whitegrid')
|
||||
lineplot = sns.lineplot(data=df, x='Episode', y='Score', ci=95, palette=PALETTE,
|
||||
hue_order=hue_order, hue=hue, style=style)
|
||||
lineplot.set_title(f'{sorted(list(df["Measurement"].unique()))}')
|
||||
# lineplot.set_title(f'{sorted(list(df["Measurement"].unique()))}')
|
||||
plot(filepath, ext=ext) # plot raises errors not lineplot!
|
||||
except (FileNotFoundError, RuntimeError):
|
||||
print('Struggling to plot Figure using LaTeX - going back to normal.')
|
||||
@ -42,5 +42,5 @@ def prepare_plot(filepath, results_df, ext='png', hue='Measurement', style=None)
|
||||
sns.set(rc={'text.usetex': False}, style='whitegrid')
|
||||
lineplot = sns.lineplot(data=df, x='Episode', y='Score', hue=hue, style=style,
|
||||
ci=95, palette=PALETTE, hue_order=hue_order)
|
||||
lineplot.set_title(f'{sorted(list(df["Measurement"].unique()))}')
|
||||
# lineplot.set_title(f'{sorted(list(df["Measurement"].unique()))}')
|
||||
plot(filepath, ext=ext)
|
||||
|
23
main.py
23
main.py
@ -34,17 +34,19 @@ def combine_runs(run_path: Union[str, PathLike]):
|
||||
df_list.append(monitor_df)
|
||||
|
||||
df = pd.concat(df_list, ignore_index=True)
|
||||
df = df.fillna(0).rename(columns={'episode': 'Episode', 'run': 'Run'})
|
||||
df = df.fillna(0).rename(columns={'episode': 'Episode', 'run': 'Run'}).sort_values(['Run', 'Episode'])
|
||||
columns = [col for col in df.columns if col not in IGNORED_DF_COLUMNS]
|
||||
|
||||
roll_n = 50
|
||||
skip_n = 40
|
||||
|
||||
non_overlapp_window = df.groupby(['Run', 'Episode']).rolling(roll_n, min_periods=1).mean()
|
||||
|
||||
df_melted = non_overlapp_window[columns].reset_index().melt(id_vars=['Episode', 'Run'],
|
||||
value_vars=columns, var_name="Measurement",
|
||||
value_name="Score")
|
||||
|
||||
if df_melted['Episode'].max() > 100:
|
||||
skip_n = round(df_melted['Episode'].max() * 0.01)
|
||||
df_melted = df_melted[df_melted['Episode'] % skip_n == 0]
|
||||
|
||||
prepare_plot(run_path / f'{run_path.name}_monitor_lineplot.png', df_melted)
|
||||
@ -71,13 +73,14 @@ def compare_runs(run_path: Path, run_identifier: int, parameter: Union[str, List
|
||||
columns = [col for col in df.columns if col in parameter]
|
||||
|
||||
roll_n = 40
|
||||
skip_n = 20
|
||||
|
||||
non_overlapp_window = df.groupby(['Model', 'Run', 'Episode']).rolling(roll_n, min_periods=1).mean()
|
||||
|
||||
df_melted = non_overlapp_window[columns].reset_index().melt(id_vars=['Episode', 'Run', 'Model'],
|
||||
value_vars=columns, var_name="Measurement",
|
||||
value_name="Score")
|
||||
if df_melted['Episode'].max() > 100:
|
||||
skip_n = round(df_melted['Episode'].max() * 0.01)
|
||||
df_melted = df_melted[df_melted['Episode'] % skip_n == 0]
|
||||
|
||||
style = 'Measurement' if len(columns) > 1 else None
|
||||
@ -113,7 +116,7 @@ if __name__ == '__main__':
|
||||
move_props = MovementProperties(allow_diagonal_movement=False,
|
||||
allow_square_movement=True,
|
||||
allow_no_op=False)
|
||||
train_steps = 1e6
|
||||
train_steps = 1e5
|
||||
time_stamp = int(time.time())
|
||||
|
||||
out_path = None
|
||||
@ -131,12 +134,11 @@ if __name__ == '__main__':
|
||||
cast_shadows=True, doors_have_area=False, env_seed=seed, verbose=False,
|
||||
)
|
||||
|
||||
# env = make_env(env_kwargs)()
|
||||
env = SubprocVecEnv([make_env(env_kwargs) for _ in range(12)], start_method="spawn")
|
||||
|
||||
if modeL_type.__name__ in ["PPO", "A2C"]:
|
||||
kwargs = dict(ent_coef=0.01)
|
||||
env = SubprocVecEnv([make_env(env_kwargs) for _ in range(6)], start_method="spawn")
|
||||
elif modeL_type.__name__ in ["RegDQN", "DQN", "QRDQN"]:
|
||||
env = make_env(env_kwargs)()
|
||||
kwargs = dict(buffer_size=50000,
|
||||
learning_starts=64,
|
||||
batch_size=64,
|
||||
@ -145,6 +147,7 @@ if __name__ == '__main__':
|
||||
exploration_final_eps=0.025)
|
||||
else:
|
||||
raise NameError(f'The model "{modeL_type.__name__}" has the wrong name.')
|
||||
|
||||
model = modeL_type("MlpPolicy", env, verbose=1, seed=seed, device='cpu', **kwargs)
|
||||
|
||||
out_path = Path('debug_out') / f'{model.__class__.__name__}_{time_stamp}'
|
||||
@ -165,7 +168,11 @@ if __name__ == '__main__':
|
||||
save_path = out_path / f'model_{identifier}.zip'
|
||||
save_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
model.save(save_path)
|
||||
env.env_method('save_params', out_path.parent / f'env_{model.__class__.__name__}_{time_stamp}.yaml')
|
||||
param_path = out_path.parent / f'env_{model.__class__.__name__}_{time_stamp}.yaml'
|
||||
try:
|
||||
env.env_method('save_params', param_path)
|
||||
except AttributeError:
|
||||
env.save_params(param_path)
|
||||
print("Model Trained and saved")
|
||||
print("Model Group Done.. Plotting...")
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user