From e7d31aa27250f58306db51a4b87b8b16d128740f Mon Sep 17 00:00:00 2001 From: steffen-illium <steffen.illium@ifi.lmu.de> Date: Thu, 20 May 2021 09:49:08 +0200 Subject: [PATCH] Logging Monitor Callback --- environments/factory/base_factory.py | 15 +++++--- environments/factory/simple_factory.py | 53 +++++++++++++++++--------- environments/logging/monitor.py | 10 ++--- environments/logging/training.py | 35 +++++++++++++++++ 4 files changed, 83 insertions(+), 30 deletions(-) create mode 100644 environments/logging/training.py diff --git a/environments/factory/base_factory.py b/environments/factory/base_factory.py index b413164..89dae4d 100644 --- a/environments/factory/base_factory.py +++ b/environments/factory/base_factory.py @@ -56,6 +56,7 @@ class BaseFactory(gym.Env): self.allow_vertical_movement = True self.allow_horizontal_movement = True self.allow_no_OP = True + self.done_at_collision = True self._registered_actions = self.movement_actions + int(self.allow_no_OP) + self.register_additional_actions() self.level = h.one_hot_level( h.parse_level(Path(__file__).parent / h.LEVELS_DIR / f'{level}.txt') @@ -96,7 +97,7 @@ class BaseFactory(gym.Env): self.steps += 1 # Move this in a seperate function? - states = list() + agent_states = list() for agent_i, action in enumerate(actions): agent_i_state = AgentState(agent_i, action) if self._is_moving_action(action): @@ -107,13 +108,15 @@ class BaseFactory(gym.Env): pos, valid = self.additional_actions(agent_i, action) # Update state accordingly agent_i_state.update(pos=pos, action_valid=valid) - states.append(agent_i_state) + agent_states.append(agent_i_state) - for i, collision_vec in enumerate(self.check_all_collisions(states, self.state.shape[0])): - states[i].update(collision_vector=collision_vec) + for i, collision_vec in enumerate(self.check_all_collisions(agent_states, self.state.shape[0])): + agent_states[i].update(collision_vector=collision_vec) + if self.done_at_collision and collision_vec.any(): + self.done = True - self.agent_states = states - reward, info = self.calculate_reward(states) + self.agent_states = agent_states + reward, info = self.calculate_reward(agent_states) if self.steps >= self.max_steps: self.done = True diff --git a/environments/factory/simple_factory.py b/environments/factory/simple_factory.py index 13124a3..1d2a2e0 100644 --- a/environments/factory/simple_factory.py +++ b/environments/factory/simple_factory.py @@ -17,7 +17,7 @@ DIRT_INDEX = -1 @dataclass class DirtProperties: - clean_amount = 0.25 + clean_amount = 10 max_spawn_ratio = 0.1 gain_amount = 0.1 spawn_frequency = 5 @@ -31,8 +31,9 @@ class SimpleFactory(BaseFactory): def _is_clean_up_action(self, action): return self.action_space.n - 1 == action - def __init__(self, *args, dirt_properties: DirtProperties, **kwargs): + def __init__(self, *args, dirt_properties: DirtProperties, verbose=False, **kwargs): self._dirt_properties = dirt_properties + self.verbose = verbose super(SimpleFactory, self).__init__(*args, **kwargs) self.slice_strings.update({self.state.shape[0]-1: 'dirt'}) self.renderer = None # expensive - dont use it when not required ! @@ -81,6 +82,9 @@ class SimpleFactory(BaseFactory): return pos, cleanup_was_sucessfull def step(self, actions): + if self.state[h.LEVEL_IDX][self.agent_i_position(0)] == h.IS_OCCUPIED_CELL: + print(f'fAgent placed on wall!!!!, step is :{self.steps}') + raise Exception('Agent placed on wall!!!!') _, _, _, info = super(SimpleFactory, self).step(actions) if not self.next_dirt_spawn: self.spawn_dirt() @@ -94,12 +98,6 @@ class SimpleFactory(BaseFactory): if self._is_clean_up_action(action): agent_i_pos = self.agent_i_position(agent_i) _, valid = self.clean_up(agent_i_pos) - if valid: - print(f'Agent {agent_i} did just clean up some dirt at {agent_i_pos}.') - self.monitor.add('dirt_cleaned', self._dirt_properties.clean_amount) - else: - print(f'Agent {agent_i} just tried to clean up some dirt at {agent_i_pos}, but was unsucsessfull.') - self.monitor.add('failed_cleanup_attempt', 1) return agent_i_pos, valid else: raise RuntimeError('This should not happen!!!') @@ -120,25 +118,44 @@ class SimpleFactory(BaseFactory): dirty_tiles = len(np.nonzero(self.state[DIRT_INDEX])) try: - this_step_reward = (dirty_tiles / current_dirt_amount) + # penalty = current_dirt_amount + penalty = 0 except (ZeroDivisionError, RuntimeWarning): - this_step_reward = 0 - + penalty = 0 + inforcements = 0 for agent_state in agent_states: - collisions = agent_state.collisions - print(f't = {self.steps}\tAgent {agent_state.i} has collisions with ' - f'{[self.slice_strings[entity] for entity in collisions if entity != self.string_slices["dirt"]]}') - if self._is_clean_up_action(agent_state.action) and agent_state.action_valid: - this_step_reward += 1 + cols = agent_state.collisions + self.print(f't = {self.steps}\tAgent {agent_state.i} has collisions with ' + f'{[self.slice_strings[entity] for entity in cols if entity != self.string_slices["dirt"]]}') + if self._is_clean_up_action(agent_state.action): + if agent_state.action_valid: + inforcements += 10 + self.print(f'Agent {agent_state.i} did just clean up some dirt at {agent_state.pos}.') + self.monitor.add('dirt_cleaned', self._dirt_properties.clean_amount) + else: + self.print(f'Agent {agent_state.i} just tried to clean up some dirt ' + f'at {agent_state.pos}, but was unsucsessfull.') + self.monitor.add('failed_cleanup_attempt', 1) + elif self._is_moving_action(agent_state.action): + if not agent_state.action_valid: + penalty += 10 + else: + inforcements += 1 - for entity in collisions: + for entity in cols: if entity != self.string_slices["dirt"]: self.monitor.add(f'agent_{agent_state.i}_vs_{self.slice_strings[entity]}', 1) + + this_step_reward = max(0, inforcements-penalty) self.monitor.set('dirt_amount', current_dirt_amount) self.monitor.set('dirty_tiles', dirty_tiles) - print(f"reward is {this_step_reward}") + self.print(f"reward is {this_step_reward}") return this_step_reward, {} + def print(self, string): + if self.verbose: + print(string) + if __name__ == '__main__': render = True diff --git a/environments/logging/monitor.py b/environments/logging/monitor.py index f2842f2..5cfa59d 100644 --- a/environments/logging/monitor.py +++ b/environments/logging/monitor.py @@ -56,12 +56,10 @@ class FactoryMonitor: class MonitorCallback(BaseCallback): - def __init__(self, env, outpath='debug_out', filename='monitor'): + def __init__(self, env, filepath=Path('debug_out/monitor.pick')): super(MonitorCallback, self).__init__() - self._outpath = Path(outpath) - self._filename = filename + self.filepath = Path(filepath) self._monitor_list = list() - self.out_file = self._outpath / f'{self._filename.split(".")[0]}.pick' self.env = env self.started = False self.closed = False @@ -84,7 +82,7 @@ class MonitorCallback(BaseCallback): if self.started: pass else: - self.out_file.parent.mkdir(exist_ok=True, parents=True) + self.filepath.parent.mkdir(exist_ok=True, parents=True) self.started = True pass @@ -93,7 +91,7 @@ class MonitorCallback(BaseCallback): pass else: # self.out_file.unlink(missing_ok=True) - with self.out_file.open('wb') as f: + with self.filepath.open('wb') as f: pickle.dump(self.monitor_as_df_list, f, protocol=pickle.HIGHEST_PROTOCOL) self.closed = True diff --git a/environments/logging/training.py b/environments/logging/training.py new file mode 100644 index 0000000..ce7aa50 --- /dev/null +++ b/environments/logging/training.py @@ -0,0 +1,35 @@ +from pathlib import Path + +import pandas as pd +from stable_baselines3.common.callbacks import BaseCallback + + +class TraningMonitor(BaseCallback): + + def __init__(self, filepath, flush_interval=None): + super(TraningMonitor, self).__init__() + self.values = dict() + self.filepath = Path(filepath) + self.flush_interval = flush_interval + pass + + def _on_training_start(self) -> None: + self.flush_interval = self.flush_interval or (self.locals['total_timesteps'] * 0.1) + + def _flush(self): + df = pd.DataFrame.from_dict(self.values) + if not self.filepath.exists(): + df.to_csv(self.filepath, mode='wb', header=True) + else: + df.to_csv(self.filepath, mode='a', header=False) + self.values = dict() + + def _on_step(self) -> bool: + self.values[self.num_timesteps] = dict(reward=self.locals['rewards'].item()) + if self.num_timesteps % self.flush_interval == 0: + self._flush() + return True + + def on_training_end(self) -> None: + self._flush() +