From 8ba7c418f0a89a3f9c4f93520996216fc0c8a959 Mon Sep 17 00:00:00 2001
From: steffen-illium <steffen.illium@ifi.lmu.de>
Date: Fri, 4 Jun 2021 16:43:34 +0200
Subject: [PATCH] named tuples working

---
 environments/factory/base_factory.py   | 35 +++++++++++++++-----------
 environments/factory/simple_factory.py |  5 ++--
 main.py                                |  6 ++++-
 3 files changed, 27 insertions(+), 19 deletions(-)

diff --git a/environments/factory/base_factory.py b/environments/factory/base_factory.py
index 850f794..2918ed7 100644
--- a/environments/factory/base_factory.py
+++ b/environments/factory/base_factory.py
@@ -1,6 +1,6 @@
 import pickle
 from pathlib import Path
-from typing import List, Union, Iterable
+from typing import List, Union, Iterable, NamedTuple
 
 import gym
 import numpy as np
@@ -11,6 +11,12 @@ import yaml
 from environments import helpers as h
 
 
+class MovementProperties(NamedTuple):
+    allow_square_movement: bool = False
+    allow_diagonal_movement: bool = False
+    allow_no_op: bool = False
+
+
 class AgentState:
 
     def __init__(self, i: int, action: int):
@@ -78,16 +84,17 @@ class Actions(Register):
     def movement_actions(self):
         return self._movement_actions
 
-    def __init__(self, allow_square_movement=False, allow_diagonal_movement=False, allow_no_op=False):
+    def __init__(self, movement_properties: MovementProperties):
+        self.allow_no_op = movement_properties.allow_no_op
+        self.allow_diagonal_movement = movement_properties.allow_diagonal_movement
+        self.allow_square_movement = movement_properties.allow_square_movement
         # FIXME: There is a bug in helpers because there actions are ints. and the order matters.
-        assert not(allow_square_movement is False and allow_diagonal_movement is True), "There is a bug in helpers!!!"
+        assert not(self.allow_square_movement is False and self.allow_diagonal_movement is True), "There is a bug in helpers!!!"
         super(Actions, self).__init__()
-        self.allow_no_op = allow_no_op
-        self.allow_diagonal_movement = allow_diagonal_movement
-        self.allow_square_movement = allow_square_movement
-        if allow_square_movement:
+
+        if self.allow_square_movement:
             self + ['north', 'east', 'south', 'west']
-        if allow_diagonal_movement:
+        if self.allow_diagonal_movement:
             self + ['north-east', 'south-east', 'south-west', 'north-west']
         self._movement_actions = self._register.copy()
         if self.allow_no_op:
@@ -124,20 +131,18 @@ class BaseFactory(gym.Env):
         return self._actions.movement_actions
 
     def __init__(self, level='simple', n_agents=1, max_steps=int(5e2), pomdp_radius: Union[None, int] = None,
-                 allow_square_movement=True, allow_diagonal_movement=True, allow_no_op=True,
+                 movement_properties: MovementProperties = MovementProperties(),
                  omit_agent_slice_in_obs=False, **kwargs):
-        self.allow_no_op = allow_no_op
-        self.allow_diagonal_movement = allow_diagonal_movement
-        self.allow_square_movement = allow_square_movement
+
+        self.movement_properties = movement_properties
+
         self.n_agents = n_agents
         self.max_steps = max_steps
         self.pomdp_radius = pomdp_radius
         self.omit_agent_slice_in_obs = omit_agent_slice_in_obs
 
         self.done_at_collision = False
-        _actions = Actions(allow_square_movement=self.allow_square_movement,
-                           allow_diagonal_movement=self.allow_diagonal_movement,
-                           allow_no_op=allow_no_op)
+        _actions = Actions(self.movement_properties)
         self._actions = _actions + self.additional_actions
 
         self._level = h.one_hot_level(
diff --git a/environments/factory/simple_factory.py b/environments/factory/simple_factory.py
index 8b89116..2dde6df 100644
--- a/environments/factory/simple_factory.py
+++ b/environments/factory/simple_factory.py
@@ -1,6 +1,6 @@
 from collections import OrderedDict
 from dataclasses import dataclass
-from typing import List, Union
+from typing import List, Union, NamedTuple
 import random
 
 import numpy as np
@@ -15,8 +15,7 @@ DIRT_INDEX = -1
 CLEAN_UP_ACTION = 'clean_up'
 
 
-@dataclass
-class DirtProperties:
+class DirtProperties(NamedTuple):
     clean_amount: int = 2            # How much does the robot clean with one action.
     max_spawn_ratio: float = 0.2       # On max how much tiles does the dirt spawn in percent.
     gain_amount: float = 0.5           # How much dirt does spawn per tile
diff --git a/main.py b/main.py
index 89c9781..d06d030 100644
--- a/main.py
+++ b/main.py
@@ -12,6 +12,7 @@ from gym.wrappers import FrameStack
 from stable_baselines3.common.callbacks import CallbackList
 from stable_baselines3.common.vec_env import VecFrameStack, DummyVecEnv
 
+from environments.factory.base_factory import MovementProperties
 from environments.factory.simple_factory import DirtProperties, SimpleFactory
 from environments.helpers import IGNORED_DF_COLUMNS
 from environments.logging.monitor import MonitorCallback
@@ -94,6 +95,9 @@ if __name__ == '__main__':
     # from sb3_contrib import QRDQN
 
     dirt_props = DirtProperties()
+    move_props = MovementProperties(allow_diagonal_movement=False,
+                                    allow_square_movement=True,
+                                    allow_no_op=False)
     time_stamp = int(time.time())
 
     out_path = None
@@ -104,7 +108,7 @@ if __name__ == '__main__':
         for seed in range(3):
 
             env = SimpleFactory(n_agents=1, dirt_properties=dirt_props, pomdp_radius=None, max_steps=400,
-                                allow_diagonal_movement=True, allow_no_op=False, verbose=False,
+                                movement_properties=move_props,
                                 omit_agent_slice_in_obs=True)
             env.save_params(Path('debug_out', 'yaml.txt'))