274 lines
9.9 KiB
Python
274 lines
9.9 KiB
Python
from typing import Union, NamedTuple, Dict, List
|
|
|
|
import numpy as np
|
|
|
|
from environments.factory.base.base_factory import BaseFactory
|
|
from environments.factory.base.objects import Agent, Action, Entity, EnvObject, BoundingMixin
|
|
from environments.factory.base.registers import EntityCollection, EnvObjectCollection
|
|
from environments.factory.base.renderer import RenderEntity
|
|
from environments.helpers import Constants as BaseConstants
|
|
from environments.helpers import EnvActions as BaseActions
|
|
|
|
from environments import helpers as h
|
|
|
|
|
|
class Constants(BaseConstants):
    """Base constants extended with the names/encodings of the battery environment."""

    # --- Battery environment ---
    # Key under which the ChargePods collection is stored, observed and rendered.
    CHARGE_PODS = 'Charge_Pod'
    # Key under which the BatteriesRegister is stored and observed.
    BATTERIES = 'BATTERIES'
    # Reward-event / info key emitted when an agent's battery is discharged.
    BATTERY_DISCHARGED = 'DISCHARGED'
    # Encoding value returned by ChargePod.encoding.
    CHARGE_POD = 1
|
|
|
|
|
|
class Actions(BaseActions):
    """Base action set extended with the battery-charging action."""

    # Identifier that BatteryFactory.do_additional_actions compares incoming actions against.
    CHARGE = 'do_charge_action'
|
|
|
|
|
|
class RewardsBtry(NamedTuple):
    """Reward values used by the battery environment.

    Any field can be overridden by keyword, e.g. ``RewardsBtry(CHARGE_FAIL=-0.5)``.
    """

    # Reward for a successful charge action.
    CHARGE_VALID: float = 0.1
    # Reward for a failed charge action (no pod at the agent's position, or the pod refused).
    CHARGE_FAIL: float = -0.1
    # Reward given to an agent whose battery is discharged.
    BATTERY_DISCHARGED: float = -1.0
|
|
|
|
|
|
class BatteryProperties(NamedTuple):
    """Configuration of the battery environment.

    BatteryFactory also accepts a plain ``dict`` with these keys and converts it
    with ``BatteryProperties(**d)``.
    """

    # Charge level every agent's battery starts with.
    initial_charge: float = 0.8
    # Charge added to a battery by one successful charge action at a pod.
    charge_rate: float = 0.4
    # Maximum number of charge pods; they are placed on the first empty floor tiles.
    charge_locations: int = 20
    # Charge drained from each agent's battery in BatteryFactory.step_hook: either one
    # float for every action, or a dict keyed by the agent's action (agent.temp_action).
    per_action_costs: Union[dict, float] = 0.02
    # Passed to every ChargePod; intended to let several agents on one pod charge together.
    multi_charge: bool = False
    # End the episode as soon as any battery is discharged.
    # Must carry an annotation: without one NamedTuple treats it as a plain class
    # attribute, and passing it as a keyword (or via the dict form) raises TypeError.
    # Declared last so positional construction of the fields above keeps its meaning.
    done_when_discharged: bool = False
