Rework for performance

This commit is contained in:
Steffen Illium
2022-01-10 15:54:22 +01:00
parent 78bf19f7f4
commit 435056f373
10 changed files with 525 additions and 469 deletions

View File

@ -1,4 +1,4 @@
from typing import Union, NamedTuple, Dict
from typing import Union, NamedTuple, Dict, List
import numpy as np
@ -6,13 +6,29 @@ from environments.factory.base.base_factory import BaseFactory
from environments.factory.base.objects import Agent, Action, Entity, EnvObject, BoundingMixin
from environments.factory.base.registers import EntityRegister, EnvObjectRegister
from environments.factory.base.renderer import RenderEntity
from environments.helpers import Constants as c, Constants
from environments.helpers import Constants as BaseConstants
from environments.helpers import EnvActions as BaseActions
from environments.helpers import Rewards as BaseRewards
from environments import helpers as h
CHARGE_ACTION = h.EnvActions.CHARGE
CHARGE_POD = 1
class Constants(BaseConstants):
    """Constants of the battery environment, on top of the base constants."""

    # Key under which the charge pod register is stored and looked up.
    CHARGE_PODS = 'Charge_Pod'
    # Key under which the battery register is stored and looked up.
    BATTERIES = 'BATTERIES'
    # Name of the "battery is discharged" reward event / info entry.
    BATTERY_DISCHARGED = 'DISCHARGED'
    # Value returned by ``ChargePod.encoding`` (not to be confused with CHARGE_PODS).
    CHARGE_POD = 1
class Actions(BaseActions):
    """Actions of the battery environment, on top of the base actions."""

    # Identifier of the charge action; compared against the incoming action.
    CHARGE = 'do_charge_action'
class Rewards(BaseRewards):
    """Reward values of the battery environment, on top of the base rewards."""

    # Granted for a successful charge action.
    CHARGE_VALID = 0.1
    # Granted (as a penalty) for a failed charge action.
    CHARGE_FAIL = -0.1
    # Granted (as a penalty) while an agent's battery is discharged.
    BATTERY_DISCHARGED = -1.0
class BatteryProperties(NamedTuple):
@ -24,7 +40,12 @@ class BatteryProperties(NamedTuple):
multi_charge: bool = False
class Battery(EnvObject, BoundingMixin):
# Short module-level aliases for the battery-specific constants, actions and rewards
# classes; the rest of this module refers to them as ``c``, ``a`` and ``r``.
c = Constants
a = Actions
r = Rewards
class Battery(BoundingMixin, EnvObject):
@property
def is_discharged(self):
@ -37,13 +58,13 @@ class Battery(EnvObject, BoundingMixin):
def encoding(self):
return self.charge_level
def do_charge_action(self, amount):
    """Charge this battery by ``amount``, capping the charge level at 1.

    :param amount: Charge to add to the current charge level.
    :return: A dict with the keys ``valid``, ``action`` and ``reward``. If the
             charge level is not below 1, nothing is changed and the attempt is
             reported as not valid.
    """
    if not (self.charge_level < 1):
        # Nothing left to charge.
        return {'valid': c.NOT_VALID, 'action': a.CHARGE, 'reward': r.CHARGE_FAIL}
    # noinspection PyTypeChecker
    self.charge_level = min(1, amount + self.charge_level)
    return {'valid': c.VALID, 'action': a.CHARGE, 'reward': r.CHARGE_VALID}
def decharge(self, amount) -> c:
if self.charge_level != 0:
@ -54,7 +75,7 @@ class Battery(EnvObject, BoundingMixin):
else:
return c.NOT_VALID
def summarize_state(self, **_):
    """Return the public instance attributes as strings, plus the object's name.

    Attributes whose name starts with an underscore and the attribute ``data``
    are left out. Keyword arguments are accepted but ignored.

    :return: Dict mapping attribute name to ``str(value)``; ``name`` holds ``self.name``.
    """
    summary = {}
    for attr_name, attr_value in vars(self).items():
        if attr_name.startswith('_') or attr_name == 'data':
            continue
        summary[attr_name] = str(attr_value)
    summary['name'] = self.name
    return summary
