diff --git a/environments/factory/base_factory.py b/environments/factory/base_factory.py index 36972ab..538d4c2 100644 --- a/environments/factory/base_factory.py +++ b/environments/factory/base_factory.py @@ -191,8 +191,7 @@ class BaseFactory(gym.Env): def calculate_reward(self, agent_states: List[AgentState]) -> (int, dict): # Returns: Reward, Info - # Set to "raise NotImplementedError" - return 0, {} + raise NotImplementedError def render(self): raise NotImplementedError diff --git a/environments/factory/simple_factory.py b/environments/factory/simple_factory.py index 36198a7..89fcdf1 100644 --- a/environments/factory/simple_factory.py +++ b/environments/factory/simple_factory.py @@ -112,7 +112,7 @@ class SimpleFactory(BaseFactory): def calculate_reward(self, agent_states: List[AgentState]) -> (int, dict): # TODO: What reward to use? current_dirt_amount = self.state[DIRT_INDEX].sum() - dirty_tiles = len(np.nonzero(self.state[DIRT_INDEX])) + dirty_tiles = np.argwhere(self.state[DIRT_INDEX] != h.IS_FREE_CELL).shape[0] try: # penalty = current_dirt_amount @@ -128,7 +128,7 @@ class SimpleFactory(BaseFactory): if agent_state.action_valid: reward += 2 self.print(f'Agent {agent_state.i} did just clean up some dirt at {agent_state.pos}.') - self.monitor.add('dirt_cleaned', self._dirt_properties.clean_amount) + self.monitor.add('dirt_cleaned', 1) else: self.print(f'Agent {agent_state.i} just tried to clean up some dirt ' f'at {agent_state.pos}, but was unsucsessfull.') diff --git a/environments/logging/monitor.py b/environments/logging/monitor.py index fabee50..299ea53 100644 --- a/environments/logging/monitor.py +++ b/environments/logging/monitor.py @@ -4,6 +4,8 @@ from collections import defaultdict from stable_baselines3.common.callbacks import BaseCallback +from environments.logging.plotting import prepare_plot + class FactoryMonitor: @@ -58,11 +60,12 @@ class MonitorCallback(BaseCallback): ext = 'png' - def __init__(self, env, filepath=Path('debug_out/monitor.pick')): + def __init__(self, env, filepath=Path('debug_out/monitor.pick'), plotting=True): super(MonitorCallback, self).__init__() self.filepath = Path(filepath) self._monitor_list = list() self.env = env + self.plotting = plotting self.started = False self.closed = False @@ -91,7 +94,18 @@ class MonitorCallback(BaseCallback): # self.out_file.unlink(missing_ok=True) with self.filepath.open('wb') as f: pickle.dump(self.monitor_as_df_list, f, protocol=pickle.HIGHEST_PROTOCOL) - self.prepare_plot() + if self.plotting: + print('Monitor files were dumped to disk, now plotting....') + # %% Imports + import pandas as pd + # %% Load MonitorList from Disk + with self.filepath.open('rb') as f: + monitor_list = pickle.load(f) + + result = pd.concat(monitor_list, sort=False) + # result.tail() + prepare_plot(filepath=self.filepath, results_df=result, tag='monitor') + print('Plotting done.') self.closed = True def _on_step(self) -> bool: @@ -99,50 +113,5 @@ class MonitorCallback(BaseCallback): self._monitor_list.append(self.env.monitor) else: pass + return True - def plot(self, **kwargs): - from matplotlib import pyplot as plt - plt.rcParams.update(kwargs) - - plt.tight_layout() - figure = plt.gcf() - plt.show() - figure.savefig(str(self.filepath.parent / f'{self.filepath.stem}_monitor_measures.{self.ext}'), format=self.ext) - - def prepare_plot(self): - # %% Imports - import pandas as pd - import seaborn as sns - - # %% Load MonitorList from Disk - with self.filepath.open('rb') as f: - monitor_list = pickle.load(f) - - result = pd.concat(monitor_list, sort=False) - # result.tail() - - # %% - lineplot = sns.lineplot(data=result) - lineplot.title.title = f'Lineplot Summary of {len(monitor_list)} Episodes' - - # %% - sns.set_theme(palette='husl', style='whitegrid') - font_size = 16 - tex_fonts = { - # Use LaTeX to write all text - "text.usetex": True, - "font.family": "serif", - # Use 10pt font in plots, to match 10pt font in document - "axes.labelsize": font_size, - "font.size": font_size, - # Make the legend/label fonts a little smaller - "legend.fontsize": font_size - 2, - "xtick.labelsize": font_size - 2, - "ytick.labelsize": font_size - 2 - } - - try: - self.plot(**tex_fonts) - except FileNotFoundError: - tex_fonts['text.usetex'] = False - self.plot(**tex_fonts) diff --git a/environments/logging/plotting.py b/environments/logging/plotting.py new file mode 100644 index 0000000..fb4690b --- /dev/null +++ b/environments/logging/plotting.py @@ -0,0 +1,40 @@ +import seaborn as sns + +from matplotlib import pyplot as plt + + +def plot(filepath, ext='png', tag='monitor', **kwargs): + plt.rcParams.update(kwargs) + + plt.tight_layout() + figure = plt.gcf() + plt.show() + figure.savefig(str(filepath.parent / f'{filepath.stem}_{tag}_measures.{ext}'), format=ext) + + +def prepare_plot(filepath, results_df, ext='png', tag=''): + # %% + + _ = sns.lineplot(data=results_df) + + # %% + sns.set_theme(palette='husl', style='whitegrid') + font_size = 16 + tex_fonts = { + # Use LaTeX to write all text + "text.usetex": False, + "font.family": "serif", + # Use 10pt font in plots, to match 10pt font in document + "axes.labelsize": font_size, + "font.size": font_size, + # Make the legend/label fonts a little smaller + "legend.fontsize": font_size - 2, + "xtick.labelsize": font_size - 2, + "ytick.labelsize": font_size - 2 + } + + try: + plot(filepath, ext=ext, tag=tag, **tex_fonts) + except (FileNotFoundError, RuntimeError): + tex_fonts['text.usetex'] = False + plot(filepath, ext=ext, tag=tag, **tex_fonts) diff --git a/environments/logging/training.py b/environments/logging/training.py index ce7aa50..0276851 100644 --- a/environments/logging/training.py +++ b/environments/logging/training.py @@ -1,35 +1,54 @@ +from collections import defaultdict from pathlib import Path +import numpy as np import pandas as pd from stable_baselines3.common.callbacks import BaseCallback +from environments.logging.plotting import prepare_plot + class TraningMonitor(BaseCallback): def __init__(self, filepath, flush_interval=None): super(TraningMonitor, self).__init__() - self.values = dict() + self.values = defaultdict(dict) + self.rewards = defaultdict(lambda: 0) + self.filepath = Path(filepath) self.flush_interval = flush_interval + self.next_flush: int pass def _on_training_start(self) -> None: self.flush_interval = self.flush_interval or (self.locals['total_timesteps'] * 0.1) + self.next_flush = self.flush_interval def _flush(self): - df = pd.DataFrame.from_dict(self.values) + df = pd.DataFrame.from_dict(self.values, orient='index') if not self.filepath.exists(): df.to_csv(self.filepath, mode='wb', header=True) else: df.to_csv(self.filepath, mode='a', header=False) - self.values = dict() def _on_step(self) -> bool: - self.values[self.num_timesteps] = dict(reward=self.locals['rewards'].item()) - if self.num_timesteps % self.flush_interval == 0: + for idx, done in np.ndenumerate(self.locals['dones']): + idx = idx[0] + # self.values[self.num_timesteps].update(**{f'reward_env_{idx}': self.locals['rewards'][idx]}) + self.rewards[idx] += self.locals['rewards'][idx] + if done: + self.values[self.num_timesteps].update(**{f'acc_epispde_r_env_{idx}': self.rewards[idx]}) + self.rewards[idx] = 0 + + if self.num_timesteps >= self.next_flush and self.values: self._flush() + self.values = defaultdict(dict) + + self.next_flush += self.flush_interval return True def on_training_end(self) -> None: self._flush() + self.values = defaultdict(dict) + # prepare_plot()