From fcb765f4470cab3d7669126052290774f9f83f73 Mon Sep 17 00:00:00 2001
From: Steffen Illium <steffen.illium@ifi.lmu.de>
Date: Thu, 11 Nov 2021 18:42:48 +0100
Subject: [PATCH] Factory is now Battery Powered

---
 environments/factory/__init__.py           |   2 +-
 environments/factory/base/base_factory.py  |   2 +-
 environments/factory/base/objects.py       |   4 +-
 environments/factory/combined_factories.py |  60 +++++
 environments/factory/factory_battery.py    | 275 +++++++++++++++++++++
 environments/factory/factory_dirt.py       |   5 +-
 environments/factory/factory_dirt_item.py  |   7 -
 environments/factory/factory_item.py       |   2 +-
 environments/helpers.py                    |   5 +
 reload_agent.py                            |   2 +-
 studies/e_1.py                             |   4 +-
 11 files changed, 352 insertions(+), 16 deletions(-)
 create mode 100644 environments/factory/combined_factories.py
 create mode 100644 environments/factory/factory_battery.py
 delete mode 100644 environments/factory/factory_dirt_item.py

diff --git a/environments/factory/__init__.py b/environments/factory/__init__.py
index edde920..9e57a7c 100644
--- a/environments/factory/__init__.py
+++ b/environments/factory/__init__.py
@@ -1,7 +1,7 @@
 def make(env_str, n_agents=1, pomdp_r=2, max_steps=400):
     import yaml
     from pathlib import Path
-    from environments.factory.factory_dirt_item import DirtItemFactory
+    from environments.factory.combined_factories import DirtItemFactory
     from environments.factory.factory_item import ItemFactory, ItemProperties
     from environments.factory.factory_dirt import DirtProperties, DirtFactory
     from environments.utility_classes import MovementProperties, ObservationProperties, AgentRenderOptions
diff --git a/environments/factory/base/base_factory.py b/environments/factory/base/base_factory.py
index f788ebe..9106abd 100644
--- a/environments/factory/base/base_factory.py
+++ b/environments/factory/base/base_factory.py
@@ -405,7 +405,7 @@ class BaseFactory(gym.Env):
         y0, y1 = max(0, agent.y - r), min(agent.y + r + 1, self._level_shape[1])
         # Other Agent Obs = oobs
         oobs = obs_to_be_padded[:, x0:x1, y0:y1]
-        if oobs.shape[0:] != (d,) * 2:
+        if oobs.shape[0:] != (d, d):
             if xd := oobs.shape[1] % d:
                 if agent.x > r:
                     x0_pad = 0
diff --git a/environments/factory/base/objects.py b/environments/factory/base/objects.py
index e470799..adec715 100644
--- a/environments/factory/base/objects.py
+++ b/environments/factory/base/objects.py
@@ -32,7 +32,9 @@ class Object:
         else:
             return self._name
 
-    def __init__(self, str_ident: Union[str, None] = None, enum_ident: Union[Enum, None] = None, is_blocking_light=False, **kwargs):
+    def __init__(self, str_ident: Union[str, None] = None, enum_ident: Union[Enum, None] = None,
+                 is_blocking_light=False, **kwargs):
+
         self._str_ident = str_ident
         self._enum_ident = enum_ident
 