|
|
|
|
|
|
# Short module-level aliases used throughout the rest of this file.
c, a = Constants, Actions
|
|
|
|
|
|
class Battery(BoundingMixin, EnvObject):
    """Charge store bound to one entity (BatteriesRegister.spawn_batteries binds it to an agent).

    ``do_charge_action`` caps the level at 1 and ``decharge`` floors it at 0.
    """

    @property
    def is_discharged(self):
        """True once the battery holds no charge."""
        # ``<=`` so that a non-positive level (e.g. a bad initial value) also counts as empty.
        return self.charge_level <= 0

    def __init__(self, initial_charge_level: float, *args, **kwargs):
        """Create a battery.

        :param initial_charge_level: starting charge level.
        :param args: forwarded to the base classes; BatteriesRegister passes the owning
                     entity and the register itself.
        """
        super(Battery, self).__init__(*args, **kwargs)
        self.charge_level = initial_charge_level

    def encoding(self):
        """Encoding of this battery: its current charge level."""
        return self.charge_level

    def do_charge_action(self, amount):
        """Add ``amount`` of charge, capped at 1.

        :return: c.NOT_VALID if the battery is already full, else c.VALID.
        """
        if self.charge_level >= 1:
            return c.NOT_VALID
        # noinspection PyTypeChecker
        self.charge_level = min(1, self.charge_level + amount)
        # Keep the owning collection in sync with the new value, exactly as decharge() does.
        self._collection.notify_change_to_value(self)
        return c.VALID

    def decharge(self, amount):
        """Drain ``amount`` of charge, never going below 0.

        :param amount: energy cost to remove; its magnitude is subtracted.
        :return: c.NOT_VALID if the battery was already empty, else c.VALID.
        """
        if self.is_discharged:
            return c.NOT_VALID
        # Subtract the cost. ``amount + charge_level`` would *raise* the level for the
        # default positive per-action cost; ``abs`` keeps negative-signed costs draining too.
        # noinspection PyTypeChecker
        self.charge_level = max(0, self.charge_level - abs(amount))
        self._collection.notify_change_to_value(self)
        return c.VALID

    def summarize_state(self, **_):
        """Public attributes (no leading underscore, not ``data``) as strings, plus ``name``."""
        attr_dict = {key: str(val) for key, val in self.__dict__.items()
                     if not key.startswith('_') and key != 'data'}
        attr_dict.update(dict(name=self.name))
        return attr_dict
|
|
|
|
|
|
class BatteriesRegister(EnvObjectCollection):
    """Observable collection holding one Battery per agent (``individual_slices=True``)."""

    _accepted_objects = Battery

    def __init__(self, *args, **kwargs):
        super(BatteriesRegister, self).__init__(*args, individual_slices=True,
                                                is_blocking_light=False, can_be_shadowed=False, **kwargs)
        self.is_observable = True

    def spawn_batteries(self, agents, initial_charge_level):
        """Create one battery per agent, bound to that agent, and add them to this register."""
        batteries = [self._accepted_objects(initial_charge_level, agent, self) for agent in agents]
        self.add_additional_items(batteries)

    def summarize_states(self, n_steps=None):
        return super(BatteriesRegister, self).summarize_states(n_steps=n_steps)

    # TODO: move the entity lookups below into a mixin.
    def by_entity(self, entity):
        """Return the battery bound to ``entity``, or None if there is none."""
        return next((x for x in self if x.belongs_to_entity(entity)), None)

    def idx_by_entity(self, entity):
        """Return the position of ``entity``'s battery in this register, or None."""
        return next((idx for idx, x in enumerate(self) if x.belongs_to_entity(entity)), None)

    def as_array_by_entity(self, entity):
        """Return the observation slice of the battery bound to ``entity``.

        :raises ValueError: if no battery is bound to ``entity``.
        """
        idx = self.idx_by_entity(entity)
        if idx is None:
            # Indexing with None would silently return the whole array with an extra axis.
            raise ValueError(f'No battery is bound to {entity}.')
        return self._array[idx]
