2022-04-11 16:15:44 +02:00

274 lines
9.9 KiB
Python

from typing import Union, NamedTuple, Dict, List
import numpy as np
from environments.factory.base.base_factory import BaseFactory
from environments.factory.base.objects import Agent, Action, Entity, EnvObject, BoundingMixin
from environments.factory.base.registers import EntityCollection, EnvObjectCollection
from environments.factory.base.renderer import RenderEntity
from environments.helpers import Constants as BaseConstants
from environments.helpers import EnvActions as BaseActions
from environments import helpers as h
class Constants(BaseConstants):
    """Battery-environment constants layered on top of the shared BaseConstants."""

    # Keys of the battery / charge-pod collections and observations
    BATTERIES = 'BATTERIES'
    CHARGE_PODS = 'Charge_Pod'
    # Key used for the "battery is empty" reward and info entries
    BATTERY_DISCHARGED = 'DISCHARGED'
    # Value returned by ChargePod.encoding
    CHARGE_POD = 1
class Actions(BaseActions):
    """Battery-environment action identifiers layered on top of BaseActions."""

    # Same name as the BatteryFactory method that handles this action.
    CHARGE = 'do_charge_action'
class RewardsBtry(NamedTuple):
    """Reward values handed out by the battery factory."""

    CHARGE_VALID: float = 0.1          # reward for a successful charge action
    CHARGE_FAIL: float = -0.1          # reward for a failed charge action
    BATTERY_DISCHARGED: float = -1.0   # reward event added while an agent's battery is discharged
class BatteryProperties(NamedTuple):
    """Configuration of the battery mechanics of the battery factory."""

    initial_charge: float = 0.8                  # charge level every battery is spawned with
    charge_rate: float = 0.4                     # amount a ChargePod adds per charge action
    charge_locations: int = 20                   # at most this many empty floor tiles get a ChargePod
    per_action_costs: Union[dict, float] = 0.02  # cost passed to Battery.decharge each step; a dict is keyed by the agent's action
    multi_charge: bool = False                   # forwarded to every ChargePod
    # Declared with an annotation so it is a real NamedTuple field: without it, the name was a plain
    # class attribute and BatteryProperties(done_when_discharged=...) raised TypeError.
    # Appended last so existing positional construction keeps its meaning.
    done_when_discharged: bool = False           # end the episode once any battery is discharged
# Short module-level aliases used throughout this file.
c, a = Constants, Actions
class Battery(BoundingMixin, EnvObject):
    """Battery bound to one entity (agents, in this file), tracking a ``charge_level``.

    Charging caps the level at 1; discharging floors it at 0.
    """

    @property
    def is_discharged(self):
        """True once the charge level has reached (or dropped below) zero."""
        return self.charge_level <= 0

    def __init__(self, initial_charge_level: float, *args, **kwargs):
        """Create a battery at ``initial_charge_level``; remaining args go to the base classes."""
        super(Battery, self).__init__(*args, **kwargs)
        self.charge_level = initial_charge_level

    def encoding(self):
        """Return the current charge level as this object's encoding."""
        # NOTE(review): ChargePod.encoding is a property while this is a plain method;
        # confirm which form the base EnvObject/collection expects before changing it.
        return self.charge_level

    def do_charge_action(self, amount):
        """Add ``amount`` of charge, capped at 1.

        Returns:
            c.VALID if the battery was below full, c.NOT_VALID otherwise.
        """
        if self.charge_level >= 1:
            return c.NOT_VALID
        # noinspection PyTypeChecker
        self.charge_level = min(1, self.charge_level + amount)
        # Keep the collection's array in sync after a change, exactly as decharge() does.
        self._collection.notify_change_to_value(self)
        return c.VALID

    def decharge(self, amount):
        """Remove ``amount`` of charge, floored at 0.

        Returns:
            c.VALID if the battery still had charge, c.NOT_VALID if it was already empty.
        """
        if self.is_discharged:
            return c.NOT_VALID
        # The cost is subtracted: the original added it, so a positive per-action cost
        # (the default 0.02) charged the battery instead of draining it.
        # noinspection PyTypeChecker
        self.charge_level = max(0, self.charge_level - amount)
        self._collection.notify_change_to_value(self)
        return c.VALID

    def summarize_state(self, **_):
        """Return the public attributes (stringified, except ``data``) plus the name."""
        attr_dict = {key: str(val) for key, val in self.__dict__.items()
                     if not key.startswith('_') and key != 'data'}
        attr_dict.update(dict(name=self.name))
        return attr_dict