diff --git a/environments/factory/combined_factories.py b/environments/factory/combined_factories.py
new file mode 100644
index 0000000..d565181
--- /dev/null
+++ b/environments/factory/combined_factories.py
@@ -0,0 +1,60 @@
+import random
+
+from environments.factory.factory_battery import BatteryFactory, BatteryProperties
+from environments.factory.factory_dirt import DirtFactory, DirtProperties
+from environments.factory.factory_item import ItemFactory
+
+
+# noinspection PyAbstractClass
+class DirtItemFactory(ItemFactory, DirtFactory):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+
+# noinspection PyAbstractClass
+class DirtBatteryFactory(DirtFactory, BatteryFactory):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+
+if __name__ == '__main__':
+    from environments.utility_classes import AgentRenderOptions as ARO, ObservationProperties
+
+    render = True
+
+    dirt_props = DirtProperties(1, 0.05, 0.1, 3, 1, 20, 0)
+
+    obs_props = ObservationProperties(render_agents=ARO.COMBINED, omit_agent_self=True,
+                                      pomdp_r=2, additional_agent_placeholder=None)
+
+    move_props = {'allow_square_movement': True,
+                  'allow_diagonal_movement': False,
+                  'allow_no_op': False}
+
+    factory = DirtBatteryFactory(n_agents=5, done_at_collision=False,
+                                 level_name='rooms', max_steps=400,
+                                 obs_prop=obs_props, parse_doors=True,
+                                 record_episodes=True, verbose=True,
+                                 btry_prop=BatteryProperties(),
+                                 mv_prop=move_props, dirt_prop=dirt_props
+                                 )
+
+    # noinspection DuplicatedCode
+    n_actions = factory.action_space.n - 1
+    _ = factory.observation_space
+
+    for epoch in range(4):
+        random_actions = [[random.randint(0, n_actions) for _
+                           in range(factory.n_agents)] for _
+                          in range(factory.max_steps + 1)]
+        env_state = factory.reset()
+        r = 0
+        for agent_i_action in random_actions:
+            env_state, step_r, done_bool, info_obj = factory.step(agent_i_action)
+            r += step_r
+            if render:
+                factory.render()
+            if done_bool:
+                break
+        print(f'Factory run {epoch} done, reward is:\n    {r}')
+pass
diff --git a/environments/factory/factory_battery.py b/environments/factory/factory_battery.py
new file mode 100644
index 0000000..c24bf99
--- /dev/null
+++ b/environments/factory/factory_battery.py
@@ -0,0 +1,275 @@
+from typing import Union, NamedTuple, Dict
+
+import numpy as np
+
+from environments.factory.base.base_factory import BaseFactory
+from environments.factory.base.objects import Agent, Action, Entity
+from environments.factory.base.registers import EntityObjectRegister, ObjectRegister
+from environments.factory.renderer import RenderEntity
+from environments.helpers import Constants as c
+
+from environments import helpers as h
+
+
+CHARGE_ACTION = h.EnvActions.CHARGE
+ITEM_DROP_OFF = 1
+
+
+class BatteryProperties(NamedTuple):
+    initial_charge: float = 0.8             #
+    charge_rate: float = 0.4                #
+    charge_locations: int = 20               #
+    per_action_costs: Union[dict, float] = 0.02
+    done_when_discharged = False
+    multi_charge: bool = False
+
+
+class Battery(object):
+
+    @property
+    def is_discharged(self):
+        return self.charge_level == 0
+
+    @property
+    def is_blocking_light(self):
+        return False
+
+    @property
+    def can_collide(self):
+        return False
+
+    @property
+    def name(self):
+        return f'{self.__class__.__name__}({self.agent.name})'
+
+    def __init__(self, pomdp_r: int, level_shape: (int, int), agent: Agent, initial_charge_level: float):
+        super().__init__()
+        self.agent = agent
+        self._pomdp_r = pomdp_r
+        self._level_shape = level_shape
+        if self._pomdp_r:
+            self._array = np.zeros((1, pomdp_r * 2 + 1, pomdp_r * 2 + 1))
+        else:
+            self._array = np.zeros((1, *self._level_shape))
+        self.charge_level = initial_charge_level
+
+    def as_array(self):
+        self._array[:] = c.FREE_CELL.value
+        self._array[0, 0] = self.charge_level
+        return self._array
+
+    def __repr__(self):
+        return f'{self.__class__.__name__}[{self.agent.name}]({self.charge_level})'
+
+    def charge(self, amount) -> c:
+        if self.charge_level < 1:
+            # noinspection PyTypeChecker
+            self.charge_level = min(1, amount + self.charge_level)
+            return c.VALID
+        else:
+            return c.NOT_VALID
+
+    def decharge(self, amount) -> c:
+        if self.charge_level != 0:
+            # noinspection PyTypeChecker
+            self.charge_level = max(0, amount + self.charge_level)
+            return c.VALID
+        else:
+            return  c.NOT_VALID
+
+    def belongs_to_entity(self, entity):
+        return self.agent == entity
+
+    def summarize_state(self, **kwargs):
+        attr_dict = {key: str(val) for key, val in self.__dict__.items() if not key.startswith('_') and key != 'data'}
+        attr_dict.update(dict(name=self.name))
+        return attr_dict
+
+
+class BatteriesRegister(ObjectRegister):
+
+    _accepted_objects = Battery
+    is_blocking_light = False
+    can_be_shadowed = False
+    hide_from_obs_builder = True
+
+    def __init__(self, *args, **kwargs):
+        super(BatteriesRegister, self).__init__(*args, is_per_agent=True, individual_slices=True, **kwargs)
+        self.is_observable = True
+
+    def as_array(self):
+        # self._array[:] = c.FREE_CELL.value
+        for inv_idx, battery in enumerate(self):
+            self._array[inv_idx] = battery.as_array()
+        return self._array
+
+    def spawn_batteries(self, agents, pomdp_r, initial_charge_level):
+        inventories = [self._accepted_objects(pomdp_r, self._level_shape, agent,
+                                              initial_charge_level)
+                       for _, agent in enumerate(agents)]
+        self.register_additional_items(inventories)
+
+    def idx_by_entity(self, entity):
+        try:
+            return next((idx for idx, bat in enumerate(self) if bat.belongs_to_entity(entity)))
+        except StopIteration:
+            return None
+
+    def by_entity(self, entity):
+        try:
+            return next((bat for bat in self if bat.belongs_to_entity(entity)))
+        except StopIteration:
+            return None
+
+    def summarize_states(self, n_steps=None):
+        # as dict with additional nesting
+        # return dict(items=super(Inventories, self).summarize_states())
+        return super(BatteriesRegister, self).summarize_states(n_steps=n_steps)
+
+
+class ChargePod(Entity):
+
+    @property
+    def can_collide(self):
+        return False
+
+    @property
+    def encoding(self):
+        return ITEM_DROP_OFF
+
+    def __init__(self, *args, charge_rate: float = 0.4,
+                 multi_charge: bool = False, **kwargs):
+        super(ChargePod, self).__init__(*args, **kwargs)
+        self.charge_rate = charge_rate
+        self.multi_charge = multi_charge
+
+    def charge_battery(self, battery: Battery):
+        if battery.charge_level == 1.0:
+            return c.NOT_VALID
+        if sum(guest for guest in self.tile.guests if c.AGENT.name in guest.name) > 1:
+            return c.NOT_VALID
+        battery.charge(self.charge_rate)
+        return c.VALID
+
+    def summarize_state(self, n_steps=None) -> dict:
+        if n_steps == h.STEPS_START:
+            summary = super().summarize_state(n_steps=n_steps)
+            return summary
+
+
+class ChargePods(EntityObjectRegister):
+
+    _accepted_objects = ChargePod
+
+    def as_array(self):
+        self._array[:] = c.FREE_CELL.value
+        for item in self:
+            if item.pos != c.NO_POS.value:
+                self._array[0, item.x, item.y] = item.encoding
+        return self._array
+
+    def __repr__(self):
+        super(ChargePods, self).__repr__()
+
+
+class BatteryFactory(BaseFactory):
+
+    def __init__(self, *args, btry_prop=BatteryProperties(), **kwargs):
+        if isinstance(btry_prop, dict):
+            btry_prop = BatteryProperties(**btry_prop)
+        self.btry_prop = btry_prop
+        super().__init__(*args, **kwargs)
+
+    @property
+    def additional_entities(self):
+        super_entities = super().additional_entities
+
+        empty_tiles = self[c.FLOOR].empty_tiles[:self.btry_prop.charge_locations]
+        charge_pods = ChargePods.from_tiles(
+            empty_tiles, self._level_shape,
+            entity_kwargs=dict(charge_rate=self.btry_prop.charge_rate,
+                               multi_charge=self.btry_prop.multi_charge)
+        )
+
+        batteries = BatteriesRegister(self._level_shape if not self._pomdp_r else ((self.pomdp_diameter,) * 2),
+                                      )
+        batteries.spawn_batteries(self[c.AGENT], self._pomdp_r, self.btry_prop.initial_charge)
+        super_entities.update({c.BATTERIES: batteries, c.CHARGE_POD: charge_pods})
+        return super_entities
+
+    def do_additional_step(self) -> dict:
+        info_dict = super(BatteryFactory, self).do_additional_step()
+
+        # Decharge
+        batteries = self[c.BATTERIES]
+
+        for agent in self[c.AGENT]:
+            if isinstance(self.btry_prop.per_action_costs, dict):
+                energy_consumption = self.btry_prop.per_action_costs[agent.temp_action]
+            else:
+                energy_consumption = self.btry_prop.per_action_costs
+
+            batteries.by_entity(agent).decharge(energy_consumption)
+
+        return info_dict
+
+    def do_charge(self, agent) -> c:
+        if charge_pod := self[c.CHARGE_POD].by_pos(agent.pos):
+            return charge_pod.charge_battery(self[c.BATTERIES].by_entity(agent))
+        else:
+            return c.NOT_VALID
+
+    def do_additional_actions(self, agent: Agent, action: Action) -> Union[None, c]:
+        valid = super().do_additional_actions(agent, action)
+        if valid is None:
+            if action == CHARGE_ACTION:
+                valid = self.do_charge(agent)
+                return valid
+            else:
+                return None
+        else:
+            return valid
+        pass
+
+    def do_additional_reset(self) -> None:
+        # There is Nothing to reset.
+        pass
+
+    def check_additional_done(self) -> bool:
+        super_done = super(BatteryFactory, self).check_additional_done()
+        if super_done:
+            return super_done
+        else:
+            return self.btry_prop.done_when_discharged and any(battery.is_discharged for battery in self[c.BATTERIES])
+        pass
+
+    def calculate_additional_reward(self, agent: Agent) -> (int, dict):
+        reward, info_dict = super(BatteryFactory, self).calculate_additional_reward(agent)
+        if h.EnvActions.CHARGE == agent.temp_action:
+            if agent.temp_valid:
+                charge_pod = self[c.CHARGE_POD].by_pos(agent.pos)
+                info_dict.update({f'{agent.name}_charge': 1})
+                info_dict.update(agent_charged=1)
+                self.print(f'{agent.name} just charged batteries at {charge_pod.pos}.')
+                reward += 0.1
+            else:
+                self[c.DROP_OFF].by_pos(agent.pos)
+                info_dict.update({f'{agent.name}_failed_charge': 1})
+                info_dict.update(failed_charge=1)
+                self.print(f'{agent.name} just tried to charge at {agent.pos}, but failed.')
+                reward -= 0.1
+
+        if self[c.BATTERIES].by_entity(agent).is_discharged:
+            info_dict.update({f'{agent.name}_discharged': 1})
+            reward -= 1
+        else:
+            info_dict.update({f'{agent.name}_battery_level': self[c.BATTERIES].by_entity(agent).charge_level})
+        return reward, info_dict
+
+    def render_additional_assets(self):
+        # noinspection PyUnresolvedReferences
+        additional_assets = super().render_additional_assets()
+        charge_pods = [RenderEntity(c.CHARGE_POD.value, charge_pod.tile.pos) for charge_pod in self[c.CHARGE_POD]]
+        additional_assets.extend(charge_pods)
+        return additional_assets
+
diff --git a/environments/factory/factory_dirt.py b/environments/factory/factory_dirt.py
index 32cc80e..67230f4 100644
--- a/environments/factory/factory_dirt.py
+++ b/environments/factory/factory_dirt.py
@@ -261,13 +261,14 @@ if __name__ == '__main__':
 
     dirt_props = DirtProperties(1, 0.05, 0.1, 3, 1, 20, 0)
 