@ -63,53 +84,43 @@ class Battery(EnvObject, BoundingMixin):
class BatteriesRegister(EnvObjectRegister):
    """Register of ``Battery`` objects; ``spawn_batteries`` creates one per agent."""

    _accepted_objects = Battery

    def __init__(self, *args, **kwargs):
        super(BatteriesRegister, self).__init__(*args, individual_slices=True,
                                                is_blocking_light=False, can_be_shadowed=False, **kwargs)
        self.is_observable = True

    def spawn_batteries(self, agents, initial_charge_level):
        """Create a battery for every agent and register all of them.

        :param agents: Iterable of agents that each get their own battery.
        :param initial_charge_level: Charge level every new battery starts with.
        """
        batteries = [self._accepted_objects(initial_charge_level, agent, self) for agent in agents]
        self.register_additional_items(batteries)

    def summarize_states(self, n_steps=None):
        # as dict with additional nesting
        # return dict(items=super(Inventories, cls).summarize_states())
        return super(BatteriesRegister, self).summarize_states(n_steps=n_steps)

    # Todo Move this to Mixin!
    def by_entity(self, entity):
        """Return the first battery that belongs to ``entity``, or ``None``."""
        return next((x for x in self if x.belongs_to_entity(entity)), None)

    def idx_by_entity(self, entity):
        """Return the index of the first battery that belongs to ``entity``, or ``None``."""
        return next((idx for idx, x in enumerate(self) if x.belongs_to_entity(entity)), None)

    def as_array_by_entity(self, entity):
        """Return the slice of the register array that belongs to ``entity``'s battery.

        :raises ValueError: If no battery belongs to ``entity``.
        """
        idx = self.idx_by_entity(entity)
        if idx is None:
            # Indexing with None must not happen: on a NumPy array it would add an
            # axis and return the whole array instead of failing.
            raise ValueError(f'No battery is registered for {entity}.')
        return self._array[idx]
class ChargePod(Entity):
@property
def can_collide(self):
    """Always ``False`` for a charge pod."""
    return False
@property
def encoding(self):
    """Encoding value of a charge pod, i.e. ``c.CHARGE_POD``."""
    return c.CHARGE_POD
def __init__(self, *args, charge_rate: float = 0.4,
multi_charge: bool = False, **kwargs):
@ -120,9 +131,9 @@ class ChargePod(Entity):
def charge_battery(self, battery: Battery):
    """Charge ``battery`` by this pod's charge rate.

    :param battery: The battery to charge.
    :return: ``c.NOT_VALID`` if the battery is already full (charge level 1.0) or
             more than one agent is on this pod's tile, ``c.VALID`` otherwise.
    """
    if battery.charge_level == 1.0:
        return c.NOT_VALID
    # Count the agents on the tile. Summing the guests themselves would add up
    # the guest objects instead of counting them.
    n_agents = sum(1 for guest in self.tile.guests if 'agent' in guest.name.lower())
    if n_agents > 1:
        return c.NOT_VALID
    battery.do_charge_action(self.charge_rate)
    return c.VALID
def summarize_state(self, n_steps=None) -> dict:
@ -135,14 +146,6 @@ class ChargePods(EntityRegister):
_accepted_objects = ChargePod
@DeprecationWarning
def Xas_array(self):
self._array[:] = c.FREE_CELL.value
for item in self:
if item.pos != c.NO_POS.value:
self._array[0, item.x, item.y] = item.encoding
return self._array
def __repr__(self):
    """Return the representation provided by the parent register."""
    # The result has to be returned: a ``__repr__`` that returns None makes
    # ``repr()`` raise a TypeError.
    return super(ChargePods, self).__repr__()
@ -155,14 +158,14 @@ class BatteryFactory(BaseFactory):
self.btry_prop = btry_prop
super().__init__(*args, **kwargs)
def _additional_per_agent_raw_observations(self, agent) -> Dict[str, np.typing.ArrayLike]:
    """Extend the parent's per-agent raw observations with the agent's battery array.

    :param agent: The agent the observations are collected for.
    :return: The parent's observation dict with an entry added under ``c.BATTERIES``.
    """
    raw_observations = super()._additional_per_agent_raw_observations(agent)
    battery_observation = {c.BATTERIES: self[c.BATTERIES].as_array_by_entity(agent)}
    raw_observations.update(battery_observation)
    return raw_observations
def _additional_observations(self) -> Dict[str, np.typing.ArrayLike]:
    """Extend the parent's observations with the charge pod array.

    :return: The parent's observation dict with an entry added under ``c.CHARGE_PODS``.
    """
    observations = super()._additional_observations()
    charge_pod_observation = {c.CHARGE_PODS: self[c.CHARGE_PODS].as_array()}
    observations.update(charge_pod_observation)
    return observations
