"""Battery extension for the factory environment.

Defines a per-agent ``Battery``, the registers that hold batteries and charge
pods, and ``BatteryFactory``, which drains batteries on every step and lets
agents recharge on charge pods.
"""
from typing import NamedTuple, Optional, Tuple, Union

import numpy as np

from environments.factory.base.base_factory import BaseFactory
from environments.factory.base.objects import Agent, Action, Entity
from environments.factory.base.registers import EntityObjectRegister, ObjectRegister
from environments.factory.base.renderer import RenderEntity
from environments.helpers import Constants as c
from environments import helpers as h


CHARGE_ACTION = h.EnvActions.CHARGE
# Value written into the ChargePods observation slice for every charge pod.
# The name is kept as-is because other modules may import it.
ITEM_DROP_OFF = 1


class BatteryProperties(NamedTuple):
    """Configuration of the battery mechanics used by ``BatteryFactory``."""

    # Charge level every battery is spawned with.
    initial_charge: float = 0.8
    # Amount added to a battery by one successful charge action.
    charge_rate: float = 0.4
    # Upper bound on the number of charge pods placed on empty floor tiles.
    charge_locations: int = 20
    # Either a flat cost per step, or a dict looked up with ``agent.temp_action``.
    per_action_costs: Union[dict, float] = 0.02
    # Whether several agents may share one charge pod at the same time.
    multi_charge: bool = False
    # End the episode as soon as any battery is empty.
    # This used to lack an annotation, which made it a plain class attribute
    # instead of a NamedTuple field, so it could not be configured. It is
    # declared last to keep positional construction backward-compatible.
    done_when_discharged: bool = False


class Battery(object):
    """Charge state of a single agent, stored as a value between 0 and 1."""

    @property
    def is_discharged(self):
        """``True`` once the charge level has been clamped down to zero."""
        return self.charge_level == 0

    @property
    def is_blocking_light(self):
        return False

    @property
    def can_collide(self):
        return False

    @property
    def name(self):
        return f'{self.__class__.__name__}({self.agent.name})'

    def __init__(self, pomdp_r: int, level_shape: Tuple[int, int], agent: Agent,
                 initial_charge_level: float):
        """Create the battery of ``agent``.

        Args:
            pomdp_r: Partial-observation radius; a falsy value selects a
                full-level sized observation array instead.
            level_shape: Shape of the level, used when ``pomdp_r`` is falsy.
            agent: The agent this battery belongs to.
            initial_charge_level: Charge level to start with.
        """
        super().__init__()
        self.agent = agent
        self._pomdp_r = pomdp_r
        self._level_shape = level_shape
        if self._pomdp_r:
            view_size = pomdp_r * 2 + 1
            self._array = np.zeros((1, view_size, view_size))
        else:
            self._array = np.zeros((1, *self._level_shape))
        self.charge_level = initial_charge_level

    def as_array(self):
        """Return the observation array with the current charge level written in."""
        self._array[:] = c.FREE_CELL.value
        # NOTE(review): on the 3-d array this assigns the whole first row of the
        # plane, not a single cell -- confirm whether ``[0, 0, 0]`` was intended.
        self._array[0, 0] = self.charge_level
        return self._array

    def __repr__(self):
        return f'{self.__class__.__name__}[{self.agent.name}]({self.charge_level})'

    def charge(self, amount) -> c:
        """Add ``amount`` to the charge level, capped at 1.

        Returns:
            ``c.VALID`` if the battery was not already full, else ``c.NOT_VALID``.
        """
        if self.charge_level < 1:
            # noinspection PyTypeChecker
            self.charge_level = min(1, amount + self.charge_level)
            return c.VALID
        return c.NOT_VALID

    def decharge(self, amount) -> c:
        """Remove ``amount`` from the charge level, floored at 0.

        The previous implementation *added* ``amount``, so a positive
        per-action cost charged the battery instead of draining it.

        Returns:
            ``c.VALID`` if the battery was not already empty, else ``c.NOT_VALID``.
        """
        if self.charge_level != 0:
            # noinspection PyTypeChecker
            self.charge_level = max(0, self.charge_level - amount)
            return c.VALID
        return c.NOT_VALID

    def belongs_to_entity(self, entity):
        return self.agent == entity

    def summarize_state(self, **kwargs):
        """Return the public attributes (stringified) plus the battery name."""
        attr_dict = {key: str(val) for key, val in self.__dict__.items()
                     if not key.startswith('_') and key != 'data'}
        attr_dict.update(dict(name=self.name))
        return attr_dict