-    obs_props = ObservationProperties(render_agents=ARO.COMBINED, omit_agent_self=True, pomdp_r=2, additional_agent_placeholder=None)
+    obs_props = ObservationProperties(render_agents=ARO.COMBINED, omit_agent_self=True,
+                                      pomdp_r=15, additional_agent_placeholder=None)
 
     move_props = {'allow_square_movement': True,
                   'allow_diagonal_movement': False,
                   'allow_no_op': False}
 
-    factory = DirtFactory(n_agents=3, done_at_collision=False,
+    factory = DirtFactory(n_agents=5, done_at_collision=False,
                           level_name='rooms', max_steps=400,
                           obs_prop=obs_props, parse_doors=True,
                           record_episodes=True, verbose=True,
diff --git a/environments/factory/factory_dirt_item.py b/environments/factory/factory_dirt_item.py
deleted file mode 100644
index 895cfe2..0000000
--- a/environments/factory/factory_dirt_item.py
+++ /dev/null
@@ -1,7 +0,0 @@
-from environments.factory.factory_dirt import DirtFactory
-from environments.factory.factory_item import ItemFactory
-
-
-class DirtItemFactory(ItemFactory, DirtFactory):
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
diff --git a/environments/factory/factory_item.py b/environments/factory/factory_item.py
index 7b135af..c67d59e 100644
--- a/environments/factory/factory_item.py
+++ b/environments/factory/factory_item.py
@@ -117,7 +117,7 @@ class Inventories(ObjectRegister):
     can_be_shadowed = False
     hide_from_obs_builder = True
 
-    def __init__(self, *args, pomdp_r=0, **kwargs):
+    def __init__(self, *args, **kwargs):
         super(Inventories, self).__init__(*args, is_per_agent=True, individual_slices=True, **kwargs)
         self.is_observable = True
 
diff --git a/environments/helpers.py b/environments/helpers.py
index 119b860..641866e 100644
--- a/environments/helpers.py
+++ b/environments/helpers.py
@@ -49,6 +49,10 @@ class Constants(Enum):
     INVENTORY           = 'Inventory'
     DROP_OFF            = 'Drop_Off'
 
+    # Battery Env
+    CHARGE_POD          = 'Charge_Pod'
+    BATTERIES           = 'BATTERIES'
+
     def __bool__(self):
         if 'not_' in self.value:
             return False
@@ -84,6 +88,7 @@ class EnvActions(Enum):
     USE_DOOR    = 'use_door'
     CLEAN_UP    = 'clean_up'
     ITEM_ACTION = 'item_action'
+    CHARGE      = 'charge'
 
 
 m = MovingAction
diff --git a/reload_agent.py b/reload_agent.py
index 5849469..be2387f 100644
--- a/reload_agent.py
+++ b/reload_agent.py
@@ -7,7 +7,7 @@ import yaml
 from environments import helpers as h
 from environments.helpers import Constants as c
 from environments.factory.factory_dirt import DirtFactory
-from environments.factory.factory_dirt_item import DirtItemFactory
+from environments.factory.combined_factories import DirtItemFactory
 from environments.logging.recorder import RecorderCallback
 
 warnings.filterwarnings('ignore', category=FutureWarning)
diff --git a/studies/e_1.py b/studies/e_1.py
index 3370204..07b8321 100644
--- a/studies/e_1.py
+++ b/studies/e_1.py
@@ -23,7 +23,7 @@ from stable_baselines3.common.vec_env import SubprocVecEnv
 
 from environments import helpers as h
 from environments.factory.factory_dirt import DirtProperties, DirtFactory
-from environments.factory.factory_dirt_item import DirtItemFactory
+from environments.factory.combined_factories import DirtItemFactory
 from environments.factory.factory_item import ItemProperties, ItemFactory
 from environments.logging.monitor import MonitorCallback
 from environments.utility_classes import MovementProperties, ObservationProperties, AgentRenderOptions
@@ -300,7 +300,7 @@ if __name__ == '__main__':
 
     # Train starts here ############################################################
     # Build Major Loop  parameters, parameter versions, Env Classes and models
-    if True:
+    if False:
         for obs_mode in observation_modes.keys():
             for env_name in env_names:
                 for model_cls in [h.MODEL_MAP['A2C']]: