From 3d18fe98390e0b618f1c589e3a4e1a37649cd3d9 Mon Sep 17 00:00:00 2001
From: Chanumask <joelfriedrich@gmx.de>
Date: Fri, 10 Nov 2023 10:43:11 +0100
Subject: [PATCH] added test hooks (like rules)

---
 marl_factory_grid/environment/tests.py     | 42 +++++++++++
 marl_factory_grid/testing/test_config.yaml | 81 ++++++++++++++++++++++
 marl_factory_grid/testing/test_run.py      | 32 +++++++++
 3 files changed, 155 insertions(+)
 create mode 100644 marl_factory_grid/environment/tests.py
 create mode 100644 marl_factory_grid/testing/test_config.yaml
 create mode 100644 marl_factory_grid/testing/test_run.py

diff --git a/marl_factory_grid/environment/tests.py b/marl_factory_grid/environment/tests.py
new file mode 100644
index 0000000..ee615c3
--- /dev/null
+++ b/marl_factory_grid/environment/tests.py
@@ -0,0 +1,42 @@
+import abc
+from typing import List
+
+from marl_factory_grid.utils.results import TickResult, DoneResult
+
+
+class Test(abc.ABC):
+
+    @property
+    def name(self):
+        return self.__class__.__name__
+
+    def __init__(self):
+        pass
+
+    def __repr__(self):
+        return f'{self.name}'
+
+    def on_init(self, state, lvl_map):
+        return []
+
+    def on_reset(self):
+        return []
+
+    def tick_pre_step(self, state) -> List[TickResult]:
+        return []
+
+    def tick_step(self, state) -> List[TickResult]:
+        return []
+
+    def tick_post_step(self, state) -> List[TickResult]:
+        return []
+
+    def on_check_done(self, state) -> List[DoneResult]:
+        return []
+
+
+class FirstTest(Test):
+
+    def __init__(self):
+        super().__init__()
+        pass
diff --git a/marl_factory_grid/testing/test_config.yaml b/marl_factory_grid/testing/test_config.yaml
new file mode 100644
index 0000000..060ead2
--- /dev/null
+++ b/marl_factory_grid/testing/test_config.yaml
@@ -0,0 +1,81 @@
+Agents:
+  Wolfgang:
+    Actions:
+    - Noop
+    - BtryCharge
+    - CleanUp
+    - DestAction
+    - DoorUse
+    - ItemAction
+    - Move8
+    Observations:
+    - Combined:
+      - Other
+      - Walls
+    - GlobalPosition
+    - Battery
+    - ChargePods
+    - DirtPiles
+    - Destinations
+    - Doors
+    - Items
+    - Inventory
+    - DropOffLocations
+    - Maintainers
+Entities:
+  Batteries:
+    initial_charge: 0.8
+    per_action_costs: 0.02
+  ChargePods: {}
+  Destinations: {}
+  DirtPiles:
+    clean_amount: 1
+    dirt_spawn_r_var: 0.1
+    initial_amount: 2
+    initial_dirt_ratio: 0.05
+    max_global_amount: 20
+    max_local_amount: 5
+  Doors: {}
+  DropOffLocations: {}
+  GlobalPositions: {}
+  Inventories: {}
+  Items: {}
+  Machines: {}
+  Maintainers: {}
+  Zones: {}
+
+General:
+  env_seed: 69
+  individual_rewards: true
+  level_name: large
+  pomdp_r: 3
+  verbose: false
+  tests: true
+
+Rules:
+  SpawnAgents: {}
+  DoneAtBatteryDischarge: {}
+  Collision:
+    done_at_collisions: false
+  AssignGlobalPositions: {}
+  DoneAtDestinationReachAny: {}
+  DestinationReachReward: {}
+  SpawnDestinations:
+    n_dests: 1
+    spawn_mode: GROUPED
+  DoneOnAllDirtCleaned: {}
+  SpawnDirt:
+    spawn_freq: 15
+  EntitiesSmearDirtOnMove:
+    smear_ratio: 0.2
+  DoorAutoClose:
+    close_frequency: 10
+  ItemRules:
+    max_dropoff_storage_size: 0
+    n_items: 5
+    n_locations: 5
+    spawn_frequency: 15
+  MaxStepsReached:
+    max_steps: 10
+#  AgentSingleZonePlacement:
+#    n_zones: 4
diff --git a/marl_factory_grid/testing/test_run.py b/marl_factory_grid/testing/test_run.py
new file mode 100644
index 0000000..35112ca
--- /dev/null
+++ b/marl_factory_grid/testing/test_run.py
@@ -0,0 +1,32 @@
+from pathlib import Path
+from random import randint
+
+from tqdm import trange
+
+from marl_factory_grid.environment.factory import Factory
+
+if __name__ == '__main__':
+    # Render at each step?
+    render = True
+
+    # Path to config File
+    path = Path('test_config.yaml')
+
+    # Env Init
+    factory = Factory(path)
+
+    for episode in trange(5):
+        _ = factory.reset()
+        done = False
+        if render:
+            factory.render()
+        action_spaces = factory.action_space
+        while not done:
+            a = [randint(0, x.n - 1) for x in action_spaces]
+            obs_type, _, _, done, info = factory.step(a)
+            if render:
+                factory.render()
+            if done:
+                print(f'Episode {episode} done...')
+                break
+