class BatteriesRegister(ObjectRegister):
    """Register holding one ``Battery`` per agent."""

    _accepted_objects = Battery
    is_blocking_light = False
    can_be_shadowed = False
    hide_from_obs_builder = True

    def __init__(self, *args, **kwargs):
        super(BatteriesRegister, self).__init__(*args, is_per_agent=True, individual_slices=True, **kwargs)
        self.is_observable = True

    def as_array(self):
        """Write every battery's array into its own slice and return the stack."""
        for inv_idx, battery in enumerate(self):
            self._array[inv_idx] = battery.as_array()
        return self._array

    def spawn_batteries(self, agents, pomdp_r, initial_charge_level):
        """Create and register one battery for each agent in ``agents``."""
        batteries = [self._accepted_objects(pomdp_r, self._level_shape, agent, initial_charge_level)
                     for agent in agents]
        self.register_additional_items(batteries)

    def idx_by_entity(self, entity):
        """Return the index of the battery owned by ``entity``, or ``None``."""
        return next((idx for idx, bat in enumerate(self) if bat.belongs_to_entity(entity)), None)

    def by_entity(self, entity):
        """Return the battery owned by ``entity``, or ``None``."""
        return next((bat for bat in self if bat.belongs_to_entity(entity)), None)

    def summarize_states(self, n_steps=None):
        return super(BatteriesRegister, self).summarize_states(n_steps=n_steps)


class ChargePod(Entity):
    """Stationary entity on which an agent can recharge its battery."""

    @property
    def can_collide(self):
        return False

    @property
    def encoding(self):
        return ITEM_DROP_OFF

    def __init__(self, *args, charge_rate: float = 0.4,
                 multi_charge: bool = False, **kwargs):
        super(ChargePod, self).__init__(*args, **kwargs)
        self.charge_rate = charge_rate
        self.multi_charge = multi_charge

    def charge_battery(self, battery: Battery):
        """Charge ``battery`` by ``self.charge_rate``.

        Charging is refused when the battery is already full, or when more
        than one agent occupies the pod's tile and ``multi_charge`` is off.

        Returns:
            ``c.VALID`` if the battery was charged, else ``c.NOT_VALID``.
        """
        if battery.charge_level == 1.0:
            return c.NOT_VALID
        # Count the agents on the tile. The old code summed the guest objects
        # themselves, which is not a count and fails for ordinary objects.
        n_agents = sum(1 for guest in self.tile.guests if c.AGENT.name in guest.name)
        if n_agents > 1 and not self.multi_charge:
            return c.NOT_VALID
        return battery.charge(self.charge_rate)

    def summarize_state(self, n_steps=None) -> Optional[dict]:
        """Return the state summary on the first step only, otherwise ``None``."""
        if n_steps == h.STEPS_START:
            return super().summarize_state(n_steps=n_steps)
        return None


class ChargePods(EntityObjectRegister):
    """Register of all ``ChargePod`` entities of a level."""

    _accepted_objects = ChargePod

    def as_array(self):
        """Return the observation slice with every placed pod's encoding set."""
        self._array[:] = c.FREE_CELL.value
        for item in self:
            if item.pos != c.NO_POS.value:
                self._array[0, item.x, item.y] = item.encoding
        return self._array

    def __repr__(self):
        # The result must be returned; returning None makes repr() raise.
        return super(ChargePods, self).__repr__()