|
|
|
|
|
|
class ChargePod(Entity):
    """Floor entity at which an agent can recharge its battery."""

    @property
    def encoding(self):
        return c.CHARGE_POD

    def __init__(self, *args, charge_rate: float = 0.4,
                 multi_charge: bool = False, **kwargs):
        """
        :param charge_rate: charge added per successful charge action.
        :param multi_charge: if False, charging is refused while more than one agent
                             stands on this pod's tile.
        """
        super(ChargePod, self).__init__(*args, **kwargs)
        self.charge_rate = charge_rate
        self.multi_charge = multi_charge

    def charge_battery(self, battery: Battery):
        """Try to add ``charge_rate`` to ``battery``.

        :return: c.NOT_VALID if the battery is already full, or if several agents share
                 the tile while ``multi_charge`` is off; otherwise the result of
                 ``battery.do_charge_action``.
        """
        if battery.charge_level >= 1.0:
            return c.NOT_VALID
        # Count the agents on the tile. Summing the guest objects themselves fails
        # unless they support addition.
        n_agents = sum(1 for guest in self.tile.guests if 'agent' in guest.name.lower())
        if n_agents > 1 and not self.multi_charge:
            return c.NOT_VALID
        return battery.do_charge_action(self.charge_rate)

    def summarize_state(self, n_steps=None) -> Union[dict, None]:
        """Return the base summary at ``h.STEPS_START``; None for every other step."""
        if n_steps == h.STEPS_START:
            return super().summarize_state(n_steps=n_steps)
        return None
|
|
|
|
|
|
class ChargePods(EntityCollection):
    """Collection of all ChargePod entities in the level."""

    _accepted_objects = ChargePod

    def __repr__(self):
        # The result must be returned: a __repr__ returning None makes repr() raise TypeError.
        return super(ChargePods, self).__repr__()
