From c1cb7a4ffc061c31efa54a3c04081666ddc68cb4 Mon Sep 17 00:00:00 2001
From: steffen-illium
Date: Wed, 19 May 2021 18:27:22 +0200
Subject: [PATCH] Stable Baseline Running

---
 environments/factory/base_factory.py   | 12 ++----------
 environments/factory/simple_factory.py |  5 +++--
 environments/logging/monitor.py        | 11 ++++++++++-
 3 files changed, 15 insertions(+), 13 deletions(-)

diff --git a/environments/factory/base_factory.py b/environments/factory/base_factory.py
index 6d8414a..b413164 100644
--- a/environments/factory/base_factory.py
+++ b/environments/factory/base_factory.py
@@ -1,4 +1,3 @@
-import abc
 from typing import List, Union, Iterable
 
 import gym
@@ -43,10 +42,6 @@ class BaseFactory(gym.Env):
     def observation_space(self):
         return spaces.Box(low=-1, high=1, shape=self.state.shape, dtype=np.float32)
 
-    @property
-    def monitor_as_df_list(self):
-        return [x.to_pd_dataframe() for x in self._monitor_list]
-
     @property
     def movement_actions(self):
         return (int(self.allow_vertical_movement) + int(self.allow_horizontal_movement)) * 4
@@ -61,7 +56,6 @@ class BaseFactory(gym.Env):
         self.allow_vertical_movement = True
         self.allow_horizontal_movement = True
         self.allow_no_OP = True
-        self._monitor_list = list()
         self._registered_actions = self.movement_actions + int(self.allow_no_OP) + self.register_additional_actions()
         self.level = h.one_hot_level(
             h.parse_level(Path(__file__).parent / h.LEVELS_DIR / f'{level}.txt')
@@ -77,7 +71,6 @@ class BaseFactory(gym.Env):
         self.steps = 0
         self.cumulative_reward = 0
         self.monitor = FactoryMonitor(self)
-        self._monitor_list.append(self.monitor)
         self.agent_states = []
         # Agent placement ...
         agents = np.zeros((self.n_agents, *self.level.shape), dtype=np.int8)
@@ -92,7 +85,6 @@ class BaseFactory(gym.Env):
         # state.shape = level, agent 1,..., agent n,
         self.state = np.concatenate((np.expand_dims(self.level, axis=0), agents), axis=0)
         # Returns State
-
         return self.state
 
     def additional_actions(self, agent_i: int, action: int) -> ((int, int), bool):
@@ -122,11 +114,11 @@ class BaseFactory(gym.Env):
         self.agent_states = states
 
         reward, info = self.calculate_reward(states)
-        self.cumulative_reward += reward
 
         if self.steps >= self.max_steps:
             self.done = True
-        return self.state, self.cumulative_reward, self.done, info
+        self.monitor.add('step_reward', reward)
+        return self.state, reward, self.done, info
 
     def _is_moving_action(self, action):
         return action < self.movement_actions
diff --git a/environments/factory/simple_factory.py b/environments/factory/simple_factory.py
index 3ad70ea..13124a3 100644
--- a/environments/factory/simple_factory.py
+++ b/environments/factory/simple_factory.py
@@ -120,8 +120,8 @@ class SimpleFactory(BaseFactory):
         dirty_tiles = len(np.nonzero(self.state[DIRT_INDEX]))
 
         try:
-            this_step_reward = -(dirty_tiles / current_dirt_amount)
-        except ZeroDivisionError:
+            this_step_reward = (dirty_tiles / current_dirt_amount)
+        except (ZeroDivisionError, RuntimeWarning):
             this_step_reward = 0
 
         for agent_state in agent_states:
@@ -136,6 +136,7 @@ class SimpleFactory(BaseFactory):
                 self.monitor.add(f'agent_{agent_state.i}_vs_{self.slice_strings[entity]}', 1)
         self.monitor.set('dirt_amount', current_dirt_amount)
         self.monitor.set('dirty_tiles', dirty_tiles)
+        print(f"reward is {this_step_reward}")
         return this_step_reward, {}
 
 
diff --git a/environments/logging/monitor.py b/environments/logging/monitor.py
index f8f5858..f2842f2 100644
--- a/environments/logging/monitor.py
+++ b/environments/logging/monitor.py
@@ -60,17 +60,26 @@ class MonitorCallback(BaseCallback):
         super(MonitorCallback, self).__init__()
         self._outpath = Path(outpath)
         self._filename = filename
+        self._monitor_list = list()
         self.out_file = self._outpath / f'{self._filename.split(".")[0]}.pick'
         self.env = env
         self.started = False
         self.closed = False
 
+    @property
+    def monitor_as_df_list(self):
+        return [x.to_pd_dataframe() for x in self._monitor_list]
+
     def __enter__(self):
         self._on_training_start()
 
     def __exit__(self, exc_type, exc_val, exc_tb):
         self._on_training_end()
 
+    def _on_rollout_end(self) -> None:
+        self._monitor_list.append(self.env.monitor)
+        pass
+
     def _on_training_start(self) -> None:
         if self.started:
             pass
@@ -85,7 +94,7 @@ class MonitorCallback(BaseCallback):
         else:
             # self.out_file.unlink(missing_ok=True)
             with self.out_file.open('wb') as f:
-                pickle.dump(self.env.monitor_as_df_list, f, protocol=pickle.HIGHEST_PROTOCOL)
+                pickle.dump(self.monitor_as_df_list, f, protocol=pickle.HIGHEST_PROTOCOL)
             self.closed = True
 
     def _on_step(self) -> bool:
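
Usage sketch (not part of the patch): after this change, monitor bookkeeping
lives in MonitorCallback rather than in the environment: _on_rollout_end()
snapshots env.monitor once per rollout and _on_training_end() pickles the
collected DataFrames to out_file. A minimal sketch of how that might be driven
from a Stable Baselines 3 run follows; the SimpleFactory constructor call, the
MonitorCallback parameter names, and the PPO settings are assumptions for
illustration, not taken from this commit.

    from pathlib import Path

    from stable_baselines3 import PPO

    from environments.factory.simple_factory import SimpleFactory
    from environments.logging.monitor import MonitorCallback

    env = SimpleFactory()  # constructor arguments assumed
    # Parameter names below are assumed from the attributes set in __init__.
    callback = MonitorCallback(env=env, outpath=Path('debug_out'), filename='monitor')

    model = PPO('MlpPolicy', env, verbose=0)
    # _on_rollout_end() appends env.monitor to the callback's _monitor_list
    # after every rollout; _on_training_end() dumps the DataFrames to out_file.
    model.learn(total_timesteps=10_000, callback=callback)

Because __enter__/__exit__ delegate to _on_training_start()/_on_training_end(),
the same callback also works as a context manager around a hand-rolled
evaluation loop.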