mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-05-23 07:16:44 +02:00
Smaller fixes, now running.
This commit is contained in:
parent
4fb32c98c6
commit
84eb381307
@ -74,13 +74,13 @@ class MonitorCallback(BaseCallback):
|
|||||||
dones = alt_dones
|
dones = alt_dones
|
||||||
elif self.locals.get('dones', None) is not None:
|
elif self.locals.get('dones', None) is not None:
|
||||||
dones =self.locals.get('dones', None)
|
dones =self.locals.get('dones', None)
|
||||||
elif self.locals.get('dones', None) is not None:
|
elif self.locals.get('done', None) is not None:
|
||||||
dones = self.locals.get('done', [None])
|
dones = self.locals.get('done', [None])
|
||||||
else:
|
else:
|
||||||
dones = []
|
dones = []
|
||||||
|
|
||||||
for env_idx, (info, done) in enumerate(zip(infos, dones)):
|
for env_idx, (info, done) in enumerate(zip(infos, dones)):
|
||||||
self._monitor_dicts[env_idx][self.num_timesteps - env_idx] = {key: val for key, val in info.items()
|
self._monitor_dicts[env_idx][len(self._monitor_dicts[env_idx])] = {key: val for key, val in info.items()
|
||||||
if key not in ['terminal_observation', 'episode']
|
if key not in ['terminal_observation', 'episode']
|
||||||
and not key.startswith('rec_')}
|
and not key.startswith('rec_')}
|
||||||
if done:
|
if done:
|
||||||
|
@ -34,7 +34,7 @@ def prepare_plot(filepath, results_df, ext='png', hue='Measurement', style=None)
|
|||||||
sns.set(rc={'text.usetex': True}, style='whitegrid')
|
sns.set(rc={'text.usetex': True}, style='whitegrid')
|
||||||
lineplot = sns.lineplot(data=df, x='Episode', y='Score', ci=95, palette=PALETTE,
|
lineplot = sns.lineplot(data=df, x='Episode', y='Score', ci=95, palette=PALETTE,
|
||||||
hue_order=hue_order, hue=hue, style=style)
|
hue_order=hue_order, hue=hue, style=style)
|
||||||
lineplot.set_title(f'{sorted(list(df["Measurement"].unique()))}')
|
# lineplot.set_title(f'{sorted(list(df["Measurement"].unique()))}')
|
||||||
plot(filepath, ext=ext) # plot raises errors not lineplot!
|
plot(filepath, ext=ext) # plot raises errors not lineplot!
|
||||||
except (FileNotFoundError, RuntimeError):
|
except (FileNotFoundError, RuntimeError):
|
||||||
print('Struggling to plot Figure using LaTeX - going back to normal.')
|
print('Struggling to plot Figure using LaTeX - going back to normal.')
|
||||||
@ -42,5 +42,5 @@ def prepare_plot(filepath, results_df, ext='png', hue='Measurement', style=None)
|
|||||||
sns.set(rc={'text.usetex': False}, style='whitegrid')
|
sns.set(rc={'text.usetex': False}, style='whitegrid')
|
||||||
lineplot = sns.lineplot(data=df, x='Episode', y='Score', hue=hue, style=style,
|
lineplot = sns.lineplot(data=df, x='Episode', y='Score', hue=hue, style=style,
|
||||||
ci=95, palette=PALETTE, hue_order=hue_order)
|
ci=95, palette=PALETTE, hue_order=hue_order)
|
||||||
lineplot.set_title(f'{sorted(list(df["Measurement"].unique()))}')
|
# lineplot.set_title(f'{sorted(list(df["Measurement"].unique()))}')
|
||||||
plot(filepath, ext=ext)
|
plot(filepath, ext=ext)
|
||||||
|
27
main.py
27
main.py
@ -34,18 +34,20 @@ def combine_runs(run_path: Union[str, PathLike]):
|
|||||||
df_list.append(monitor_df)
|
df_list.append(monitor_df)
|
||||||
|
|
||||||
df = pd.concat(df_list, ignore_index=True)
|
df = pd.concat(df_list, ignore_index=True)
|
||||||
df = df.fillna(0).rename(columns={'episode': 'Episode', 'run': 'Run'})
|
df = df.fillna(0).rename(columns={'episode': 'Episode', 'run': 'Run'}).sort_values(['Run', 'Episode'])
|
||||||
columns = [col for col in df.columns if col not in IGNORED_DF_COLUMNS]
|
columns = [col for col in df.columns if col not in IGNORED_DF_COLUMNS]
|
||||||
|
|
||||||
roll_n = 50
|
roll_n = 50
|
||||||
skip_n = 40
|
|
||||||
|
|
||||||
non_overlapp_window = df.groupby(['Run', 'Episode']).rolling(roll_n, min_periods=1).mean()
|
non_overlapp_window = df.groupby(['Run', 'Episode']).rolling(roll_n, min_periods=1).mean()
|
||||||
|
|
||||||
df_melted = non_overlapp_window[columns].reset_index().melt(id_vars=['Episode', 'Run'],
|
df_melted = non_overlapp_window[columns].reset_index().melt(id_vars=['Episode', 'Run'],
|
||||||
value_vars=columns, var_name="Measurement",
|
value_vars=columns, var_name="Measurement",
|
||||||
value_name="Score")
|
value_name="Score")
|
||||||
df_melted = df_melted[df_melted['Episode'] % skip_n == 0]
|
|
||||||
|
if df_melted['Episode'].max() > 100:
|
||||||
|
skip_n = round(df_melted['Episode'].max() * 0.01)
|
||||||
|
df_melted = df_melted[df_melted['Episode'] % skip_n == 0]
|
||||||
|
|
||||||
prepare_plot(run_path / f'{run_path.name}_monitor_lineplot.png', df_melted)
|
prepare_plot(run_path / f'{run_path.name}_monitor_lineplot.png', df_melted)
|
||||||
print('Plotting done.')
|
print('Plotting done.')
|
||||||
@ -71,14 +73,15 @@ def compare_runs(run_path: Path, run_identifier: int, parameter: Union[str, List
|
|||||||
columns = [col for col in df.columns if col in parameter]
|
columns = [col for col in df.columns if col in parameter]
|
||||||
|
|
||||||
roll_n = 40
|
roll_n = 40
|
||||||
skip_n = 20
|
|
||||||
|
|
||||||
non_overlapp_window = df.groupby(['Model', 'Run', 'Episode']).rolling(roll_n, min_periods=1).mean()
|
non_overlapp_window = df.groupby(['Model', 'Run', 'Episode']).rolling(roll_n, min_periods=1).mean()
|
||||||
|
|
||||||
df_melted = non_overlapp_window[columns].reset_index().melt(id_vars=['Episode', 'Run', 'Model'],
|
df_melted = non_overlapp_window[columns].reset_index().melt(id_vars=['Episode', 'Run', 'Model'],
|
||||||
value_vars=columns, var_name="Measurement",
|
value_vars=columns, var_name="Measurement",
|
||||||
value_name="Score")
|
value_name="Score")
|
||||||
df_melted = df_melted[df_melted['Episode'] % skip_n == 0]
|
if df_melted['Episode'].max() > 100:
|
||||||
|
skip_n = round(df_melted['Episode'].max() * 0.01)
|
||||||
|
df_melted = df_melted[df_melted['Episode'] % skip_n == 0]
|
||||||
|
|
||||||
style = 'Measurement' if len(columns) > 1 else None
|
style = 'Measurement' if len(columns) > 1 else None
|
||||||
prepare_plot(run_path / f'{run_identifier}_compare_{parameter}.png', df_melted, hue='Model', style=style)
|
prepare_plot(run_path / f'{run_identifier}_compare_{parameter}.png', df_melted, hue='Model', style=style)
|
||||||
@ -113,7 +116,7 @@ if __name__ == '__main__':
|
|||||||
move_props = MovementProperties(allow_diagonal_movement=False,
|
move_props = MovementProperties(allow_diagonal_movement=False,
|
||||||
allow_square_movement=True,
|
allow_square_movement=True,
|
||||||
allow_no_op=False)
|
allow_no_op=False)
|
||||||
train_steps = 1e6
|
train_steps = 1e5
|
||||||
time_stamp = int(time.time())
|
time_stamp = int(time.time())
|
||||||
|
|
||||||
out_path = None
|
out_path = None
|
||||||
@ -131,12 +134,11 @@ if __name__ == '__main__':
|
|||||||
cast_shadows=True, doors_have_area=False, env_seed=seed, verbose=False,
|
cast_shadows=True, doors_have_area=False, env_seed=seed, verbose=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
# env = make_env(env_kwargs)()
|
|
||||||
env = SubprocVecEnv([make_env(env_kwargs) for _ in range(12)], start_method="spawn")
|
|
||||||
|
|
||||||
if modeL_type.__name__ in ["PPO", "A2C"]:
|
if modeL_type.__name__ in ["PPO", "A2C"]:
|
||||||
kwargs = dict(ent_coef=0.01)
|
kwargs = dict(ent_coef=0.01)
|
||||||
|
env = SubprocVecEnv([make_env(env_kwargs) for _ in range(6)], start_method="spawn")
|
||||||
elif modeL_type.__name__ in ["RegDQN", "DQN", "QRDQN"]:
|
elif modeL_type.__name__ in ["RegDQN", "DQN", "QRDQN"]:
|
||||||
|
env = make_env(env_kwargs)()
|
||||||
kwargs = dict(buffer_size=50000,
|
kwargs = dict(buffer_size=50000,
|
||||||
learning_starts=64,
|
learning_starts=64,
|
||||||
batch_size=64,
|
batch_size=64,
|
||||||
@ -145,6 +147,7 @@ if __name__ == '__main__':
|
|||||||
exploration_final_eps=0.025)
|
exploration_final_eps=0.025)
|
||||||
else:
|
else:
|
||||||
raise NameError(f'The model "{modeL_type.__name__}" has the wrong name.')
|
raise NameError(f'The model "{modeL_type.__name__}" has the wrong name.')
|
||||||
|
|
||||||
model = modeL_type("MlpPolicy", env, verbose=1, seed=seed, device='cpu', **kwargs)
|
model = modeL_type("MlpPolicy", env, verbose=1, seed=seed, device='cpu', **kwargs)
|
||||||
|
|
||||||
out_path = Path('debug_out') / f'{model.__class__.__name__}_{time_stamp}'
|
out_path = Path('debug_out') / f'{model.__class__.__name__}_{time_stamp}'
|
||||||
@ -165,7 +168,11 @@ if __name__ == '__main__':
|
|||||||
save_path = out_path / f'model_{identifier}.zip'
|
save_path = out_path / f'model_{identifier}.zip'
|
||||||
save_path.parent.mkdir(parents=True, exist_ok=True)
|
save_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
model.save(save_path)
|
model.save(save_path)
|
||||||
env.env_method('save_params', out_path.parent / f'env_{model.__class__.__name__}_{time_stamp}.yaml')
|
param_path = out_path.parent / f'env_{model.__class__.__name__}_{time_stamp}.yaml'
|
||||||
|
try:
|
||||||
|
env.env_method('save_params', param_path)
|
||||||
|
except AttributeError:
|
||||||
|
env.save_params(param_path)
|
||||||
print("Model Trained and saved")
|
print("Model Trained and saved")
|
||||||
print("Model Group Done.. Plotting...")
|
print("Model Group Done.. Plotting...")
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user