|
|
|
|
|
|
class BatteryFactory(BaseFactory):
    """Factory in which every agent carries a battery that is drained by its actions
    and can be recharged at charge pods."""

    def __init__(self, *args, btry_prop=BatteryProperties(), rewards_dest: RewardsBtry = RewardsBtry(),
                 **kwargs):
        """
        :param btry_prop: BatteryProperties, or a dict with the same keys.
        :param rewards_dest: RewardsBtry, or a dict with the same keys.
        """
        if isinstance(btry_prop, dict):
            btry_prop = BatteryProperties(**btry_prop)
        if isinstance(rewards_dest, dict):
            # Reward dicts map onto RewardsBtry; BatteryProperties(**...) rejected every reward key.
            rewards_dest = RewardsBtry(**rewards_dest)
        # Set before super().__init__(), presumably because the base constructor already
        # calls hooks (e.g. entities_hook) that read them — verify against BaseFactory.
        self.btry_prop = btry_prop
        self.rewards_dest = rewards_dest
        super().__init__(*args, **kwargs)

    def per_agent_raw_observations_hook(self, agent) -> Dict[str, np.typing.ArrayLike]:
        """Add the agent's own battery slice to its raw observations."""
        additional_raw_observations = super().per_agent_raw_observations_hook(agent)
        additional_raw_observations.update({c.BATTERIES: self[c.BATTERIES].as_array_by_entity(agent)})
        return additional_raw_observations

    def observations_hook(self) -> Dict[str, np.typing.ArrayLike]:
        """Add the charge-pod array to the shared observations."""
        additional_observations = super().observations_hook()
        additional_observations.update({c.CHARGE_PODS: self[c.CHARGE_PODS].as_array()})
        return additional_observations

    @property
    def entities_hook(self):
        """Create the charge pods and one battery per agent, added to the base entities."""
        super_entities = super().entities_hook

        empty_tiles = self[c.FLOOR].empty_tiles[:self.btry_prop.charge_locations]
        charge_pods = ChargePods.from_tiles(
            empty_tiles, self._level_shape,
            entity_kwargs=dict(charge_rate=self.btry_prop.charge_rate,
                               multi_charge=self.btry_prop.multi_charge)
        )

        # With _pomdp_r set, the battery arrays use a square of pomdp_diameter
        # (presumably the agents' observation window) instead of the full level shape.
        batteries = BatteriesRegister(self._level_shape if not self._pomdp_r else ((self.pomdp_diameter,) * 2))
        batteries.spawn_batteries(self[c.AGENT], self.btry_prop.initial_charge)
        super_entities.update({c.BATTERIES: batteries, c.CHARGE_PODS: charge_pods})
        return super_entities

    def step_hook(self) -> (List[dict], dict):
        """Drain every agent's battery by the cost of its current action (agent.temp_action)."""
        super_reward_info = super(BatteryFactory, self).step_hook()

        batteries = self[c.BATTERIES]
        costs = self.btry_prop.per_action_costs
        for agent in self[c.AGENT]:
            energy_consumption = costs[agent.temp_action] if isinstance(costs, dict) else costs
            batteries.by_entity(agent).decharge(energy_consumption)

        return super_reward_info

    def do_charge_action(self, agent) -> (bool, dict):
        """Let ``agent`` charge at the pod on its position.

        :return: (validity, reward dict with ``value``, ``reason`` and ``info`` keys).
        """
        if charge_pod := self[c.CHARGE_PODS].by_pos(agent.pos):
            valid = charge_pod.charge_battery(self[c.BATTERIES].by_entity(agent))
            if valid:
                info_dict = {f'{agent.name}_{a.CHARGE}_VALID': 1}
                self.print(f'{agent.name} just charged batteries at {charge_pod.name}.')
            else:
                info_dict = {f'{agent.name}_{a.CHARGE}_FAIL': 1}
                self.print(f'{agent.name} failed to charged batteries at {charge_pod.name}.')
        else:
            # No pod at the agent's position.
            valid = c.NOT_VALID
            info_dict = {f'{agent.name}_{a.CHARGE}_FAIL': 1}
            self.print(f'{agent.name} failed to charged batteries at {agent.pos}.')
        reward = dict(value=self.rewards_dest.CHARGE_VALID if valid else self.rewards_dest.CHARGE_FAIL,
                      reason=a.CHARGE, info=info_dict)
        return valid, reward

    def do_additional_actions(self, agent: Agent, action: Action) -> (bool, dict):
        """Handle the charge action unless a parent class already handled ``action``.

        :return: the parent's result, the charge result, or None for unknown actions.
        """
        action_result = super().do_additional_actions(agent, action)
        if action_result is not None:
            return action_result
        if action == a.CHARGE:
            return self.do_charge_action(agent)
        return None

    def reset_hook(self) -> None:
        # There is nothing to reset.
        pass

    def check_additional_done(self) -> (bool, dict):
        """Extend the parent's done check: optionally end once any battery is discharged.

        :return: (done, info dict); always a tuple.
        """
        super_done, super_dict = super(BatteryFactory, self).check_additional_done()
        if super_done:
            return super_done, super_dict
        if self.btry_prop.done_when_discharged and any(
                battery.is_discharged for battery in self[c.BATTERIES]):
            super_dict.update(DISCHARGE_DONE=1)
            return True, super_dict
        # Previously fell through and returned None, breaking callers that unpack a tuple.
        return super_done, super_dict

    def per_agent_reward_hook(self, agent: Agent) -> Dict[str, dict]:
        """Add a BATTERY_DISCHARGED reward event when the agent's battery is empty."""
        reward_event_dict = super(BatteryFactory, self).per_agent_reward_hook(agent)
        if self[c.BATTERIES].by_entity(agent).is_discharged:
            self.print(f'{agent.name} Battery is discharged!')
            info_dict = {f'{agent.name}_{c.BATTERY_DISCHARGED}': 1}
            reward_event_dict.update({c.BATTERY_DISCHARGED: {'reward': self.rewards_dest.BATTERY_DISCHARGED,
                                                             'info': info_dict}})
        return reward_event_dict

    def render_assets_hook(self):
        """Add one render entity per charge pod to the parent's assets."""
        # noinspection PyUnresolvedReferences
        additional_assets = super().render_assets_hook()
        charge_pods = [RenderEntity(c.CHARGE_PODS, charge_pod.tile.pos) for charge_pod in self[c.CHARGE_PODS]]
        additional_assets.extend(charge_pods)
        return additional_assets
|