2022-01-10 15:54:22 +01:00

269 lines
9.8 KiB
Python

from typing import Union, NamedTuple, Dict, List
import numpy as np
from environments.factory.base.base_factory import BaseFactory
from environments.factory.base.objects import Agent, Action, Entity, EnvObject, BoundingMixin
from environments.factory.base.registers import EntityRegister, EnvObjectRegister
from environments.factory.base.renderer import RenderEntity
from environments.helpers import Constants as BaseConstants
from environments.helpers import EnvActions as BaseActions
from environments.helpers import Rewards as BaseRewards
from environments import helpers as h
class Constants(BaseConstants):
# Battery Env
CHARGE_PODS = 'Charge_Pod'
BATTERIES = 'BATTERIES'
BATTERY_DISCHARGED = 'DISCHARGED'
CHARGE_POD = 1
class Actions(BaseActions):
CHARGE = 'do_charge_action'
class Rewards(BaseRewards):
CHARGE_VALID = 0.1
CHARGE_FAIL = -0.1
BATTERY_DISCHARGED = -1.0
class BatteryProperties(NamedTuple):
initial_charge: float = 0.8 #
charge_rate: float = 0.4 #
charge_locations: int = 20 #
per_action_costs: Union[dict, float] = 0.02
done_when_discharged = False
multi_charge: bool = False
c = Constants
a = Actions
r = Rewards
class Battery(BoundingMixin, EnvObject):
@property
def is_discharged(self):
return self.charge_level == 0
def __init__(self, initial_charge_level: float, *args, **kwargs):
super(Battery, self).__init__(*args, **kwargs)
self.charge_level = initial_charge_level
def encoding(self):
return self.charge_level
def do_charge_action(self, amount):
if self.charge_level < 1:
# noinspection PyTypeChecker
self.charge_level = min(1, amount + self.charge_level)
return dict(valid=c.VALID, action=a.CHARGE, reward=r.CHARGE_VALID)
else:
return dict(valid=c.NOT_VALID, action=a.CHARGE, reward=r.CHARGE_FAIL)
def decharge(self, amount) -> c:
if self.charge_level != 0:
# noinspection PyTypeChecker
self.charge_level = max(0, amount + self.charge_level)
self._register.notify_change_to_value(self)
return c.VALID
else:
return c.NOT_VALID
def summarize_state(self, **_):
attr_dict = {key: str(val) for key, val in self.__dict__.items() if not key.startswith('_') and key != 'data'}
attr_dict.update(dict(name=self.name))
return attr_dict
class BatteriesRegister(EnvObjectRegister):
_accepted_objects = Battery
def __init__(self, *args, **kwargs):
super(BatteriesRegister, self).__init__(*args, individual_slices=True,
is_blocking_light=False, can_be_shadowed=False, **kwargs)
self.is_observable = True
def spawn_batteries(self, agents, initial_charge_level):
batteries = [self._accepted_objects(initial_charge_level, agent, self) for _, agent in enumerate(agents)]
self.register_additional_items(batteries)
def summarize_states(self, n_steps=None):
# as dict with additional nesting
# return dict(items=super(Inventories, cls).summarize_states())
return super(BatteriesRegister, self).summarize_states(n_steps=n_steps)
# Todo Move this to Mixin!
def by_entity(self, entity):
try:
return next((x for x in self if x.belongs_to_entity(entity)))
except StopIteration:
return None
def idx_by_entity(self, entity):
try:
return next((idx for idx, x in enumerate(self) if x.belongs_to_entity(entity)))
except StopIteration:
return None
def as_array_by_entity(self, entity):
return self._array[self.idx_by_entity(entity)]
class ChargePod(Entity):
@property
def encoding(self):
return c.CHARGE_POD
def __init__(self, *args, charge_rate: float = 0.4,
multi_charge: bool = False, **kwargs):
super(ChargePod, self).__init__(*args, **kwargs)
self.charge_rate = charge_rate
self.multi_charge = multi_charge
def charge_battery(self, battery: Battery):
if battery.charge_level == 1.0:
return c.NOT_VALID
if sum(guest for guest in self.tile.guests if 'agent' in guest.name.lower()) > 1:
return c.NOT_VALID
battery.do_charge_action(self.charge_rate)
return c.VALID
def summarize_state(self, n_steps=None) -> dict:
if n_steps == h.STEPS_START:
summary = super().summarize_state(n_steps=n_steps)
return summary
class ChargePods(EntityRegister):
_accepted_objects = ChargePod
def __repr__(self):
super(ChargePods, self).__repr__()
class BatteryFactory(BaseFactory):
def __init__(self, *args, btry_prop=BatteryProperties(), **kwargs):
if isinstance(btry_prop, dict):
btry_prop = BatteryProperties(**btry_prop)
self.btry_prop = btry_prop
super().__init__(*args, **kwargs)
def _additional_per_agent_raw_observations(self, agent) -> Dict[str, np.typing.ArrayLike]:
additional_raw_observations = super()._additional_per_agent_raw_observations(agent)
additional_raw_observations.update({c.BATTERIES: self[c.BATTERIES].as_array_by_entity(agent)})
return additional_raw_observations
def _additional_observations(self) -> Dict[str, np.typing.ArrayLike]:
additional_observations = super()._additional_observations()
additional_observations.update({c.CHARGE_PODS: self[c.CHARGE_PODS].as_array()})
return additional_observations
@property
def additional_entities(self):
super_entities = super().additional_entities
empty_tiles = self[c.FLOOR].empty_tiles[:self.btry_prop.charge_locations]
charge_pods = ChargePods.from_tiles(
empty_tiles, self._level_shape,
entity_kwargs=dict(charge_rate=self.btry_prop.charge_rate,
multi_charge=self.btry_prop.multi_charge)
)
batteries = BatteriesRegister(self._level_shape if not self._pomdp_r else ((self.pomdp_diameter,) * 2),
)
batteries.spawn_batteries(self[c.AGENT], self.btry_prop.initial_charge)
super_entities.update({c.BATTERIES: batteries, c.CHARGE_PODS: charge_pods})
return super_entities
def do_additional_step(self) -> (List[dict], dict):
super_reward_info = super(BatteryFactory, self).do_additional_step()
# Decharge
batteries = self[c.BATTERIES]
for agent in self[c.AGENT]:
if isinstance(self.btry_prop.per_action_costs, dict):
energy_consumption = self.btry_prop.per_action_costs[agent.temp_action]
else:
energy_consumption = self.btry_prop.per_action_costs
batteries.by_entity(agent).decharge(energy_consumption)
return super_reward_info
def do_charge_action(self, agent) -> (dict, dict):
if charge_pod := self[c.CHARGE_PODS].by_pos(agent.pos):
valid = charge_pod.charge_battery(self[c.BATTERIES].by_entity(agent))
if valid:
info_dict = {f'{agent.name}_{a.CHARGE}_VALID': 1}
self.print(f'{agent.name} just charged batteries at {charge_pod.name}.')
else:
info_dict = {f'{agent.name}_{a.CHARGE}_FAIL': 1}
self.print(f'{agent.name} failed to charged batteries at {charge_pod.name}.')
else:
valid = c.NOT_VALID
info_dict = {f'{agent.name}_{a.CHARGE}_FAIL': 1}
# info_dict = {f'{agent.name}_no_charger': 1}
self.print(f'{agent.name} failed to charged batteries at {agent.pos}.')
reward = dict(value=r.CHARGE_VALID if valid else r.CHARGE_FAIL, reason=a.CHARGE, info=info_dict)
return valid, reward
def do_additional_actions(self, agent: Agent, action: Action) -> (bool, dict):
action_result = super().do_additional_actions(agent, action)
if action_result is None:
if action == a.CHARGE:
action_result = self.do_charge_action(agent)
return action_result
else:
return None
else:
return action_result
pass
def do_additional_reset(self) -> None:
# There is Nothing to reset.
pass
def check_additional_done(self) -> (bool, dict):
super_done, super_dict = super(BatteryFactory, self).check_additional_done()
if super_done:
return super_done, super_dict
else:
if self.btry_prop.done_when_discharged:
if btry_done := any(battery.is_discharged for battery in self[c.BATTERIES]):
super_dict.update(DISCHARGE_DONE=1)
return btry_done, super_dict
else:
pass
else:
pass
pass
def additional_per_agent_reward(self, agent: Agent) -> Dict[str, dict]:
reward_event_dict = super(BatteryFactory, self).additional_per_agent_reward(agent)
if self[c.BATTERIES].by_entity(agent).is_discharged:
self.print(f'{agent.name} Battery is discharged!')
info_dict = {f'{agent.name}_{c.BATTERY_DISCHARGED}': 1}
reward_event_dict.update({c.BATTERY_DISCHARGED: {'reward': r.BATTERY_DISCHARGED, 'info': info_dict}})
else:
# All Fine
pass
return reward_event_dict
def render_additional_assets(self):
# noinspection PyUnresolvedReferences
additional_assets = super().render_additional_assets()
charge_pods = [RenderEntity(c.CHARGE_PODS, charge_pod.tile.pos) for charge_pod in self[c.CHARGE_PODS]]
additional_assets.extend(charge_pods)
return additional_assets