From 8ce92d5db46b06a77bc783f4be3e2b42a17ed991 Mon Sep 17 00:00:00 2001
From: steffen-illium <steffen.illium@ifi.lmu.de>
Date: Fri, 4 Jun 2021 17:16:45 +0200
Subject: [PATCH] h, w =  fixed

---
 environments/factory/base_factory.py   |  6 +++++-
 environments/factory/levels/rooms.txt  | 13 +++++++++++
 environments/factory/renderer.py       |  7 +++---
 environments/factory/simple_factory.py | 30 +++++++++++++++-----------
 main.py                                | 11 +++++-----
 5 files changed, 44 insertions(+), 23 deletions(-)
 create mode 100644 environments/factory/levels/rooms.txt

diff --git a/environments/factory/base_factory.py b/environments/factory/base_factory.py
index 2918ed7..0256f81 100644
--- a/environments/factory/base_factory.py
+++ b/environments/factory/base_factory.py
@@ -12,7 +12,7 @@ from environments import helpers as h
 
 
 class MovementProperties(NamedTuple):
-    allow_square_movement: bool = False
+    allow_square_movement: bool = True
     allow_diagonal_movement: bool = False
     allow_no_op: bool = False
 
@@ -111,6 +111,10 @@ class StateSlice(Register):
 
 class BaseFactory(gym.Env):
 
+    # def __setattr__(self, key, value):
+    #     if isinstance(value, dict):
+
+
     @property
     def action_space(self):
         return spaces.Discrete(self._actions.n)
diff --git a/environments/factory/levels/rooms.txt b/environments/factory/levels/rooms.txt
new file mode 100644
index 0000000..83d2e9c
--- /dev/null
+++ b/environments/factory/levels/rooms.txt
@@ -0,0 +1,13 @@
+###############
+#------#------#
+#---#--#------#
+#--------#----#
+#------#------#
+#------#------#
+###-#######-###
+#----##-------#
+#-----#----#--#
+#-------------#
+#-----#-------#
+#-----#-------#
+###############
\ No newline at end of file
diff --git a/environments/factory/renderer.py b/environments/factory/renderer.py
index b598509..fdb56d2 100644
--- a/environments/factory/renderer.py
+++ b/environments/factory/renderer.py
@@ -26,7 +26,7 @@ class Renderer:
         self.grid_lines = grid_lines
         self.view_radius = view_radius
         pygame.init()
-        self.screen_size = (grid_h*cell_size, grid_w*cell_size)
+        self.screen_size = (grid_w*cell_size, grid_h*cell_size)
         self.screen = pygame.display.set_mode(self.screen_size)
         self.clock = pygame.time.Clock()
         assets = list((Path(__file__).parent / 'assets').rglob('*.png'))
@@ -36,7 +36,7 @@ class Renderer:
     def fill_bg(self):
         self.screen.fill(Renderer.BG_COLOR)
         if self.grid_lines:
-            h, w = self.screen_size
+            w, h = self.screen_size
             for x in range(0, w, self.cell_size):
                 for y in range(0, h, self.cell_size):
                     rect = pygame.Rect(x, y, self.cell_size, self.cell_size)
@@ -81,7 +81,8 @@ class Renderer:
                     shape_surf.set_alpha(64)
                     blits.appendleft(dict(source=shape_surf, dest=visibility_rect))
                 blits.append(bp)
-        for blit in blits: self.screen.blit(**blit)
+        for blit in blits:
+            self.screen.blit(**blit)
         pygame.display.flip()
         self.clock.tick(self.fps)
 
diff --git a/environments/factory/simple_factory.py b/environments/factory/simple_factory.py
index 2dde6df..20969f5 100644
--- a/environments/factory/simple_factory.py
+++ b/environments/factory/simple_factory.py
@@ -1,11 +1,12 @@
 from collections import OrderedDict
 from dataclasses import dataclass
+from pathlib import Path
 from typing import List, Union, NamedTuple
 import random
 
 import numpy as np
 
-from environments.factory.base_factory import BaseFactory, AgentState
+from environments.factory.base_factory import BaseFactory, AgentState, MovementProperties
 from environments import helpers as h
 
 from environments.logging.monitor import MonitorCallback
@@ -186,16 +187,19 @@ if __name__ == '__main__':
     render = True
 
     dirt_props = DirtProperties()
-    factory = SimpleFactory(n_agents=2, dirt_properties=dirt_props)
+    move_props = MovementProperties(allow_diagonal_movement=False, allow_no_op=False)
+    factory = SimpleFactory(n_agents=2, dirt_properties=dirt_props, movement_properties=move_props, level='rooms',
+                            pomdp_radius=2)
+
     n_actions = factory.action_space.n - 1
-    with MonitorCallback(factory):
-        for epoch in range(100):
-            random_actions = [(random.randint(0, n_actions), random.randint(0, n_actions)) for _ in range(200)]
-            env_state, this_reward, done_bool, _ = factory.reset()
-            for agent_i_action in random_actions:
-                env_state, reward, done_bool, info_obj = factory.step(agent_i_action)
-                if render:
-                    factory.render()
-                if done_bool:
-                    break
-            print(f'Factory run {epoch} done, reward is:\n    {reward}')
+
+    for epoch in range(100):
+        random_actions = [(random.randint(0, n_actions), random.randint(0, n_actions)) for _ in range(200)]
+        env_state, this_reward, done_bool, _ = factory.reset()
+        for agent_i_action in random_actions:
+            env_state, reward, done_bool, info_obj = factory.step(agent_i_action)
+            if render:
+                factory.render()
+            if done_bool:
+                break
+        print(f'Factory run {epoch} done, reward is:\n    {reward}')
diff --git a/main.py b/main.py
index d06d030..a9f8093 100644
--- a/main.py
+++ b/main.py
@@ -102,24 +102,23 @@ if __name__ == '__main__':
 
     out_path = None
 
-    # for modeL_type in [PPO, A2C, RegDQN, DQN]:
-    modeL_type = PPO
-    for coef in [0.01, 0.1, 0.25]:
+    for modeL_type in [PPO, A2C, RegDQN, DQN]:
         for seed in range(3):
 
             env = SimpleFactory(n_agents=1, dirt_properties=dirt_props, pomdp_radius=None, max_steps=400,
-                                movement_properties=move_props,
+                                movement_properties=move_props, level='rooms',
                                 omit_agent_slice_in_obs=True)
             env.save_params(Path('debug_out', 'yaml.txt'))
 
             # env = FrameStack(env, 4)
 
-            model = modeL_type("MlpPolicy", env, verbose=1, seed=seed, device='cpu')
+            kwargs = dict(ent_coef=0.01) if isinstance(modeL_type, (PPO, A2C)) else {}
+            model = modeL_type("MlpPolicy", env, verbose=1, seed=seed, device='cpu', **kwargs)
 
             out_path = Path('debug_out') / f'{model.__class__.__name__}_{time_stamp}'
 
             # identifier = f'{seed}_{model.__class__.__name__}_{time_stamp}'
-            identifier = f'{seed}_{str(coef).replace(".", "")}_{time_stamp}'
+            identifier = f'{seed}_{modeL_type.__class__.__name__}_{time_stamp}'
             out_path /= identifier
 
             callbacks = CallbackList(