From 36fe59c95c49d3e278d5e36eaec47ccd1b22a153 Mon Sep 17 00:00:00 2001 From: steffen-illium Date: Fri, 28 May 2021 14:43:03 +0200 Subject: [PATCH] zwischenstand, no checkout pls! --- environments/factory/base_factory.py | 57 +++++++++++++++++++++++----- environments/logging/plotting.py | 3 +- main.py | 38 +++++++++---------- reload_agent.py | 32 ++++++++++++++++ 4 files changed, 100 insertions(+), 30 deletions(-) create mode 100644 reload_agent.py diff --git a/environments/factory/base_factory.py b/environments/factory/base_factory.py index 5ac6674..4661552 100644 --- a/environments/factory/base_factory.py +++ b/environments/factory/base_factory.py @@ -1,4 +1,4 @@ -from typing import List, Union, Iterable +from typing import List, Union, Iterable, TypedDict import gym from gym import spaces @@ -32,6 +32,37 @@ class AgentState: raise AttributeError(f'"{key}" cannot be updated, this attr is not a part of {self.__class__.__name__}') +class Actions: + + def __init__(self, allow_square_movement=True, allow_diagonal_movement=True, allow_no_OP=True): + self.allow_no_OP = allow_no_OP + self.allow_diagonal_movement = allow_diagonal_movement + self.allow_square_movement = allow_square_movement + self._registerd_actions = dict() + if allow_square_movement: + self + {key: val for key, val in enumerate(['north', 'east', 'south', 'west'])} + if allow_diagonal_movement: + self + {key: val for key, val in enumerate(['north-east', 'south-east', 'south-west', 'north-west'])} + + self._movement_actions = self._registerd_actions.copy() + if self.allow_no_OP: + self + {0:'no-op'} + + + def __len__(self): + return len(self._registerd_actions) + + def __add__(self, other: dict): + assert all([isinstance(x, int) for x in other.keys()]), f'All action keys have to be of type {int}.' + assert all([isinstance(x, str) for x in other.values()]), f'All action values have to be of type {str}.' + self._registerd_actions.update({key+len(self._registerd_actions): value for key,value in other.items()}) + return self + + def register_additional_actions(self, other:dict): + self_with_additional_actions = self + other + return self_with_additional_actions + + class BaseFactory(gym.Env): @property @@ -44,7 +75,16 @@ class BaseFactory(gym.Env): @property def movement_actions(self): - return (int(self.allow_square_movement) + int(self.allow_diagonal_movement)) * 4 + if self._movement_actions is None: + self._movement_actions = dict() + if self.allow_square_movement: + self._movement_actions.update( + ) + if self.allow_diagonal_movement: + self.{key: val for key, val in zip(range(4), ['ne', 'ne', 'nw', 'nw'])} + + return self._movement_actions + @property def string_slices(self): @@ -53,18 +93,17 @@ class BaseFactory(gym.Env): def __init__(self, level='simple', n_agents=1, max_steps=int(2e2)): self.n_agents = n_agents self.max_steps = max_steps - self.allow_square_movement = True - self.allow_diagonal_movement = True - self.allow_no_OP = True self.done_at_collision = False - self._registered_actions = self.movement_actions + int(self.allow_no_OP) + self.register_additional_actions() + self._actions = Actions(allow_square_movement=True, allow_diagonal_movement=True, allow_no_OP=False) + + self.level = h.one_hot_level( h.parse_level(Path(__file__).parent / h.LEVELS_DIR / f'{level}.txt') ) self.slice_strings = {0: 'level', **{i: f'agent#{i}' for i in range(1, self.n_agents+1)}} self.reset() - def register_additional_actions(self) -> int: + def register_additional_actions(self) -> dict: raise NotImplementedError('Please register additional actions ') def reset(self) -> (np.ndarray, int, bool, dict): @@ -123,10 +162,10 @@ class BaseFactory(gym.Env): return self.state, reward, done, info def _is_moving_action(self, action): - return action < self.movement_actions + return self._registered_actions[action] in self.movement_actions def _is_no_op(self, action): - return self.allow_no_OP and (action - self.movement_actions) == 0 + return self._registered_actions[action] == 'no-op' def check_all_collisions(self, agent_states: List[AgentState], collisions: int) -> np.ndarray: collision_vecs = np.zeros((len(agent_states), collisions)) # n_agents x n_slices diff --git a/environments/logging/plotting.py b/environments/logging/plotting.py index 6a40561..a9726d7 100644 --- a/environments/logging/plotting.py +++ b/environments/logging/plotting.py @@ -30,9 +30,8 @@ def plot(filepath, ext='png', tag='monitor', **kwargs): def prepare_plot(filepath, results_df, ext='png', tag=''): - # %% - _ = sns.lineplot(data=results_df, ci='sd', x='step') + _ = sns.lineplot(data=results_df, x='Episode', y='Score', hue='Measurement', ci='sd') # %% sns.set_theme(palette=PALETTE, style='whitegrid') diff --git a/main.py b/main.py index bf3b430..1aedd46 100644 --- a/main.py +++ b/main.py @@ -5,6 +5,7 @@ from os import PathLike from pathlib import Path import time import pandas as pd +from natsort import natsorted from stable_baselines3.common.callbacks import CallbackList @@ -25,8 +26,8 @@ def combine_runs(run_path: Union[str, PathLike]): monitor_list = pickle.load(f) for m_idx in range(len(monitor_list)): - monitor_list[m_idx]['episode'] = str(m_idx) - monitor_list[m_idx]['run'] = str(run) + monitor_list[m_idx]['episode'] = m_idx + monitor_list[m_idx]['run'] = run df = pd.concat(monitor_list, ignore_index=True) df['train_step'] = range(df.shape[0]) @@ -42,31 +43,30 @@ def combine_runs(run_path: Union[str, PathLike]): df_list.append(df) df = pd.concat(df_list, ignore_index=True) - df = df.fillna(0) + df = df.fillna(0).rename(columns={'episode': 'Episode', 'run': 'Run'}) - df_group = df.groupby(['episode', 'run']).aggregate({col: 'mean' if col in ['dirt_amount', + df_group = df.groupby(['Episode', 'Run']).aggregate({col: 'mean' if col in ['dirt_amount', 'dirty_tiles'] else 'sum' - for col in df.columns if col not in ['episode', 'run'] - }).reset_index() + for col in df.columns if + col not in ['Episode', 'Run', 'train_step'] + }) + non_overlapp_window = df_group.groupby(['Run', (df_group.index.get_level_values('Episode') // 50)]).mean() - import seaborn as sns - from matplotlib import pyplot as plt - df_melted = df_group.melt(id_vars=['episode', 'run'], - value_vars=['agent_0_vs_level', 'dirt_amount', - 'dirty_tiles', 'step_reward', - 'failed_cleanup_attempt', - 'dirt_cleaned'], var_name="Variable", - value_name="Score") + df_melted = non_overlapp_window.reset_index().melt(id_vars=['Episode', 'Run'], + value_vars=['agent_0_vs_level', 'dirt_amount', + 'dirty_tiles', 'step_reward', + 'failed_cleanup_attempt', + 'dirt_cleaned'], var_name="Measurement", + value_name="Score") - sns.lineplot(data=df_melted, x='episode', y='Score', hue='Variable', ci='sd') - plt.show() + prepare_plot(run_path / f'{run_path.name}_monitor_lineplot.png', df_melted) print('Plotting done.') if __name__ == '__main__': - combine_runs('debug_out/PPO_1622120377') + combine_runs('debug_out/PPO_1622128912') exit() from stable_baselines3 import DQN, PPO @@ -82,7 +82,7 @@ if __name__ == '__main__': model = PPO("MlpPolicy", env, verbose=1, ent_coef=0.0, seed=seed, device='cpu') - out_path = Path('../debug_out') / f'{model.__class__.__name__}_{time_stamp}' + out_path = Path('debug_out') / f'{model.__class__.__name__}_{time_stamp}' identifier = f'{seed}_{model.__class__.__name__}_{time_stamp}' out_path /= identifier @@ -92,7 +92,7 @@ if __name__ == '__main__': MonitorCallback(env, filepath=out_path / f'monitor_{identifier}.pick', plotting=False)] ) - model.learn(total_timesteps=int(5e5), callback=callbacks) + model.learn(total_timesteps=int(2e6), callback=callbacks) save_path = out_path / f'model_{identifier}.zip' save_path.parent.mkdir(parents=True, exist_ok=True) diff --git a/reload_agent.py b/reload_agent.py new file mode 100644 index 0000000..00dbc79 --- /dev/null +++ b/reload_agent.py @@ -0,0 +1,32 @@ +import warnings +from pathlib import Path +import time + +from natsort import natsorted +from stable_baselines3 import PPO +from stable_baselines3.common.base_class import BaseAlgorithm +from stable_baselines3.common.callbacks import CallbackList +from stable_baselines3.common.evaluation import evaluate_policy + +from environments.factory.simple_factory import DirtProperties, SimpleFactory +from environments.logging.monitor import MonitorCallback +from environments.logging.training import TraningMonitor + +warnings.filterwarnings('ignore', category=FutureWarning) +warnings.filterwarnings('ignore', category=UserWarning) + + +if __name__ == '__main__': + dirt_props = DirtProperties() + env = SimpleFactory(n_agents=1, dirt_properties=dirt_props) + + out_path = Path('debug_out') + model_files = list(natsorted(out_path.rglob('*.zip'))) + this_model = model_files[0] + + model = PPO.load(this_model) + evaluation_result = evaluate_policy(model, env, n_eval_episodes=100, deterministic=False, + render=True) + print(evaluation_result) + + env.close()