From 83d77df216667db173a97c54bf06a5bfbe58afe1 Mon Sep 17 00:00:00 2001
From: steffen-illium <steffen.illium@ifi.lmu.de>
Date: Fri, 14 May 2021 15:21:31 +0200
Subject: [PATCH] Monitor and Agent State Merge

---
 environments/factory/base_factory.py          |  2 +-
 .../factory/simple_factory_getting_dirty.py   | 25 +++++++++++++------
 2 files changed, 18 insertions(+), 9 deletions(-)

diff --git a/environments/factory/base_factory.py b/environments/factory/base_factory.py
index 2553245..b33c628 100644
--- a/environments/factory/base_factory.py
+++ b/environments/factory/base_factory.py
@@ -90,7 +90,7 @@ class BaseFactory:
         self.slice_strings = {0: 'level', **{i: f'agent#{i}' for i in range(1, self.n_agents+1)}}
         self.reset()
 
-    def reset(self):
+    def reset(self)  -> (np.ndarray, int, bool, dict):
         self.done = False
         self.steps = 0
         self.cumulative_reward = 0
diff --git a/environments/factory/simple_factory_getting_dirty.py b/environments/factory/simple_factory_getting_dirty.py
index 1062091..4490e76 100644
--- a/environments/factory/simple_factory_getting_dirty.py
+++ b/environments/factory/simple_factory_getting_dirty.py
@@ -60,12 +60,12 @@ class GettingDirty(BaseFactory):
         else:
             raise RuntimeError('This should not happen!!!')
 
-    def reset(self) -> None:
-        # ToDo: When self.reset returns the new states and stuff, use it here!
-        super().reset()  # state, agents, ... =
+    def reset(self) -> (np.ndarray, int, bool, dict):
+        state, r, done, _ = super().reset()  # state, reward, done, info ... =
         dirt_slice = np.zeros((1, *self.state.shape[1:]))
         self.state = np.concatenate((self.state, dirt_slice))  # dirt is now the last slice
         self.spawn_dirt()
+        return self.state, r, self.done, {}
 
     def calculate_reward(self, agent_states: List[AgentState]) -> (int, dict):
         this_step_reward = 0
@@ -85,8 +85,17 @@ if __name__ == '__main__':
     import random
     dirt_props = DirtProperties()
     factory = GettingDirty(n_agents=1, dirt_properties=dirt_props)
-    random_actions = [random.randint(0, 8) for _ in range(2000)]
-    for random_action in random_actions:
-        state, r, done, _ = factory.step(random_action)
-    print(f'Factory run done, reward is:\n    {r}')
-    print(f'The following running stats have been recorded:\n{dict(factory.monitor)}')
+    monitor_list = list()
+    for epoch in range(100):
+        random_actions = [random.randint(0, 7) for _ in range(200)]
+        state, r, done, _ = factory.reset()
+        for action in random_actions:
+            state, r, done, info = factory.step(action)
+        monitor_list.append(factory.monitor)
+        print(f'Factory run done, reward is:\n    {r}')
+    from pathlib import Path
+    import pickle
+    out_path = Path('debug_out')
+    out_path.mkdir(exist_ok=True, parents=True)
+    with (out_path / 'monitor.pick').open('rb') as f:
+        pickle.dump(monitor_list, f)