From 8768d9b75fe8181ee9f9b35eb99694df336dbdff Mon Sep 17 00:00:00 2001 From: steffen-illium Date: Thu, 27 May 2021 15:31:45 +0200 Subject: [PATCH] refactored main plus small changes --- environments/factory/renderer.py | 4 +- environments/factory/simple_factory.py | 25 +++--- environments/helpers.py | 3 + environments/logging/plotting.py | 19 ++++- main.py | 103 +++++++++++++++++++++++++ 5 files changed, 141 insertions(+), 13 deletions(-) create mode 100644 main.py diff --git a/environments/factory/renderer.py b/environments/factory/renderer.py index 2e45b59..2ee96d0 100644 --- a/environments/factory/renderer.py +++ b/environments/factory/renderer.py @@ -15,8 +15,8 @@ class Entity: class Renderer: - BG_COLOR = (178, 190, 195)#(99, 110, 114) - WHITE = (223, 230, 233)#(200, 200, 200) + BG_COLOR = (178, 190, 195) # (99, 110, 114) + WHITE = (223, 230, 233) # (200, 200, 200) AGENT_VIEW_COLOR = (9, 132, 227) def __init__(self, grid_w=16, grid_h=16, cell_size=40, fps=4, grid_lines=True, view_radius=2): diff --git a/environments/factory/simple_factory.py b/environments/factory/simple_factory.py index 14a2043..273075e 100644 --- a/environments/factory/simple_factory.py +++ b/environments/factory/simple_factory.py @@ -17,10 +17,12 @@ DIRT_INDEX = -1 @dataclass class DirtProperties: - clean_amount = 10 - max_spawn_ratio = 0.1 - gain_amount = 0.1 - spawn_frequency = 5 + clean_amount = 2 # How much does the robot clean with one action. + max_spawn_ratio = 0.2 # On max how much tiles does the dirt spawn in percent. + gain_amount = 0.5 # How much dirt does spawn per tile + spawn_frequency = 5 # Spawn Frequency in Steps + max_local_amount = 1 # Max dirt amount per tile. + max_global_amount = 20 # Max dirt amount in the whole environment. class SimpleFactory(BaseFactory): @@ -64,13 +66,15 @@ class SimpleFactory(BaseFactory): self.renderer.render(OrderedDict(dirt=dirt, wall=walls, **agents)) def spawn_dirt(self) -> None: - if not self.state[DIRT_INDEX].sum() > self.max_dirt or not np.argwhere(self.state[DIRT_INDEX] != h.IS_FREE_CELL).shape[0] > 10: + if not np.argwhere(self.state[DIRT_INDEX] != h.IS_FREE_CELL).shape[0] > self._dirt_properties.max_global_amount: free_for_dirt = self.free_cells(excluded_slices=DIRT_INDEX) # randomly distribute dirt across the grid n_dirt_tiles = int(random.uniform(0, self._dirt_properties.max_spawn_ratio) * len(free_for_dirt)) for x, y in free_for_dirt[:n_dirt_tiles]: - self.state[DIRT_INDEX, x, y] += self._dirt_properties.gain_amount + new_value = self.state[DIRT_INDEX, x, y] + self._dirt_properties.gain_amount + self.state[DIRT_INDEX, x, y] = max(new_value, self._dirt_properties.max_local_amount) + else: pass @@ -130,19 +134,20 @@ class SimpleFactory(BaseFactory): f'{[self.slice_strings[entity] for entity in cols if entity != self.string_slices["dirt"]]}') if self._is_clean_up_action(agent_state.action): if agent_state.action_valid: - reward += 0.9 + reward += 1 self.print(f'Agent {agent_state.i} did just clean up some dirt at {agent_state.pos}.') self.monitor.set('dirt_cleaned', 1) else: + reward -= 1 self.print(f'Agent {agent_state.i} just tried to clean up some dirt ' f'at {agent_state.pos}, but was unsucsessfull.') self.monitor.set('failed_cleanup_attempt', 1) - reward -= 0.01 + elif self._is_moving_action(agent_state.action): if agent_state.action_valid: - reward -= 0.2 + reward -= 0.01 else: - reward -= 0.1 + reward -= 0.5 for entity in cols: if entity != self.string_slices["dirt"]: diff --git a/environments/helpers.py b/environments/helpers.py index be3d274..2f47848 100644 --- a/environments/helpers.py +++ b/environments/helpers.py @@ -64,6 +64,9 @@ def check_agent_move(state, dim, action): or y_new >= agent_slice.shape[0] ) + # Check for collision with level walls + valid = valid and not state[LEVEL_IDX][x_new, y_new] + return (x, y), (x_new, y_new), valid diff --git a/environments/logging/plotting.py b/environments/logging/plotting.py index 57fb922..6a40561 100644 --- a/environments/logging/plotting.py +++ b/environments/logging/plotting.py @@ -3,6 +3,23 @@ import seaborn as sns from matplotlib import pyplot as plt +PALETTE = 10 * ( + "#377eb8", + "#4daf4a", + "#984ea3", + "#e41a1c", + "#ff7f00", + "#a65628", + "#f781bf", + "#888888", + "#a6cee3", + "#b2df8a", + "#cab2d6", + "#fb9a99", + "#fdbf6f", +) + + def plot(filepath, ext='png', tag='monitor', **kwargs): plt.rcParams.update(kwargs) @@ -18,7 +35,7 @@ def prepare_plot(filepath, results_df, ext='png', tag=''): _ = sns.lineplot(data=results_df, ci='sd', x='step') # %% - sns.set_theme(palette='husl', style='whitegrid') + sns.set_theme(palette=PALETTE, style='whitegrid') font_size = 16 tex_fonts = { # Use LaTeX to write all text diff --git a/main.py b/main.py new file mode 100644 index 0000000..f2215cb --- /dev/null +++ b/main.py @@ -0,0 +1,103 @@ +import pickle +import warnings +from typing import Union +from os import PathLike +from pathlib import Path +import time +import pandas as pd + +from stable_baselines3.common.callbacks import CallbackList + +from environments.factory.simple_factory import DirtProperties, SimpleFactory +from environments.logging.monitor import MonitorCallback +from environments.logging.plotting import prepare_plot +from environments.logging.training import TraningMonitor + +warnings.filterwarnings('ignore', category=FutureWarning) +warnings.filterwarnings('ignore', category=UserWarning) + + +def combine_runs(run_path: Union[str, PathLike]): + run_path = Path(run_path) + df_list = list() + for run, monitor_file in enumerate(run_path.rglob('monitor_*.pick')): + with monitor_file.open('rb') as f: + monitor_list = pickle.load(f) + + for m_idx in range(len(monitor_list)): + monitor_list[m_idx]['episode'] = str(m_idx) + monitor_list[m_idx]['run'] = str(run) + + df = pd.concat(monitor_list, ignore_index=True) + df['train_step'] = range(df.shape[0]) + + df = df.fillna(0) + + #for column in list(df.columns): + # if column not in ['episode', 'run', 'step', 'train_step']: + # if 'clean' in column or '_vs_' in column: + # df[f'{column}_sum_roll'] = df[column].rolling(window=50, min_periods=1).sum() + # else: + # df[f'{column}_mean_roll'] = df[column].rolling(window=50, min_periods=1).mean() + + df_list.append(df) + df = pd.concat(df_list, ignore_index=True) + df = df.fillna(0) + + df_group = df.groupby(['episode', 'run']).aggregate({col: 'mean' if col in ['dirt_amount', + 'dirty_tiles'] else 'sum' + for col in df.columns if col not in ['episode', 'run'] + }).reset_index() + + import seaborn as sns + from matplotlib import pyplot as plt + df_melted = df_group.melt(id_vars=['train_step', 'run'], + value_vars=['agent_0_vs_level', 'dirt_amount', + 'dirty_tiles', 'step_reward', + 'failed_cleanup_attempt', + 'dirt_cleaned'], var_name="Variable", + value_name="Score") + + sns.lineplot(data=df_melted, x='train_step', y='Score', hue='Variable', ci='sd') + plt.show() + + prepare_plot(filepath=run_path / f'{run_path.name}_monitor_out_combined', + results_df=df.filter(regex=(".+_roll|(step)$")), tag='monitor') + print('Plotting done.') + + +if __name__ == '__main__': + + # combine_runs('debug_out/PPO_1622113195') + # exit() + + from stable_baselines3 import DQN, PPO + + dirt_props = DirtProperties() + time_stamp = int(time.time()) + + out_path = None + + for seed in range(5): + + env = SimpleFactory(n_agents=1, dirt_properties=dirt_props) + model = PPO("MlpPolicy", env, verbose=1, ent_coef=0.0, seed=seed) + + out_path = Path('../debug_out') / f'{model.__class__.__name__}_{time_stamp}' + + identifier = f'{seed}_{model.__class__.__name__}_{time_stamp}' + out_path /= identifier + + callbacks = CallbackList( + [TraningMonitor(out_path / f'train_logging_{identifier}.csv'), + MonitorCallback(env, filepath=out_path / f'monitor_{identifier}.pick', plotting=False)] + ) + + model.learn(total_timesteps=int(5e5), callback=callbacks) + + save_path = out_path / f'model_{identifier}.zip' + save_path.parent.mkdir(parents=True, exist_ok=True) + model.save(save_path) + + if out_path: + combine_runs(out_path)