Rework for performance
This commit is contained in:
@ -1,4 +1,4 @@
|
||||
from typing import Union, NamedTuple, Dict
|
||||
from typing import Union, NamedTuple, Dict, List
|
||||
|
||||
import numpy as np
|
||||
|
||||
@ -6,13 +6,29 @@ from environments.factory.base.base_factory import BaseFactory
|
||||
from environments.factory.base.objects import Agent, Action, Entity, EnvObject, BoundingMixin
|
||||
from environments.factory.base.registers import EntityRegister, EnvObjectRegister
|
||||
from environments.factory.base.renderer import RenderEntity
|
||||
from environments.helpers import Constants as c, Constants
|
||||
from environments.helpers import Constants as BaseConstants
|
||||
from environments.helpers import EnvActions as BaseActions
|
||||
from environments.helpers import Rewards as BaseRewards
|
||||
|
||||
from environments import helpers as h
|
||||
|
||||
|
||||
CHARGE_ACTION = h.EnvActions.CHARGE
|
||||
CHARGE_POD = 1
|
||||
class Constants(BaseConstants):
|
||||
# Battery Env
|
||||
CHARGE_PODS = 'Charge_Pod'
|
||||
BATTERIES = 'BATTERIES'
|
||||
BATTERY_DISCHARGED = 'DISCHARGED'
|
||||
CHARGE_POD = 1
|
||||
|
||||
|
||||
class Actions(BaseActions):
|
||||
CHARGE = 'do_charge_action'
|
||||
|
||||
|
||||
class Rewards(BaseRewards):
|
||||
CHARGE_VALID = 0.1
|
||||
CHARGE_FAIL = -0.1
|
||||
BATTERY_DISCHARGED = -1.0
|
||||
|
||||
|
||||
class BatteryProperties(NamedTuple):
|
||||
@ -24,7 +40,12 @@ class BatteryProperties(NamedTuple):
|
||||
multi_charge: bool = False
|
||||
|
||||
|
||||
class Battery(EnvObject, BoundingMixin):
|
||||
c = Constants
|
||||
a = Actions
|
||||
r = Rewards
|
||||
|
||||
|
||||
class Battery(BoundingMixin, EnvObject):
|
||||
|
||||
@property
|
||||
def is_discharged(self):
|
||||
@ -37,13 +58,13 @@ class Battery(EnvObject, BoundingMixin):
|
||||
def encoding(self):
|
||||
return self.charge_level
|
||||
|
||||
def charge(self, amount) -> c:
|
||||
def do_charge_action(self, amount):
|
||||
if self.charge_level < 1:
|
||||
# noinspection PyTypeChecker
|
||||
self.charge_level = min(1, amount + self.charge_level)
|
||||
return c.VALID
|
||||
return dict(valid=c.VALID, action=a.CHARGE, reward=r.CHARGE_VALID)
|
||||
else:
|
||||
return c.NOT_VALID
|
||||
return dict(valid=c.NOT_VALID, action=a.CHARGE, reward=r.CHARGE_FAIL)
|
||||
|
||||
def decharge(self, amount) -> c:
|
||||
if self.charge_level != 0:
|
||||
@ -54,7 +75,7 @@ class Battery(EnvObject, BoundingMixin):
|
||||
else:
|
||||
return c.NOT_VALID
|
||||
|
||||
def summarize_state(self, **kwargs):
|
||||
def summarize_state(self, **_):
|
||||
attr_dict = {key: str(val) for key, val in self.__dict__.items() if not key.startswith('_') and key != 'data'}
|
||||
attr_dict.update(dict(name=self.name))
|
||||
return attr_dict
|
||||
@ -63,53 +84,43 @@ class Battery(EnvObject, BoundingMixin):
|
||||
class BatteriesRegister(EnvObjectRegister):
|
||||
|
||||
_accepted_objects = Battery
|
||||
is_blocking_light = False
|
||||
can_be_shadowed = False
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(BatteriesRegister, self).__init__(*args, is_per_agent=True, individual_slices=True, **kwargs)
|
||||
super(BatteriesRegister, self).__init__(*args, individual_slices=True,
|
||||
is_blocking_light=False, can_be_shadowed=False, **kwargs)
|
||||
self.is_observable = True
|
||||
|
||||
def as_array(self):
|
||||
# ToDO: Make this Lazy
|
||||
self._array[:] = c.FREE_CELL.value
|
||||
for inv_idx, battery in enumerate(self):
|
||||
self._array[inv_idx] = battery.as_array()
|
||||
return self._array
|
||||
|
||||
def spawn_batteries(self, agents, pomdp_r, initial_charge_level):
|
||||
batteries = [self._accepted_objects(pomdp_r, self._shape, agent,
|
||||
initial_charge_level)
|
||||
for _, agent in enumerate(agents)]
|
||||
def spawn_batteries(self, agents, initial_charge_level):
|
||||
batteries = [self._accepted_objects(initial_charge_level, agent, self) for _, agent in enumerate(agents)]
|
||||
self.register_additional_items(batteries)
|
||||
|
||||
def idx_by_entity(self, entity):
|
||||
try:
|
||||
return next((idx for idx, bat in enumerate(self) if bat.belongs_to_entity(entity)))
|
||||
except StopIteration:
|
||||
return None
|
||||
|
||||
def by_entity(self, entity):
|
||||
try:
|
||||
return next((bat for bat in self if bat.belongs_to_entity(entity)))
|
||||
except StopIteration:
|
||||
return None
|
||||
|
||||
def summarize_states(self, n_steps=None):
|
||||
# as dict with additional nesting
|
||||
# return dict(items=super(Inventories, cls).summarize_states())
|
||||
return super(BatteriesRegister, self).summarize_states(n_steps=n_steps)
|
||||
|
||||
# Todo Move this to Mixin!
|
||||
def by_entity(self, entity):
|
||||
try:
|
||||
return next((x for x in self if x.belongs_to_entity(entity)))
|
||||
except StopIteration:
|
||||
return None
|
||||
|
||||
def idx_by_entity(self, entity):
|
||||
try:
|
||||
return next((idx for idx, x in enumerate(self) if x.belongs_to_entity(entity)))
|
||||
except StopIteration:
|
||||
return None
|
||||
|
||||
def as_array_by_entity(self, entity):
|
||||
return self._array[self.idx_by_entity(entity)]
|
||||
|
||||
|
||||
class ChargePod(Entity):
|
||||
|
||||
@property
|
||||
def can_collide(self):
|
||||
return False
|
||||
|
||||
@property
|
||||
def encoding(self):
|
||||
return CHARGE_POD
|
||||
return c.CHARGE_POD
|
||||
|
||||
def __init__(self, *args, charge_rate: float = 0.4,
|
||||
multi_charge: bool = False, **kwargs):
|
||||
@ -120,9 +131,9 @@ class ChargePod(Entity):
|
||||
def charge_battery(self, battery: Battery):
|
||||
if battery.charge_level == 1.0:
|
||||
return c.NOT_VALID
|
||||
if sum(guest for guest in self.tile.guests if c.AGENT.name in guest.name) > 1:
|
||||
if sum(guest for guest in self.tile.guests if 'agent' in guest.name.lower()) > 1:
|
||||
return c.NOT_VALID
|
||||
battery.charge(self.charge_rate)
|
||||
battery.do_charge_action(self.charge_rate)
|
||||
return c.VALID
|
||||
|
||||
def summarize_state(self, n_steps=None) -> dict:
|
||||
@ -135,14 +146,6 @@ class ChargePods(EntityRegister):
|
||||
|
||||
_accepted_objects = ChargePod
|
||||
|
||||
@DeprecationWarning
|
||||
def Xas_array(self):
|
||||
self._array[:] = c.FREE_CELL.value
|
||||
for item in self:
|
||||
if item.pos != c.NO_POS.value:
|
||||
self._array[0, item.x, item.y] = item.encoding
|
||||
return self._array
|
||||
|
||||
def __repr__(self):
|
||||
super(ChargePods, self).__repr__()
|
||||
|
||||
@ -155,14 +158,14 @@ class BatteryFactory(BaseFactory):
|
||||
self.btry_prop = btry_prop
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def _additional_per_agent_raw_observations(self, agent) -> Dict[Constants, np.typing.ArrayLike]:
|
||||
def _additional_per_agent_raw_observations(self, agent) -> Dict[str, np.typing.ArrayLike]:
|
||||
additional_raw_observations = super()._additional_per_agent_raw_observations(agent)
|
||||
additional_raw_observations.update({c.BATTERIES: self[c.BATTERIES].by_entity(agent).as_array()})
|
||||
additional_raw_observations.update({c.BATTERIES: self[c.BATTERIES].as_array_by_entity(agent)})
|
||||
return additional_raw_observations
|
||||
|
||||
def _additional_observations(self) -> Dict[Constants, np.typing.ArrayLike]:
|
||||
def _additional_observations(self) -> Dict[str, np.typing.ArrayLike]:
|
||||
additional_observations = super()._additional_observations()
|
||||
additional_observations.update({c.CHARGE_POD: self[c.CHARGE_POD].as_array()})
|
||||
additional_observations.update({c.CHARGE_PODS: self[c.CHARGE_PODS].as_array()})
|
||||
return additional_observations
|
||||
|
||||
@property
|
||||
@ -178,12 +181,12 @@ class BatteryFactory(BaseFactory):
|
||||
|
||||
batteries = BatteriesRegister(self._level_shape if not self._pomdp_r else ((self.pomdp_diameter,) * 2),
|
||||
)
|
||||
batteries.spawn_batteries(self[c.AGENT], self._pomdp_r, self.btry_prop.initial_charge)
|
||||
super_entities.update({c.BATTERIES: batteries, c.CHARGE_POD: charge_pods})
|
||||
batteries.spawn_batteries(self[c.AGENT], self.btry_prop.initial_charge)
|
||||
super_entities.update({c.BATTERIES: batteries, c.CHARGE_PODS: charge_pods})
|
||||
return super_entities
|
||||
|
||||
def do_additional_step(self) -> dict:
|
||||
info_dict = super(BatteryFactory, self).do_additional_step()
|
||||
def do_additional_step(self) -> (List[dict], dict):
|
||||
super_reward_info = super(BatteryFactory, self).do_additional_step()
|
||||
|
||||
# Decharge
|
||||
batteries = self[c.BATTERIES]
|
||||
@ -196,65 +199,70 @@ class BatteryFactory(BaseFactory):
|
||||
|
||||
batteries.by_entity(agent).decharge(energy_consumption)
|
||||
|
||||
return info_dict
|
||||
return super_reward_info
|
||||
|
||||
def do_charge(self, agent) -> c:
|
||||
if charge_pod := self[c.CHARGE_POD].by_pos(agent.pos):
|
||||
return charge_pod.charge_battery(self[c.BATTERIES].by_entity(agent))
|
||||
def do_charge_action(self, agent) -> (dict, dict):
|
||||
if charge_pod := self[c.CHARGE_PODS].by_pos(agent.pos):
|
||||
valid = charge_pod.charge_battery(self[c.BATTERIES].by_entity(agent))
|
||||
if valid:
|
||||
info_dict = {f'{agent.name}_{a.CHARGE}_VALID': 1}
|
||||
self.print(f'{agent.name} just charged batteries at {charge_pod.name}.')
|
||||
else:
|
||||
info_dict = {f'{agent.name}_{a.CHARGE}_FAIL': 1}
|
||||
self.print(f'{agent.name} failed to charged batteries at {charge_pod.name}.')
|
||||
else:
|
||||
return c.NOT_VALID
|
||||
valid = c.NOT_VALID
|
||||
info_dict = {f'{agent.name}_{a.CHARGE}_FAIL': 1}
|
||||
# info_dict = {f'{agent.name}_no_charger': 1}
|
||||
self.print(f'{agent.name} failed to charged batteries at {agent.pos}.')
|
||||
reward = dict(value=r.CHARGE_VALID if valid else r.CHARGE_FAIL, reason=a.CHARGE, info=info_dict)
|
||||
return valid, reward
|
||||
|
||||
def do_additional_actions(self, agent: Agent, action: Action) -> Union[None, c]:
|
||||
valid = super().do_additional_actions(agent, action)
|
||||
if valid is None:
|
||||
if action == CHARGE_ACTION:
|
||||
valid = self.do_charge(agent)
|
||||
return valid
|
||||
def do_additional_actions(self, agent: Agent, action: Action) -> (bool, dict):
|
||||
action_result = super().do_additional_actions(agent, action)
|
||||
if action_result is None:
|
||||
if action == a.CHARGE:
|
||||
action_result = self.do_charge_action(agent)
|
||||
return action_result
|
||||
else:
|
||||
return None
|
||||
else:
|
||||
return valid
|
||||
return action_result
|
||||
pass
|
||||
|
||||
def do_additional_reset(self) -> None:
|
||||
# There is Nothing to reset.
|
||||
pass
|
||||
|
||||
def check_additional_done(self) -> bool:
|
||||
super_done = super(BatteryFactory, self).check_additional_done()
|
||||
def check_additional_done(self) -> (bool, dict):
|
||||
super_done, super_dict = super(BatteryFactory, self).check_additional_done()
|
||||
if super_done:
|
||||
return super_done
|
||||
return super_done, super_dict
|
||||
else:
|
||||
return self.btry_prop.done_when_discharged and any(battery.is_discharged for battery in self[c.BATTERIES])
|
||||
if self.btry_prop.done_when_discharged:
|
||||
if btry_done := any(battery.is_discharged for battery in self[c.BATTERIES]):
|
||||
super_dict.update(DISCHARGE_DONE=1)
|
||||
return btry_done, super_dict
|
||||
else:
|
||||
pass
|
||||
else:
|
||||
pass
|
||||
pass
|
||||
|
||||
def calculate_additional_reward(self, agent: Agent) -> (int, dict):
|
||||
reward, info_dict = super(BatteryFactory, self).calculate_additional_reward(agent)
|
||||
if h.EnvActions.CHARGE == agent.temp_action:
|
||||
if agent.temp_valid:
|
||||
charge_pod = self[c.CHARGE_POD].by_pos(agent.pos)
|
||||
info_dict.update({f'{agent.name}_charge': 1})
|
||||
info_dict.update(agent_charged=1)
|
||||
self.print(f'{agent.name} just charged batteries at {charge_pod.pos}.')
|
||||
reward += 0.1
|
||||
else:
|
||||
self[c.DROP_OFF].by_pos(agent.pos)
|
||||
info_dict.update({f'{agent.name}_failed_charge': 1})
|
||||
info_dict.update(failed_charge=1)
|
||||
self.print(f'{agent.name} just tried to charge at {agent.pos}, but failed.')
|
||||
reward -= 0.1
|
||||
|
||||
def additional_per_agent_reward(self, agent: Agent) -> Dict[str, dict]:
|
||||
reward_event_dict = super(BatteryFactory, self).additional_per_agent_reward(agent)
|
||||
if self[c.BATTERIES].by_entity(agent).is_discharged:
|
||||
info_dict.update({f'{agent.name}_discharged': 1})
|
||||
reward -= 1
|
||||
self.print(f'{agent.name} Battery is discharged!')
|
||||
info_dict = {f'{agent.name}_{c.BATTERY_DISCHARGED}': 1}
|
||||
reward_event_dict.update({c.BATTERY_DISCHARGED: {'reward': r.BATTERY_DISCHARGED, 'info': info_dict}})
|
||||
else:
|
||||
info_dict.update({f'{agent.name}_battery_level': self[c.BATTERIES].by_entity(agent).charge_level})
|
||||
return reward, info_dict
|
||||
# All Fine
|
||||
pass
|
||||
return reward_event_dict
|
||||
|
||||
def render_additional_assets(self):
|
||||
# noinspection PyUnresolvedReferences
|
||||
additional_assets = super().render_additional_assets()
|
||||
charge_pods = [RenderEntity(c.CHARGE_POD.value, charge_pod.tile.pos) for charge_pod in self[c.CHARGE_POD]]
|
||||
charge_pods = [RenderEntity(c.CHARGE_PODS, charge_pod.tile.pos) for charge_pod in self[c.CHARGE_PODS]]
|
||||
additional_assets.extend(charge_pods)
|
||||
return additional_assets
|
||||
|
||||
|
Reference in New Issue
Block a user