@property
@ -178,12 +181,12 @@ class BatteryFactory(BaseFactory):
batteries = BatteriesRegister(self._level_shape if not self._pomdp_r else ((self.pomdp_diameter,) * 2),
)
batteries.spawn_batteries(self[c.AGENT], self._pomdp_r, self.btry_prop.initial_charge)
super_entities.update({c.BATTERIES: batteries, c.CHARGE_POD: charge_pods})
batteries.spawn_batteries(self[c.AGENT], self.btry_prop.initial_charge)
super_entities.update({c.BATTERIES: batteries, c.CHARGE_PODS: charge_pods})
return super_entities
def do_additional_step(self) -> dict:
info_dict = super(BatteryFactory, self).do_additional_step()
def do_additional_step(self) -> (List[dict], dict):
super_reward_info = super(BatteryFactory, self).do_additional_step()
# Decharge
batteries = self[c.BATTERIES]
@ -196,65 +199,70 @@ class BatteryFactory(BaseFactory):
batteries.by_entity(agent).decharge(energy_consumption)
return info_dict
return super_reward_info
def do_charge_action(self, agent) -> (dict, dict):
    """Let ``agent`` charge its battery at the charge pod on its position.

    :param agent: The agent that performs the charge action.
    :return: ``(valid, reward)``, where ``valid`` is the outcome of the attempt
             (``c.NOT_VALID`` if there is no charge pod at the agent's position)
             and ``reward`` is a dict with the keys ``value``, ``reason`` and ``info``.
    """
    charge_pod = self[c.CHARGE_PODS].by_pos(agent.pos)
    if not charge_pod:
        # No charge pod at the agent's position.
        valid = c.NOT_VALID
        info_dict = {f'{agent.name}_{a.CHARGE}_FAIL': 1}
        # info_dict = {f'{agent.name}_no_charger': 1}
        self.print(f'{agent.name} failed to charged batteries at {agent.pos}.')
    else:
        valid = charge_pod.charge_battery(self[c.BATTERIES].by_entity(agent))
        if valid:
            info_dict = {f'{agent.name}_{a.CHARGE}_VALID': 1}
            self.print(f'{agent.name} just charged batteries at {charge_pod.name}.')
        else:
            info_dict = {f'{agent.name}_{a.CHARGE}_FAIL': 1}
            self.print(f'{agent.name} failed to charged batteries at {charge_pod.name}.')
    reward_value = r.CHARGE_VALID if valid else r.CHARGE_FAIL
    reward = {'value': reward_value, 'reason': a.CHARGE, 'info': info_dict}
    return valid, reward
def do_additional_actions(self, agent: Agent, action: Action) -> (bool, dict):
    """Handle the charge action unless the parent class already handled ``action``.

    :param agent: The agent that performs the action.
    :param action: The action to perform.
    :return: The parent's result if it is not ``None``; otherwise the result of
             ``do_charge_action`` for the charge action, or ``None`` for any other action.
    """
    action_result = super().do_additional_actions(agent, action)
    if action_result is not None:
        return action_result
    if action == a.CHARGE:
        return self.do_charge_action(agent)
    return None
def do_additional_reset(self) -> None:
    """Do nothing; there is nothing to reset here."""
    return None
def check_additional_done(self) -> (bool, dict):
    """Report whether the episode is done, including the battery-discharged condition.

    :return: ``(done, info)``. The parent's result is passed through when it already
             signals done. Otherwise, if ``done_when_discharged`` is set and any
             battery is discharged, ``info`` gets ``DISCHARGE_DONE=1`` and ``done``
             is ``True``. In every other case the parent's (not done) result is returned.
    """
    super_done, super_dict = super(BatteryFactory, self).check_additional_done()
    if super_done:
        return super_done, super_dict
    if self.btry_prop.done_when_discharged and any(battery.is_discharged for battery in self[c.BATTERIES]):
        super_dict.update(DISCHARGE_DONE=1)
        return True, super_dict
    # Always hand back a (done, info) pair; falling through would return None,
    # which cannot be unpacked the way this method unpacks its parent's result.
    return super_done, super_dict
def additional_per_agent_reward(self, agent: Agent) -> Dict[str, dict]:
    """Add the battery-discharged reward event for ``agent`` if its battery is empty.

    :param agent: The agent the reward events are collected for.
    :return: The parent's reward event dict; if the agent's battery is discharged,
             an entry under ``c.BATTERY_DISCHARGED`` with the keys ``reward`` and
             ``info`` is added.
    """
    reward_event_dict = super(BatteryFactory, self).additional_per_agent_reward(agent)
    if not self[c.BATTERIES].by_entity(agent).is_discharged:
        # All Fine
        return reward_event_dict
    self.print(f'{agent.name} Battery is discharged!')
    info_dict = {f'{agent.name}_{c.BATTERY_DISCHARGED}': 1}
    discharged_event = {'reward': r.BATTERY_DISCHARGED, 'info': info_dict}
    reward_event_dict.update({c.BATTERY_DISCHARGED: discharged_event})
    return reward_event_dict
def render_additional_assets(self):
    """Extend the parent's render assets with one render entity per charge pod.

    :return: The parent's asset list, extended by a ``RenderEntity`` for every
             charge pod, placed at the position of the pod's tile.
    """
    # noinspection PyUnresolvedReferences
    assets = super().render_additional_assets()
    pod_assets = [RenderEntity(c.CHARGE_PODS, pod.tile.pos) for pod in self[c.CHARGE_PODS]]
    assets.extend(pod_assets)
    return assets