Rework of Observations and Entity Differentiation, lazy obs build by notification
This commit is contained in:
@ -1,18 +1,18 @@
|
||||
from typing import Union, NamedTuple
|
||||
from typing import Union, NamedTuple, Dict
|
||||
|
||||
import numpy as np
|
||||
|
||||
from environments.factory.base.base_factory import BaseFactory
|
||||
from environments.factory.base.objects import Agent, Action, Entity
|
||||
from environments.factory.base.registers import EntityObjectRegister, ObjectRegister
|
||||
from environments.factory.base.objects import Agent, Action, Entity, EnvObject, BoundingMixin
|
||||
from environments.factory.base.registers import EntityRegister, EnvObjectRegister
|
||||
from environments.factory.base.renderer import RenderEntity
|
||||
from environments.helpers import Constants as c
|
||||
from environments.helpers import Constants as c, Constants
|
||||
|
||||
from environments import helpers as h
|
||||
|
||||
|
||||
CHARGE_ACTION = h.EnvActions.CHARGE
|
||||
ITEM_DROP_OFF = 1
|
||||
CHARGE_POD = 1
|
||||
|
||||
|
||||
class BatteryProperties(NamedTuple):
|
||||
@ -24,42 +24,18 @@ class BatteryProperties(NamedTuple):
|
||||
multi_charge: bool = False
|
||||
|
||||
|
||||
class Battery(object):
|
||||
class Battery(EnvObject, BoundingMixin):
|
||||
|
||||
@property
|
||||
def is_discharged(self):
|
||||
return self.charge_level == 0
|
||||
|
||||
@property
|
||||
def is_blocking_light(self):
|
||||
return False
|
||||
|
||||
@property
|
||||
def can_collide(self):
|
||||
return False
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
return f'{self.__class__.__name__}({self.agent.name})'
|
||||
|
||||
def __init__(self, pomdp_r: int, level_shape: (int, int), agent: Agent, initial_charge_level: float):
|
||||
super().__init__()
|
||||
self.agent = agent
|
||||
self._pomdp_r = pomdp_r
|
||||
self._level_shape = level_shape
|
||||
if self._pomdp_r:
|
||||
self._array = np.zeros((1, pomdp_r * 2 + 1, pomdp_r * 2 + 1))
|
||||
else:
|
||||
self._array = np.zeros((1, *self._level_shape))
|
||||
def __init__(self, initial_charge_level: float, *args, **kwargs):
|
||||
super(Battery, self).__init__(*args, **kwargs)
|
||||
self.charge_level = initial_charge_level
|
||||
|
||||
def as_array(self):
|
||||
self._array[:] = c.FREE_CELL.value
|
||||
self._array[0, 0] = self.charge_level
|
||||
return self._array
|
||||
|
||||
def __repr__(self):
|
||||
return f'{self.__class__.__name__}[{self.agent.name}]({self.charge_level})'
|
||||
def encoding(self):
|
||||
return self.charge_level
|
||||
|
||||
def charge(self, amount) -> c:
|
||||
if self.charge_level < 1:
|
||||
@ -73,12 +49,10 @@ class Battery(object):
|
||||
if self.charge_level != 0:
|
||||
# noinspection PyTypeChecker
|
||||
self.charge_level = max(0, amount + self.charge_level)
|
||||
self._register.notify_change_to_value(self)
|
||||
return c.VALID
|
||||
else:
|
||||
return c.NOT_VALID
|
||||
|
||||
def belongs_to_entity(self, entity):
|
||||
return self.agent == entity
|
||||
return c.NOT_VALID
|
||||
|
||||
def summarize_state(self, **kwargs):
|
||||
attr_dict = {key: str(val) for key, val in self.__dict__.items() if not key.startswith('_') and key != 'data'}
|
||||
@ -86,7 +60,7 @@ class Battery(object):
|
||||
return attr_dict
|
||||
|
||||
|
||||
class BatteriesRegister(ObjectRegister):
|
||||
class BatteriesRegister(EnvObjectRegister):
|
||||
|
||||
_accepted_objects = Battery
|
||||
is_blocking_light = False
|
||||
@ -98,16 +72,17 @@ class BatteriesRegister(ObjectRegister):
|
||||
self.is_observable = True
|
||||
|
||||
def as_array(self):
|
||||
# self._array[:] = c.FREE_CELL.value
|
||||
# ToDO: Make this Lazy
|
||||
self._array[:] = c.FREE_CELL.value
|
||||
for inv_idx, battery in enumerate(self):
|
||||
self._array[inv_idx] = battery.as_array()
|
||||
return self._array
|
||||
|
||||
def spawn_batteries(self, agents, pomdp_r, initial_charge_level):
|
||||
inventories = [self._accepted_objects(pomdp_r, self._level_shape, agent,
|
||||
initial_charge_level)
|
||||
for _, agent in enumerate(agents)]
|
||||
self.register_additional_items(inventories)
|
||||
batteries = [self._accepted_objects(pomdp_r, self._shape, agent,
|
||||
initial_charge_level)
|
||||
for _, agent in enumerate(agents)]
|
||||
self.register_additional_items(batteries)
|
||||
|
||||
def idx_by_entity(self, entity):
|
||||
try:
|
||||
@ -135,7 +110,7 @@ class ChargePod(Entity):
|
||||
|
||||
@property
|
||||
def encoding(self):
|
||||
return ITEM_DROP_OFF
|
||||
return CHARGE_POD
|
||||
|
||||
def __init__(self, *args, charge_rate: float = 0.4,
|
||||
multi_charge: bool = False, **kwargs):
|
||||
@ -157,11 +132,12 @@ class ChargePod(Entity):
|
||||
return summary
|
||||
|
||||
|
||||
class ChargePods(EntityObjectRegister):
|
||||
class ChargePods(EntityRegister):
|
||||
|
||||
_accepted_objects = ChargePod
|
||||
|
||||
def as_array(self):
|
||||
@DeprecationWarning
|
||||
def Xas_array(self):
|
||||
self._array[:] = c.FREE_CELL.value
|
||||
for item in self:
|
||||
if item.pos != c.NO_POS.value:
|
||||
@ -180,6 +156,16 @@ class BatteryFactory(BaseFactory):
|
||||
self.btry_prop = btry_prop
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def _additional_raw_observations(self, agent) -> Dict[Constants, np.typing.ArrayLike]:
|
||||
additional_raw_observations = super()._additional_raw_observations(agent)
|
||||
additional_raw_observations.update({c.BATTERIES: self[c.BATTERIES].by_entity(agent).as_array()})
|
||||
return additional_raw_observations
|
||||
|
||||
def _additional_observations(self) -> Dict[Constants, np.typing.ArrayLike]:
|
||||
additional_observations = super()._additional_observations()
|
||||
additional_observations.update({c.CHARGE_POD: self[c.CHARGE_POD].as_array()})
|
||||
return additional_observations
|
||||
|
||||
@property
|
||||
def additional_entities(self):
|
||||
super_entities = super().additional_entities
|
||||
|
Reference in New Issue
Block a user