class BatteryFactory(BaseFactory):
    """Factory environment in which agents carry batteries that drain over time."""

    def __init__(self, *args, btry_prop=BatteryProperties(), **kwargs):
        """Store the battery configuration, then initialise the base factory.

        Args:
            btry_prop: A ``BatteryProperties`` instance or a dict of its fields.
        """
        if isinstance(btry_prop, dict):
            btry_prop = BatteryProperties(**btry_prop)
        # Must be set before the base initialiser runs (original ordering kept).
        self.btry_prop = btry_prop
        super().__init__(*args, **kwargs)

    @property
    def additional_entities(self):
        """Extend the parent's entities with the battery and charge pod registers."""
        super_entities = super().additional_entities

        empty_tiles = self[c.FLOOR].empty_tiles[:self.btry_prop.charge_locations]
        charge_pods = ChargePods.from_tiles(
            empty_tiles, self._level_shape,
            entity_kwargs=dict(charge_rate=self.btry_prop.charge_rate,
                               multi_charge=self.btry_prop.multi_charge)
        )

        batteries = BatteriesRegister(self._level_shape if not self._pomdp_r else ((self.pomdp_diameter,) * 2),
                                      )
        batteries.spawn_batteries(self[c.AGENT], self._pomdp_r, self.btry_prop.initial_charge)
        super_entities.update({c.BATTERIES: batteries, c.CHARGE_POD: charge_pods})
        return super_entities

    def do_additional_step(self) -> dict:
        """Drain every agent's battery by the cost of its last action."""
        info_dict = super(BatteryFactory, self).do_additional_step()

        batteries = self[c.BATTERIES]
        costs = self.btry_prop.per_action_costs
        for agent in self[c.AGENT]:
            energy_consumption = costs[agent.temp_action] if isinstance(costs, dict) else costs
            batteries.by_entity(agent).decharge(energy_consumption)

        return info_dict

    def do_charge(self, agent) -> c:
        """Charge the agent's battery if it stands on a charge pod."""
        if charge_pod := self[c.CHARGE_POD].by_pos(agent.pos):
            return charge_pod.charge_battery(self[c.BATTERIES].by_entity(agent))
        return c.NOT_VALID

    def do_additional_actions(self, agent: Agent, action: Action) -> Union[None, c]:
        """Handle the charge action unless the parent already handled ``action``.

        Returns:
            The parent's result if it is not ``None``; the outcome of the charge
            attempt for ``CHARGE_ACTION``; otherwise ``None``.
        """
        valid = super().do_additional_actions(agent, action)
        if valid is not None:
            return valid
        if action == CHARGE_ACTION:
            return self.do_charge(agent)
        return None

    def do_additional_reset(self) -> None:
        # There is nothing to reset.
        pass

    def check_additional_done(self) -> bool:
        """Report done if the parent does, or if configured and a battery is empty."""
        super_done = super(BatteryFactory, self).check_additional_done()
        if super_done:
            return super_done
        return self.btry_prop.done_when_discharged and any(battery.is_discharged for battery in self[c.BATTERIES])

    def calculate_additional_reward(self, agent: Agent) -> Tuple[float, dict]:
        """Add the charge and discharge reward terms for ``agent``.

        A valid charge action adds 0.1 and a failed one subtracts 0.1. An
        empty battery subtracts 1; otherwise the battery level is reported
        in the info dict.
        """
        reward, info_dict = super(BatteryFactory, self).calculate_additional_reward(agent)
        if h.EnvActions.CHARGE == agent.temp_action:
            if agent.temp_valid:
                charge_pod = self[c.CHARGE_POD].by_pos(agent.pos)
                info_dict.update({f'{agent.name}_charge': 1})
                info_dict.update(agent_charged=1)
                self.print(f'{agent.name} just charged batteries at {charge_pod.pos}.')
                reward += 0.1
            else:
                # A stray drop-off register lookup was removed here: its result
                # was unused and this factory does not create that register.
                info_dict.update({f'{agent.name}_failed_charge': 1})
                info_dict.update(failed_charge=1)
                self.print(f'{agent.name} just tried to charge at {agent.pos}, but failed.')
                reward -= 0.1

        battery = self[c.BATTERIES].by_entity(agent)
        if battery.is_discharged:
            info_dict.update({f'{agent.name}_discharged': 1})
            reward -= 1
        else:
            info_dict.update({f'{agent.name}_battery_level': battery.charge_level})
        return reward, info_dict

    def render_additional_assets(self):
        """Append one render entity per charge pod to the parent's assets."""
        # noinspection PyUnresolvedReferences
        additional_assets = super().render_additional_assets()
        charge_pods = [RenderEntity(c.CHARGE_POD.value, charge_pod.tile.pos) for charge_pod in self[c.CHARGE_POD]]
        additional_assets.extend(charge_pods)
        return additional_assets