class BatteriesRegister(EnvObjectCollection):
    """Observable collection of Battery objects, one per bound entity."""

    _accepted_objects = Battery

    def __init__(self, *args, **kwargs):
        super(BatteriesRegister, self).__init__(*args, individual_slices=True,
                                                is_blocking_light=False, can_be_shadowed=False, **kwargs)
        self.is_observable = True

    def spawn_batteries(self, agents, initial_charge_level):
        """Create one battery per agent, each starting at ``initial_charge_level``."""
        batteries = [self._accepted_objects(initial_charge_level, agent, self) for agent in agents]
        self.add_additional_items(batteries)

    def summarize_states(self, n_steps=None):
        return super(BatteriesRegister, self).summarize_states(n_steps=n_steps)

    # TODO: move the entity lookups below to a mixin.
    def by_entity(self, entity):
        """Return the battery bound to ``entity``, or None if there is none."""
        return next((battery for battery in self if battery.belongs_to_entity(entity)), None)

    def idx_by_entity(self, entity):
        """Return the position of the battery bound to ``entity``, or None if there is none."""
        return next((idx for idx, battery in enumerate(self) if battery.belongs_to_entity(entity)), None)

    def as_array_by_entity(self, entity):
        """Return the array entry of the battery bound to ``entity``.

        Raises:
            KeyError: if no battery is bound to ``entity``.
        """
        idx = self.idx_by_entity(entity)
        if idx is None:
            # Indexing with None would not fail loudly (on a numpy array it adds an axis and
            # returns everything) — presumably _array is numpy; raise instead of returning wrong data.
            raise KeyError(f'No battery is bound to {entity}.')
        return self._array[idx]
class ChargePod(Entity):
    """Entity on a floor tile that recharges the battery of an agent standing on it."""

    @property
    def encoding(self):
        return c.CHARGE_POD

    def __init__(self, *args, charge_rate: float = 0.4,
                 multi_charge: bool = False, **kwargs):
        """
        Args:
            charge_rate: amount added to a battery per successful charge.
            multi_charge: allow charging while more than one agent shares this tile.
        """
        super(ChargePod, self).__init__(*args, **kwargs)
        self.charge_rate = charge_rate
        self.multi_charge = multi_charge

    def charge_battery(self, battery: Battery):
        """Add ``charge_rate`` to ``battery``.

        Returns:
            c.NOT_VALID if the battery is already full, or if several agents share the
            tile while ``multi_charge`` is off; otherwise the result of
            ``battery.do_charge_action``.
        """
        if battery.charge_level == 1.0:
            return c.NOT_VALID
        # Count the agents on this tile. The original summed the guest objects
        # themselves, which raises TypeError.
        n_agents = sum(1 for guest in self.tile.guests if 'agent' in guest.name.lower())
        # multi_charge was previously stored but ignored; it now permits shared tiles.
        if n_agents > 1 and not self.multi_charge:
            return c.NOT_VALID
        return battery.do_charge_action(self.charge_rate)

    def summarize_state(self, n_steps=None) -> Union[dict, None]:
        """Return the base summary at the start step only, and None on every other step."""
        if n_steps == h.STEPS_START:
            return super().summarize_state(n_steps=n_steps)
        return None
class ChargePods(EntityCollection):
    """Collection of ChargePod entities."""

    _accepted_objects = ChargePod

    def __repr__(self):
        # The original discarded this value and implicitly returned None,
        # which makes repr() raise TypeError.
        return super(ChargePods, self).__repr__()
