Adjustments and Documentation, recording and new environments, refactoring
This commit is contained in:
@@ -2,8 +2,11 @@ def make(env_name, pomdp_r=2, max_steps=400, stack_n_frames=3, n_agents=1, indiv
|
||||
import yaml
|
||||
from pathlib import Path
|
||||
from environments.factory.combined_factories import DirtItemFactory
|
||||
from environments.factory.factory_item import ItemFactory, ItemProperties
|
||||
from environments.factory.factory_dirt import DirtProperties, DirtFactory, RewardsDirt
|
||||
from environments.factory.factory_item import ItemFactory
|
||||
from environments.factory.additional.item.item_util import ItemProperties
|
||||
from environments.factory.factory_dirt import DirtFactory
|
||||
from environments.factory.dirt_util import DirtProperties
|
||||
from environments.factory.dirt_util import RewardsDirt
|
||||
from environments.utility_classes import AgentRenderOptions
|
||||
|
||||
with (Path(__file__).parent / 'levels' / 'parameters' / f'{env_name}.yaml').open('r') as stream:
|
||||
|
||||
0
environments/factory/additional/__init__.py
Normal file
0
environments/factory/additional/__init__.py
Normal file
0
environments/factory/additional/btry/__init__.py
Normal file
0
environments/factory/additional/btry/__init__.py
Normal file
50
environments/factory/additional/btry/btry_collections.py
Normal file
50
environments/factory/additional/btry/btry_collections.py
Normal file
@@ -0,0 +1,50 @@
|
||||
from environments.factory.additional.btry.btry_objects import Battery, ChargePod
|
||||
from environments.factory.base.registers import EnvObjectCollection, EntityCollection
|
||||
|
||||
|
||||
class BatteriesRegister(EnvObjectCollection):
|
||||
|
||||
_accepted_objects = Battery
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(BatteriesRegister, self).__init__(*args, individual_slices=True,
|
||||
is_blocking_light=False, can_be_shadowed=False, **kwargs)
|
||||
self.is_observable = True
|
||||
|
||||
def spawn_batteries(self, agents, initial_charge_level):
|
||||
batteries = [self._accepted_objects(initial_charge_level, agent, self) for _, agent in enumerate(agents)]
|
||||
self.add_additional_items(batteries)
|
||||
|
||||
def summarize_states(self, n_steps=None):
|
||||
# as dict with additional nesting
|
||||
# return dict(items=super(Inventories, cls).summarize_states())
|
||||
return super(BatteriesRegister, self).summarize_states(n_steps=n_steps)
|
||||
|
||||
# Todo Move this to Mixin!
|
||||
def by_entity(self, entity):
|
||||
try:
|
||||
return next((x for x in self if x.belongs_to_entity(entity)))
|
||||
except StopIteration:
|
||||
return None
|
||||
|
||||
def idx_by_entity(self, entity):
|
||||
try:
|
||||
return next((idx for idx, x in enumerate(self) if x.belongs_to_entity(entity)))
|
||||
except StopIteration:
|
||||
return None
|
||||
|
||||
def as_array_by_entity(self, entity):
|
||||
return self._array[self.idx_by_entity(entity)]
|
||||
|
||||
|
||||
class ChargePods(EntityCollection):
|
||||
|
||||
_accepted_objects = ChargePod
|
||||
|
||||
def __repr__(self):
|
||||
super(ChargePods, self).__repr__()
|
||||
|
||||
def summarize_states(self, n_steps=None):
|
||||
# as dict with additional nesting
|
||||
# return dict(items=super(Inventories, cls).summarize_states())
|
||||
return super(ChargePods, self).summarize_states(n_steps=n_steps)
|
||||
67
environments/factory/additional/btry/btry_objects.py
Normal file
67
environments/factory/additional/btry/btry_objects.py
Normal file
@@ -0,0 +1,67 @@
|
||||
from environments import helpers as h
|
||||
from environments.factory.base.objects import BoundingMixin, EnvObject, Entity
|
||||
from environments.factory.additional.btry.btry_util import Constants as c
|
||||
|
||||
|
||||
class Battery(BoundingMixin, EnvObject):
|
||||
|
||||
@property
|
||||
def is_discharged(self):
|
||||
return self.charge_level == 0
|
||||
|
||||
def __init__(self, initial_charge_level: float, *args, **kwargs):
|
||||
super(Battery, self).__init__(*args, **kwargs)
|
||||
self.charge_level = initial_charge_level
|
||||
|
||||
def encoding(self):
|
||||
return self.charge_level
|
||||
|
||||
def do_charge_action(self, amount):
|
||||
if self.charge_level < 1:
|
||||
# noinspection PyTypeChecker
|
||||
self.charge_level = min(1, amount + self.charge_level)
|
||||
return c.VALID
|
||||
else:
|
||||
return c.NOT_VALID
|
||||
|
||||
def decharge(self, amount) -> c:
|
||||
if self.charge_level != 0:
|
||||
# noinspection PyTypeChecker
|
||||
self.charge_level = max(0, amount + self.charge_level)
|
||||
self._collection.notify_change_to_value(self)
|
||||
return c.VALID
|
||||
else:
|
||||
return c.NOT_VALID
|
||||
|
||||
def summarize_state(self, **_):
|
||||
attr_dict = {key: str(val) for key, val in self.__dict__.items() if not key.startswith('_') and key != 'data'}
|
||||
attr_dict.update(dict(name=self.name))
|
||||
return attr_dict
|
||||
|
||||
|
||||
class ChargePod(Entity):
|
||||
|
||||
@property
|
||||
def encoding(self):
|
||||
return c.CHARGE_POD
|
||||
|
||||
def __init__(self, *args, charge_rate: float = 0.4,
|
||||
multi_charge: bool = False, **kwargs):
|
||||
super(ChargePod, self).__init__(*args, **kwargs)
|
||||
self.charge_rate = charge_rate
|
||||
self.multi_charge = multi_charge
|
||||
|
||||
def charge_battery(self, battery: Battery):
|
||||
if battery.charge_level == 1.0:
|
||||
return c.NOT_VALID
|
||||
if sum(guest for guest in self.tile.guests if 'agent' in guest.name.lower()) > 1:
|
||||
return c.NOT_VALID
|
||||
valid = battery.do_charge_action(self.charge_rate)
|
||||
return valid
|
||||
|
||||
def summarize_state(self, n_steps=None) -> dict:
|
||||
if n_steps == h.STEPS_START:
|
||||
summary = super().summarize_state(n_steps=n_steps)
|
||||
return summary
|
||||
else:
|
||||
{}
|
||||
30
environments/factory/additional/btry/btry_util.py
Normal file
30
environments/factory/additional/btry/btry_util.py
Normal file
@@ -0,0 +1,30 @@
|
||||
from typing import NamedTuple, Union
|
||||
|
||||
from environments.helpers import Constants as BaseConstants, EnvActions as BaseActions
|
||||
|
||||
|
||||
class Constants(BaseConstants):
|
||||
# Battery Env
|
||||
CHARGE_PODS = 'Charge_Pod'
|
||||
BATTERIES = 'BATTERIES'
|
||||
BATTERY_DISCHARGED = 'DISCHARGED'
|
||||
CHARGE_POD = 1
|
||||
|
||||
|
||||
class Actions(BaseActions):
|
||||
CHARGE = 'do_charge_action'
|
||||
|
||||
|
||||
class RewardsBtry(NamedTuple):
|
||||
CHARGE_VALID: float = 0.1
|
||||
CHARGE_FAIL: float = -0.1
|
||||
BATTERY_DISCHARGED: float = -1.0
|
||||
|
||||
|
||||
class BatteryProperties(NamedTuple):
|
||||
initial_charge: float = 0.8 #
|
||||
charge_rate: float = 0.4 #
|
||||
charge_locations: int = 20 #
|
||||
per_action_costs: Union[dict, float] = 0.02
|
||||
done_when_discharged: bool = False
|
||||
multi_charge: bool = False
|
||||
139
environments/factory/additional/btry/factory_battery.py
Normal file
139
environments/factory/additional/btry/factory_battery.py
Normal file
@@ -0,0 +1,139 @@
|
||||
from typing import Dict, List
|
||||
|
||||
import numpy as np
|
||||
|
||||
from environments.factory.additional.btry.btry_collections import BatteriesRegister, ChargePods
|
||||
from environments.factory.additional.btry.btry_util import Constants, Actions, RewardsBtry, BatteryProperties
|
||||
from environments.factory.base.base_factory import BaseFactory
|
||||
from environments.factory.base.objects import Agent, Action
|
||||
from environments.factory.base.renderer import RenderEntity
|
||||
|
||||
c = Constants
|
||||
a = Actions
|
||||
|
||||
|
||||
class BatteryFactory(BaseFactory):
|
||||
|
||||
def __init__(self, *args, btry_prop=BatteryProperties(), rewards_btry: RewardsBtry = RewardsBtry(),
|
||||
**kwargs):
|
||||
if isinstance(btry_prop, dict):
|
||||
btry_prop = BatteryProperties(**btry_prop)
|
||||
if isinstance(rewards_btry, dict):
|
||||
rewards_btry = RewardsBtry(**rewards_btry)
|
||||
self.btry_prop = btry_prop
|
||||
self.rewards_dest = rewards_btry
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def per_agent_raw_observations_hook(self, agent) -> Dict[str, np.typing.ArrayLike]:
|
||||
additional_raw_observations = super().per_agent_raw_observations_hook(agent)
|
||||
additional_raw_observations.update({c.BATTERIES: self[c.BATTERIES].as_array_by_entity(agent)})
|
||||
return additional_raw_observations
|
||||
|
||||
def observations_hook(self) -> Dict[str, np.typing.ArrayLike]:
|
||||
additional_observations = super().observations_hook()
|
||||
additional_observations.update({c.CHARGE_PODS: self[c.CHARGE_PODS].as_array()})
|
||||
return additional_observations
|
||||
|
||||
@property
|
||||
def entities_hook(self):
|
||||
super_entities = super().entities_hook
|
||||
|
||||
empty_tiles = self[c.FLOOR].empty_tiles[:self.btry_prop.charge_locations]
|
||||
charge_pods = ChargePods.from_tiles(
|
||||
empty_tiles, self._level_shape,
|
||||
entity_kwargs=dict(charge_rate=self.btry_prop.charge_rate,
|
||||
multi_charge=self.btry_prop.multi_charge)
|
||||
)
|
||||
|
||||
batteries = BatteriesRegister(self._level_shape if not self._pomdp_r else ((self.pomdp_diameter,) * 2),
|
||||
)
|
||||
batteries.spawn_batteries(self[c.AGENT], self.btry_prop.initial_charge)
|
||||
super_entities.update({c.BATTERIES: batteries, c.CHARGE_PODS: charge_pods})
|
||||
return super_entities
|
||||
|
||||
def step_hook(self) -> (List[dict], dict):
|
||||
super_reward_info = super(BatteryFactory, self).step_hook()
|
||||
|
||||
# Decharge
|
||||
batteries = self[c.BATTERIES]
|
||||
|
||||
for agent in self[c.AGENT]:
|
||||
if isinstance(self.btry_prop.per_action_costs, dict):
|
||||
energy_consumption = self.btry_prop.per_action_costs[agent.temp_action]
|
||||
else:
|
||||
energy_consumption = self.btry_prop.per_action_costs
|
||||
|
||||
batteries.by_entity(agent).decharge(energy_consumption)
|
||||
|
||||
return super_reward_info
|
||||
|
||||
def do_charge_action(self, agent) -> (dict, dict):
|
||||
if charge_pod := self[c.CHARGE_PODS].by_pos(agent.pos):
|
||||
valid = charge_pod.charge_battery(self[c.BATTERIES].by_entity(agent))
|
||||
if valid:
|
||||
info_dict = {f'{agent.name}_{a.CHARGE}_VALID': 1}
|
||||
self.print(f'{agent.name} just charged batteries at {charge_pod.name}.')
|
||||
else:
|
||||
info_dict = {f'{agent.name}_{a.CHARGE}_FAIL': 1}
|
||||
self.print(f'{agent.name} failed to charged batteries at {charge_pod.name}.')
|
||||
else:
|
||||
valid = c.NOT_VALID
|
||||
info_dict = {f'{agent.name}_{a.CHARGE}_FAIL': 1}
|
||||
# info_dict = {f'{agent.name}_no_charger': 1}
|
||||
self.print(f'{agent.name} failed to charged batteries at {agent.pos}.')
|
||||
reward = dict(value=self.rewards_dest.CHARGE_VALID if valid else self.rewards_dest.CHARGE_FAIL,
|
||||
reason=a.CHARGE, info=info_dict)
|
||||
return valid, reward
|
||||
|
||||
def do_additional_actions(self, agent: Agent, action: Action) -> (bool, dict):
|
||||
action_result = super().do_additional_actions(agent, action)
|
||||
if action_result is None:
|
||||
if action == a.CHARGE:
|
||||
action_result = self.do_charge_action(agent)
|
||||
return action_result
|
||||
else:
|
||||
return None
|
||||
else:
|
||||
return action_result
|
||||
pass
|
||||
|
||||
def reset_hook(self) -> (List[dict], dict):
|
||||
super_reward_info = super(BatteryFactory, self).reset_hook()
|
||||
# There is Nothing to reset.
|
||||
return super_reward_info
|
||||
|
||||
def check_additional_done(self) -> (bool, dict):
|
||||
super_done, super_dict = super(BatteryFactory, self).check_additional_done()
|
||||
if super_done:
|
||||
return super_done, super_dict
|
||||
else:
|
||||
if self.btry_prop.done_when_discharged:
|
||||
if btry_done := any(battery.is_discharged for battery in self[c.BATTERIES]):
|
||||
super_dict.update(DISCHARGE_DONE=1)
|
||||
return btry_done, super_dict
|
||||
else:
|
||||
pass
|
||||
else:
|
||||
pass
|
||||
return super_done, super_dict
|
||||
|
||||
def per_agent_reward_hook(self, agent: Agent) -> List[dict]:
|
||||
reward_event_list = super(BatteryFactory, self).per_agent_reward_hook(agent)
|
||||
if self[c.BATTERIES].by_entity(agent).is_discharged:
|
||||
self.print(f'{agent.name} Battery is discharged!')
|
||||
info_dict = {f'{agent.name}_{c.BATTERY_DISCHARGED}': 1}
|
||||
reward_event_list.append({'value': self.rewards_dest.BATTERY_DISCHARGED,
|
||||
'reason': c.BATTERY_DISCHARGED,
|
||||
'info': info_dict}
|
||||
)
|
||||
else:
|
||||
# All Fine
|
||||
pass
|
||||
return reward_event_list
|
||||
|
||||
def render_assets_hook(self):
|
||||
# noinspection PyUnresolvedReferences
|
||||
additional_assets = super().render_assets_hook()
|
||||
charge_pods = [RenderEntity(c.CHARGE_PODS, charge_pod.tile.pos) for charge_pod in self[c.CHARGE_PODS]]
|
||||
additional_assets.extend(charge_pods)
|
||||
return additional_assets
|
||||
@@ -1,12 +1,15 @@
|
||||
import random
|
||||
|
||||
from environments.factory.factory_battery import BatteryFactory, BatteryProperties
|
||||
from environments.factory.factory_dest import DestFactory
|
||||
from environments.factory.factory_dirt import DirtFactory, DirtProperties
|
||||
from environments.factory.factory_item import ItemFactory
|
||||
|
||||
|
||||
# noinspection PyAbstractClass
|
||||
from environments.factory.additional.btry.btry_util import BatteryProperties
|
||||
from environments.factory.additional.btry.factory_battery import BatteryFactory
|
||||
from environments.factory.additional.dest.factory_dest import DestFactory
|
||||
from environments.factory.additional.dirt.dirt_util import DirtProperties
|
||||
from environments.factory.additional.dirt.factory_dirt import DirtFactory
|
||||
from environments.factory.additional.item.factory_item import ItemFactory
|
||||
|
||||
|
||||
class DirtItemFactory(ItemFactory, DirtFactory):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
@@ -24,6 +27,12 @@ class DirtDestItemFactory(ItemFactory, DirtFactory, DestFactory):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
|
||||
# noinspection PyAbstractClass
|
||||
class DestBatteryFactory(BatteryFactory, DestFactory):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
from environments.utility_classes import AgentRenderOptions as ARO, ObservationProperties
|
||||
|
||||
0
environments/factory/additional/dest/__init__.py
Normal file
0
environments/factory/additional/dest/__init__.py
Normal file
44
environments/factory/additional/dest/dest_collections.py
Normal file
44
environments/factory/additional/dest/dest_collections.py
Normal file
@@ -0,0 +1,44 @@
|
||||
from environments.factory.base.registers import EntityCollection
|
||||
from environments.factory.additional.dest.dest_util import Constants as c
|
||||
from environments.factory.additional.dest.dest_enitites import Destination
|
||||
|
||||
|
||||
class Destinations(EntityCollection):
|
||||
|
||||
_accepted_objects = Destination
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.is_blocking_light = False
|
||||
self.can_be_shadowed = False
|
||||
|
||||
def as_array(self):
|
||||
self._array[:] = c.FREE_CELL
|
||||
# ToDo: Switch to new Style Array Put
|
||||
# indices = list(zip(range(len(cls)), *zip(*[x.pos for x in cls])))
|
||||
# np.put(cls._array, [np.ravel_multi_index(x, cls._array.shape) for x in indices], cls.encodings)
|
||||
for item in self:
|
||||
if item.pos != c.NO_POS:
|
||||
self._array[0, item.x, item.y] = item.encoding
|
||||
return self._array
|
||||
|
||||
def __repr__(self):
|
||||
return super(Destinations, self).__repr__()
|
||||
|
||||
def summarize_states(self, n_steps=None):
|
||||
return {}
|
||||
|
||||
|
||||
class ReachedDestinations(Destinations):
|
||||
_accepted_objects = Destination
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(ReachedDestinations, self).__init__(*args, **kwargs)
|
||||
self.can_be_shadowed = False
|
||||
self.is_blocking_light = False
|
||||
|
||||
def summarize_states(self, n_steps=None):
|
||||
return {}
|
||||
|
||||
def __repr__(self):
|
||||
return super(ReachedDestinations, self).__repr__()
|
||||
44
environments/factory/additional/dest/dest_enitites.py
Normal file
44
environments/factory/additional/dest/dest_enitites.py
Normal file
@@ -0,0 +1,44 @@
|
||||
from collections import defaultdict
|
||||
|
||||
from environments.factory.base.objects import Entity, Agent
|
||||
from environments.factory.additional.dest.dest_util import Constants as c
|
||||
|
||||
|
||||
class Destination(Entity):
|
||||
|
||||
@property
|
||||
def any_agent_has_dwelled(self):
|
||||
return bool(len(self._per_agent_times))
|
||||
|
||||
@property
|
||||
def currently_dwelling_names(self):
|
||||
return self._per_agent_times.keys()
|
||||
|
||||
@property
|
||||
def encoding(self):
|
||||
return c.DESTINATION
|
||||
|
||||
def __init__(self, *args, dwell_time: int = 0, **kwargs):
|
||||
super(Destination, self).__init__(*args, **kwargs)
|
||||
self.dwell_time = dwell_time
|
||||
self._per_agent_times = defaultdict(lambda: dwell_time)
|
||||
|
||||
def do_wait_action(self, agent: Agent):
|
||||
self._per_agent_times[agent.name] -= 1
|
||||
return c.VALID
|
||||
|
||||
def leave(self, agent: Agent):
|
||||
del self._per_agent_times[agent.name]
|
||||
|
||||
@property
|
||||
def is_considered_reached(self):
|
||||
agent_at_position = any(c.AGENT.lower() in x.name.lower() for x in self.tile.guests_that_can_collide)
|
||||
return (agent_at_position and not self.dwell_time) or any(x == 0 for x in self._per_agent_times.values())
|
||||
|
||||
def agent_is_dwelling(self, agent: Agent):
|
||||
return self._per_agent_times[agent.name] < self.dwell_time
|
||||
|
||||
def summarize_state(self, n_steps=None) -> dict:
|
||||
state_summary = super().summarize_state(n_steps=n_steps)
|
||||
state_summary.update(per_agent_times=self._per_agent_times)
|
||||
return state_summary
|
||||
41
environments/factory/additional/dest/dest_util.py
Normal file
41
environments/factory/additional/dest/dest_util.py
Normal file
@@ -0,0 +1,41 @@
|
||||
from typing import NamedTuple
|
||||
|
||||
from environments.helpers import Constants as BaseConstants, EnvActions as BaseActions
|
||||
|
||||
|
||||
class Constants(BaseConstants):
|
||||
# Destination Env
|
||||
DEST = 'Destination'
|
||||
DESTINATION = 1
|
||||
DESTINATION_DONE = 0.5
|
||||
DEST_REACHED = 'ReachedDestination'
|
||||
|
||||
|
||||
class Actions(BaseActions):
|
||||
WAIT_ON_DEST = 'WAIT'
|
||||
|
||||
|
||||
class RewardsDest(NamedTuple):
|
||||
|
||||
WAIT_VALID: float = 0.1
|
||||
WAIT_FAIL: float = -0.1
|
||||
DEST_REACHED: float = 5.0
|
||||
|
||||
|
||||
class DestModeOptions(object):
|
||||
DONE = 'DONE'
|
||||
GROUPED = 'GROUPED'
|
||||
PER_DEST = 'PER_DEST'
|
||||
|
||||
|
||||
class DestProperties(NamedTuple):
|
||||
n_dests: int = 1 # How many destinations are there
|
||||
dwell_time: int = 0 # How long does the agent need to "wait" on a destination
|
||||
spawn_frequency: int = 0
|
||||
spawn_in_other_zone: bool = True #
|
||||
spawn_mode: str = DestModeOptions.DONE
|
||||
|
||||
assert dwell_time >= 0, 'dwell_time cannot be < 0!'
|
||||
assert spawn_frequency >= 0, 'spawn_frequency cannot be < 0!'
|
||||
assert n_dests >= 0, 'n_destinations cannot be < 0!'
|
||||
assert (spawn_mode == DestModeOptions.DONE) != bool(spawn_frequency)
|
||||
@@ -1,132 +1,19 @@
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from enum import Enum
|
||||
from typing import List, Union, NamedTuple, Dict
|
||||
from typing import List, Union, Dict
|
||||
import numpy as np
|
||||
import random
|
||||
|
||||
from environments.factory.additional.dest.dest_collections import Destinations, ReachedDestinations
|
||||
from environments.factory.additional.dest.dest_enitites import Destination
|
||||
from environments.factory.additional.dest.dest_util import Constants, Actions, RewardsDest, DestModeOptions, \
|
||||
DestProperties
|
||||
from environments.factory.base.base_factory import BaseFactory
|
||||
from environments.helpers import Constants as BaseConstants
|
||||
from environments.helpers import EnvActions as BaseActions
|
||||
from environments.factory.base.objects import Agent, Entity, Action
|
||||
from environments.factory.base.registers import Entities, EntityCollection
|
||||
from environments.factory.base.objects import Agent, Action
|
||||
from environments.factory.base.registers import Entities
|
||||
|
||||
from environments.factory.base.renderer import RenderEntity
|
||||
|
||||
|
||||
class Constants(BaseConstants):
|
||||
# Destination Env
|
||||
DEST = 'Destination'
|
||||
DESTINATION = 1
|
||||
DESTINATION_DONE = 0.5
|
||||
DEST_REACHED = 'ReachedDestination'
|
||||
|
||||
|
||||
class Actions(BaseActions):
|
||||
WAIT_ON_DEST = 'WAIT'
|
||||
|
||||
|
||||
class RewardsDest(NamedTuple):
|
||||
|
||||
WAIT_VALID: float = 0.1
|
||||
WAIT_FAIL: float = -0.1
|
||||
DEST_REACHED: float = 5.0
|
||||
|
||||
|
||||
class Destination(Entity):
|
||||
|
||||
@property
|
||||
def any_agent_has_dwelled(self):
|
||||
return bool(len(self._per_agent_times))
|
||||
|
||||
@property
|
||||
def currently_dwelling_names(self):
|
||||
return self._per_agent_times.keys()
|
||||
|
||||
@property
|
||||
def encoding(self):
|
||||
return c.DESTINATION
|
||||
|
||||
def __init__(self, *args, dwell_time: int = 0, **kwargs):
|
||||
super(Destination, self).__init__(*args, **kwargs)
|
||||
self.dwell_time = dwell_time
|
||||
self._per_agent_times = defaultdict(lambda: dwell_time)
|
||||
|
||||
def do_wait_action(self, agent: Agent):
|
||||
self._per_agent_times[agent.name] -= 1
|
||||
return c.VALID
|
||||
|
||||
def leave(self, agent: Agent):
|
||||
del self._per_agent_times[agent.name]
|
||||
|
||||
@property
|
||||
def is_considered_reached(self):
|
||||
agent_at_position = any(c.AGENT.lower() in x.name.lower() for x in self.tile.guests_that_can_collide)
|
||||
return (agent_at_position and not self.dwell_time) or any(x == 0 for x in self._per_agent_times.values())
|
||||
|
||||
def agent_is_dwelling(self, agent: Agent):
|
||||
return self._per_agent_times[agent.name] < self.dwell_time
|
||||
|
||||
def summarize_state(self, n_steps=None) -> dict:
|
||||
state_summary = super().summarize_state(n_steps=n_steps)
|
||||
state_summary.update(per_agent_times=self._per_agent_times)
|
||||
return state_summary
|
||||
|
||||
|
||||
class Destinations(EntityCollection):
|
||||
|
||||
_accepted_objects = Destination
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.is_blocking_light = False
|
||||
self.can_be_shadowed = False
|
||||
|
||||
def as_array(self):
|
||||
self._array[:] = c.FREE_CELL
|
||||
# ToDo: Switch to new Style Array Put
|
||||
# indices = list(zip(range(len(cls)), *zip(*[x.pos for x in cls])))
|
||||
# np.put(cls._array, [np.ravel_multi_index(x, cls._array.shape) for x in indices], cls.encodings)
|
||||
for item in self:
|
||||
if item.pos != c.NO_POS:
|
||||
self._array[0, item.x, item.y] = item.encoding
|
||||
return self._array
|
||||
|
||||
def __repr__(self):
|
||||
super(Destinations, self).__repr__()
|
||||
|
||||
|
||||
class ReachedDestinations(Destinations):
|
||||
_accepted_objects = Destination
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(ReachedDestinations, self).__init__(*args, **kwargs)
|
||||
self.can_be_shadowed = False
|
||||
self.is_blocking_light = False
|
||||
|
||||
def summarize_states(self, n_steps=None):
|
||||
return {}
|
||||
|
||||
|
||||
class DestModeOptions(object):
|
||||
DONE = 'DONE'
|
||||
GROUPED = 'GROUPED'
|
||||
PER_DEST = 'PER_DEST'
|
||||
|
||||
|
||||
class DestProperties(NamedTuple):
|
||||
n_dests: int = 1 # How many destinations are there
|
||||
dwell_time: int = 0 # How long does the agent need to "wait" on a destination
|
||||
spawn_frequency: int = 0
|
||||
spawn_in_other_zone: bool = True #
|
||||
spawn_mode: str = DestModeOptions.DONE
|
||||
|
||||
assert dwell_time >= 0, 'dwell_time cannot be < 0!'
|
||||
assert spawn_frequency >= 0, 'spawn_frequency cannot be < 0!'
|
||||
assert n_dests >= 0, 'n_destinations cannot be < 0!'
|
||||
assert (spawn_mode == DestModeOptions.DONE) != bool(spawn_frequency)
|
||||
|
||||
|
||||
c = Constants
|
||||
a = Actions
|
||||
|
||||
@@ -135,7 +22,7 @@ a = Actions
|
||||
class DestFactory(BaseFactory):
|
||||
# noinspection PyMissingConstructor
|
||||
|
||||
def __init__(self, *args, dest_prop: DestProperties = DestProperties(), rewards_dest: RewardsDest = RewardsDest(),
|
||||
def __init__(self, *args, dest_prop: DestProperties = DestProperties(), rewards_dest: RewardsDest = RewardsDest(),
|
||||
env_seed=time.time_ns(), **kwargs):
|
||||
if isinstance(dest_prop, dict):
|
||||
dest_prop = DestProperties(**dest_prop)
|
||||
@@ -151,6 +38,7 @@ class DestFactory(BaseFactory):
|
||||
def actions_hook(self) -> Union[Action, List[Action]]:
|
||||
# noinspection PyUnresolvedReferences
|
||||
super_actions = super().actions_hook
|
||||
# If targets are considers reached after some time, agents need an action for that.
|
||||
if self.dest_prop.dwell_time:
|
||||
super_actions.append(Action(enum_ident=a.WAIT_ON_DEST))
|
||||
return super_actions
|
||||
@@ -207,7 +95,7 @@ class DestFactory(BaseFactory):
|
||||
if destinations_to_spawn:
|
||||
n_dest_to_spawn = len(destinations_to_spawn)
|
||||
if self.dest_prop.spawn_mode != DestModeOptions.GROUPED:
|
||||
destinations = [Destination(tile, c.DEST) for tile in self[c.FLOOR].empty_tiles[:n_dest_to_spawn]]
|
||||
destinations = [Destination(tile, self[c.DEST]) for tile in self[c.FLOOR].empty_tiles[:n_dest_to_spawn]]
|
||||
self[c.DEST].add_additional_items(destinations)
|
||||
for dest in destinations_to_spawn:
|
||||
del self._dest_spawn_timer[dest]
|
||||
@@ -229,9 +117,10 @@ class DestFactory(BaseFactory):
|
||||
super_reward_info = super().step_hook()
|
||||
for key, val in self._dest_spawn_timer.items():
|
||||
self._dest_spawn_timer[key] = min(self.dest_prop.spawn_frequency, self._dest_spawn_timer[key] + 1)
|
||||
|
||||
for dest in list(self[c.DEST].values()):
|
||||
if dest.is_considered_reached:
|
||||
dest.change_parent_collection(self[c.DEST])
|
||||
dest.change_parent_collection(self[c.DEST_REACHED])
|
||||
self._dest_spawn_timer[dest.name] = 0
|
||||
self.print(f'{dest.name} is reached now, removing...')
|
||||
else:
|
||||
@@ -251,18 +140,19 @@ class DestFactory(BaseFactory):
|
||||
additional_observations.update({c.DEST: self[c.DEST].as_array()})
|
||||
return additional_observations
|
||||
|
||||
def per_agent_reward_hook(self, agent: Agent) -> Dict[str, dict]:
|
||||
def per_agent_reward_hook(self, agent: Agent) -> List[dict]:
|
||||
# noinspection PyUnresolvedReferences
|
||||
reward_event_dict = super().per_agent_reward_hook(agent)
|
||||
reward_event_list = super().per_agent_reward_hook(agent)
|
||||
if len(self[c.DEST_REACHED]):
|
||||
for reached_dest in list(self[c.DEST_REACHED]):
|
||||
if agent.pos == reached_dest.pos:
|
||||
self.print(f'{agent.name} just reached destination at {agent.pos}')
|
||||
self[c.DEST_REACHED].delete_env_object(reached_dest)
|
||||
info_dict = {f'{agent.name}_{c.DEST_REACHED}': 1}
|
||||
reward_event_dict.update({c.DEST_REACHED: {'reward': self.rewards_dest.DEST_REACHED,
|
||||
'info': info_dict}})
|
||||
return reward_event_dict
|
||||
reward_event_list.append({'value': self.rewards_dest.DEST_REACHED,
|
||||
'reason': c.DEST_REACHED,
|
||||
'info': info_dict})
|
||||
return reward_event_list
|
||||
|
||||
def render_assets_hook(self, mode='human'):
|
||||
# noinspection PyUnresolvedReferences
|
||||
0
environments/factory/additional/dirt/__init__.py
Normal file
0
environments/factory/additional/dirt/__init__.py
Normal file
42
environments/factory/additional/dirt/dirt_collections.py
Normal file
42
environments/factory/additional/dirt/dirt_collections.py
Normal file
@@ -0,0 +1,42 @@
|
||||
from environments.factory.additional.dirt.dirt_entity import Dirt
|
||||
from environments.factory.additional.dirt.dirt_util import DirtProperties
|
||||
from environments.factory.base.objects import Floor
|
||||
from environments.factory.base.registers import EntityCollection
|
||||
from environments.factory.additional.dirt.dirt_util import Constants as c
|
||||
|
||||
|
||||
class DirtRegister(EntityCollection):
|
||||
|
||||
_accepted_objects = Dirt
|
||||
|
||||
@property
|
||||
def amount(self):
|
||||
return sum([dirt.amount for dirt in self])
|
||||
|
||||
@property
|
||||
def dirt_properties(self):
|
||||
return self._dirt_properties
|
||||
|
||||
def __init__(self, dirt_properties, *args):
|
||||
super(DirtRegister, self).__init__(*args)
|
||||
self._dirt_properties: DirtProperties = dirt_properties
|
||||
|
||||
def spawn_dirt(self, then_dirty_tiles) -> bool:
|
||||
if isinstance(then_dirty_tiles, Floor):
|
||||
then_dirty_tiles = [then_dirty_tiles]
|
||||
for tile in then_dirty_tiles:
|
||||
if not self.amount > self.dirt_properties.max_global_amount:
|
||||
dirt = self.by_pos(tile.pos)
|
||||
if dirt is None:
|
||||
dirt = Dirt(tile, self, amount=self.dirt_properties.max_spawn_amount)
|
||||
self.add_item(dirt)
|
||||
else:
|
||||
new_value = dirt.amount + self.dirt_properties.max_spawn_amount
|
||||
dirt.set_new_amount(min(new_value, self.dirt_properties.max_local_amount))
|
||||
else:
|
||||
return c.NOT_VALID
|
||||
return c.VALID
|
||||
|
||||
def __repr__(self):
|
||||
s = super(DirtRegister, self).__repr__()
|
||||
return f'{s[:-1]}, {self.amount})'
|
||||
26
environments/factory/additional/dirt/dirt_entity.py
Normal file
26
environments/factory/additional/dirt/dirt_entity.py
Normal file
@@ -0,0 +1,26 @@
|
||||
from environments.factory.base.objects import Entity
|
||||
|
||||
|
||||
class Dirt(Entity):
|
||||
|
||||
@property
|
||||
def amount(self):
|
||||
return self._amount
|
||||
|
||||
@property
|
||||
def encoding(self):
|
||||
# Edit this if you want items to be drawn in the ops differntly
|
||||
return self._amount
|
||||
|
||||
def __init__(self, *args, amount=None, **kwargs):
|
||||
super(Dirt, self).__init__(*args, **kwargs)
|
||||
self._amount = amount
|
||||
|
||||
def set_new_amount(self, amount):
|
||||
self._amount = amount
|
||||
self._collection.notify_change_to_value(self)
|
||||
|
||||
def summarize_state(self, **kwargs):
|
||||
state_dict = super().summarize_state(**kwargs)
|
||||
state_dict.update(amount=float(self.amount))
|
||||
return state_dict
|
||||
30
environments/factory/additional/dirt/dirt_util.py
Normal file
30
environments/factory/additional/dirt/dirt_util.py
Normal file
@@ -0,0 +1,30 @@
|
||||
from typing import NamedTuple
|
||||
|
||||
from environments.helpers import Constants as BaseConstants, EnvActions as BaseActions
|
||||
|
||||
|
||||
class Constants(BaseConstants):
|
||||
DIRT = 'Dirt'
|
||||
|
||||
|
||||
class Actions(BaseActions):
|
||||
CLEAN_UP = 'do_cleanup_action'
|
||||
|
||||
|
||||
class RewardsDirt(NamedTuple):
|
||||
CLEAN_UP_VALID: float = 0.5
|
||||
CLEAN_UP_FAIL: float = -0.1
|
||||
CLEAN_UP_LAST_PIECE: float = 4.5
|
||||
|
||||
|
||||
class DirtProperties(NamedTuple):
|
||||
initial_dirt_ratio: float = 0.3 # On INIT, on max how many tiles does the dirt spawn in percent.
|
||||
initial_dirt_spawn_r_var: float = 0.05 # How much does the dirt spawn amount vary?
|
||||
clean_amount: float = 1 # How much does the robot clean with one actions.
|
||||
max_spawn_ratio: float = 0.20 # On max how many tiles does the dirt spawn in percent.
|
||||
max_spawn_amount: float = 0.3 # How much dirt does spawn per tile at max.
|
||||
spawn_frequency: int = 0 # Spawn Frequency in Steps.
|
||||
max_local_amount: int = 2 # Max dirt amount per tile.
|
||||
max_global_amount: int = 20 # Max dirt amount in the whole environment.
|
||||
dirt_smear_amount: float = 0.2 # Agents smear dirt, when not cleaning up in place.
|
||||
done_when_clean: bool = True
|
||||
@@ -1,111 +1,22 @@
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import List, Union, NamedTuple, Dict
|
||||
from typing import List, Union, Dict
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
|
||||
from algorithms.TSP_dirt_agent import TSPDirtAgent
|
||||
from environments.helpers import Constants as BaseConstants
|
||||
from environments.helpers import EnvActions as BaseActions
|
||||
from environments.factory.additional.dirt.dirt_collections import DirtRegister
|
||||
from environments.factory.additional.dirt.dirt_entity import Dirt
|
||||
from environments.factory.additional.dirt.dirt_util import Constants, Actions, RewardsDirt, DirtProperties
|
||||
|
||||
from environments.factory.base.base_factory import BaseFactory
|
||||
from environments.factory.base.objects import Agent, Action, Entity, Floor
|
||||
from environments.factory.base.registers import Entities, EntityCollection
|
||||
from environments.factory.base.objects import Agent, Action
|
||||
from environments.factory.base.registers import Entities
|
||||
|
||||
from environments.factory.base.renderer import RenderEntity
|
||||
from environments.utility_classes import ObservationProperties
|
||||
|
||||
|
||||
class Constants(BaseConstants):
|
||||
DIRT = 'Dirt'
|
||||
|
||||
|
||||
class Actions(BaseActions):
|
||||
CLEAN_UP = 'do_cleanup_action'
|
||||
|
||||
|
||||
class RewardsDirt(NamedTuple):
|
||||
CLEAN_UP_VALID: float = 0.5
|
||||
CLEAN_UP_FAIL: float = -0.1
|
||||
CLEAN_UP_LAST_PIECE: float = 4.5
|
||||
|
||||
|
||||
class DirtProperties(NamedTuple):
|
||||
initial_dirt_ratio: float = 0.3 # On INIT, on max how many tiles does the dirt spawn in percent.
|
||||
initial_dirt_spawn_r_var: float = 0.05 # How much does the dirt spawn amount vary?
|
||||
clean_amount: float = 1 # How much does the robot clean with one actions.
|
||||
max_spawn_ratio: float = 0.20 # On max how many tiles does the dirt spawn in percent.
|
||||
max_spawn_amount: float = 0.3 # How much dirt does spawn per tile at max.
|
||||
spawn_frequency: int = 0 # Spawn Frequency in Steps.
|
||||
max_local_amount: int = 2 # Max dirt amount per tile.
|
||||
max_global_amount: int = 20 # Max dirt amount in the whole environment.
|
||||
dirt_smear_amount: float = 0.2 # Agents smear dirt, when not cleaning up in place.
|
||||
done_when_clean: bool = True
|
||||
|
||||
|
||||
class Dirt(Entity):
|
||||
|
||||
@property
|
||||
def amount(self):
|
||||
return self._amount
|
||||
|
||||
@property
|
||||
def encoding(self):
|
||||
# Edit this if you want items to be drawn in the ops differntly
|
||||
return self._amount
|
||||
|
||||
def __init__(self, *args, amount=None, **kwargs):
|
||||
super(Dirt, self).__init__(*args, **kwargs)
|
||||
self._amount = amount
|
||||
|
||||
def set_new_amount(self, amount):
|
||||
self._amount = amount
|
||||
self._collection.notify_change_to_value(self)
|
||||
|
||||
def summarize_state(self, **kwargs):
|
||||
state_dict = super().summarize_state(**kwargs)
|
||||
state_dict.update(amount=float(self.amount))
|
||||
return state_dict
|
||||
|
||||
|
||||
class DirtRegister(EntityCollection):
|
||||
|
||||
_accepted_objects = Dirt
|
||||
|
||||
@property
|
||||
def amount(self):
|
||||
return sum([dirt.amount for dirt in self])
|
||||
|
||||
@property
|
||||
def dirt_properties(self):
|
||||
return self._dirt_properties
|
||||
|
||||
def __init__(self, dirt_properties, *args):
|
||||
super(DirtRegister, self).__init__(*args)
|
||||
self._dirt_properties: DirtProperties = dirt_properties
|
||||
|
||||
def spawn_dirt(self, then_dirty_tiles) -> bool:
|
||||
if isinstance(then_dirty_tiles, Floor):
|
||||
then_dirty_tiles = [then_dirty_tiles]
|
||||
for tile in then_dirty_tiles:
|
||||
if not self.amount > self.dirt_properties.max_global_amount:
|
||||
dirt = self.by_pos(tile.pos)
|
||||
if dirt is None:
|
||||
dirt = Dirt(tile, self, amount=self.dirt_properties.max_spawn_amount)
|
||||
self.add_item(dirt)
|
||||
else:
|
||||
new_value = dirt.amount + self.dirt_properties.max_spawn_amount
|
||||
dirt.set_new_amount(min(new_value, self.dirt_properties.max_local_amount))
|
||||
else:
|
||||
return c.NOT_VALID
|
||||
return c.VALID
|
||||
|
||||
def __repr__(self):
|
||||
s = super(DirtRegister, self).__repr__()
|
||||
return f'{s[:-1]}, {self.amount})'
|
||||
|
||||
|
||||
def softmax(x):
|
||||
"""Compute softmax values for each sets of scores in x."""
|
||||
e_x = np.exp(x - np.max(x))
|
||||
@@ -200,8 +111,8 @@ class DirtFactory(BaseFactory):
|
||||
super_reward_info = super().step_hook()
|
||||
if smear_amount := self.dirt_prop.dirt_smear_amount:
|
||||
for agent in self[c.AGENT]:
|
||||
if agent.temp_valid and agent.last_pos != c.NO_POS:
|
||||
if self._actions.is_moving_action(agent.temp_action):
|
||||
if agent.step_result['action_valid'] and agent.last_pos != c.NO_POS:
|
||||
if self._actions.is_moving_action(agent.step_result['action_name']):
|
||||
if old_pos_dirt := self[c.DIRT].by_pos(agent.last_pos):
|
||||
if smeared_dirt := round(old_pos_dirt.amount * smear_amount, 2):
|
||||
old_pos_dirt.set_new_amount(max(0, old_pos_dirt.amount-smeared_dirt))
|
||||
@@ -248,8 +159,8 @@ class DirtFactory(BaseFactory):
|
||||
additional_observations.update({c.DIRT: self[c.DIRT].as_array()})
|
||||
return additional_observations
|
||||
|
||||
def gather_additional_info(self, agent: Agent) -> dict:
|
||||
event_reward_dict = super().per_agent_reward_hook(agent)
|
||||
def post_step_hook(self) -> List[Dict[str, int]]:
|
||||
super_post_step = super(DirtFactory, self).post_step_hook()
|
||||
info_dict = dict()
|
||||
|
||||
dirt = [dirt.amount for dirt in self[c.DIRT]]
|
||||
@@ -264,8 +175,8 @@ class DirtFactory(BaseFactory):
|
||||
info_dict.update(dirt_amount=current_dirt_amount)
|
||||
info_dict.update(dirty_tile_count=dirty_tile_count)
|
||||
|
||||
event_reward_dict.update({'info': info_dict})
|
||||
return event_reward_dict
|
||||
super_post_step.append(info_dict)
|
||||
return super_post_step
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
@@ -304,7 +215,6 @@ if __name__ == '__main__':
|
||||
# inject_agents=[TSPDirtAgent],
|
||||
)
|
||||
|
||||
factory.save_params(Path('rewards_param'))
|
||||
|
||||
# noinspection DuplicatedCode
|
||||
n_actions = factory.action_space.n - 1
|
||||
0
environments/factory/additional/item/__init__.py
Normal file
0
environments/factory/additional/item/__init__.py
Normal file
@@ -1,179 +1,16 @@
|
||||
import time
|
||||
from collections import deque
|
||||
from typing import List, Union, NamedTuple, Dict
|
||||
from typing import List, Union, Dict
|
||||
import numpy as np
|
||||
import random
|
||||
|
||||
from environments.factory.additional.item.item_collections import ItemRegister, Inventories, DropOffLocations
|
||||
from environments.factory.additional.item.item_util import Constants, Actions, RewardsItem, ItemProperties
|
||||
from environments.factory.base.base_factory import BaseFactory
|
||||
from environments.helpers import Constants as BaseConstants
|
||||
from environments.helpers import EnvActions as BaseActions
|
||||
from environments import helpers as h
|
||||
from environments.factory.base.objects import Agent, Entity, Action, Floor
|
||||
from environments.factory.base.registers import Entities, EntityCollection, BoundEnvObjCollection, ObjectCollection
|
||||
from environments.factory.base.objects import Agent, Action
|
||||
from environments.factory.base.registers import Entities
|
||||
|
||||
from environments.factory.base.renderer import RenderEntity
|
||||
|
||||
|
||||
class Constants(BaseConstants):
|
||||
NO_ITEM = 0
|
||||
ITEM_DROP_OFF = 1
|
||||
# Item Env
|
||||
ITEM = 'Item'
|
||||
INVENTORY = 'Inventory'
|
||||
DROP_OFF = 'Drop_Off'
|
||||
|
||||
|
||||
class Actions(BaseActions):
|
||||
ITEM_ACTION = 'ITEMACTION'
|
||||
|
||||
|
||||
class RewardsItem(NamedTuple):
|
||||
DROP_OFF_VALID: float = 0.1
|
||||
DROP_OFF_FAIL: float = -0.1
|
||||
PICK_UP_FAIL: float = -0.1
|
||||
PICK_UP_VALID: float = 0.1
|
||||
|
||||
|
||||
class Item(Entity):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self._auto_despawn = -1
|
||||
|
||||
@property
|
||||
def auto_despawn(self):
|
||||
return self._auto_despawn
|
||||
|
||||
@property
|
||||
def encoding(self):
|
||||
# Edit this if you want items to be drawn in the ops differently
|
||||
return 1
|
||||
|
||||
def set_auto_despawn(self, auto_despawn):
|
||||
self._auto_despawn = auto_despawn
|
||||
|
||||
def set_tile_to(self, no_pos_tile):
|
||||
assert self._collection.__class__.__name__ != ItemRegister.__class__
|
||||
self._tile = no_pos_tile
|
||||
|
||||
|
||||
class ItemRegister(EntityCollection):
|
||||
|
||||
_accepted_objects = Item
|
||||
|
||||
def spawn_items(self, tiles: List[Floor]):
|
||||
items = [Item(tile, self) for tile in tiles]
|
||||
self.add_additional_items(items)
|
||||
|
||||
def despawn_items(self, items: List[Item]):
|
||||
items = [items] if isinstance(items, Item) else items
|
||||
for item in items:
|
||||
del self[item]
|
||||
|
||||
|
||||
class Inventory(BoundEnvObjCollection):
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
return f'{self.__class__.__name__}({self._bound_entity.name})'
|
||||
|
||||
def __init__(self, agent: Agent, capacity: int, *args, **kwargs):
|
||||
super(Inventory, self).__init__(agent, *args, is_blocking_light=False, can_be_shadowed=False, **kwargs)
|
||||
self.capacity = capacity
|
||||
|
||||
def as_array(self):
|
||||
if self._array is None:
|
||||
self._array = np.zeros((1, *self._shape))
|
||||
return super(Inventory, self).as_array()
|
||||
|
||||
def summarize_states(self, **kwargs):
|
||||
attr_dict = {key: str(val) for key, val in self.__dict__.items() if not key.startswith('_') and key != 'data'}
|
||||
attr_dict.update(dict(items={key: val.summarize_state(**kwargs) for key, val in self.items()}))
|
||||
attr_dict.update(dict(name=self.name))
|
||||
return attr_dict
|
||||
|
||||
def pop(self):
|
||||
item_to_pop = self[0]
|
||||
self.delete_env_object(item_to_pop)
|
||||
return item_to_pop
|
||||
|
||||
|
||||
class Inventories(ObjectCollection):
|
||||
|
||||
_accepted_objects = Inventory
|
||||
is_blocking_light = False
|
||||
can_be_shadowed = False
|
||||
|
||||
def __init__(self, obs_shape, *args, **kwargs):
|
||||
super(Inventories, self).__init__(*args, is_per_agent=True, individual_slices=True, **kwargs)
|
||||
self._obs_shape = obs_shape
|
||||
|
||||
def as_array(self):
|
||||
return np.stack([inventory.as_array() for inv_idx, inventory in enumerate(self)])
|
||||
|
||||
def spawn_inventories(self, agents, capacity):
|
||||
inventories = [self._accepted_objects(agent, capacity, self._obs_shape)
|
||||
for _, agent in enumerate(agents)]
|
||||
self.add_additional_items(inventories)
|
||||
|
||||
def idx_by_entity(self, entity):
|
||||
try:
|
||||
return next((idx for idx, inv in enumerate(self) if inv.belongs_to_entity(entity)))
|
||||
except StopIteration:
|
||||
return None
|
||||
|
||||
def by_entity(self, entity):
|
||||
try:
|
||||
return next((inv for inv in self if inv.belongs_to_entity(entity)))
|
||||
except StopIteration:
|
||||
return None
|
||||
|
||||
def summarize_states(self, **kwargs):
|
||||
return {key: val.summarize_states(**kwargs) for key, val in self.items()}
|
||||
|
||||
|
||||
class DropOffLocation(Entity):
|
||||
|
||||
@property
|
||||
def encoding(self):
|
||||
return Constants.ITEM_DROP_OFF
|
||||
|
||||
def __init__(self, *args, storage_size_until_full: int = 5, auto_item_despawn_interval: int = 5, **kwargs):
|
||||
super(DropOffLocation, self).__init__(*args, **kwargs)
|
||||
self.auto_item_despawn_interval = auto_item_despawn_interval
|
||||
self.storage = deque(maxlen=storage_size_until_full or None)
|
||||
|
||||
def place_item(self, item: Item):
|
||||
if self.is_full:
|
||||
raise RuntimeWarning("There is currently no way to clear the storage or make it unfull.")
|
||||
return c.NOT_VALID
|
||||
else:
|
||||
self.storage.append(item)
|
||||
item.set_auto_despawn(self.auto_item_despawn_interval)
|
||||
return c.VALID
|
||||
|
||||
@property
|
||||
def is_full(self):
|
||||
return False if not self.storage.maxlen else self.storage.maxlen == len(self.storage)
|
||||
|
||||
def summarize_state(self, n_steps=None) -> dict:
|
||||
if n_steps == h.STEPS_START:
|
||||
return super().summarize_state(n_steps=n_steps)
|
||||
|
||||
|
||||
class DropOffLocations(EntityCollection):
|
||||
|
||||
_accepted_objects = DropOffLocation
|
||||
|
||||
|
||||
class ItemProperties(NamedTuple):
|
||||
n_items: int = 5 # How many items are there at the same time
|
||||
spawn_frequency: int = 10 # Spawn Frequency in Steps
|
||||
n_drop_off_locations: int = 5 # How many DropOff locations are there at the same time
|
||||
max_dropoff_storage_size: int = 0 # How many items are needed until the dropoff is full
|
||||
max_agent_inventory_capacity: int = 5 # How many items are needed until the agent inventory is full
|
||||
|
||||
|
||||
c = Constants
|
||||
a = Actions
|
||||
|
||||
87
environments/factory/additional/item/item_collections.py
Normal file
87
environments/factory/additional/item/item_collections.py
Normal file
@@ -0,0 +1,87 @@
|
||||
from typing import List
|
||||
|
||||
import numpy as np
|
||||
|
||||
from environments.factory.base.objects import Floor, Agent
|
||||
from environments.factory.base.registers import EntityCollection, BoundEnvObjCollection, ObjectCollection
|
||||
from environments.factory.additional.item.item_entities import Item, DropOffLocation
|
||||
|
||||
|
||||
class ItemRegister(EntityCollection):
|
||||
|
||||
_accepted_objects = Item
|
||||
|
||||
def spawn_items(self, tiles: List[Floor]):
|
||||
items = [Item(tile, self) for tile in tiles]
|
||||
self.add_additional_items(items)
|
||||
|
||||
def despawn_items(self, items: List[Item]):
|
||||
items = [items] if isinstance(items, Item) else items
|
||||
for item in items:
|
||||
del self[item]
|
||||
|
||||
|
||||
class Inventory(BoundEnvObjCollection):
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
return f'{self.__class__.__name__}({self._bound_entity.name})'
|
||||
|
||||
def __init__(self, agent: Agent, capacity: int, *args, **kwargs):
|
||||
super(Inventory, self).__init__(agent, *args, is_blocking_light=False, can_be_shadowed=False, **kwargs)
|
||||
self.capacity = capacity
|
||||
|
||||
def as_array(self):
|
||||
if self._array is None:
|
||||
self._array = np.zeros((1, *self._shape))
|
||||
return super(Inventory, self).as_array()
|
||||
|
||||
def summarize_states(self, **kwargs):
|
||||
attr_dict = {key: str(val) for key, val in self.__dict__.items() if not key.startswith('_') and key != 'data'}
|
||||
attr_dict.update(dict(items={key: val.summarize_state(**kwargs) for key, val in self.items()}))
|
||||
attr_dict.update(dict(name=self.name))
|
||||
return attr_dict
|
||||
|
||||
def pop(self):
|
||||
item_to_pop = self[0]
|
||||
self.delete_env_object(item_to_pop)
|
||||
return item_to_pop
|
||||
|
||||
|
||||
class Inventories(ObjectCollection):
|
||||
|
||||
_accepted_objects = Inventory
|
||||
is_blocking_light = False
|
||||
can_be_shadowed = False
|
||||
|
||||
def __init__(self, obs_shape, *args, **kwargs):
|
||||
super(Inventories, self).__init__(*args, is_per_agent=True, individual_slices=True, **kwargs)
|
||||
self._obs_shape = obs_shape
|
||||
|
||||
def as_array(self):
|
||||
return np.stack([inventory.as_array() for inv_idx, inventory in enumerate(self)])
|
||||
|
||||
def spawn_inventories(self, agents, capacity):
|
||||
inventories = [self._accepted_objects(agent, capacity, self._obs_shape)
|
||||
for _, agent in enumerate(agents)]
|
||||
self.add_additional_items(inventories)
|
||||
|
||||
def idx_by_entity(self, entity):
|
||||
try:
|
||||
return next((idx for idx, inv in enumerate(self) if inv.belongs_to_entity(entity)))
|
||||
except StopIteration:
|
||||
return None
|
||||
|
||||
def by_entity(self, entity):
|
||||
try:
|
||||
return next((inv for inv in self if inv.belongs_to_entity(entity)))
|
||||
except StopIteration:
|
||||
return None
|
||||
|
||||
def summarize_states(self, **kwargs):
|
||||
return {key: val.summarize_states(**kwargs) for key, val in self.items()}
|
||||
|
||||
|
||||
class DropOffLocations(EntityCollection):
|
||||
|
||||
_accepted_objects = DropOffLocation
|
||||
61
environments/factory/additional/item/item_entities.py
Normal file
61
environments/factory/additional/item/item_entities.py
Normal file
@@ -0,0 +1,61 @@
|
||||
from collections import deque
|
||||
|
||||
from environments import helpers as h
|
||||
from environments.factory.additional.item.item_util import Constants
|
||||
from environments.factory.base.objects import Entity
|
||||
|
||||
|
||||
class Item(Entity):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self._auto_despawn = -1
|
||||
|
||||
@property
|
||||
def auto_despawn(self):
|
||||
return self._auto_despawn
|
||||
|
||||
@property
|
||||
def encoding(self):
|
||||
# Edit this if you want items to be drawn in the ops differently
|
||||
return 1
|
||||
|
||||
def set_auto_despawn(self, auto_despawn):
|
||||
self._auto_despawn = auto_despawn
|
||||
|
||||
def set_tile_to(self, no_pos_tile):
|
||||
self._tile = no_pos_tile
|
||||
|
||||
def summarize_state(self, **_) -> dict:
|
||||
super_summarization = super(Item, self).summarize_state()
|
||||
super_summarization.update(dict(auto_despawn=self.auto_despawn))
|
||||
return super_summarization
|
||||
|
||||
|
||||
class DropOffLocation(Entity):
|
||||
|
||||
@property
|
||||
def encoding(self):
|
||||
return Constants.ITEM_DROP_OFF
|
||||
|
||||
def __init__(self, *args, storage_size_until_full: int = 5, auto_item_despawn_interval: int = 5, **kwargs):
|
||||
super(DropOffLocation, self).__init__(*args, **kwargs)
|
||||
self.auto_item_despawn_interval = auto_item_despawn_interval
|
||||
self.storage = deque(maxlen=storage_size_until_full or None)
|
||||
|
||||
def place_item(self, item: Item):
|
||||
if self.is_full:
|
||||
raise RuntimeWarning("There is currently no way to clear the storage or make it unfull.")
|
||||
return c.NOT_VALID
|
||||
else:
|
||||
self.storage.append(item)
|
||||
item.set_auto_despawn(self.auto_item_despawn_interval)
|
||||
return Constants.VALID
|
||||
|
||||
@property
|
||||
def is_full(self):
|
||||
return False if not self.storage.maxlen else self.storage.maxlen == len(self.storage)
|
||||
|
||||
def summarize_state(self, n_steps=None) -> dict:
|
||||
if n_steps == h.STEPS_START:
|
||||
return super().summarize_state(n_steps=n_steps)
|
||||
31
environments/factory/additional/item/item_util.py
Normal file
31
environments/factory/additional/item/item_util.py
Normal file
@@ -0,0 +1,31 @@
|
||||
from typing import NamedTuple
|
||||
|
||||
from environments.helpers import Constants as BaseConstants, EnvActions as BaseActions
|
||||
|
||||
|
||||
class Constants(BaseConstants):
|
||||
NO_ITEM = 0
|
||||
ITEM_DROP_OFF = 1
|
||||
# Item Env
|
||||
ITEM = 'Item'
|
||||
INVENTORY = 'Inventory'
|
||||
DROP_OFF = 'Drop_Off'
|
||||
|
||||
|
||||
class Actions(BaseActions):
|
||||
ITEM_ACTION = 'ITEMACTION'
|
||||
|
||||
|
||||
class RewardsItem(NamedTuple):
|
||||
DROP_OFF_VALID: float = 0.1
|
||||
DROP_OFF_FAIL: float = -0.1
|
||||
PICK_UP_FAIL: float = -0.1
|
||||
PICK_UP_VALID: float = 0.1
|
||||
|
||||
|
||||
class ItemProperties(NamedTuple):
|
||||
n_items: int = 5 # How many items are there at the same time
|
||||
spawn_frequency: int = 10 # Spawn Frequency in Steps
|
||||
n_drop_off_locations: int = 5 # How many DropOff locations are there at the same time
|
||||
max_dropoff_storage_size: int = 0 # How many items are needed until the dropoff is full
|
||||
max_agent_inventory_capacity: int = 5 # How many items are needed until the agent inventory is full
|
||||
@@ -180,6 +180,7 @@ class BaseFactory(gym.Env):
|
||||
self._entities.add_additional_items({c.DOORS: doors})
|
||||
|
||||
# Actions
|
||||
# TODO: Move this to Agent init, so that agents can have individual action sets.
|
||||
self._actions = Actions(self.mv_prop, can_use_doors=self.parse_doors)
|
||||
if additional_actions := self.actions_hook:
|
||||
self._actions.add_additional_items(additional_actions)
|
||||
@@ -308,7 +309,8 @@ class BaseFactory(gym.Env):
|
||||
info.update(self._summarize_state())
|
||||
|
||||
# Post step Hook for later use
|
||||
info.update(self.post_step_hook())
|
||||
for post_step_info in self.post_step_hook():
|
||||
info.update(post_step_info)
|
||||
|
||||
obs, _ = self._build_observations()
|
||||
|
||||
@@ -367,14 +369,16 @@ class BaseFactory(gym.Env):
|
||||
agent_obs = global_agent_obs.copy()
|
||||
agent_obs[(0, *agent.pos)] -= agent.encoding
|
||||
else:
|
||||
agent_obs = global_agent_obs
|
||||
agent_obs = global_agent_obs.copy()
|
||||
else:
|
||||
# agent_obs == None!!!!!
|
||||
agent_obs = global_agent_obs
|
||||
|
||||
# Build Level Observations
|
||||
if self.obs_prop.render_agents == a_obs.LEVEL:
|
||||
assert agent_obs is not None
|
||||
lvl_obs = lvl_obs.copy()
|
||||
lvl_obs += global_agent_obs
|
||||
lvl_obs += agent_obs
|
||||
|
||||
obs_dict[c.WALLS] = lvl_obs
|
||||
if self.obs_prop.render_agents in [a_obs.SEPERATE, a_obs.COMBINED] and agent_obs is not None:
|
||||
@@ -600,7 +604,9 @@ class BaseFactory(gym.Env):
|
||||
for reward in agent.step_result['rewards']:
|
||||
combined_info_dict.update(reward['info'])
|
||||
|
||||
# Combine Info dicts into a global one
|
||||
combined_info_dict = dict(combined_info_dict)
|
||||
|
||||
combined_info_dict.update(info)
|
||||
|
||||
global_reward_sum = sum(global_env_rewards)
|
||||
@@ -616,9 +622,11 @@ class BaseFactory(gym.Env):
|
||||
|
||||
def start_recording(self):
|
||||
self._record_episodes = True
|
||||
return self._record_episodes
|
||||
|
||||
def stop_recording(self):
|
||||
self._record_episodes = False
|
||||
return not self._record_episodes
|
||||
|
||||
# noinspection PyGlobalUndefined
|
||||
def render(self, mode='human'):
|
||||
@@ -719,12 +727,12 @@ class BaseFactory(gym.Env):
|
||||
return {}
|
||||
|
||||
@abc.abstractmethod
|
||||
def per_agent_reward_hook(self, agent: Agent) -> Dict[str, dict]:
|
||||
return {}
|
||||
def per_agent_reward_hook(self, agent: Agent) -> List[dict]:
|
||||
return []
|
||||
|
||||
@abc.abstractmethod
|
||||
def post_step_hook(self) -> dict:
|
||||
return {}
|
||||
def post_step_hook(self) -> List[dict]:
|
||||
return []
|
||||
|
||||
@abc.abstractmethod
|
||||
def per_agent_raw_observations_hook(self, agent) -> Dict[str, np.typing.ArrayLike]:
|
||||
|
||||
@@ -119,7 +119,6 @@ class Entity(EnvObject):
|
||||
|
||||
def __repr__(self):
|
||||
return super(Entity, self).__repr__() + f'(@{self.pos})'
|
||||
# With Position in Env
|
||||
|
||||
|
||||
# TODO: Missing Documentation
|
||||
|
||||
@@ -117,7 +117,7 @@ class EnvObjectCollection(ObjectCollection):
|
||||
return self._array
|
||||
|
||||
def summarize_states(self, n_steps=None):
|
||||
return [val.summarize_state(n_steps=n_steps) for val in self.values()]
|
||||
return [entity.summarize_state(n_steps=n_steps) for entity in self.values()]
|
||||
|
||||
def notify_change_to_free(self, env_object: EnvObject):
|
||||
self._array_change_notifyer(env_object, value=c.FREE_CELL)
|
||||
|
||||
@@ -1,273 +0,0 @@
|
||||
from typing import Union, NamedTuple, Dict, List
|
||||
|
||||
import numpy as np
|
||||
|
||||
from environments.factory.base.base_factory import BaseFactory
|
||||
from environments.factory.base.objects import Agent, Action, Entity, EnvObject, BoundingMixin
|
||||
from environments.factory.base.registers import EntityCollection, EnvObjectCollection
|
||||
from environments.factory.base.renderer import RenderEntity
|
||||
from environments.helpers import Constants as BaseConstants
|
||||
from environments.helpers import EnvActions as BaseActions
|
||||
|
||||
from environments import helpers as h
|
||||
|
||||
|
||||
class Constants(BaseConstants):
|
||||
# Battery Env
|
||||
CHARGE_PODS = 'Charge_Pod'
|
||||
BATTERIES = 'BATTERIES'
|
||||
BATTERY_DISCHARGED = 'DISCHARGED'
|
||||
CHARGE_POD = 1
|
||||
|
||||
|
||||
class Actions(BaseActions):
|
||||
CHARGE = 'do_charge_action'
|
||||
|
||||
|
||||
class RewardsBtry(NamedTuple):
|
||||
CHARGE_VALID: float = 0.1
|
||||
CHARGE_FAIL: float = -0.1
|
||||
BATTERY_DISCHARGED: float = -1.0
|
||||
|
||||
|
||||
class BatteryProperties(NamedTuple):
|
||||
initial_charge: float = 0.8 #
|
||||
charge_rate: float = 0.4 #
|
||||
charge_locations: int = 20 #
|
||||
per_action_costs: Union[dict, float] = 0.02
|
||||
done_when_discharged = False
|
||||
multi_charge: bool = False
|
||||
|
||||
|
||||
c = Constants
|
||||
a = Actions
|
||||
|
||||
|
||||
class Battery(BoundingMixin, EnvObject):
|
||||
|
||||
@property
|
||||
def is_discharged(self):
|
||||
return self.charge_level == 0
|
||||
|
||||
def __init__(self, initial_charge_level: float, *args, **kwargs):
|
||||
super(Battery, self).__init__(*args, **kwargs)
|
||||
self.charge_level = initial_charge_level
|
||||
|
||||
def encoding(self):
|
||||
return self.charge_level
|
||||
|
||||
def do_charge_action(self, amount):
|
||||
if self.charge_level < 1:
|
||||
# noinspection PyTypeChecker
|
||||
self.charge_level = min(1, amount + self.charge_level)
|
||||
return c.VALID
|
||||
else:
|
||||
return c.NOT_VALID
|
||||
|
||||
def decharge(self, amount) -> c:
|
||||
if self.charge_level != 0:
|
||||
# noinspection PyTypeChecker
|
||||
self.charge_level = max(0, amount + self.charge_level)
|
||||
self._collection.notify_change_to_value(self)
|
||||
return c.VALID
|
||||
else:
|
||||
return c.NOT_VALID
|
||||
|
||||
def summarize_state(self, **_):
|
||||
attr_dict = {key: str(val) for key, val in self.__dict__.items() if not key.startswith('_') and key != 'data'}
|
||||
attr_dict.update(dict(name=self.name))
|
||||
return attr_dict
|
||||
|
||||
|
||||
class BatteriesRegister(EnvObjectCollection):
|
||||
|
||||
_accepted_objects = Battery
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(BatteriesRegister, self).__init__(*args, individual_slices=True,
|
||||
is_blocking_light=False, can_be_shadowed=False, **kwargs)
|
||||
self.is_observable = True
|
||||
|
||||
def spawn_batteries(self, agents, initial_charge_level):
|
||||
batteries = [self._accepted_objects(initial_charge_level, agent, self) for _, agent in enumerate(agents)]
|
||||
self.add_additional_items(batteries)
|
||||
|
||||
def summarize_states(self, n_steps=None):
|
||||
# as dict with additional nesting
|
||||
# return dict(items=super(Inventories, cls).summarize_states())
|
||||
return super(BatteriesRegister, self).summarize_states(n_steps=n_steps)
|
||||
|
||||
# Todo Move this to Mixin!
|
||||
def by_entity(self, entity):
|
||||
try:
|
||||
return next((x for x in self if x.belongs_to_entity(entity)))
|
||||
except StopIteration:
|
||||
return None
|
||||
|
||||
def idx_by_entity(self, entity):
|
||||
try:
|
||||
return next((idx for idx, x in enumerate(self) if x.belongs_to_entity(entity)))
|
||||
except StopIteration:
|
||||
return None
|
||||
|
||||
def as_array_by_entity(self, entity):
|
||||
return self._array[self.idx_by_entity(entity)]
|
||||
|
||||
|
||||
class ChargePod(Entity):
|
||||
|
||||
@property
|
||||
def encoding(self):
|
||||
return c.CHARGE_POD
|
||||
|
||||
def __init__(self, *args, charge_rate: float = 0.4,
|
||||
multi_charge: bool = False, **kwargs):
|
||||
super(ChargePod, self).__init__(*args, **kwargs)
|
||||
self.charge_rate = charge_rate
|
||||
self.multi_charge = multi_charge
|
||||
|
||||
def charge_battery(self, battery: Battery):
|
||||
if battery.charge_level == 1.0:
|
||||
return c.NOT_VALID
|
||||
if sum(guest for guest in self.tile.guests if 'agent' in guest.name.lower()) > 1:
|
||||
return c.NOT_VALID
|
||||
valid = battery.do_charge_action(self.charge_rate)
|
||||
return valid
|
||||
|
||||
def summarize_state(self, n_steps=None) -> dict:
|
||||
if n_steps == h.STEPS_START:
|
||||
summary = super().summarize_state(n_steps=n_steps)
|
||||
return summary
|
||||
|
||||
|
||||
class ChargePods(EntityCollection):
|
||||
|
||||
_accepted_objects = ChargePod
|
||||
|
||||
def __repr__(self):
|
||||
super(ChargePods, self).__repr__()
|
||||
|
||||
|
||||
class BatteryFactory(BaseFactory):
|
||||
|
||||
def __init__(self, *args, btry_prop=BatteryProperties(), rewards_dest: RewardsBtry = RewardsBtry(),
|
||||
**kwargs):
|
||||
if isinstance(btry_prop, dict):
|
||||
btry_prop = BatteryProperties(**btry_prop)
|
||||
if isinstance(rewards_dest, dict):
|
||||
rewards_dest = BatteryProperties(**rewards_dest)
|
||||
self.btry_prop = btry_prop
|
||||
self.rewards_dest = rewards_dest
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def per_agent_raw_observations_hook(self, agent) -> Dict[str, np.typing.ArrayLike]:
|
||||
additional_raw_observations = super().per_agent_raw_observations_hook(agent)
|
||||
additional_raw_observations.update({c.BATTERIES: self[c.BATTERIES].as_array_by_entity(agent)})
|
||||
return additional_raw_observations
|
||||
|
||||
def observations_hook(self) -> Dict[str, np.typing.ArrayLike]:
|
||||
additional_observations = super().observations_hook()
|
||||
additional_observations.update({c.CHARGE_PODS: self[c.CHARGE_PODS].as_array()})
|
||||
return additional_observations
|
||||
|
||||
@property
|
||||
def entities_hook(self):
|
||||
super_entities = super().entities_hook
|
||||
|
||||
empty_tiles = self[c.FLOOR].empty_tiles[:self.btry_prop.charge_locations]
|
||||
charge_pods = ChargePods.from_tiles(
|
||||
empty_tiles, self._level_shape,
|
||||
entity_kwargs=dict(charge_rate=self.btry_prop.charge_rate,
|
||||
multi_charge=self.btry_prop.multi_charge)
|
||||
)
|
||||
|
||||
batteries = BatteriesRegister(self._level_shape if not self._pomdp_r else ((self.pomdp_diameter,) * 2),
|
||||
)
|
||||
batteries.spawn_batteries(self[c.AGENT], self.btry_prop.initial_charge)
|
||||
super_entities.update({c.BATTERIES: batteries, c.CHARGE_PODS: charge_pods})
|
||||
return super_entities
|
||||
|
||||
def step_hook(self) -> (List[dict], dict):
|
||||
super_reward_info = super(BatteryFactory, self).step_hook()
|
||||
|
||||
# Decharge
|
||||
batteries = self[c.BATTERIES]
|
||||
|
||||
for agent in self[c.AGENT]:
|
||||
if isinstance(self.btry_prop.per_action_costs, dict):
|
||||
energy_consumption = self.btry_prop.per_action_costs[agent.temp_action]
|
||||
else:
|
||||
energy_consumption = self.btry_prop.per_action_costs
|
||||
|
||||
batteries.by_entity(agent).decharge(energy_consumption)
|
||||
|
||||
return super_reward_info
|
||||
|
||||
def do_charge_action(self, agent) -> (dict, dict):
|
||||
if charge_pod := self[c.CHARGE_PODS].by_pos(agent.pos):
|
||||
valid = charge_pod.charge_battery(self[c.BATTERIES].by_entity(agent))
|
||||
if valid:
|
||||
info_dict = {f'{agent.name}_{a.CHARGE}_VALID': 1}
|
||||
self.print(f'{agent.name} just charged batteries at {charge_pod.name}.')
|
||||
else:
|
||||
info_dict = {f'{agent.name}_{a.CHARGE}_FAIL': 1}
|
||||
self.print(f'{agent.name} failed to charged batteries at {charge_pod.name}.')
|
||||
else:
|
||||
valid = c.NOT_VALID
|
||||
info_dict = {f'{agent.name}_{a.CHARGE}_FAIL': 1}
|
||||
# info_dict = {f'{agent.name}_no_charger': 1}
|
||||
self.print(f'{agent.name} failed to charged batteries at {agent.pos}.')
|
||||
reward = dict(value=self.rewards_dest.CHARGE_VALID if valid else self.rewards_dest.CHARGE_FAIL,
|
||||
reason=a.CHARGE, info=info_dict)
|
||||
return valid, reward
|
||||
|
||||
def do_additional_actions(self, agent: Agent, action: Action) -> (bool, dict):
|
||||
action_result = super().do_additional_actions(agent, action)
|
||||
if action_result is None:
|
||||
if action == a.CHARGE:
|
||||
action_result = self.do_charge_action(agent)
|
||||
return action_result
|
||||
else:
|
||||
return None
|
||||
else:
|
||||
return action_result
|
||||
pass
|
||||
|
||||
def reset_hook(self) -> None:
|
||||
# There is Nothing to reset.
|
||||
pass
|
||||
|
||||
def check_additional_done(self) -> (bool, dict):
|
||||
super_done, super_dict = super(BatteryFactory, self).check_additional_done()
|
||||
if super_done:
|
||||
return super_done, super_dict
|
||||
else:
|
||||
if self.btry_prop.done_when_discharged:
|
||||
if btry_done := any(battery.is_discharged for battery in self[c.BATTERIES]):
|
||||
super_dict.update(DISCHARGE_DONE=1)
|
||||
return btry_done, super_dict
|
||||
else:
|
||||
pass
|
||||
else:
|
||||
pass
|
||||
pass
|
||||
|
||||
def per_agent_reward_hook(self, agent: Agent) -> Dict[str, dict]:
|
||||
reward_event_dict = super(BatteryFactory, self).per_agent_reward_hook(agent)
|
||||
if self[c.BATTERIES].by_entity(agent).is_discharged:
|
||||
self.print(f'{agent.name} Battery is discharged!')
|
||||
info_dict = {f'{agent.name}_{c.BATTERY_DISCHARGED}': 1}
|
||||
reward_event_dict.update({c.BATTERY_DISCHARGED: {'reward': self.rewards_dest.BATTERY_DISCHARGED,
|
||||
'info': info_dict}}
|
||||
)
|
||||
else:
|
||||
# All Fine
|
||||
pass
|
||||
return reward_event_dict
|
||||
|
||||
def render_assets_hook(self):
|
||||
# noinspection PyUnresolvedReferences
|
||||
additional_assets = super().render_assets_hook()
|
||||
charge_pods = [RenderEntity(c.CHARGE_PODS, charge_pod.tile.pos) for charge_pod in self[c.CHARGE_PODS]]
|
||||
additional_assets.extend(charge_pods)
|
||||
return additional_assets
|
||||
@@ -3,7 +3,9 @@ from typing import Dict, List, Union
|
||||
import numpy as np
|
||||
|
||||
from environments.factory.base.objects import Agent, Entity, Action
|
||||
from environments.factory.factory_dirt import Dirt, DirtRegister, DirtFactory
|
||||
from environments.factory.factory_dirt import DirtFactory
|
||||
from environments.factory.additional.dirt.dirt_collections import DirtRegister
|
||||
from environments.factory.additional.dirt.dirt_entity import Dirt
|
||||
from environments.factory.base.objects import Floor
|
||||
from environments.factory.base.registers import Floors, Entities, EntityCollection
|
||||
|
||||
@@ -28,7 +30,6 @@ class StationaryMachinesDirtFactory(DirtFactory):
|
||||
|
||||
def entities_hook(self) -> Dict[(str, Entities)]:
|
||||
super_entities = super().entities_hook()
|
||||
|
||||
return super_entities
|
||||
|
||||
def reset_hook(self) -> None:
|
||||
@@ -48,8 +49,8 @@ class StationaryMachinesDirtFactory(DirtFactory):
|
||||
super_per_agent_raw_observations = super().per_agent_raw_observations_hook(agent)
|
||||
return super_per_agent_raw_observations
|
||||
|
||||
def per_agent_reward_hook(self, agent: Agent) -> Dict[str, dict]:
|
||||
pass
|
||||
def per_agent_reward_hook(self, agent: Agent) -> List[dict]:
|
||||
return super(StationaryMachinesDirtFactory, self).per_agent_reward_hook(agent)
|
||||
|
||||
def pre_step_hook(self) -> None:
|
||||
pass
|
||||
|
||||
Reference in New Issue
Block a user