276 lines
9.3 KiB
Python

from typing import Union, NamedTuple
import numpy as np
from environments.factory.base.base_factory import BaseFactory
from environments.factory.base.objects import Agent, Action, Entity
from environments.factory.base.registers import EntityObjectRegister, ObjectRegister
from environments.factory.base.renderer import RenderEntity
from environments.helpers import Constants as c
from environments import helpers as h
CHARGE_ACTION = h.EnvActions.CHARGE
ITEM_DROP_OFF = 1
class BatteryProperties(NamedTuple):
initial_charge: float = 0.8 #
charge_rate: float = 0.4 #
charge_locations: int = 20 #
per_action_costs: Union[dict, float] = 0.02
done_when_discharged = False
multi_charge: bool = False
class Battery(object):
@property
def is_discharged(self):
return self.charge_level == 0
@property
def is_blocking_light(self):
return False
@property
def can_collide(self):
return False
@property
def name(self):
return f'{self.__class__.__name__}({self.agent.name})'
def __init__(self, pomdp_r: int, level_shape: (int, int), agent: Agent, initial_charge_level: float):
super().__init__()
self.agent = agent
self._pomdp_r = pomdp_r
self._level_shape = level_shape
if self._pomdp_r:
self._array = np.zeros((1, pomdp_r * 2 + 1, pomdp_r * 2 + 1))
else:
self._array = np.zeros((1, *self._level_shape))
self.charge_level = initial_charge_level
def as_array(self):
self._array[:] = c.FREE_CELL.value
self._array[0, 0] = self.charge_level
return self._array
def __repr__(self):
return f'{self.__class__.__name__}[{self.agent.name}]({self.charge_level})'
def charge(self, amount) -> c:
if self.charge_level < 1:
# noinspection PyTypeChecker
self.charge_level = min(1, amount + self.charge_level)
return c.VALID
else:
return c.NOT_VALID
def decharge(self, amount) -> c:
if self.charge_level != 0:
# noinspection PyTypeChecker
self.charge_level = max(0, amount + self.charge_level)
return c.VALID
else:
return c.NOT_VALID
def belongs_to_entity(self, entity):
return self.agent == entity
def summarize_state(self, **kwargs):
attr_dict = {key: str(val) for key, val in self.__dict__.items() if not key.startswith('_') and key != 'data'}
attr_dict.update(dict(name=self.name))
return attr_dict
class BatteriesRegister(ObjectRegister):
_accepted_objects = Battery
is_blocking_light = False
can_be_shadowed = False
hide_from_obs_builder = True
def __init__(self, *args, **kwargs):
super(BatteriesRegister, self).__init__(*args, is_per_agent=True, individual_slices=True, **kwargs)
self.is_observable = True
def as_array(self):
# self._array[:] = c.FREE_CELL.value
for inv_idx, battery in enumerate(self):
self._array[inv_idx] = battery.as_array()
return self._array
def spawn_batteries(self, agents, pomdp_r, initial_charge_level):
inventories = [self._accepted_objects(pomdp_r, self._level_shape, agent,
initial_charge_level)
for _, agent in enumerate(agents)]
self.register_additional_items(inventories)
def idx_by_entity(self, entity):
try:
return next((idx for idx, bat in enumerate(self) if bat.belongs_to_entity(entity)))
except StopIteration:
return None
def by_entity(self, entity):
try:
return next((bat for bat in self if bat.belongs_to_entity(entity)))
except StopIteration:
return None
def summarize_states(self, n_steps=None):
# as dict with additional nesting
# return dict(items=super(Inventories, self).summarize_states())
return super(BatteriesRegister, self).summarize_states(n_steps=n_steps)
class ChargePod(Entity):
@property
def can_collide(self):
return False
@property
def encoding(self):
return ITEM_DROP_OFF
def __init__(self, *args, charge_rate: float = 0.4,
multi_charge: bool = False, **kwargs):
super(ChargePod, self).__init__(*args, **kwargs)
self.charge_rate = charge_rate
self.multi_charge = multi_charge
def charge_battery(self, battery: Battery):
if battery.charge_level == 1.0:
return c.NOT_VALID
if sum(guest for guest in self.tile.guests if c.AGENT.name in guest.name) > 1:
return c.NOT_VALID
battery.charge(self.charge_rate)
return c.VALID
def summarize_state(self, n_steps=None) -> dict:
if n_steps == h.STEPS_START:
summary = super().summarize_state(n_steps=n_steps)
return summary
class ChargePods(EntityObjectRegister):
_accepted_objects = ChargePod
def as_array(self):
self._array[:] = c.FREE_CELL.value
for item in self:
if item.pos != c.NO_POS.value:
self._array[0, item.x, item.y] = item.encoding
return self._array
def __repr__(self):
super(ChargePods, self).__repr__()
class BatteryFactory(BaseFactory):
def __init__(self, *args, btry_prop=BatteryProperties(), **kwargs):
if isinstance(btry_prop, dict):
btry_prop = BatteryProperties(**btry_prop)
self.btry_prop = btry_prop
super().__init__(*args, **kwargs)
@property
def additional_entities(self):
super_entities = super().additional_entities
empty_tiles = self[c.FLOOR].empty_tiles[:self.btry_prop.charge_locations]
charge_pods = ChargePods.from_tiles(
empty_tiles, self._level_shape,
entity_kwargs=dict(charge_rate=self.btry_prop.charge_rate,
multi_charge=self.btry_prop.multi_charge)
)
batteries = BatteriesRegister(self._level_shape if not self._pomdp_r else ((self.pomdp_diameter,) * 2),
)
batteries.spawn_batteries(self[c.AGENT], self._pomdp_r, self.btry_prop.initial_charge)
super_entities.update({c.BATTERIES: batteries, c.CHARGE_POD: charge_pods})
return super_entities
def do_additional_step(self) -> dict:
info_dict = super(BatteryFactory, self).do_additional_step()
# Decharge
batteries = self[c.BATTERIES]
for agent in self[c.AGENT]:
if isinstance(self.btry_prop.per_action_costs, dict):
energy_consumption = self.btry_prop.per_action_costs[agent.temp_action]
else:
energy_consumption = self.btry_prop.per_action_costs
batteries.by_entity(agent).decharge(energy_consumption)
return info_dict
def do_charge(self, agent) -> c:
if charge_pod := self[c.CHARGE_POD].by_pos(agent.pos):
return charge_pod.charge_battery(self[c.BATTERIES].by_entity(agent))
else:
return c.NOT_VALID
def do_additional_actions(self, agent: Agent, action: Action) -> Union[None, c]:
valid = super().do_additional_actions(agent, action)
if valid is None:
if action == CHARGE_ACTION:
valid = self.do_charge(agent)
return valid
else:
return None
else:
return valid
pass
def do_additional_reset(self) -> None:
# There is Nothing to reset.
pass
def check_additional_done(self) -> bool:
super_done = super(BatteryFactory, self).check_additional_done()
if super_done:
return super_done
else:
return self.btry_prop.done_when_discharged and any(battery.is_discharged for battery in self[c.BATTERIES])
pass
def calculate_additional_reward(self, agent: Agent) -> (int, dict):
reward, info_dict = super(BatteryFactory, self).calculate_additional_reward(agent)
if h.EnvActions.CHARGE == agent.temp_action:
if agent.temp_valid:
charge_pod = self[c.CHARGE_POD].by_pos(agent.pos)
info_dict.update({f'{agent.name}_charge': 1})
info_dict.update(agent_charged=1)
self.print(f'{agent.name} just charged batteries at {charge_pod.pos}.')
reward += 0.1
else:
self[c.DROP_OFF].by_pos(agent.pos)
info_dict.update({f'{agent.name}_failed_charge': 1})
info_dict.update(failed_charge=1)
self.print(f'{agent.name} just tried to charge at {agent.pos}, but failed.')
reward -= 0.1
if self[c.BATTERIES].by_entity(agent).is_discharged:
info_dict.update({f'{agent.name}_discharged': 1})
reward -= 1
else:
info_dict.update({f'{agent.name}_battery_level': self[c.BATTERIES].by_entity(agent).charge_level})
return reward, info_dict
def render_additional_assets(self):
# noinspection PyUnresolvedReferences
additional_assets = super().render_additional_assets()
charge_pods = [RenderEntity(c.CHARGE_POD.value, charge_pod.tile.pos) for charge_pod in self[c.CHARGE_POD]]
additional_assets.extend(charge_pods)
return additional_assets