class BatteryFactory(BaseFactory):
    """Factory environment where every agent carries a Battery and can recharge at ChargePods.

    Batteries and charge pods are registered as extra entities and exposed as observations.
    Batteries are charged through the ``a.CHARGE`` action and decharged by
    ``btry_prop.per_action_costs`` in ``step_hook``.
    """

    def __init__(self, *args, btry_prop=BatteryProperties(), rewards_dest: RewardsBtry = RewardsBtry(),
                 **kwargs):
        """Store battery configuration and rewards, then initialise the base factory.

        Args:
            btry_prop: battery configuration; a plain dict is converted to BatteryProperties.
            rewards_dest: reward values; a plain dict is converted to RewardsBtry.
        """
        if isinstance(btry_prop, dict):
            btry_prop = BatteryProperties(**btry_prop)
        if isinstance(rewards_dest, dict):
            # Previously parsed with BatteryProperties, which rejects the reward keys.
            rewards_dest = RewardsBtry(**rewards_dest)
        # Assigned before super().__init__, presumably because the base constructor
        # already runs hooks that read them (e.g. entities_hook) — TODO confirm.
        self.btry_prop = btry_prop
        self.rewards_dest = rewards_dest
        super().__init__(*args, **kwargs)

    def per_agent_raw_observations_hook(self, agent) -> Dict[str, np.typing.ArrayLike]:
        """Add the agent's own battery entry to its raw observations."""
        additional_raw_observations = super().per_agent_raw_observations_hook(agent)
        additional_raw_observations.update({c.BATTERIES: self[c.BATTERIES].as_array_by_entity(agent)})
        return additional_raw_observations

    def observations_hook(self) -> Dict[str, np.typing.ArrayLike]:
        """Add the charge pod array to the shared observations."""
        additional_observations = super().observations_hook()
        additional_observations.update({c.CHARGE_PODS: self[c.CHARGE_PODS].as_array()})
        return additional_observations

    @property
    def entities_hook(self):
        """Register charge pods on the first empty floor tiles and one battery per agent."""
        super_entities = super().entities_hook

        empty_tiles = self[c.FLOOR].empty_tiles[:self.btry_prop.charge_locations]
        charge_pods = ChargePods.from_tiles(
            empty_tiles, self._level_shape,
            entity_kwargs=dict(charge_rate=self.btry_prop.charge_rate,
                               multi_charge=self.btry_prop.multi_charge)
        )

        # When _pomdp_r is set (presumably a partial-observability radius), the battery
        # array is sized to the pomdp_diameter square instead of the level shape.
        batteries_shape = self._level_shape if not self._pomdp_r else ((self.pomdp_diameter,) * 2)
        batteries = BatteriesRegister(batteries_shape)
        batteries.spawn_batteries(self[c.AGENT], self.btry_prop.initial_charge)

        super_entities.update({c.BATTERIES: batteries, c.CHARGE_PODS: charge_pods})
        return super_entities

    def step_hook(self) -> (List[dict], dict):
        """Decharge every agent's battery by its per-action cost, then pass on the base result."""
        super_reward_info = super(BatteryFactory, self).step_hook()
        batteries = self[c.BATTERIES]
        costs = self.btry_prop.per_action_costs
        for agent in self[c.AGENT]:
            # A dict maps the agent's last action to a cost; otherwise the cost is flat.
            energy_consumption = costs[agent.temp_action] if isinstance(costs, dict) else costs
            batteries.by_entity(agent).decharge(energy_consumption)
        return super_reward_info

    def do_charge_action(self, agent) -> (dict, dict):
        """Try to charge ``agent``'s battery at the charge pod on its position.

        Returns:
            ``(valid, reward)``, where ``reward`` is a dict with ``value``, ``reason`` and ``info``.
        """
        if charge_pod := self[c.CHARGE_PODS].by_pos(agent.pos):
            valid = charge_pod.charge_battery(self[c.BATTERIES].by_entity(agent))
            if valid:
                info_dict = {f'{agent.name}_{a.CHARGE}_VALID': 1}
                self.print(f'{agent.name} just charged batteries at {charge_pod.name}.')
            else:
                info_dict = {f'{agent.name}_{a.CHARGE}_FAIL': 1}
                self.print(f'{agent.name} failed to charged batteries at {charge_pod.name}.')
        else:
            # No charge pod under the agent.
            valid = c.NOT_VALID
            info_dict = {f'{agent.name}_{a.CHARGE}_FAIL': 1}
            self.print(f'{agent.name} failed to charged batteries at {agent.pos}.')
        reward = dict(value=self.rewards_dest.CHARGE_VALID if valid else self.rewards_dest.CHARGE_FAIL,
                      reason=a.CHARGE, info=info_dict)
        return valid, reward

    def do_additional_actions(self, agent: Agent, action: Action) -> (bool, dict):
        """Handle ``a.CHARGE`` unless a base class already handled the action.

        Returns None when neither the base class nor this class handles ``action``.
        """
        action_result = super().do_additional_actions(agent, action)
        if action_result is not None:
            return action_result
        if action == a.CHARGE:
            return self.do_charge_action(agent)
        return None

    def reset_hook(self) -> None:
        # There is nothing to reset.
        # NOTE(review): unlike the other hooks this does not call super(); confirm the base
        # reset_hook has no work of its own.
        pass

    def check_additional_done(self) -> (bool, dict):
        """Report done if the base class does, or if any battery is discharged while
        ``btry_prop.done_when_discharged`` is set.

        Returns:
            ``(done, info_dict)``; always a tuple, so callers can unpack it.
        """
        super_done, super_dict = super(BatteryFactory, self).check_additional_done()
        if super_done:
            return super_done, super_dict
        if self.btry_prop.done_when_discharged and any(
                battery.is_discharged for battery in self[c.BATTERIES]):
            super_dict.update(DISCHARGE_DONE=1)
            return True, super_dict
        # Previously fell through and returned None here, breaking tuple unpacking.
        return False, super_dict

    def per_agent_reward_hook(self, agent: Agent) -> Dict[str, dict]:
        """Add a BATTERY_DISCHARGED reward event while the agent's battery is discharged."""
        reward_event_dict = super(BatteryFactory, self).per_agent_reward_hook(agent)
        if self[c.BATTERIES].by_entity(agent).is_discharged:
            self.print(f'{agent.name} Battery is discharged!')
            info_dict = {f'{agent.name}_{c.BATTERY_DISCHARGED}': 1}
            reward_event_dict.update({c.BATTERY_DISCHARGED: {'reward': self.rewards_dest.BATTERY_DISCHARGED,
                                                             'info': info_dict}})
        return reward_event_dict

    def render_assets_hook(self):
        """Append one render entity per charge pod to the base render assets."""
        # noinspection PyUnresolvedReferences
        additional_assets = super().render_assets_hook()
        charge_pods = [RenderEntity(c.CHARGE_PODS, charge_pod.tile.pos) for charge_pod in self[c.CHARGE_PODS]]
        additional_assets.extend(charge_pods)
        return additional_assets