From e7d31aa27250f58306db51a4b87b8b16d128740f Mon Sep 17 00:00:00 2001
From: steffen-illium <steffen.illium@ifi.lmu.de>
Date: Thu, 20 May 2021 09:49:08 +0200
Subject: [PATCH] Logging Monitor Callback

---
 environments/factory/base_factory.py   | 15 +++++---
 environments/factory/simple_factory.py | 53 +++++++++++++++++---------
 environments/logging/monitor.py        | 10 ++---
 environments/logging/training.py       | 35 +++++++++++++++++
 4 files changed, 83 insertions(+), 30 deletions(-)
 create mode 100644 environments/logging/training.py

diff --git a/environments/factory/base_factory.py b/environments/factory/base_factory.py
index b413164..89dae4d 100644
--- a/environments/factory/base_factory.py
+++ b/environments/factory/base_factory.py
@@ -56,6 +56,7 @@ class BaseFactory(gym.Env):
         self.allow_vertical_movement = True
         self.allow_horizontal_movement = True
         self.allow_no_OP = True
+        self.done_at_collision = True
         self._registered_actions = self.movement_actions + int(self.allow_no_OP) + self.register_additional_actions()
         self.level = h.one_hot_level(
             h.parse_level(Path(__file__).parent / h.LEVELS_DIR / f'{level}.txt')
@@ -96,7 +97,7 @@ class BaseFactory(gym.Env):
         self.steps += 1
 
         # Move this in a seperate function?
-        states = list()
+        agent_states = list()
         for agent_i, action in enumerate(actions):
             agent_i_state = AgentState(agent_i, action)
             if self._is_moving_action(action):
@@ -107,13 +108,15 @@ class BaseFactory(gym.Env):
                 pos, valid = self.additional_actions(agent_i, action)
             # Update state accordingly
             agent_i_state.update(pos=pos, action_valid=valid)
-            states.append(agent_i_state)
+            agent_states.append(agent_i_state)
 
-        for i, collision_vec in enumerate(self.check_all_collisions(states, self.state.shape[0])):
-            states[i].update(collision_vector=collision_vec)
+        for i, collision_vec in enumerate(self.check_all_collisions(agent_states, self.state.shape[0])):
+            agent_states[i].update(collision_vector=collision_vec)
+            if self.done_at_collision and collision_vec.any():
+                self.done = True
 
-        self.agent_states = states
-        reward, info = self.calculate_reward(states)
+        self.agent_states = agent_states
+        reward, info = self.calculate_reward(agent_states)
 
         if self.steps >= self.max_steps:
             self.done = True
diff --git a/environments/factory/simple_factory.py b/environments/factory/simple_factory.py
index 13124a3..1d2a2e0 100644
--- a/environments/factory/simple_factory.py
+++ b/environments/factory/simple_factory.py
@@ -17,7 +17,7 @@ DIRT_INDEX = -1
 
 @dataclass
 class DirtProperties:
-    clean_amount = 0.25
+    clean_amount = 10
     max_spawn_ratio = 0.1
     gain_amount = 0.1
     spawn_frequency = 5
@@ -31,8 +31,9 @@ class SimpleFactory(BaseFactory):
     def _is_clean_up_action(self, action):
         return self.action_space.n - 1 == action
 
-    def __init__(self, *args, dirt_properties: DirtProperties, **kwargs):
+    def __init__(self, *args, dirt_properties: DirtProperties, verbose=False, **kwargs):
         self._dirt_properties = dirt_properties
+        self.verbose = verbose
         super(SimpleFactory, self).__init__(*args, **kwargs)
         self.slice_strings.update({self.state.shape[0]-1: 'dirt'})
         self.renderer = None  # expensive - dont use it when not required !
@@ -81,6 +82,9 @@ class SimpleFactory(BaseFactory):
             return pos, cleanup_was_sucessfull
 
     def step(self, actions):
+        if self.state[h.LEVEL_IDX][self.agent_i_position(0)] == h.IS_OCCUPIED_CELL:
+            print(f'fAgent placed on wall!!!!, step is :{self.steps}')
+            raise Exception('Agent placed on wall!!!!')
         _, _, _, info = super(SimpleFactory, self).step(actions)
         if not self.next_dirt_spawn:
             self.spawn_dirt()
@@ -94,12 +98,6 @@ class SimpleFactory(BaseFactory):
             if self._is_clean_up_action(action):
                 agent_i_pos = self.agent_i_position(agent_i)
                 _, valid = self.clean_up(agent_i_pos)
-                if valid:
-                    print(f'Agent {agent_i} did just clean up some dirt at {agent_i_pos}.')
-                    self.monitor.add('dirt_cleaned', self._dirt_properties.clean_amount)
-                else:
-                    print(f'Agent {agent_i} just tried to clean up some dirt at {agent_i_pos}, but was unsucsessfull.')
-                    self.monitor.add('failed_cleanup_attempt', 1)
                 return agent_i_pos, valid
             else:
                 raise RuntimeError('This should not happen!!!')
@@ -120,25 +118,44 @@ class SimpleFactory(BaseFactory):
         dirty_tiles = len(np.nonzero(self.state[DIRT_INDEX]))
 
         try:
-            this_step_reward = (dirty_tiles / current_dirt_amount)
+            # penalty = current_dirt_amount
+            penalty = 0
         except (ZeroDivisionError, RuntimeWarning):
-            this_step_reward = 0
-
+            penalty = 0
+        inforcements = 0
         for agent_state in agent_states:
-            collisions = agent_state.collisions
-            print(f't = {self.steps}\tAgent {agent_state.i} has collisions with '
-                  f'{[self.slice_strings[entity] for entity in collisions if entity != self.string_slices["dirt"]]}')
-            if self._is_clean_up_action(agent_state.action) and agent_state.action_valid:
-                this_step_reward += 1
+            cols = agent_state.collisions
+            self.print(f't = {self.steps}\tAgent {agent_state.i} has collisions with '
+                       f'{[self.slice_strings[entity] for entity in cols if entity != self.string_slices["dirt"]]}')
+            if self._is_clean_up_action(agent_state.action):
+                if agent_state.action_valid:
+                    inforcements += 10
+                    self.print(f'Agent {agent_state.i} did just clean up some dirt at {agent_state.pos}.')
+                    self.monitor.add('dirt_cleaned', self._dirt_properties.clean_amount)
+                else:
+                    self.print(f'Agent {agent_state.i} just tried to clean up some dirt '
+                               f'at {agent_state.pos}, but was unsucsessfull.')
+                    self.monitor.add('failed_cleanup_attempt', 1)
+            elif self._is_moving_action(agent_state.action):
+                if not agent_state.action_valid:
+                    penalty += 10
+                else:
+                    inforcements += 1
 
-            for entity in collisions:
+            for entity in cols:
                 if entity != self.string_slices["dirt"]:
                     self.monitor.add(f'agent_{agent_state.i}_vs_{self.slice_strings[entity]}', 1)
+
+        this_step_reward = max(0, inforcements-penalty)
         self.monitor.set('dirt_amount', current_dirt_amount)
         self.monitor.set('dirty_tiles', dirty_tiles)
-        print(f"reward is {this_step_reward}")
+        self.print(f"reward is {this_step_reward}")
         return this_step_reward, {}
 
+    def print(self, string):
+        if self.verbose:
+            print(string)
+
 
 if __name__ == '__main__':
     render = True
diff --git a/environments/logging/monitor.py b/environments/logging/monitor.py
index f2842f2..5cfa59d 100644
--- a/environments/logging/monitor.py
+++ b/environments/logging/monitor.py
@@ -56,12 +56,10 @@ class FactoryMonitor:
 
 class MonitorCallback(BaseCallback):
 
-    def __init__(self, env, outpath='debug_out', filename='monitor'):
+    def __init__(self, env, filepath=Path('debug_out/monitor.pick')):
         super(MonitorCallback, self).__init__()
-        self._outpath = Path(outpath)
-        self._filename = filename
+        self.filepath = Path(filepath)
         self._monitor_list = list()
-        self.out_file = self._outpath / f'{self._filename.split(".")[0]}.pick'
         self.env = env
         self.started = False
         self.closed = False
@@ -84,7 +82,7 @@ class MonitorCallback(BaseCallback):
         if self.started:
             pass
         else:
-            self.out_file.parent.mkdir(exist_ok=True, parents=True)
+            self.filepath.parent.mkdir(exist_ok=True, parents=True)
             self.started = True
         pass
 
@@ -93,7 +91,7 @@ class MonitorCallback(BaseCallback):
             pass
         else:
             # self.out_file.unlink(missing_ok=True)
-            with self.out_file.open('wb') as f:
+            with self.filepath.open('wb') as f:
                 pickle.dump(self.monitor_as_df_list, f, protocol=pickle.HIGHEST_PROTOCOL)
             self.closed = True
 
diff --git a/environments/logging/training.py b/environments/logging/training.py
new file mode 100644
index 0000000..ce7aa50
--- /dev/null
+++ b/environments/logging/training.py
@@ -0,0 +1,35 @@
+from pathlib import Path
+
+import pandas as pd
+from stable_baselines3.common.callbacks import BaseCallback
+
+
+class TraningMonitor(BaseCallback):
+
+    def __init__(self, filepath, flush_interval=None):
+        super(TraningMonitor, self).__init__()
+        self.values = dict()
+        self.filepath = Path(filepath)
+        self.flush_interval = flush_interval
+        pass
+
+    def _on_training_start(self) -> None:
+        self.flush_interval = self.flush_interval or (self.locals['total_timesteps'] * 0.1)
+
+    def _flush(self):
+        df = pd.DataFrame.from_dict(self.values)
+        if not self.filepath.exists():
+            df.to_csv(self.filepath, mode='wb', header=True)
+        else:
+            df.to_csv(self.filepath, mode='a', header=False)
+        self.values = dict()
+
+    def _on_step(self) -> bool:
+        self.values[self.num_timesteps] = dict(reward=self.locals['rewards'].item())
+        if self.num_timesteps % self.flush_interval == 0:
+            self._flush()
+        return True
+
+    def on_training_end(self) -> None:
+        self._flush()
+