Rework of Observations and Entity Differentiation, lazy obs build by notification

This commit is contained in:
Steffen Illium
2021-12-22 10:48:36 +01:00
parent 7f7a3d9a3b
commit b43f595207
14 changed files with 961 additions and 487 deletions

View File

@ -1,18 +1,18 @@
from typing import Union, NamedTuple
from typing import Union, NamedTuple, Dict
import numpy as np
from environments.factory.base.base_factory import BaseFactory
from environments.factory.base.objects import Agent, Action, Entity
from environments.factory.base.registers import EntityObjectRegister, ObjectRegister
from environments.factory.base.objects import Agent, Action, Entity, EnvObject, BoundingMixin
from environments.factory.base.registers import EntityRegister, EnvObjectRegister
from environments.factory.base.renderer import RenderEntity
from environments.helpers import Constants as c
from environments.helpers import Constants as c, Constants
from environments import helpers as h
CHARGE_ACTION = h.EnvActions.CHARGE
ITEM_DROP_OFF = 1
CHARGE_POD = 1
class BatteryProperties(NamedTuple):
@ -24,42 +24,18 @@ class BatteryProperties(NamedTuple):
multi_charge: bool = False
class Battery(object):
class Battery(EnvObject, BoundingMixin):
@property
def is_discharged(self):
return self.charge_level == 0
@property
def is_blocking_light(self):
return False
@property
def can_collide(self):
return False
@property
def name(self):
return f'{self.__class__.__name__}({self.agent.name})'
def __init__(self, pomdp_r: int, level_shape: (int, int), agent: Agent, initial_charge_level: float):
super().__init__()
self.agent = agent
self._pomdp_r = pomdp_r
self._level_shape = level_shape
if self._pomdp_r:
self._array = np.zeros((1, pomdp_r * 2 + 1, pomdp_r * 2 + 1))
else:
self._array = np.zeros((1, *self._level_shape))
def __init__(self, initial_charge_level: float, *args, **kwargs):
super(Battery, self).__init__(*args, **kwargs)
self.charge_level = initial_charge_level
def as_array(self):
self._array[:] = c.FREE_CELL.value
self._array[0, 0] = self.charge_level
return self._array
def __repr__(self):
return f'{self.__class__.__name__}[{self.agent.name}]({self.charge_level})'
def encoding(self):
return self.charge_level
def charge(self, amount) -> c:
if self.charge_level < 1:
@ -73,12 +49,10 @@ class Battery(object):
if self.charge_level != 0:
# noinspection PyTypeChecker
self.charge_level = max(0, amount + self.charge_level)
self._register.notify_change_to_value(self)
return c.VALID
else:
return c.NOT_VALID
def belongs_to_entity(self, entity):
return self.agent == entity
return c.NOT_VALID
def summarize_state(self, **kwargs):
attr_dict = {key: str(val) for key, val in self.__dict__.items() if not key.startswith('_') and key != 'data'}
@ -86,7 +60,7 @@ class Battery(object):
return attr_dict
class BatteriesRegister(ObjectRegister):
class BatteriesRegister(EnvObjectRegister):
_accepted_objects = Battery
is_blocking_light = False
@ -98,16 +72,17 @@ class BatteriesRegister(ObjectRegister):
self.is_observable = True
def as_array(self):
# self._array[:] = c.FREE_CELL.value
# ToDO: Make this Lazy
self._array[:] = c.FREE_CELL.value
for inv_idx, battery in enumerate(self):
self._array[inv_idx] = battery.as_array()
return self._array
def spawn_batteries(self, agents, pomdp_r, initial_charge_level):
inventories = [self._accepted_objects(pomdp_r, self._level_shape, agent,
initial_charge_level)
for _, agent in enumerate(agents)]
self.register_additional_items(inventories)
batteries = [self._accepted_objects(pomdp_r, self._shape, agent,
initial_charge_level)
for _, agent in enumerate(agents)]
self.register_additional_items(batteries)
def idx_by_entity(self, entity):
try:
@ -135,7 +110,7 @@ class ChargePod(Entity):
@property
def encoding(self):
return ITEM_DROP_OFF
return CHARGE_POD
def __init__(self, *args, charge_rate: float = 0.4,
multi_charge: bool = False, **kwargs):
@ -157,11 +132,12 @@ class ChargePod(Entity):
return summary
class ChargePods(EntityObjectRegister):
class ChargePods(EntityRegister):
_accepted_objects = ChargePod
def as_array(self):
@DeprecationWarning
def Xas_array(self):
self._array[:] = c.FREE_CELL.value
for item in self:
if item.pos != c.NO_POS.value:
@ -180,6 +156,16 @@ class BatteryFactory(BaseFactory):
self.btry_prop = btry_prop
super().__init__(*args, **kwargs)
def _additional_raw_observations(self, agent) -> Dict[Constants, np.typing.ArrayLike]:
additional_raw_observations = super()._additional_raw_observations(agent)
additional_raw_observations.update({c.BATTERIES: self[c.BATTERIES].by_entity(agent).as_array()})
return additional_raw_observations
def _additional_observations(self) -> Dict[Constants, np.typing.ArrayLike]:
additional_observations = super()._additional_observations()
additional_observations.update({c.CHARGE_POD: self[c.CHARGE_POD].as_array()})
return additional_observations
@property
def additional_entities(self):
super_entities = super().additional_entities