diff --git a/environments/factory/base_factory.py b/environments/factory/base_factory.py index 4134735..f3ca75a 100644 --- a/environments/factory/base_factory.py +++ b/environments/factory/base_factory.py @@ -189,6 +189,10 @@ class BaseFactory(gym.Env): if self.steps >= self.max_steps: done = True self.monitor.set('step_reward', reward) + self.monitor.set('step', self.steps) + + if done: + info.update(monitor=self.monitor) return self.state, reward, done, info def _is_moving_action(self, action): diff --git a/environments/factory/simple_factory.py b/environments/factory/simple_factory.py index 1595a62..03a506f 100644 --- a/environments/factory/simple_factory.py +++ b/environments/factory/simple_factory.py @@ -145,28 +145,27 @@ class SimpleFactory(BaseFactory): self.print(f'Agent {agent_state.i} did just clean up some dirt at {agent_state.pos}.') self.monitor.set('dirt_cleaned', 1) else: - reward -= 1 + reward -= 0.5 self.print(f'Agent {agent_state.i} just tried to clean up some dirt ' f'at {agent_state.pos}, but was unsucsessfull.') self.monitor.set('failed_cleanup_attempt', 1) elif self._is_moving_action(agent_state.action): if agent_state.action_valid: - reward -= 0.01 + reward -= 0.00 else: reward -= 0.5 else: self.monitor.set('no_op', 1) - reward -= 0.25 + reward -= 0.1 for entity in cols: if entity != self.state_slices.by_name("dirt"): self.monitor.set(f'agent_{agent_state.i}_vs_{self.state_slices[entity]}', 1) self.monitor.set('dirt_amount', current_dirt_amount) - self.monitor.set('dirty_tiles', dirty_tiles) - self.monitor.set('step', self.steps) + self.monitor.set('dirty_tile_count', dirty_tiles) self.print(f"reward is {reward}") # Potential based rewards -> # track the last reward , minus the current reward = potential diff --git a/environments/helpers.py b/environments/helpers.py index 2f47848..c1ba774 100644 --- a/environments/helpers.py +++ b/environments/helpers.py @@ -8,6 +8,8 @@ LEVEL_IDX = 0 AGENT_START_IDX = 1 IS_FREE_CELL = 0 IS_OCCUPIED_CELL = 1 +TO_BE_AVERAGED = ['dirt_amount', 'dirty_tiles'] +IGNORED_DF_COLUMNS = ['Episode', 'Run', 'train_step', 'step', 'index'] # Utility functions diff --git a/environments/logging/monitor.py b/environments/logging/monitor.py index ccfe1f7..4fc5ce2 100644 --- a/environments/logging/monitor.py +++ b/environments/logging/monitor.py @@ -4,7 +4,9 @@ from collections import defaultdict from stable_baselines3.common.callbacks import BaseCallback +from environments.helpers import IGNORED_DF_COLUMNS from environments.logging.plotting import prepare_plot +import pandas as pd class FactoryMonitor: @@ -59,16 +61,12 @@ class MonitorCallback(BaseCallback): def __init__(self, env, filepath=Path('debug_out/monitor.pick'), plotting=True): super(MonitorCallback, self).__init__() self.filepath = Path(filepath) - self._monitor_list = list() + self._monitor_df = pd.DataFrame() self.env = env self.plotting = plotting self.started = False self.closed = False - @property - def monitor_as_df_list(self): - return [x.to_pd_dataframe() for x in self._monitor_list] - def __enter__(self): self._on_training_start() @@ -89,11 +87,10 @@ class MonitorCallback(BaseCallback): else: # self.out_file.unlink(missing_ok=True) with self.filepath.open('wb') as f: - pickle.dump(self.monitor_as_df_list, f, protocol=pickle.HIGHEST_PROTOCOL) + pickle.dump(self._monitor_df.reset_index(), f, protocol=pickle.HIGHEST_PROTOCOL) if self.plotting: print('Monitor files were dumped to disk, now plotting....') - # %% Imports - import pandas as pd + # %% Load MonitorList from Disk with self.filepath.open('rb') as f: monitor_list = pickle.load(f) @@ -111,14 +108,21 @@ class MonitorCallback(BaseCallback): if column != 'episode': df[f'{column}_roll'] = df[column].rolling(window=50).mean() # result.tail() - prepare_plot(filepath=self.filepath, results_df=df.filter(regex=(".+_roll")), tag='monitor') + prepare_plot(filepath=self.filepath, results_df=df.filter(regex=(".+_roll"))) print('Plotting done.') self.closed = True def _on_step(self) -> bool: - if self.locals['dones'].item(): - self._monitor_list.append(self.env.monitor) - else: - pass + for env_idx, done in enumerate(self.locals.get('dones', [])): + if done: + env_monitor_df = self.locals['infos'][env_idx]['monitor'].to_pd_dataframe() + columns = [col for col in env_monitor_df.columns if col not in IGNORED_DF_COLUMNS] + env_monitor_df = env_monitor_df.aggregate( + {col: 'mean' if 'amount' in col or 'count' in col else 'sum' for col in columns} + ) + env_monitor_df['episode'] = len(self._monitor_df) + self._monitor_df = self._monitor_df.append([env_monitor_df]) + else: + pass return True diff --git a/environments/logging/plotting.py b/environments/logging/plotting.py index 4613698..64cad3d 100644 --- a/environments/logging/plotting.py +++ b/environments/logging/plotting.py @@ -29,7 +29,7 @@ def plot(filepath, ext='png', **kwargs): figure.savefig(str(filepath), format=ext) -def prepare_plot(filepath, results_df, ext='png', tag=''): +def prepare_plot(filepath, results_df, ext='png'): _ = sns.lineplot(data=results_df, x='Episode', y='Score', hue='Measurement', ci='sd') @@ -50,8 +50,7 @@ def prepare_plot(filepath, results_df, ext='png', tag=''): } try: - plot(filepath, ext=ext, tag=tag, **tex_fonts) + plot(filepath, ext=ext, **tex_fonts) except (FileNotFoundError, RuntimeError): tex_fonts['text.usetex'] = False - plot(filepath, ext=ext, tag=tag, **tex_fonts) - plt.show() + plot(filepath, ext=ext, **tex_fonts) diff --git a/environments/logging/training.py b/environments/logging/training.py index 0276851..d467981 100644 --- a/environments/logging/training.py +++ b/environments/logging/training.py @@ -32,7 +32,7 @@ class TraningMonitor(BaseCallback): df.to_csv(self.filepath, mode='a', header=False) def _on_step(self) -> bool: - for idx, done in np.ndenumerate(self.locals['dones']): + for idx, done in np.ndenumerate(self.locals.get('dones', [])): idx = idx[0] # self.values[self.num_timesteps].update(**{f'reward_env_{idx}': self.locals['rewards'][idx]}) self.rewards[idx] += self.locals['rewards'][idx] diff --git a/main.py b/main.py index 2b74805..15d70f9 100644 --- a/main.py +++ b/main.py @@ -9,6 +9,7 @@ import pandas as pd from stable_baselines3.common.callbacks import CallbackList from environments.factory.simple_factory import DirtProperties, SimpleFactory +from environments.helpers import IGNORED_DF_COLUMNS from environments.logging.monitor import MonitorCallback from environments.logging.plotting import prepare_plot from environments.logging.training import TraningMonitor @@ -22,16 +23,11 @@ def combine_runs(run_path: Union[str, PathLike]): df_list = list() for run, monitor_file in enumerate(run_path.rglob('monitor_*.pick')): with monitor_file.open('rb') as f: - monitor_list = pickle.load(f) + monitor_df = pickle.load(f) - for m_idx in range(len(monitor_list)): - monitor_list[m_idx]['episode'] = m_idx - monitor_list[m_idx]['run'] = run + monitor_df['run'] = run - df = pd.concat(monitor_list, ignore_index=True) - df['train_step'] = range(df.shape[0]) - - df = df.fillna(0) + monitor_df = monitor_df.fillna(0) #for column in list(df.columns): # if column not in ['episode', 'run', 'step', 'train_step']: @@ -40,20 +36,16 @@ def combine_runs(run_path: Union[str, PathLike]): # else: # df[f'{column}_mean_roll'] = df[column].rolling(window=50, min_periods=1).mean() - df_list.append(df) + df_list.append(monitor_df) df = pd.concat(df_list, ignore_index=True) df = df.fillna(0).rename(columns={'episode': 'Episode', 'run': 'Run'}) - columns = [col for col in df.columns if col not in ['Episode', 'Run', 'train_step', 'step']] + columns = [col for col in df.columns if col not in IGNORED_DF_COLUMNS] - df_group = df.groupby(['Episode', 'Run']).aggregate( - {col: 'mean' if col in ['dirt_amount', 'dirty_tiles'] else 'sum' for col in columns} - ) + non_overlapp_window = df.groupby(['Run', df['Episode'] // 20]).mean() - non_overlapp_window = df_group.groupby(['Run', (df_group.index.get_level_values('Episode') // 20)]).mean() - - df_melted = non_overlapp_window.reset_index().melt(id_vars=['Episode', 'Run'], - value_vars=columns, var_name="Measurement", - value_name="Score") + df_melted = non_overlapp_window[columns].reset_index().melt(id_vars=['Episode', 'Run'], + value_vars=columns, var_name="Measurement", + value_name="Score") prepare_plot(run_path / f'{run_path.name}_monitor_lineplot.png', df_melted) print('Plotting done.') @@ -61,36 +53,38 @@ def combine_runs(run_path: Union[str, PathLike]): if __name__ == '__main__': - # combine_runs('debug_out/PPO_1622120377') + # combine_runs('debug_out/PPO_1622399010') # exit() - from stable_baselines3 import PPO # DQN + from stable_baselines3 import PPO, DQN dirt_props = DirtProperties() time_stamp = int(time.time()) out_path = None - for seed in range(5): + for modeL_type in [PPO]: + for seed in range(5): - env = SimpleFactory(n_agents=1, dirt_properties=dirt_props, allow_diagonal_movement=True, allow_no_op=False) + env = SimpleFactory(n_agents=1, dirt_properties=dirt_props, + allow_diagonal_movement=False, allow_no_op=False) - model = PPO("MlpPolicy", env, verbose=1, ent_coef=0.0, seed=seed, device='cpu') + model = modeL_type("MlpPolicy", env, verbose=1, seed=seed, device='cpu') - out_path = Path('debug_out') / f'{model.__class__.__name__}_{time_stamp}' + out_path = Path('debug_out') / f'{model.__class__.__name__}_{time_stamp}' - identifier = f'{seed}_{model.__class__.__name__}_{time_stamp}' - out_path /= identifier + identifier = f'{seed}_{model.__class__.__name__}_{time_stamp}' + out_path /= identifier - callbacks = CallbackList( - [TraningMonitor(out_path / f'train_logging_{identifier}.csv'), - MonitorCallback(env, filepath=out_path / f'monitor_{identifier}.pick', plotting=False)] - ) + callbacks = CallbackList( + [TraningMonitor(out_path / f'train_logging_{identifier}.csv'), + MonitorCallback(env, filepath=out_path / f'monitor_{identifier}.pick', plotting=False)] + ) - model.learn(total_timesteps=int(2e6), callback=callbacks) + model.learn(total_timesteps=int(5e5), callback=callbacks) - save_path = out_path / f'model_{identifier}.zip' - save_path.parent.mkdir(parents=True, exist_ok=True) - model.save(save_path) + save_path = out_path / f'model_{identifier}.zip' + save_path.parent.mkdir(parents=True, exist_ok=True) + model.save(save_path) - if out_path: - combine_runs(out_path) + if out_path: + combine_runs(out_path.parent)