Adjustments and Documentation, recording and new environments, refactoring
This commit is contained in:
parent
e7461d7dcf
commit
6a24e7b518
@ -2,8 +2,11 @@ def make(env_name, pomdp_r=2, max_steps=400, stack_n_frames=3, n_agents=1, indiv
|
|||||||
import yaml
|
import yaml
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from environments.factory.combined_factories import DirtItemFactory
|
from environments.factory.combined_factories import DirtItemFactory
|
||||||
from environments.factory.factory_item import ItemFactory, ItemProperties
|
from environments.factory.factory_item import ItemFactory
|
||||||
from environments.factory.factory_dirt import DirtProperties, DirtFactory, RewardsDirt
|
from environments.factory.additional.item.item_util import ItemProperties
|
||||||
|
from environments.factory.factory_dirt import DirtFactory
|
||||||
|
from environments.factory.dirt_util import DirtProperties
|
||||||
|
from environments.factory.dirt_util import RewardsDirt
|
||||||
from environments.utility_classes import AgentRenderOptions
|
from environments.utility_classes import AgentRenderOptions
|
||||||
|
|
||||||
with (Path(__file__).parent / 'levels' / 'parameters' / f'{env_name}.yaml').open('r') as stream:
|
with (Path(__file__).parent / 'levels' / 'parameters' / f'{env_name}.yaml').open('r') as stream:
|
||||||
|
0
environments/factory/additional/__init__.py
Normal file
0
environments/factory/additional/__init__.py
Normal file
0
environments/factory/additional/btry/__init__.py
Normal file
0
environments/factory/additional/btry/__init__.py
Normal file
50
environments/factory/additional/btry/btry_collections.py
Normal file
50
environments/factory/additional/btry/btry_collections.py
Normal file
@ -0,0 +1,50 @@
|
|||||||
|
from environments.factory.additional.btry.btry_objects import Battery, ChargePod
|
||||||
|
from environments.factory.base.registers import EnvObjectCollection, EntityCollection
|
||||||
|
|
||||||
|
|
||||||
|
class BatteriesRegister(EnvObjectCollection):
|
||||||
|
|
||||||
|
_accepted_objects = Battery
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super(BatteriesRegister, self).__init__(*args, individual_slices=True,
|
||||||
|
is_blocking_light=False, can_be_shadowed=False, **kwargs)
|
||||||
|
self.is_observable = True
|
||||||
|
|
||||||
|
def spawn_batteries(self, agents, initial_charge_level):
|
||||||
|
batteries = [self._accepted_objects(initial_charge_level, agent, self) for _, agent in enumerate(agents)]
|
||||||
|
self.add_additional_items(batteries)
|
||||||
|
|
||||||
|
def summarize_states(self, n_steps=None):
|
||||||
|
# as dict with additional nesting
|
||||||
|
# return dict(items=super(Inventories, cls).summarize_states())
|
||||||
|
return super(BatteriesRegister, self).summarize_states(n_steps=n_steps)
|
||||||
|
|
||||||
|
# Todo Move this to Mixin!
|
||||||
|
def by_entity(self, entity):
|
||||||
|
try:
|
||||||
|
return next((x for x in self if x.belongs_to_entity(entity)))
|
||||||
|
except StopIteration:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def idx_by_entity(self, entity):
|
||||||
|
try:
|
||||||
|
return next((idx for idx, x in enumerate(self) if x.belongs_to_entity(entity)))
|
||||||
|
except StopIteration:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def as_array_by_entity(self, entity):
|
||||||
|
return self._array[self.idx_by_entity(entity)]
|
||||||
|
|
||||||
|
|
||||||
|
class ChargePods(EntityCollection):
|
||||||
|
|
||||||
|
_accepted_objects = ChargePod
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
super(ChargePods, self).__repr__()
|
||||||
|
|
||||||
|
def summarize_states(self, n_steps=None):
|
||||||
|
# as dict with additional nesting
|
||||||
|
# return dict(items=super(Inventories, cls).summarize_states())
|
||||||
|
return super(ChargePods, self).summarize_states(n_steps=n_steps)
|
67
environments/factory/additional/btry/btry_objects.py
Normal file
67
environments/factory/additional/btry/btry_objects.py
Normal file
@ -0,0 +1,67 @@
|
|||||||
|
from environments import helpers as h
|
||||||
|
from environments.factory.base.objects import BoundingMixin, EnvObject, Entity
|
||||||
|
from environments.factory.additional.btry.btry_util import Constants as c
|
||||||
|
|
||||||
|
|
||||||
|
class Battery(BoundingMixin, EnvObject):
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_discharged(self):
|
||||||
|
return self.charge_level == 0
|
||||||
|
|
||||||
|
def __init__(self, initial_charge_level: float, *args, **kwargs):
|
||||||
|
super(Battery, self).__init__(*args, **kwargs)
|
||||||
|
self.charge_level = initial_charge_level
|
||||||
|
|
||||||
|
def encoding(self):
|
||||||
|
return self.charge_level
|
||||||
|
|
||||||
|
def do_charge_action(self, amount):
|
||||||
|
if self.charge_level < 1:
|
||||||
|
# noinspection PyTypeChecker
|
||||||
|
self.charge_level = min(1, amount + self.charge_level)
|
||||||
|
return c.VALID
|
||||||
|
else:
|
||||||
|
return c.NOT_VALID
|
||||||
|
|
||||||
|
def decharge(self, amount) -> c:
|
||||||
|
if self.charge_level != 0:
|
||||||
|
# noinspection PyTypeChecker
|
||||||
|
self.charge_level = max(0, amount + self.charge_level)
|
||||||
|
self._collection.notify_change_to_value(self)
|
||||||
|
return c.VALID
|
||||||
|
else:
|
||||||
|
return c.NOT_VALID
|
||||||
|
|
||||||
|
def summarize_state(self, **_):
|
||||||
|
attr_dict = {key: str(val) for key, val in self.__dict__.items() if not key.startswith('_') and key != 'data'}
|
||||||
|
attr_dict.update(dict(name=self.name))
|
||||||
|
return attr_dict
|
||||||
|
|
||||||
|
|
||||||
|
class ChargePod(Entity):
|
||||||
|
|
||||||
|
@property
|
||||||
|
def encoding(self):
|
||||||
|
return c.CHARGE_POD
|
||||||
|
|
||||||
|
def __init__(self, *args, charge_rate: float = 0.4,
|
||||||
|
multi_charge: bool = False, **kwargs):
|
||||||
|
super(ChargePod, self).__init__(*args, **kwargs)
|
||||||
|
self.charge_rate = charge_rate
|
||||||
|
self.multi_charge = multi_charge
|
||||||
|
|
||||||
|
def charge_battery(self, battery: Battery):
|
||||||
|
if battery.charge_level == 1.0:
|
||||||
|
return c.NOT_VALID
|
||||||
|
if sum(guest for guest in self.tile.guests if 'agent' in guest.name.lower()) > 1:
|
||||||
|
return c.NOT_VALID
|
||||||
|
valid = battery.do_charge_action(self.charge_rate)
|
||||||
|
return valid
|
||||||
|
|
||||||
|
def summarize_state(self, n_steps=None) -> dict:
|
||||||
|
if n_steps == h.STEPS_START:
|
||||||
|
summary = super().summarize_state(n_steps=n_steps)
|
||||||
|
return summary
|
||||||
|
else:
|
||||||
|
{}
|
30
environments/factory/additional/btry/btry_util.py
Normal file
30
environments/factory/additional/btry/btry_util.py
Normal file
@ -0,0 +1,30 @@
|
|||||||
|
from typing import NamedTuple, Union
|
||||||
|
|
||||||
|
from environments.helpers import Constants as BaseConstants, EnvActions as BaseActions
|
||||||
|
|
||||||
|
|
||||||
|
class Constants(BaseConstants):
|
||||||
|
# Battery Env
|
||||||
|
CHARGE_PODS = 'Charge_Pod'
|
||||||
|
BATTERIES = 'BATTERIES'
|
||||||
|
BATTERY_DISCHARGED = 'DISCHARGED'
|
||||||
|
CHARGE_POD = 1
|
||||||
|
|
||||||
|
|
||||||
|
class Actions(BaseActions):
|
||||||
|
CHARGE = 'do_charge_action'
|
||||||
|
|
||||||
|
|
||||||
|
class RewardsBtry(NamedTuple):
|
||||||
|
CHARGE_VALID: float = 0.1
|
||||||
|
CHARGE_FAIL: float = -0.1
|
||||||
|
BATTERY_DISCHARGED: float = -1.0
|
||||||
|
|
||||||
|
|
||||||
|
class BatteryProperties(NamedTuple):
|
||||||
|
initial_charge: float = 0.8 #
|
||||||
|
charge_rate: float = 0.4 #
|
||||||
|
charge_locations: int = 20 #
|
||||||
|
per_action_costs: Union[dict, float] = 0.02
|
||||||
|
done_when_discharged: bool = False
|
||||||
|
multi_charge: bool = False
|
139
environments/factory/additional/btry/factory_battery.py
Normal file
139
environments/factory/additional/btry/factory_battery.py
Normal file
@ -0,0 +1,139 @@
|
|||||||
|
from typing import Dict, List
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from environments.factory.additional.btry.btry_collections import BatteriesRegister, ChargePods
|
||||||
|
from environments.factory.additional.btry.btry_util import Constants, Actions, RewardsBtry, BatteryProperties
|
||||||
|
from environments.factory.base.base_factory import BaseFactory
|
||||||
|
from environments.factory.base.objects import Agent, Action
|
||||||
|
from environments.factory.base.renderer import RenderEntity
|
||||||
|
|
||||||
|
c = Constants
|
||||||
|
a = Actions
|
||||||
|
|
||||||
|
|
||||||
|
class BatteryFactory(BaseFactory):
|
||||||
|
|
||||||
|
def __init__(self, *args, btry_prop=BatteryProperties(), rewards_btry: RewardsBtry = RewardsBtry(),
|
||||||
|
**kwargs):
|
||||||
|
if isinstance(btry_prop, dict):
|
||||||
|
btry_prop = BatteryProperties(**btry_prop)
|
||||||
|
if isinstance(rewards_btry, dict):
|
||||||
|
rewards_btry = RewardsBtry(**rewards_btry)
|
||||||
|
self.btry_prop = btry_prop
|
||||||
|
self.rewards_dest = rewards_btry
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
def per_agent_raw_observations_hook(self, agent) -> Dict[str, np.typing.ArrayLike]:
|
||||||
|
additional_raw_observations = super().per_agent_raw_observations_hook(agent)
|
||||||
|
additional_raw_observations.update({c.BATTERIES: self[c.BATTERIES].as_array_by_entity(agent)})
|
||||||
|
return additional_raw_observations
|
||||||
|
|
||||||
|
def observations_hook(self) -> Dict[str, np.typing.ArrayLike]:
|
||||||
|
additional_observations = super().observations_hook()
|
||||||
|
additional_observations.update({c.CHARGE_PODS: self[c.CHARGE_PODS].as_array()})
|
||||||
|
return additional_observations
|
||||||
|
|
||||||
|
@property
|
||||||
|
def entities_hook(self):
|
||||||
|
super_entities = super().entities_hook
|
||||||
|
|
||||||
|
empty_tiles = self[c.FLOOR].empty_tiles[:self.btry_prop.charge_locations]
|
||||||
|
charge_pods = ChargePods.from_tiles(
|
||||||
|
empty_tiles, self._level_shape,
|
||||||
|
entity_kwargs=dict(charge_rate=self.btry_prop.charge_rate,
|
||||||
|
multi_charge=self.btry_prop.multi_charge)
|
||||||
|
)
|
||||||
|
|
||||||
|
batteries = BatteriesRegister(self._level_shape if not self._pomdp_r else ((self.pomdp_diameter,) * 2),
|
||||||
|
)
|
||||||
|
batteries.spawn_batteries(self[c.AGENT], self.btry_prop.initial_charge)
|
||||||
|
super_entities.update({c.BATTERIES: batteries, c.CHARGE_PODS: charge_pods})
|
||||||
|
return super_entities
|
||||||
|
|
||||||
|
def step_hook(self) -> (List[dict], dict):
|
||||||
|
super_reward_info = super(BatteryFactory, self).step_hook()
|
||||||
|
|
||||||
|
# Decharge
|
||||||
|
batteries = self[c.BATTERIES]
|
||||||
|
|
||||||
|
for agent in self[c.AGENT]:
|
||||||
|
if isinstance(self.btry_prop.per_action_costs, dict):
|
||||||
|
energy_consumption = self.btry_prop.per_action_costs[agent.temp_action]
|
||||||
|
else:
|
||||||
|
energy_consumption = self.btry_prop.per_action_costs
|
||||||
|
|
||||||
|
batteries.by_entity(agent).decharge(energy_consumption)
|
||||||
|
|
||||||
|
return super_reward_info
|
||||||
|
|
||||||
|
def do_charge_action(self, agent) -> (dict, dict):
|
||||||
|
if charge_pod := self[c.CHARGE_PODS].by_pos(agent.pos):
|
||||||
|
valid = charge_pod.charge_battery(self[c.BATTERIES].by_entity(agent))
|
||||||
|
if valid:
|
||||||
|
info_dict = {f'{agent.name}_{a.CHARGE}_VALID': 1}
|
||||||
|
self.print(f'{agent.name} just charged batteries at {charge_pod.name}.')
|
||||||
|
else:
|
||||||
|
info_dict = {f'{agent.name}_{a.CHARGE}_FAIL': 1}
|
||||||
|
self.print(f'{agent.name} failed to charged batteries at {charge_pod.name}.')
|
||||||
|
else:
|
||||||
|
valid = c.NOT_VALID
|
||||||
|
info_dict = {f'{agent.name}_{a.CHARGE}_FAIL': 1}
|
||||||
|
# info_dict = {f'{agent.name}_no_charger': 1}
|
||||||
|
self.print(f'{agent.name} failed to charged batteries at {agent.pos}.')
|
||||||
|
reward = dict(value=self.rewards_dest.CHARGE_VALID if valid else self.rewards_dest.CHARGE_FAIL,
|
||||||
|
reason=a.CHARGE, info=info_dict)
|
||||||
|
return valid, reward
|
||||||
|
|
||||||
|
def do_additional_actions(self, agent: Agent, action: Action) -> (bool, dict):
|
||||||
|
action_result = super().do_additional_actions(agent, action)
|
||||||
|
if action_result is None:
|
||||||
|
if action == a.CHARGE:
|
||||||
|
action_result = self.do_charge_action(agent)
|
||||||
|
return action_result
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
else:
|
||||||
|
return action_result
|
||||||
|
pass
|
||||||
|
|
||||||
|
def reset_hook(self) -> (List[dict], dict):
|
||||||
|
super_reward_info = super(BatteryFactory, self).reset_hook()
|
||||||
|
# There is Nothing to reset.
|
||||||
|
return super_reward_info
|
||||||
|
|
||||||
|
def check_additional_done(self) -> (bool, dict):
|
||||||
|
super_done, super_dict = super(BatteryFactory, self).check_additional_done()
|
||||||
|
if super_done:
|
||||||
|
return super_done, super_dict
|
||||||
|
else:
|
||||||
|
if self.btry_prop.done_when_discharged:
|
||||||
|
if btry_done := any(battery.is_discharged for battery in self[c.BATTERIES]):
|
||||||
|
super_dict.update(DISCHARGE_DONE=1)
|
||||||
|
return btry_done, super_dict
|
||||||
|
else:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
pass
|
||||||
|
return super_done, super_dict
|
||||||
|
|
||||||
|
def per_agent_reward_hook(self, agent: Agent) -> List[dict]:
|
||||||
|
reward_event_list = super(BatteryFactory, self).per_agent_reward_hook(agent)
|
||||||
|
if self[c.BATTERIES].by_entity(agent).is_discharged:
|
||||||
|
self.print(f'{agent.name} Battery is discharged!')
|
||||||
|
info_dict = {f'{agent.name}_{c.BATTERY_DISCHARGED}': 1}
|
||||||
|
reward_event_list.append({'value': self.rewards_dest.BATTERY_DISCHARGED,
|
||||||
|
'reason': c.BATTERY_DISCHARGED,
|
||||||
|
'info': info_dict}
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# All Fine
|
||||||
|
pass
|
||||||
|
return reward_event_list
|
||||||
|
|
||||||
|
def render_assets_hook(self):
|
||||||
|
# noinspection PyUnresolvedReferences
|
||||||
|
additional_assets = super().render_assets_hook()
|
||||||
|
charge_pods = [RenderEntity(c.CHARGE_PODS, charge_pod.tile.pos) for charge_pod in self[c.CHARGE_PODS]]
|
||||||
|
additional_assets.extend(charge_pods)
|
||||||
|
return additional_assets
|
@ -1,12 +1,15 @@
|
|||||||
import random
|
import random
|
||||||
|
|
||||||
from environments.factory.factory_battery import BatteryFactory, BatteryProperties
|
|
||||||
from environments.factory.factory_dest import DestFactory
|
|
||||||
from environments.factory.factory_dirt import DirtFactory, DirtProperties
|
|
||||||
from environments.factory.factory_item import ItemFactory
|
|
||||||
|
|
||||||
|
|
||||||
# noinspection PyAbstractClass
|
# noinspection PyAbstractClass
|
||||||
|
from environments.factory.additional.btry.btry_util import BatteryProperties
|
||||||
|
from environments.factory.additional.btry.factory_battery import BatteryFactory
|
||||||
|
from environments.factory.additional.dest.factory_dest import DestFactory
|
||||||
|
from environments.factory.additional.dirt.dirt_util import DirtProperties
|
||||||
|
from environments.factory.additional.dirt.factory_dirt import DirtFactory
|
||||||
|
from environments.factory.additional.item.factory_item import ItemFactory
|
||||||
|
|
||||||
|
|
||||||
class DirtItemFactory(ItemFactory, DirtFactory):
|
class DirtItemFactory(ItemFactory, DirtFactory):
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
@ -24,6 +27,12 @@ class DirtDestItemFactory(ItemFactory, DirtFactory, DestFactory):
|
|||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
# noinspection PyAbstractClass
|
||||||
|
class DestBatteryFactory(BatteryFactory, DestFactory):
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
from environments.utility_classes import AgentRenderOptions as ARO, ObservationProperties
|
from environments.utility_classes import AgentRenderOptions as ARO, ObservationProperties
|
||||||
|
|
0
environments/factory/additional/dest/__init__.py
Normal file
0
environments/factory/additional/dest/__init__.py
Normal file
44
environments/factory/additional/dest/dest_collections.py
Normal file
44
environments/factory/additional/dest/dest_collections.py
Normal file
@ -0,0 +1,44 @@
|
|||||||
|
from environments.factory.base.registers import EntityCollection
|
||||||
|
from environments.factory.additional.dest.dest_util import Constants as c
|
||||||
|
from environments.factory.additional.dest.dest_enitites import Destination
|
||||||
|
|
||||||
|
|
||||||
|
class Destinations(EntityCollection):
|
||||||
|
|
||||||
|
_accepted_objects = Destination
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.is_blocking_light = False
|
||||||
|
self.can_be_shadowed = False
|
||||||
|
|
||||||
|
def as_array(self):
|
||||||
|
self._array[:] = c.FREE_CELL
|
||||||
|
# ToDo: Switch to new Style Array Put
|
||||||
|
# indices = list(zip(range(len(cls)), *zip(*[x.pos for x in cls])))
|
||||||
|
# np.put(cls._array, [np.ravel_multi_index(x, cls._array.shape) for x in indices], cls.encodings)
|
||||||
|
for item in self:
|
||||||
|
if item.pos != c.NO_POS:
|
||||||
|
self._array[0, item.x, item.y] = item.encoding
|
||||||
|
return self._array
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return super(Destinations, self).__repr__()
|
||||||
|
|
||||||
|
def summarize_states(self, n_steps=None):
|
||||||
|
return {}
|
||||||
|
|
||||||
|
|
||||||
|
class ReachedDestinations(Destinations):
|
||||||
|
_accepted_objects = Destination
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super(ReachedDestinations, self).__init__(*args, **kwargs)
|
||||||
|
self.can_be_shadowed = False
|
||||||
|
self.is_blocking_light = False
|
||||||
|
|
||||||
|
def summarize_states(self, n_steps=None):
|
||||||
|
return {}
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return super(ReachedDestinations, self).__repr__()
|
44
environments/factory/additional/dest/dest_enitites.py
Normal file
44
environments/factory/additional/dest/dest_enitites.py
Normal file
@ -0,0 +1,44 @@
|
|||||||
|
from collections import defaultdict
|
||||||
|
|
||||||
|
from environments.factory.base.objects import Entity, Agent
|
||||||
|
from environments.factory.additional.dest.dest_util import Constants as c
|
||||||
|
|
||||||
|
|
||||||
|
class Destination(Entity):
|
||||||
|
|
||||||
|
@property
|
||||||
|
def any_agent_has_dwelled(self):
|
||||||
|
return bool(len(self._per_agent_times))
|
||||||
|
|
||||||
|
@property
|
||||||
|
def currently_dwelling_names(self):
|
||||||
|
return self._per_agent_times.keys()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def encoding(self):
|
||||||
|
return c.DESTINATION
|
||||||
|
|
||||||
|
def __init__(self, *args, dwell_time: int = 0, **kwargs):
|
||||||
|
super(Destination, self).__init__(*args, **kwargs)
|
||||||
|
self.dwell_time = dwell_time
|
||||||
|
self._per_agent_times = defaultdict(lambda: dwell_time)
|
||||||
|
|
||||||
|
def do_wait_action(self, agent: Agent):
|
||||||
|
self._per_agent_times[agent.name] -= 1
|
||||||
|
return c.VALID
|
||||||
|
|
||||||
|
def leave(self, agent: Agent):
|
||||||
|
del self._per_agent_times[agent.name]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_considered_reached(self):
|
||||||
|
agent_at_position = any(c.AGENT.lower() in x.name.lower() for x in self.tile.guests_that_can_collide)
|
||||||
|
return (agent_at_position and not self.dwell_time) or any(x == 0 for x in self._per_agent_times.values())
|
||||||
|
|
||||||
|
def agent_is_dwelling(self, agent: Agent):
|
||||||
|
return self._per_agent_times[agent.name] < self.dwell_time
|
||||||
|
|
||||||
|
def summarize_state(self, n_steps=None) -> dict:
|
||||||
|
state_summary = super().summarize_state(n_steps=n_steps)
|
||||||
|
state_summary.update(per_agent_times=self._per_agent_times)
|
||||||
|
return state_summary
|
41
environments/factory/additional/dest/dest_util.py
Normal file
41
environments/factory/additional/dest/dest_util.py
Normal file
@ -0,0 +1,41 @@
|
|||||||
|
from typing import NamedTuple
|
||||||
|
|
||||||
|
from environments.helpers import Constants as BaseConstants, EnvActions as BaseActions
|
||||||
|
|
||||||
|
|
||||||
|
class Constants(BaseConstants):
|
||||||
|
# Destination Env
|
||||||
|
DEST = 'Destination'
|
||||||
|
DESTINATION = 1
|
||||||
|
DESTINATION_DONE = 0.5
|
||||||
|
DEST_REACHED = 'ReachedDestination'
|
||||||
|
|
||||||
|
|
||||||
|
class Actions(BaseActions):
|
||||||
|
WAIT_ON_DEST = 'WAIT'
|
||||||
|
|
||||||
|
|
||||||
|
class RewardsDest(NamedTuple):
|
||||||
|
|
||||||
|
WAIT_VALID: float = 0.1
|
||||||
|
WAIT_FAIL: float = -0.1
|
||||||
|
DEST_REACHED: float = 5.0
|
||||||
|
|
||||||
|
|
||||||
|
class DestModeOptions(object):
|
||||||
|
DONE = 'DONE'
|
||||||
|
GROUPED = 'GROUPED'
|
||||||
|
PER_DEST = 'PER_DEST'
|
||||||
|
|
||||||
|
|
||||||
|
class DestProperties(NamedTuple):
|
||||||
|
n_dests: int = 1 # How many destinations are there
|
||||||
|
dwell_time: int = 0 # How long does the agent need to "wait" on a destination
|
||||||
|
spawn_frequency: int = 0
|
||||||
|
spawn_in_other_zone: bool = True #
|
||||||
|
spawn_mode: str = DestModeOptions.DONE
|
||||||
|
|
||||||
|
assert dwell_time >= 0, 'dwell_time cannot be < 0!'
|
||||||
|
assert spawn_frequency >= 0, 'spawn_frequency cannot be < 0!'
|
||||||
|
assert n_dests >= 0, 'n_destinations cannot be < 0!'
|
||||||
|
assert (spawn_mode == DestModeOptions.DONE) != bool(spawn_frequency)
|
@ -1,132 +1,19 @@
|
|||||||
import time
|
import time
|
||||||
from collections import defaultdict
|
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import List, Union, NamedTuple, Dict
|
from typing import List, Union, Dict
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import random
|
import random
|
||||||
|
|
||||||
|
from environments.factory.additional.dest.dest_collections import Destinations, ReachedDestinations
|
||||||
|
from environments.factory.additional.dest.dest_enitites import Destination
|
||||||
|
from environments.factory.additional.dest.dest_util import Constants, Actions, RewardsDest, DestModeOptions, \
|
||||||
|
DestProperties
|
||||||
from environments.factory.base.base_factory import BaseFactory
|
from environments.factory.base.base_factory import BaseFactory
|
||||||
from environments.helpers import Constants as BaseConstants
|
from environments.factory.base.objects import Agent, Action
|
||||||
from environments.helpers import EnvActions as BaseActions
|
from environments.factory.base.registers import Entities
|
||||||
from environments.factory.base.objects import Agent, Entity, Action
|
|
||||||
from environments.factory.base.registers import Entities, EntityCollection
|
|
||||||
|
|
||||||
from environments.factory.base.renderer import RenderEntity
|
from environments.factory.base.renderer import RenderEntity
|
||||||
|
|
||||||
|
|
||||||
class Constants(BaseConstants):
|
|
||||||
# Destination Env
|
|
||||||
DEST = 'Destination'
|
|
||||||
DESTINATION = 1
|
|
||||||
DESTINATION_DONE = 0.5
|
|
||||||
DEST_REACHED = 'ReachedDestination'
|
|
||||||
|
|
||||||
|
|
||||||
class Actions(BaseActions):
|
|
||||||
WAIT_ON_DEST = 'WAIT'
|
|
||||||
|
|
||||||
|
|
||||||
class RewardsDest(NamedTuple):
|
|
||||||
|
|
||||||
WAIT_VALID: float = 0.1
|
|
||||||
WAIT_FAIL: float = -0.1
|
|
||||||
DEST_REACHED: float = 5.0
|
|
||||||
|
|
||||||
|
|
||||||
class Destination(Entity):
|
|
||||||
|
|
||||||
@property
|
|
||||||
def any_agent_has_dwelled(self):
|
|
||||||
return bool(len(self._per_agent_times))
|
|
||||||
|
|
||||||
@property
|
|
||||||
def currently_dwelling_names(self):
|
|
||||||
return self._per_agent_times.keys()
|
|
||||||
|
|
||||||
@property
|
|
||||||
def encoding(self):
|
|
||||||
return c.DESTINATION
|
|
||||||
|
|
||||||
def __init__(self, *args, dwell_time: int = 0, **kwargs):
|
|
||||||
super(Destination, self).__init__(*args, **kwargs)
|
|
||||||
self.dwell_time = dwell_time
|
|
||||||
self._per_agent_times = defaultdict(lambda: dwell_time)
|
|
||||||
|
|
||||||
def do_wait_action(self, agent: Agent):
|
|
||||||
self._per_agent_times[agent.name] -= 1
|
|
||||||
return c.VALID
|
|
||||||
|
|
||||||
def leave(self, agent: Agent):
|
|
||||||
del self._per_agent_times[agent.name]
|
|
||||||
|
|
||||||
@property
|
|
||||||
def is_considered_reached(self):
|
|
||||||
agent_at_position = any(c.AGENT.lower() in x.name.lower() for x in self.tile.guests_that_can_collide)
|
|
||||||
return (agent_at_position and not self.dwell_time) or any(x == 0 for x in self._per_agent_times.values())
|
|
||||||
|
|
||||||
def agent_is_dwelling(self, agent: Agent):
|
|
||||||
return self._per_agent_times[agent.name] < self.dwell_time
|
|
||||||
|
|
||||||
def summarize_state(self, n_steps=None) -> dict:
|
|
||||||
state_summary = super().summarize_state(n_steps=n_steps)
|
|
||||||
state_summary.update(per_agent_times=self._per_agent_times)
|
|
||||||
return state_summary
|
|
||||||
|
|
||||||
|
|
||||||
class Destinations(EntityCollection):
|
|
||||||
|
|
||||||
_accepted_objects = Destination
|
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
|
||||||
super().__init__(*args, **kwargs)
|
|
||||||
self.is_blocking_light = False
|
|
||||||
self.can_be_shadowed = False
|
|
||||||
|
|
||||||
def as_array(self):
|
|
||||||
self._array[:] = c.FREE_CELL
|
|
||||||
# ToDo: Switch to new Style Array Put
|
|
||||||
# indices = list(zip(range(len(cls)), *zip(*[x.pos for x in cls])))
|
|
||||||
# np.put(cls._array, [np.ravel_multi_index(x, cls._array.shape) for x in indices], cls.encodings)
|
|
||||||
for item in self:
|
|
||||||
if item.pos != c.NO_POS:
|
|
||||||
self._array[0, item.x, item.y] = item.encoding
|
|
||||||
return self._array
|
|
||||||
|
|
||||||
def __repr__(self):
|
|
||||||
super(Destinations, self).__repr__()
|
|
||||||
|
|
||||||
|
|
||||||
class ReachedDestinations(Destinations):
|
|
||||||
_accepted_objects = Destination
|
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
|
||||||
super(ReachedDestinations, self).__init__(*args, **kwargs)
|
|
||||||
self.can_be_shadowed = False
|
|
||||||
self.is_blocking_light = False
|
|
||||||
|
|
||||||
def summarize_states(self, n_steps=None):
|
|
||||||
return {}
|
|
||||||
|
|
||||||
|
|
||||||
class DestModeOptions(object):
|
|
||||||
DONE = 'DONE'
|
|
||||||
GROUPED = 'GROUPED'
|
|
||||||
PER_DEST = 'PER_DEST'
|
|
||||||
|
|
||||||
|
|
||||||
class DestProperties(NamedTuple):
|
|
||||||
n_dests: int = 1 # How many destinations are there
|
|
||||||
dwell_time: int = 0 # How long does the agent need to "wait" on a destination
|
|
||||||
spawn_frequency: int = 0
|
|
||||||
spawn_in_other_zone: bool = True #
|
|
||||||
spawn_mode: str = DestModeOptions.DONE
|
|
||||||
|
|
||||||
assert dwell_time >= 0, 'dwell_time cannot be < 0!'
|
|
||||||
assert spawn_frequency >= 0, 'spawn_frequency cannot be < 0!'
|
|
||||||
assert n_dests >= 0, 'n_destinations cannot be < 0!'
|
|
||||||
assert (spawn_mode == DestModeOptions.DONE) != bool(spawn_frequency)
|
|
||||||
|
|
||||||
|
|
||||||
c = Constants
|
c = Constants
|
||||||
a = Actions
|
a = Actions
|
||||||
|
|
||||||
@ -151,6 +38,7 @@ class DestFactory(BaseFactory):
|
|||||||
def actions_hook(self) -> Union[Action, List[Action]]:
|
def actions_hook(self) -> Union[Action, List[Action]]:
|
||||||
# noinspection PyUnresolvedReferences
|
# noinspection PyUnresolvedReferences
|
||||||
super_actions = super().actions_hook
|
super_actions = super().actions_hook
|
||||||
|
# If targets are considers reached after some time, agents need an action for that.
|
||||||
if self.dest_prop.dwell_time:
|
if self.dest_prop.dwell_time:
|
||||||
super_actions.append(Action(enum_ident=a.WAIT_ON_DEST))
|
super_actions.append(Action(enum_ident=a.WAIT_ON_DEST))
|
||||||
return super_actions
|
return super_actions
|
||||||
@ -207,7 +95,7 @@ class DestFactory(BaseFactory):
|
|||||||
if destinations_to_spawn:
|
if destinations_to_spawn:
|
||||||
n_dest_to_spawn = len(destinations_to_spawn)
|
n_dest_to_spawn = len(destinations_to_spawn)
|
||||||
if self.dest_prop.spawn_mode != DestModeOptions.GROUPED:
|
if self.dest_prop.spawn_mode != DestModeOptions.GROUPED:
|
||||||
destinations = [Destination(tile, c.DEST) for tile in self[c.FLOOR].empty_tiles[:n_dest_to_spawn]]
|
destinations = [Destination(tile, self[c.DEST]) for tile in self[c.FLOOR].empty_tiles[:n_dest_to_spawn]]
|
||||||
self[c.DEST].add_additional_items(destinations)
|
self[c.DEST].add_additional_items(destinations)
|
||||||
for dest in destinations_to_spawn:
|
for dest in destinations_to_spawn:
|
||||||
del self._dest_spawn_timer[dest]
|
del self._dest_spawn_timer[dest]
|
||||||
@ -229,9 +117,10 @@ class DestFactory(BaseFactory):
|
|||||||
super_reward_info = super().step_hook()
|
super_reward_info = super().step_hook()
|
||||||
for key, val in self._dest_spawn_timer.items():
|
for key, val in self._dest_spawn_timer.items():
|
||||||
self._dest_spawn_timer[key] = min(self.dest_prop.spawn_frequency, self._dest_spawn_timer[key] + 1)
|
self._dest_spawn_timer[key] = min(self.dest_prop.spawn_frequency, self._dest_spawn_timer[key] + 1)
|
||||||
|
|
||||||
for dest in list(self[c.DEST].values()):
|
for dest in list(self[c.DEST].values()):
|
||||||
if dest.is_considered_reached:
|
if dest.is_considered_reached:
|
||||||
dest.change_parent_collection(self[c.DEST])
|
dest.change_parent_collection(self[c.DEST_REACHED])
|
||||||
self._dest_spawn_timer[dest.name] = 0
|
self._dest_spawn_timer[dest.name] = 0
|
||||||
self.print(f'{dest.name} is reached now, removing...')
|
self.print(f'{dest.name} is reached now, removing...')
|
||||||
else:
|
else:
|
||||||
@ -251,18 +140,19 @@ class DestFactory(BaseFactory):
|
|||||||
additional_observations.update({c.DEST: self[c.DEST].as_array()})
|
additional_observations.update({c.DEST: self[c.DEST].as_array()})
|
||||||
return additional_observations
|
return additional_observations
|
||||||
|
|
||||||
def per_agent_reward_hook(self, agent: Agent) -> Dict[str, dict]:
|
def per_agent_reward_hook(self, agent: Agent) -> List[dict]:
|
||||||
# noinspection PyUnresolvedReferences
|
# noinspection PyUnresolvedReferences
|
||||||
reward_event_dict = super().per_agent_reward_hook(agent)
|
reward_event_list = super().per_agent_reward_hook(agent)
|
||||||
if len(self[c.DEST_REACHED]):
|
if len(self[c.DEST_REACHED]):
|
||||||
for reached_dest in list(self[c.DEST_REACHED]):
|
for reached_dest in list(self[c.DEST_REACHED]):
|
||||||
if agent.pos == reached_dest.pos:
|
if agent.pos == reached_dest.pos:
|
||||||
self.print(f'{agent.name} just reached destination at {agent.pos}')
|
self.print(f'{agent.name} just reached destination at {agent.pos}')
|
||||||
self[c.DEST_REACHED].delete_env_object(reached_dest)
|
self[c.DEST_REACHED].delete_env_object(reached_dest)
|
||||||
info_dict = {f'{agent.name}_{c.DEST_REACHED}': 1}
|
info_dict = {f'{agent.name}_{c.DEST_REACHED}': 1}
|
||||||
reward_event_dict.update({c.DEST_REACHED: {'reward': self.rewards_dest.DEST_REACHED,
|
reward_event_list.append({'value': self.rewards_dest.DEST_REACHED,
|
||||||
'info': info_dict}})
|
'reason': c.DEST_REACHED,
|
||||||
return reward_event_dict
|
'info': info_dict})
|
||||||
|
return reward_event_list
|
||||||
|
|
||||||
def render_assets_hook(self, mode='human'):
|
def render_assets_hook(self, mode='human'):
|
||||||
# noinspection PyUnresolvedReferences
|
# noinspection PyUnresolvedReferences
|
0
environments/factory/additional/dirt/__init__.py
Normal file
0
environments/factory/additional/dirt/__init__.py
Normal file
42
environments/factory/additional/dirt/dirt_collections.py
Normal file
42
environments/factory/additional/dirt/dirt_collections.py
Normal file
@ -0,0 +1,42 @@
|
|||||||
|
from environments.factory.additional.dirt.dirt_entity import Dirt
|
||||||
|
from environments.factory.additional.dirt.dirt_util import DirtProperties
|
||||||
|
from environments.factory.base.objects import Floor
|
||||||
|
from environments.factory.base.registers import EntityCollection
|
||||||
|
from environments.factory.additional.dirt.dirt_util import Constants as c
|
||||||
|
|
||||||
|
|
||||||
|
class DirtRegister(EntityCollection):
|
||||||
|
|
||||||
|
_accepted_objects = Dirt
|
||||||
|
|
||||||
|
@property
|
||||||
|
def amount(self):
|
||||||
|
return sum([dirt.amount for dirt in self])
|
||||||
|
|
||||||
|
@property
|
||||||
|
def dirt_properties(self):
|
||||||
|
return self._dirt_properties
|
||||||
|
|
||||||
|
def __init__(self, dirt_properties, *args):
|
||||||
|
super(DirtRegister, self).__init__(*args)
|
||||||
|
self._dirt_properties: DirtProperties = dirt_properties
|
||||||
|
|
||||||
|
def spawn_dirt(self, then_dirty_tiles) -> bool:
|
||||||
|
if isinstance(then_dirty_tiles, Floor):
|
||||||
|
then_dirty_tiles = [then_dirty_tiles]
|
||||||
|
for tile in then_dirty_tiles:
|
||||||
|
if not self.amount > self.dirt_properties.max_global_amount:
|
||||||
|
dirt = self.by_pos(tile.pos)
|
||||||
|
if dirt is None:
|
||||||
|
dirt = Dirt(tile, self, amount=self.dirt_properties.max_spawn_amount)
|
||||||
|
self.add_item(dirt)
|
||||||
|
else:
|
||||||
|
new_value = dirt.amount + self.dirt_properties.max_spawn_amount
|
||||||
|
dirt.set_new_amount(min(new_value, self.dirt_properties.max_local_amount))
|
||||||
|
else:
|
||||||
|
return c.NOT_VALID
|
||||||
|
return c.VALID
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
s = super(DirtRegister, self).__repr__()
|
||||||
|
return f'{s[:-1]}, {self.amount})'
|
26
environments/factory/additional/dirt/dirt_entity.py
Normal file
26
environments/factory/additional/dirt/dirt_entity.py
Normal file
@ -0,0 +1,26 @@
|
|||||||
|
from environments.factory.base.objects import Entity
|
||||||
|
|
||||||
|
|
||||||
|
class Dirt(Entity):
|
||||||
|
|
||||||
|
@property
|
||||||
|
def amount(self):
|
||||||
|
return self._amount
|
||||||
|
|
||||||
|
@property
|
||||||
|
def encoding(self):
|
||||||
|
# Edit this if you want items to be drawn in the ops differntly
|
||||||
|
return self._amount
|
||||||
|
|
||||||
|
def __init__(self, *args, amount=None, **kwargs):
|
||||||
|
super(Dirt, self).__init__(*args, **kwargs)
|
||||||
|
self._amount = amount
|
||||||
|
|
||||||
|
def set_new_amount(self, amount):
|
||||||
|
self._amount = amount
|
||||||
|
self._collection.notify_change_to_value(self)
|
||||||
|
|
||||||
|
def summarize_state(self, **kwargs):
|
||||||
|
state_dict = super().summarize_state(**kwargs)
|
||||||
|
state_dict.update(amount=float(self.amount))
|
||||||
|
return state_dict
|
30
environments/factory/additional/dirt/dirt_util.py
Normal file
30
environments/factory/additional/dirt/dirt_util.py
Normal file
@ -0,0 +1,30 @@
|
|||||||
|
from typing import NamedTuple
|
||||||
|
|
||||||
|
from environments.helpers import Constants as BaseConstants, EnvActions as BaseActions
|
||||||
|
|
||||||
|
|
||||||
|
class Constants(BaseConstants):
|
||||||
|
DIRT = 'Dirt'
|
||||||
|
|
||||||
|
|
||||||
|
class Actions(BaseActions):
|
||||||
|
CLEAN_UP = 'do_cleanup_action'
|
||||||
|
|
||||||
|
|
||||||
|
class RewardsDirt(NamedTuple):
|
||||||
|
CLEAN_UP_VALID: float = 0.5
|
||||||
|
CLEAN_UP_FAIL: float = -0.1
|
||||||
|
CLEAN_UP_LAST_PIECE: float = 4.5
|
||||||
|
|
||||||
|
|
||||||
|
class DirtProperties(NamedTuple):
|
||||||
|
initial_dirt_ratio: float = 0.3 # On INIT, on max how many tiles does the dirt spawn in percent.
|
||||||
|
initial_dirt_spawn_r_var: float = 0.05 # How much does the dirt spawn amount vary?
|
||||||
|
clean_amount: float = 1 # How much does the robot clean with one actions.
|
||||||
|
max_spawn_ratio: float = 0.20 # On max how many tiles does the dirt spawn in percent.
|
||||||
|
max_spawn_amount: float = 0.3 # How much dirt does spawn per tile at max.
|
||||||
|
spawn_frequency: int = 0 # Spawn Frequency in Steps.
|
||||||
|
max_local_amount: int = 2 # Max dirt amount per tile.
|
||||||
|
max_global_amount: int = 20 # Max dirt amount in the whole environment.
|
||||||
|
dirt_smear_amount: float = 0.2 # Agents smear dirt, when not cleaning up in place.
|
||||||
|
done_when_clean: bool = True
|
@ -1,111 +1,22 @@
|
|||||||
import time
|
import time
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List, Union, NamedTuple, Dict
|
from typing import List, Union, Dict
|
||||||
import random
|
import random
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from algorithms.TSP_dirt_agent import TSPDirtAgent
|
from environments.factory.additional.dirt.dirt_collections import DirtRegister
|
||||||
from environments.helpers import Constants as BaseConstants
|
from environments.factory.additional.dirt.dirt_entity import Dirt
|
||||||
from environments.helpers import EnvActions as BaseActions
|
from environments.factory.additional.dirt.dirt_util import Constants, Actions, RewardsDirt, DirtProperties
|
||||||
|
|
||||||
from environments.factory.base.base_factory import BaseFactory
|
from environments.factory.base.base_factory import BaseFactory
|
||||||
from environments.factory.base.objects import Agent, Action, Entity, Floor
|
from environments.factory.base.objects import Agent, Action
|
||||||
from environments.factory.base.registers import Entities, EntityCollection
|
from environments.factory.base.registers import Entities
|
||||||
|
|
||||||
from environments.factory.base.renderer import RenderEntity
|
from environments.factory.base.renderer import RenderEntity
|
||||||
from environments.utility_classes import ObservationProperties
|
from environments.utility_classes import ObservationProperties
|
||||||
|
|
||||||
|
|
||||||
class Constants(BaseConstants):
|
|
||||||
DIRT = 'Dirt'
|
|
||||||
|
|
||||||
|
|
||||||
class Actions(BaseActions):
|
|
||||||
CLEAN_UP = 'do_cleanup_action'
|
|
||||||
|
|
||||||
|
|
||||||
class RewardsDirt(NamedTuple):
|
|
||||||
CLEAN_UP_VALID: float = 0.5
|
|
||||||
CLEAN_UP_FAIL: float = -0.1
|
|
||||||
CLEAN_UP_LAST_PIECE: float = 4.5
|
|
||||||
|
|
||||||
|
|
||||||
class DirtProperties(NamedTuple):
|
|
||||||
initial_dirt_ratio: float = 0.3 # On INIT, on max how many tiles does the dirt spawn in percent.
|
|
||||||
initial_dirt_spawn_r_var: float = 0.05 # How much does the dirt spawn amount vary?
|
|
||||||
clean_amount: float = 1 # How much does the robot clean with one actions.
|
|
||||||
max_spawn_ratio: float = 0.20 # On max how many tiles does the dirt spawn in percent.
|
|
||||||
max_spawn_amount: float = 0.3 # How much dirt does spawn per tile at max.
|
|
||||||
spawn_frequency: int = 0 # Spawn Frequency in Steps.
|
|
||||||
max_local_amount: int = 2 # Max dirt amount per tile.
|
|
||||||
max_global_amount: int = 20 # Max dirt amount in the whole environment.
|
|
||||||
dirt_smear_amount: float = 0.2 # Agents smear dirt, when not cleaning up in place.
|
|
||||||
done_when_clean: bool = True
|
|
||||||
|
|
||||||
|
|
||||||
class Dirt(Entity):
|
|
||||||
|
|
||||||
@property
|
|
||||||
def amount(self):
|
|
||||||
return self._amount
|
|
||||||
|
|
||||||
@property
|
|
||||||
def encoding(self):
|
|
||||||
# Edit this if you want items to be drawn in the ops differntly
|
|
||||||
return self._amount
|
|
||||||
|
|
||||||
def __init__(self, *args, amount=None, **kwargs):
|
|
||||||
super(Dirt, self).__init__(*args, **kwargs)
|
|
||||||
self._amount = amount
|
|
||||||
|
|
||||||
def set_new_amount(self, amount):
|
|
||||||
self._amount = amount
|
|
||||||
self._collection.notify_change_to_value(self)
|
|
||||||
|
|
||||||
def summarize_state(self, **kwargs):
|
|
||||||
state_dict = super().summarize_state(**kwargs)
|
|
||||||
state_dict.update(amount=float(self.amount))
|
|
||||||
return state_dict
|
|
||||||
|
|
||||||
|
|
||||||
class DirtRegister(EntityCollection):
|
|
||||||
|
|
||||||
_accepted_objects = Dirt
|
|
||||||
|
|
||||||
@property
|
|
||||||
def amount(self):
|
|
||||||
return sum([dirt.amount for dirt in self])
|
|
||||||
|
|
||||||
@property
|
|
||||||
def dirt_properties(self):
|
|
||||||
return self._dirt_properties
|
|
||||||
|
|
||||||
def __init__(self, dirt_properties, *args):
|
|
||||||
super(DirtRegister, self).__init__(*args)
|
|
||||||
self._dirt_properties: DirtProperties = dirt_properties
|
|
||||||
|
|
||||||
def spawn_dirt(self, then_dirty_tiles) -> bool:
|
|
||||||
if isinstance(then_dirty_tiles, Floor):
|
|
||||||
then_dirty_tiles = [then_dirty_tiles]
|
|
||||||
for tile in then_dirty_tiles:
|
|
||||||
if not self.amount > self.dirt_properties.max_global_amount:
|
|
||||||
dirt = self.by_pos(tile.pos)
|
|
||||||
if dirt is None:
|
|
||||||
dirt = Dirt(tile, self, amount=self.dirt_properties.max_spawn_amount)
|
|
||||||
self.add_item(dirt)
|
|
||||||
else:
|
|
||||||
new_value = dirt.amount + self.dirt_properties.max_spawn_amount
|
|
||||||
dirt.set_new_amount(min(new_value, self.dirt_properties.max_local_amount))
|
|
||||||
else:
|
|
||||||
return c.NOT_VALID
|
|
||||||
return c.VALID
|
|
||||||
|
|
||||||
def __repr__(self):
|
|
||||||
s = super(DirtRegister, self).__repr__()
|
|
||||||
return f'{s[:-1]}, {self.amount})'
|
|
||||||
|
|
||||||
|
|
||||||
def softmax(x):
|
def softmax(x):
|
||||||
"""Compute softmax values for each sets of scores in x."""
|
"""Compute softmax values for each sets of scores in x."""
|
||||||
e_x = np.exp(x - np.max(x))
|
e_x = np.exp(x - np.max(x))
|
||||||
@ -200,8 +111,8 @@ class DirtFactory(BaseFactory):
|
|||||||
super_reward_info = super().step_hook()
|
super_reward_info = super().step_hook()
|
||||||
if smear_amount := self.dirt_prop.dirt_smear_amount:
|
if smear_amount := self.dirt_prop.dirt_smear_amount:
|
||||||
for agent in self[c.AGENT]:
|
for agent in self[c.AGENT]:
|
||||||
if agent.temp_valid and agent.last_pos != c.NO_POS:
|
if agent.step_result['action_valid'] and agent.last_pos != c.NO_POS:
|
||||||
if self._actions.is_moving_action(agent.temp_action):
|
if self._actions.is_moving_action(agent.step_result['action_name']):
|
||||||
if old_pos_dirt := self[c.DIRT].by_pos(agent.last_pos):
|
if old_pos_dirt := self[c.DIRT].by_pos(agent.last_pos):
|
||||||
if smeared_dirt := round(old_pos_dirt.amount * smear_amount, 2):
|
if smeared_dirt := round(old_pos_dirt.amount * smear_amount, 2):
|
||||||
old_pos_dirt.set_new_amount(max(0, old_pos_dirt.amount-smeared_dirt))
|
old_pos_dirt.set_new_amount(max(0, old_pos_dirt.amount-smeared_dirt))
|
||||||
@ -248,8 +159,8 @@ class DirtFactory(BaseFactory):
|
|||||||
additional_observations.update({c.DIRT: self[c.DIRT].as_array()})
|
additional_observations.update({c.DIRT: self[c.DIRT].as_array()})
|
||||||
return additional_observations
|
return additional_observations
|
||||||
|
|
||||||
def gather_additional_info(self, agent: Agent) -> dict:
|
def post_step_hook(self) -> List[Dict[str, int]]:
|
||||||
event_reward_dict = super().per_agent_reward_hook(agent)
|
super_post_step = super(DirtFactory, self).post_step_hook()
|
||||||
info_dict = dict()
|
info_dict = dict()
|
||||||
|
|
||||||
dirt = [dirt.amount for dirt in self[c.DIRT]]
|
dirt = [dirt.amount for dirt in self[c.DIRT]]
|
||||||
@ -264,8 +175,8 @@ class DirtFactory(BaseFactory):
|
|||||||
info_dict.update(dirt_amount=current_dirt_amount)
|
info_dict.update(dirt_amount=current_dirt_amount)
|
||||||
info_dict.update(dirty_tile_count=dirty_tile_count)
|
info_dict.update(dirty_tile_count=dirty_tile_count)
|
||||||
|
|
||||||
event_reward_dict.update({'info': info_dict})
|
super_post_step.append(info_dict)
|
||||||
return event_reward_dict
|
return super_post_step
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
@ -304,7 +215,6 @@ if __name__ == '__main__':
|
|||||||
# inject_agents=[TSPDirtAgent],
|
# inject_agents=[TSPDirtAgent],
|
||||||
)
|
)
|
||||||
|
|
||||||
factory.save_params(Path('rewards_param'))
|
|
||||||
|
|
||||||
# noinspection DuplicatedCode
|
# noinspection DuplicatedCode
|
||||||
n_actions = factory.action_space.n - 1
|
n_actions = factory.action_space.n - 1
|
0
environments/factory/additional/item/__init__.py
Normal file
0
environments/factory/additional/item/__init__.py
Normal file
@ -1,179 +1,16 @@
|
|||||||
import time
|
import time
|
||||||
from collections import deque
|
from typing import List, Union, Dict
|
||||||
from typing import List, Union, NamedTuple, Dict
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import random
|
import random
|
||||||
|
|
||||||
|
from environments.factory.additional.item.item_collections import ItemRegister, Inventories, DropOffLocations
|
||||||
|
from environments.factory.additional.item.item_util import Constants, Actions, RewardsItem, ItemProperties
|
||||||
from environments.factory.base.base_factory import BaseFactory
|
from environments.factory.base.base_factory import BaseFactory
|
||||||
from environments.helpers import Constants as BaseConstants
|
from environments.factory.base.objects import Agent, Action
|
||||||
from environments.helpers import EnvActions as BaseActions
|
from environments.factory.base.registers import Entities
|
||||||
from environments import helpers as h
|
|
||||||
from environments.factory.base.objects import Agent, Entity, Action, Floor
|
|
||||||
from environments.factory.base.registers import Entities, EntityCollection, BoundEnvObjCollection, ObjectCollection
|
|
||||||
|
|
||||||
from environments.factory.base.renderer import RenderEntity
|
from environments.factory.base.renderer import RenderEntity
|
||||||
|
|
||||||
|
|
||||||
class Constants(BaseConstants):
|
|
||||||
NO_ITEM = 0
|
|
||||||
ITEM_DROP_OFF = 1
|
|
||||||
# Item Env
|
|
||||||
ITEM = 'Item'
|
|
||||||
INVENTORY = 'Inventory'
|
|
||||||
DROP_OFF = 'Drop_Off'
|
|
||||||
|
|
||||||
|
|
||||||
class Actions(BaseActions):
|
|
||||||
ITEM_ACTION = 'ITEMACTION'
|
|
||||||
|
|
||||||
|
|
||||||
class RewardsItem(NamedTuple):
|
|
||||||
DROP_OFF_VALID: float = 0.1
|
|
||||||
DROP_OFF_FAIL: float = -0.1
|
|
||||||
PICK_UP_FAIL: float = -0.1
|
|
||||||
PICK_UP_VALID: float = 0.1
|
|
||||||
|
|
||||||
|
|
||||||
class Item(Entity):
|
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
|
||||||
super().__init__(*args, **kwargs)
|
|
||||||
self._auto_despawn = -1
|
|
||||||
|
|
||||||
@property
|
|
||||||
def auto_despawn(self):
|
|
||||||
return self._auto_despawn
|
|
||||||
|
|
||||||
@property
|
|
||||||
def encoding(self):
|
|
||||||
# Edit this if you want items to be drawn in the ops differently
|
|
||||||
return 1
|
|
||||||
|
|
||||||
def set_auto_despawn(self, auto_despawn):
|
|
||||||
self._auto_despawn = auto_despawn
|
|
||||||
|
|
||||||
def set_tile_to(self, no_pos_tile):
|
|
||||||
assert self._collection.__class__.__name__ != ItemRegister.__class__
|
|
||||||
self._tile = no_pos_tile
|
|
||||||
|
|
||||||
|
|
||||||
class ItemRegister(EntityCollection):
|
|
||||||
|
|
||||||
_accepted_objects = Item
|
|
||||||
|
|
||||||
def spawn_items(self, tiles: List[Floor]):
|
|
||||||
items = [Item(tile, self) for tile in tiles]
|
|
||||||
self.add_additional_items(items)
|
|
||||||
|
|
||||||
def despawn_items(self, items: List[Item]):
|
|
||||||
items = [items] if isinstance(items, Item) else items
|
|
||||||
for item in items:
|
|
||||||
del self[item]
|
|
||||||
|
|
||||||
|
|
||||||
class Inventory(BoundEnvObjCollection):
|
|
||||||
|
|
||||||
@property
|
|
||||||
def name(self):
|
|
||||||
return f'{self.__class__.__name__}({self._bound_entity.name})'
|
|
||||||
|
|
||||||
def __init__(self, agent: Agent, capacity: int, *args, **kwargs):
|
|
||||||
super(Inventory, self).__init__(agent, *args, is_blocking_light=False, can_be_shadowed=False, **kwargs)
|
|
||||||
self.capacity = capacity
|
|
||||||
|
|
||||||
def as_array(self):
|
|
||||||
if self._array is None:
|
|
||||||
self._array = np.zeros((1, *self._shape))
|
|
||||||
return super(Inventory, self).as_array()
|
|
||||||
|
|
||||||
def summarize_states(self, **kwargs):
|
|
||||||
attr_dict = {key: str(val) for key, val in self.__dict__.items() if not key.startswith('_') and key != 'data'}
|
|
||||||
attr_dict.update(dict(items={key: val.summarize_state(**kwargs) for key, val in self.items()}))
|
|
||||||
attr_dict.update(dict(name=self.name))
|
|
||||||
return attr_dict
|
|
||||||
|
|
||||||
def pop(self):
|
|
||||||
item_to_pop = self[0]
|
|
||||||
self.delete_env_object(item_to_pop)
|
|
||||||
return item_to_pop
|
|
||||||
|
|
||||||
|
|
||||||
class Inventories(ObjectCollection):
|
|
||||||
|
|
||||||
_accepted_objects = Inventory
|
|
||||||
is_blocking_light = False
|
|
||||||
can_be_shadowed = False
|
|
||||||
|
|
||||||
def __init__(self, obs_shape, *args, **kwargs):
|
|
||||||
super(Inventories, self).__init__(*args, is_per_agent=True, individual_slices=True, **kwargs)
|
|
||||||
self._obs_shape = obs_shape
|
|
||||||
|
|
||||||
def as_array(self):
|
|
||||||
return np.stack([inventory.as_array() for inv_idx, inventory in enumerate(self)])
|
|
||||||
|
|
||||||
def spawn_inventories(self, agents, capacity):
|
|
||||||
inventories = [self._accepted_objects(agent, capacity, self._obs_shape)
|
|
||||||
for _, agent in enumerate(agents)]
|
|
||||||
self.add_additional_items(inventories)
|
|
||||||
|
|
||||||
def idx_by_entity(self, entity):
|
|
||||||
try:
|
|
||||||
return next((idx for idx, inv in enumerate(self) if inv.belongs_to_entity(entity)))
|
|
||||||
except StopIteration:
|
|
||||||
return None
|
|
||||||
|
|
||||||
def by_entity(self, entity):
|
|
||||||
try:
|
|
||||||
return next((inv for inv in self if inv.belongs_to_entity(entity)))
|
|
||||||
except StopIteration:
|
|
||||||
return None
|
|
||||||
|
|
||||||
def summarize_states(self, **kwargs):
|
|
||||||
return {key: val.summarize_states(**kwargs) for key, val in self.items()}
|
|
||||||
|
|
||||||
|
|
||||||
class DropOffLocation(Entity):
|
|
||||||
|
|
||||||
@property
|
|
||||||
def encoding(self):
|
|
||||||
return Constants.ITEM_DROP_OFF
|
|
||||||
|
|
||||||
def __init__(self, *args, storage_size_until_full: int = 5, auto_item_despawn_interval: int = 5, **kwargs):
|
|
||||||
super(DropOffLocation, self).__init__(*args, **kwargs)
|
|
||||||
self.auto_item_despawn_interval = auto_item_despawn_interval
|
|
||||||
self.storage = deque(maxlen=storage_size_until_full or None)
|
|
||||||
|
|
||||||
def place_item(self, item: Item):
|
|
||||||
if self.is_full:
|
|
||||||
raise RuntimeWarning("There is currently no way to clear the storage or make it unfull.")
|
|
||||||
return c.NOT_VALID
|
|
||||||
else:
|
|
||||||
self.storage.append(item)
|
|
||||||
item.set_auto_despawn(self.auto_item_despawn_interval)
|
|
||||||
return c.VALID
|
|
||||||
|
|
||||||
@property
|
|
||||||
def is_full(self):
|
|
||||||
return False if not self.storage.maxlen else self.storage.maxlen == len(self.storage)
|
|
||||||
|
|
||||||
def summarize_state(self, n_steps=None) -> dict:
|
|
||||||
if n_steps == h.STEPS_START:
|
|
||||||
return super().summarize_state(n_steps=n_steps)
|
|
||||||
|
|
||||||
|
|
||||||
class DropOffLocations(EntityCollection):
|
|
||||||
|
|
||||||
_accepted_objects = DropOffLocation
|
|
||||||
|
|
||||||
|
|
||||||
class ItemProperties(NamedTuple):
|
|
||||||
n_items: int = 5 # How many items are there at the same time
|
|
||||||
spawn_frequency: int = 10 # Spawn Frequency in Steps
|
|
||||||
n_drop_off_locations: int = 5 # How many DropOff locations are there at the same time
|
|
||||||
max_dropoff_storage_size: int = 0 # How many items are needed until the dropoff is full
|
|
||||||
max_agent_inventory_capacity: int = 5 # How many items are needed until the agent inventory is full
|
|
||||||
|
|
||||||
|
|
||||||
c = Constants
|
c = Constants
|
||||||
a = Actions
|
a = Actions
|
||||||
|
|
87
environments/factory/additional/item/item_collections.py
Normal file
87
environments/factory/additional/item/item_collections.py
Normal file
@ -0,0 +1,87 @@
|
|||||||
|
from typing import List
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from environments.factory.base.objects import Floor, Agent
|
||||||
|
from environments.factory.base.registers import EntityCollection, BoundEnvObjCollection, ObjectCollection
|
||||||
|
from environments.factory.additional.item.item_entities import Item, DropOffLocation
|
||||||
|
|
||||||
|
|
||||||
|
class ItemRegister(EntityCollection):
|
||||||
|
|
||||||
|
_accepted_objects = Item
|
||||||
|
|
||||||
|
def spawn_items(self, tiles: List[Floor]):
|
||||||
|
items = [Item(tile, self) for tile in tiles]
|
||||||
|
self.add_additional_items(items)
|
||||||
|
|
||||||
|
def despawn_items(self, items: List[Item]):
|
||||||
|
items = [items] if isinstance(items, Item) else items
|
||||||
|
for item in items:
|
||||||
|
del self[item]
|
||||||
|
|
||||||
|
|
||||||
|
class Inventory(BoundEnvObjCollection):
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self):
|
||||||
|
return f'{self.__class__.__name__}({self._bound_entity.name})'
|
||||||
|
|
||||||
|
def __init__(self, agent: Agent, capacity: int, *args, **kwargs):
|
||||||
|
super(Inventory, self).__init__(agent, *args, is_blocking_light=False, can_be_shadowed=False, **kwargs)
|
||||||
|
self.capacity = capacity
|
||||||
|
|
||||||
|
def as_array(self):
|
||||||
|
if self._array is None:
|
||||||
|
self._array = np.zeros((1, *self._shape))
|
||||||
|
return super(Inventory, self).as_array()
|
||||||
|
|
||||||
|
def summarize_states(self, **kwargs):
|
||||||
|
attr_dict = {key: str(val) for key, val in self.__dict__.items() if not key.startswith('_') and key != 'data'}
|
||||||
|
attr_dict.update(dict(items={key: val.summarize_state(**kwargs) for key, val in self.items()}))
|
||||||
|
attr_dict.update(dict(name=self.name))
|
||||||
|
return attr_dict
|
||||||
|
|
||||||
|
def pop(self):
|
||||||
|
item_to_pop = self[0]
|
||||||
|
self.delete_env_object(item_to_pop)
|
||||||
|
return item_to_pop
|
||||||
|
|
||||||
|
|
||||||
|
class Inventories(ObjectCollection):
|
||||||
|
|
||||||
|
_accepted_objects = Inventory
|
||||||
|
is_blocking_light = False
|
||||||
|
can_be_shadowed = False
|
||||||
|
|
||||||
|
def __init__(self, obs_shape, *args, **kwargs):
|
||||||
|
super(Inventories, self).__init__(*args, is_per_agent=True, individual_slices=True, **kwargs)
|
||||||
|
self._obs_shape = obs_shape
|
||||||
|
|
||||||
|
def as_array(self):
|
||||||
|
return np.stack([inventory.as_array() for inv_idx, inventory in enumerate(self)])
|
||||||
|
|
||||||
|
def spawn_inventories(self, agents, capacity):
|
||||||
|
inventories = [self._accepted_objects(agent, capacity, self._obs_shape)
|
||||||
|
for _, agent in enumerate(agents)]
|
||||||
|
self.add_additional_items(inventories)
|
||||||
|
|
||||||
|
def idx_by_entity(self, entity):
|
||||||
|
try:
|
||||||
|
return next((idx for idx, inv in enumerate(self) if inv.belongs_to_entity(entity)))
|
||||||
|
except StopIteration:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def by_entity(self, entity):
|
||||||
|
try:
|
||||||
|
return next((inv for inv in self if inv.belongs_to_entity(entity)))
|
||||||
|
except StopIteration:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def summarize_states(self, **kwargs):
|
||||||
|
return {key: val.summarize_states(**kwargs) for key, val in self.items()}
|
||||||
|
|
||||||
|
|
||||||
|
class DropOffLocations(EntityCollection):
|
||||||
|
|
||||||
|
_accepted_objects = DropOffLocation
|
61
environments/factory/additional/item/item_entities.py
Normal file
61
environments/factory/additional/item/item_entities.py
Normal file
@ -0,0 +1,61 @@
|
|||||||
|
from collections import deque
|
||||||
|
|
||||||
|
from environments import helpers as h
|
||||||
|
from environments.factory.additional.item.item_util import Constants
|
||||||
|
from environments.factory.base.objects import Entity
|
||||||
|
|
||||||
|
|
||||||
|
class Item(Entity):
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self._auto_despawn = -1
|
||||||
|
|
||||||
|
@property
|
||||||
|
def auto_despawn(self):
|
||||||
|
return self._auto_despawn
|
||||||
|
|
||||||
|
@property
|
||||||
|
def encoding(self):
|
||||||
|
# Edit this if you want items to be drawn in the ops differently
|
||||||
|
return 1
|
||||||
|
|
||||||
|
def set_auto_despawn(self, auto_despawn):
|
||||||
|
self._auto_despawn = auto_despawn
|
||||||
|
|
||||||
|
def set_tile_to(self, no_pos_tile):
|
||||||
|
self._tile = no_pos_tile
|
||||||
|
|
||||||
|
def summarize_state(self, **_) -> dict:
|
||||||
|
super_summarization = super(Item, self).summarize_state()
|
||||||
|
super_summarization.update(dict(auto_despawn=self.auto_despawn))
|
||||||
|
return super_summarization
|
||||||
|
|
||||||
|
|
||||||
|
class DropOffLocation(Entity):
|
||||||
|
|
||||||
|
@property
|
||||||
|
def encoding(self):
|
||||||
|
return Constants.ITEM_DROP_OFF
|
||||||
|
|
||||||
|
def __init__(self, *args, storage_size_until_full: int = 5, auto_item_despawn_interval: int = 5, **kwargs):
|
||||||
|
super(DropOffLocation, self).__init__(*args, **kwargs)
|
||||||
|
self.auto_item_despawn_interval = auto_item_despawn_interval
|
||||||
|
self.storage = deque(maxlen=storage_size_until_full or None)
|
||||||
|
|
||||||
|
def place_item(self, item: Item):
|
||||||
|
if self.is_full:
|
||||||
|
raise RuntimeWarning("There is currently no way to clear the storage or make it unfull.")
|
||||||
|
return c.NOT_VALID
|
||||||
|
else:
|
||||||
|
self.storage.append(item)
|
||||||
|
item.set_auto_despawn(self.auto_item_despawn_interval)
|
||||||
|
return Constants.VALID
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_full(self):
|
||||||
|
return False if not self.storage.maxlen else self.storage.maxlen == len(self.storage)
|
||||||
|
|
||||||
|
def summarize_state(self, n_steps=None) -> dict:
|
||||||
|
if n_steps == h.STEPS_START:
|
||||||
|
return super().summarize_state(n_steps=n_steps)
|
31
environments/factory/additional/item/item_util.py
Normal file
31
environments/factory/additional/item/item_util.py
Normal file
@ -0,0 +1,31 @@
|
|||||||
|
from typing import NamedTuple
|
||||||
|
|
||||||
|
from environments.helpers import Constants as BaseConstants, EnvActions as BaseActions
|
||||||
|
|
||||||
|
|
||||||
|
class Constants(BaseConstants):
|
||||||
|
NO_ITEM = 0
|
||||||
|
ITEM_DROP_OFF = 1
|
||||||
|
# Item Env
|
||||||
|
ITEM = 'Item'
|
||||||
|
INVENTORY = 'Inventory'
|
||||||
|
DROP_OFF = 'Drop_Off'
|
||||||
|
|
||||||
|
|
||||||
|
class Actions(BaseActions):
|
||||||
|
ITEM_ACTION = 'ITEMACTION'
|
||||||
|
|
||||||
|
|
||||||
|
class RewardsItem(NamedTuple):
|
||||||
|
DROP_OFF_VALID: float = 0.1
|
||||||
|
DROP_OFF_FAIL: float = -0.1
|
||||||
|
PICK_UP_FAIL: float = -0.1
|
||||||
|
PICK_UP_VALID: float = 0.1
|
||||||
|
|
||||||
|
|
||||||
|
class ItemProperties(NamedTuple):
|
||||||
|
n_items: int = 5 # How many items are there at the same time
|
||||||
|
spawn_frequency: int = 10 # Spawn Frequency in Steps
|
||||||
|
n_drop_off_locations: int = 5 # How many DropOff locations are there at the same time
|
||||||
|
max_dropoff_storage_size: int = 0 # How many items are needed until the dropoff is full
|
||||||
|
max_agent_inventory_capacity: int = 5 # How many items are needed until the agent inventory is full
|
@ -180,6 +180,7 @@ class BaseFactory(gym.Env):
|
|||||||
self._entities.add_additional_items({c.DOORS: doors})
|
self._entities.add_additional_items({c.DOORS: doors})
|
||||||
|
|
||||||
# Actions
|
# Actions
|
||||||
|
# TODO: Move this to Agent init, so that agents can have individual action sets.
|
||||||
self._actions = Actions(self.mv_prop, can_use_doors=self.parse_doors)
|
self._actions = Actions(self.mv_prop, can_use_doors=self.parse_doors)
|
||||||
if additional_actions := self.actions_hook:
|
if additional_actions := self.actions_hook:
|
||||||
self._actions.add_additional_items(additional_actions)
|
self._actions.add_additional_items(additional_actions)
|
||||||
@ -308,7 +309,8 @@ class BaseFactory(gym.Env):
|
|||||||
info.update(self._summarize_state())
|
info.update(self._summarize_state())
|
||||||
|
|
||||||
# Post step Hook for later use
|
# Post step Hook for later use
|
||||||
info.update(self.post_step_hook())
|
for post_step_info in self.post_step_hook():
|
||||||
|
info.update(post_step_info)
|
||||||
|
|
||||||
obs, _ = self._build_observations()
|
obs, _ = self._build_observations()
|
||||||
|
|
||||||
@ -367,14 +369,16 @@ class BaseFactory(gym.Env):
|
|||||||
agent_obs = global_agent_obs.copy()
|
agent_obs = global_agent_obs.copy()
|
||||||
agent_obs[(0, *agent.pos)] -= agent.encoding
|
agent_obs[(0, *agent.pos)] -= agent.encoding
|
||||||
else:
|
else:
|
||||||
agent_obs = global_agent_obs
|
agent_obs = global_agent_obs.copy()
|
||||||
else:
|
else:
|
||||||
|
# agent_obs == None!!!!!
|
||||||
agent_obs = global_agent_obs
|
agent_obs = global_agent_obs
|
||||||
|
|
||||||
# Build Level Observations
|
# Build Level Observations
|
||||||
if self.obs_prop.render_agents == a_obs.LEVEL:
|
if self.obs_prop.render_agents == a_obs.LEVEL:
|
||||||
|
assert agent_obs is not None
|
||||||
lvl_obs = lvl_obs.copy()
|
lvl_obs = lvl_obs.copy()
|
||||||
lvl_obs += global_agent_obs
|
lvl_obs += agent_obs
|
||||||
|
|
||||||
obs_dict[c.WALLS] = lvl_obs
|
obs_dict[c.WALLS] = lvl_obs
|
||||||
if self.obs_prop.render_agents in [a_obs.SEPERATE, a_obs.COMBINED] and agent_obs is not None:
|
if self.obs_prop.render_agents in [a_obs.SEPERATE, a_obs.COMBINED] and agent_obs is not None:
|
||||||
@ -600,7 +604,9 @@ class BaseFactory(gym.Env):
|
|||||||
for reward in agent.step_result['rewards']:
|
for reward in agent.step_result['rewards']:
|
||||||
combined_info_dict.update(reward['info'])
|
combined_info_dict.update(reward['info'])
|
||||||
|
|
||||||
|
# Combine Info dicts into a global one
|
||||||
combined_info_dict = dict(combined_info_dict)
|
combined_info_dict = dict(combined_info_dict)
|
||||||
|
|
||||||
combined_info_dict.update(info)
|
combined_info_dict.update(info)
|
||||||
|
|
||||||
global_reward_sum = sum(global_env_rewards)
|
global_reward_sum = sum(global_env_rewards)
|
||||||
@ -616,9 +622,11 @@ class BaseFactory(gym.Env):
|
|||||||
|
|
||||||
def start_recording(self):
|
def start_recording(self):
|
||||||
self._record_episodes = True
|
self._record_episodes = True
|
||||||
|
return self._record_episodes
|
||||||
|
|
||||||
def stop_recording(self):
|
def stop_recording(self):
|
||||||
self._record_episodes = False
|
self._record_episodes = False
|
||||||
|
return not self._record_episodes
|
||||||
|
|
||||||
# noinspection PyGlobalUndefined
|
# noinspection PyGlobalUndefined
|
||||||
def render(self, mode='human'):
|
def render(self, mode='human'):
|
||||||
@ -719,12 +727,12 @@ class BaseFactory(gym.Env):
|
|||||||
return {}
|
return {}
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def per_agent_reward_hook(self, agent: Agent) -> Dict[str, dict]:
|
def per_agent_reward_hook(self, agent: Agent) -> List[dict]:
|
||||||
return {}
|
return []
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def post_step_hook(self) -> dict:
|
def post_step_hook(self) -> List[dict]:
|
||||||
return {}
|
return []
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def per_agent_raw_observations_hook(self, agent) -> Dict[str, np.typing.ArrayLike]:
|
def per_agent_raw_observations_hook(self, agent) -> Dict[str, np.typing.ArrayLike]:
|
||||||
|
@ -119,7 +119,6 @@ class Entity(EnvObject):
|
|||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return super(Entity, self).__repr__() + f'(@{self.pos})'
|
return super(Entity, self).__repr__() + f'(@{self.pos})'
|
||||||
# With Position in Env
|
|
||||||
|
|
||||||
|
|
||||||
# TODO: Missing Documentation
|
# TODO: Missing Documentation
|
||||||
|
@ -117,7 +117,7 @@ class EnvObjectCollection(ObjectCollection):
|
|||||||
return self._array
|
return self._array
|
||||||
|
|
||||||
def summarize_states(self, n_steps=None):
|
def summarize_states(self, n_steps=None):
|
||||||
return [val.summarize_state(n_steps=n_steps) for val in self.values()]
|
return [entity.summarize_state(n_steps=n_steps) for entity in self.values()]
|
||||||
|
|
||||||
def notify_change_to_free(self, env_object: EnvObject):
|
def notify_change_to_free(self, env_object: EnvObject):
|
||||||
self._array_change_notifyer(env_object, value=c.FREE_CELL)
|
self._array_change_notifyer(env_object, value=c.FREE_CELL)
|
||||||
|
@ -1,273 +0,0 @@
|
|||||||
from typing import Union, NamedTuple, Dict, List
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
from environments.factory.base.base_factory import BaseFactory
|
|
||||||
from environments.factory.base.objects import Agent, Action, Entity, EnvObject, BoundingMixin
|
|
||||||
from environments.factory.base.registers import EntityCollection, EnvObjectCollection
|
|
||||||
from environments.factory.base.renderer import RenderEntity
|
|
||||||
from environments.helpers import Constants as BaseConstants
|
|
||||||
from environments.helpers import EnvActions as BaseActions
|
|
||||||
|
|
||||||
from environments import helpers as h
|
|
||||||
|
|
||||||
|
|
||||||
class Constants(BaseConstants):
|
|
||||||
# Battery Env
|
|
||||||
CHARGE_PODS = 'Charge_Pod'
|
|
||||||
BATTERIES = 'BATTERIES'
|
|
||||||
BATTERY_DISCHARGED = 'DISCHARGED'
|
|
||||||
CHARGE_POD = 1
|
|
||||||
|
|
||||||
|
|
||||||
class Actions(BaseActions):
|
|
||||||
CHARGE = 'do_charge_action'
|
|
||||||
|
|
||||||
|
|
||||||
class RewardsBtry(NamedTuple):
|
|
||||||
CHARGE_VALID: float = 0.1
|
|
||||||
CHARGE_FAIL: float = -0.1
|
|
||||||
BATTERY_DISCHARGED: float = -1.0
|
|
||||||
|
|
||||||
|
|
||||||
class BatteryProperties(NamedTuple):
|
|
||||||
initial_charge: float = 0.8 #
|
|
||||||
charge_rate: float = 0.4 #
|
|
||||||
charge_locations: int = 20 #
|
|
||||||
per_action_costs: Union[dict, float] = 0.02
|
|
||||||
done_when_discharged = False
|
|
||||||
multi_charge: bool = False
|
|
||||||
|
|
||||||
|
|
||||||
c = Constants
|
|
||||||
a = Actions
|
|
||||||
|
|
||||||
|
|
||||||
class Battery(BoundingMixin, EnvObject):
|
|
||||||
|
|
||||||
@property
|
|
||||||
def is_discharged(self):
|
|
||||||
return self.charge_level == 0
|
|
||||||
|
|
||||||
def __init__(self, initial_charge_level: float, *args, **kwargs):
|
|
||||||
super(Battery, self).__init__(*args, **kwargs)
|
|
||||||
self.charge_level = initial_charge_level
|
|
||||||
|
|
||||||
def encoding(self):
|
|
||||||
return self.charge_level
|
|
||||||
|
|
||||||
def do_charge_action(self, amount):
|
|
||||||
if self.charge_level < 1:
|
|
||||||
# noinspection PyTypeChecker
|
|
||||||
self.charge_level = min(1, amount + self.charge_level)
|
|
||||||
return c.VALID
|
|
||||||
else:
|
|
||||||
return c.NOT_VALID
|
|
||||||
|
|
||||||
def decharge(self, amount) -> c:
|
|
||||||
if self.charge_level != 0:
|
|
||||||
# noinspection PyTypeChecker
|
|
||||||
self.charge_level = max(0, amount + self.charge_level)
|
|
||||||
self._collection.notify_change_to_value(self)
|
|
||||||
return c.VALID
|
|
||||||
else:
|
|
||||||
return c.NOT_VALID
|
|
||||||
|
|
||||||
def summarize_state(self, **_):
|
|
||||||
attr_dict = {key: str(val) for key, val in self.__dict__.items() if not key.startswith('_') and key != 'data'}
|
|
||||||
attr_dict.update(dict(name=self.name))
|
|
||||||
return attr_dict
|
|
||||||
|
|
||||||
|
|
||||||
class BatteriesRegister(EnvObjectCollection):
|
|
||||||
|
|
||||||
_accepted_objects = Battery
|
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
|
||||||
super(BatteriesRegister, self).__init__(*args, individual_slices=True,
|
|
||||||
is_blocking_light=False, can_be_shadowed=False, **kwargs)
|
|
||||||
self.is_observable = True
|
|
||||||
|
|
||||||
def spawn_batteries(self, agents, initial_charge_level):
|
|
||||||
batteries = [self._accepted_objects(initial_charge_level, agent, self) for _, agent in enumerate(agents)]
|
|
||||||
self.add_additional_items(batteries)
|
|
||||||
|
|
||||||
def summarize_states(self, n_steps=None):
|
|
||||||
# as dict with additional nesting
|
|
||||||
# return dict(items=super(Inventories, cls).summarize_states())
|
|
||||||
return super(BatteriesRegister, self).summarize_states(n_steps=n_steps)
|
|
||||||
|
|
||||||
# Todo Move this to Mixin!
|
|
||||||
def by_entity(self, entity):
|
|
||||||
try:
|
|
||||||
return next((x for x in self if x.belongs_to_entity(entity)))
|
|
||||||
except StopIteration:
|
|
||||||
return None
|
|
||||||
|
|
||||||
def idx_by_entity(self, entity):
|
|
||||||
try:
|
|
||||||
return next((idx for idx, x in enumerate(self) if x.belongs_to_entity(entity)))
|
|
||||||
except StopIteration:
|
|
||||||
return None
|
|
||||||
|
|
||||||
def as_array_by_entity(self, entity):
|
|
||||||
return self._array[self.idx_by_entity(entity)]
|
|
||||||
|
|
||||||
|
|
||||||
class ChargePod(Entity):
|
|
||||||
|
|
||||||
@property
|
|
||||||
def encoding(self):
|
|
||||||
return c.CHARGE_POD
|
|
||||||
|
|
||||||
def __init__(self, *args, charge_rate: float = 0.4,
|
|
||||||
multi_charge: bool = False, **kwargs):
|
|
||||||
super(ChargePod, self).__init__(*args, **kwargs)
|
|
||||||
self.charge_rate = charge_rate
|
|
||||||
self.multi_charge = multi_charge
|
|
||||||
|
|
||||||
def charge_battery(self, battery: Battery):
|
|
||||||
if battery.charge_level == 1.0:
|
|
||||||
return c.NOT_VALID
|
|
||||||
if sum(guest for guest in self.tile.guests if 'agent' in guest.name.lower()) > 1:
|
|
||||||
return c.NOT_VALID
|
|
||||||
valid = battery.do_charge_action(self.charge_rate)
|
|
||||||
return valid
|
|
||||||
|
|
||||||
def summarize_state(self, n_steps=None) -> dict:
|
|
||||||
if n_steps == h.STEPS_START:
|
|
||||||
summary = super().summarize_state(n_steps=n_steps)
|
|
||||||
return summary
|
|
||||||
|
|
||||||
|
|
||||||
class ChargePods(EntityCollection):
|
|
||||||
|
|
||||||
_accepted_objects = ChargePod
|
|
||||||
|
|
||||||
def __repr__(self):
|
|
||||||
super(ChargePods, self).__repr__()
|
|
||||||
|
|
||||||
|
|
||||||
class BatteryFactory(BaseFactory):
|
|
||||||
|
|
||||||
def __init__(self, *args, btry_prop=BatteryProperties(), rewards_dest: RewardsBtry = RewardsBtry(),
|
|
||||||
**kwargs):
|
|
||||||
if isinstance(btry_prop, dict):
|
|
||||||
btry_prop = BatteryProperties(**btry_prop)
|
|
||||||
if isinstance(rewards_dest, dict):
|
|
||||||
rewards_dest = BatteryProperties(**rewards_dest)
|
|
||||||
self.btry_prop = btry_prop
|
|
||||||
self.rewards_dest = rewards_dest
|
|
||||||
super().__init__(*args, **kwargs)
|
|
||||||
|
|
||||||
def per_agent_raw_observations_hook(self, agent) -> Dict[str, np.typing.ArrayLike]:
|
|
||||||
additional_raw_observations = super().per_agent_raw_observations_hook(agent)
|
|
||||||
additional_raw_observations.update({c.BATTERIES: self[c.BATTERIES].as_array_by_entity(agent)})
|
|
||||||
return additional_raw_observations
|
|
||||||
|
|
||||||
def observations_hook(self) -> Dict[str, np.typing.ArrayLike]:
|
|
||||||
additional_observations = super().observations_hook()
|
|
||||||
additional_observations.update({c.CHARGE_PODS: self[c.CHARGE_PODS].as_array()})
|
|
||||||
return additional_observations
|
|
||||||
|
|
||||||
@property
|
|
||||||
def entities_hook(self):
|
|
||||||
super_entities = super().entities_hook
|
|
||||||
|
|
||||||
empty_tiles = self[c.FLOOR].empty_tiles[:self.btry_prop.charge_locations]
|
|
||||||
charge_pods = ChargePods.from_tiles(
|
|
||||||
empty_tiles, self._level_shape,
|
|
||||||
entity_kwargs=dict(charge_rate=self.btry_prop.charge_rate,
|
|
||||||
multi_charge=self.btry_prop.multi_charge)
|
|
||||||
)
|
|
||||||
|
|
||||||
batteries = BatteriesRegister(self._level_shape if not self._pomdp_r else ((self.pomdp_diameter,) * 2),
|
|
||||||
)
|
|
||||||
batteries.spawn_batteries(self[c.AGENT], self.btry_prop.initial_charge)
|
|
||||||
super_entities.update({c.BATTERIES: batteries, c.CHARGE_PODS: charge_pods})
|
|
||||||
return super_entities
|
|
||||||
|
|
||||||
def step_hook(self) -> (List[dict], dict):
|
|
||||||
super_reward_info = super(BatteryFactory, self).step_hook()
|
|
||||||
|
|
||||||
# Decharge
|
|
||||||
batteries = self[c.BATTERIES]
|
|
||||||
|
|
||||||
for agent in self[c.AGENT]:
|
|
||||||
if isinstance(self.btry_prop.per_action_costs, dict):
|
|
||||||
energy_consumption = self.btry_prop.per_action_costs[agent.temp_action]
|
|
||||||
else:
|
|
||||||
energy_consumption = self.btry_prop.per_action_costs
|
|
||||||
|
|
||||||
batteries.by_entity(agent).decharge(energy_consumption)
|
|
||||||
|
|
||||||
return super_reward_info
|
|
||||||
|
|
||||||
def do_charge_action(self, agent) -> (dict, dict):
|
|
||||||
if charge_pod := self[c.CHARGE_PODS].by_pos(agent.pos):
|
|
||||||
valid = charge_pod.charge_battery(self[c.BATTERIES].by_entity(agent))
|
|
||||||
if valid:
|
|
||||||
info_dict = {f'{agent.name}_{a.CHARGE}_VALID': 1}
|
|
||||||
self.print(f'{agent.name} just charged batteries at {charge_pod.name}.')
|
|
||||||
else:
|
|
||||||
info_dict = {f'{agent.name}_{a.CHARGE}_FAIL': 1}
|
|
||||||
self.print(f'{agent.name} failed to charged batteries at {charge_pod.name}.')
|
|
||||||
else:
|
|
||||||
valid = c.NOT_VALID
|
|
||||||
info_dict = {f'{agent.name}_{a.CHARGE}_FAIL': 1}
|
|
||||||
# info_dict = {f'{agent.name}_no_charger': 1}
|
|
||||||
self.print(f'{agent.name} failed to charged batteries at {agent.pos}.')
|
|
||||||
reward = dict(value=self.rewards_dest.CHARGE_VALID if valid else self.rewards_dest.CHARGE_FAIL,
|
|
||||||
reason=a.CHARGE, info=info_dict)
|
|
||||||
return valid, reward
|
|
||||||
|
|
||||||
def do_additional_actions(self, agent: Agent, action: Action) -> (bool, dict):
|
|
||||||
action_result = super().do_additional_actions(agent, action)
|
|
||||||
if action_result is None:
|
|
||||||
if action == a.CHARGE:
|
|
||||||
action_result = self.do_charge_action(agent)
|
|
||||||
return action_result
|
|
||||||
else:
|
|
||||||
return None
|
|
||||||
else:
|
|
||||||
return action_result
|
|
||||||
pass
|
|
||||||
|
|
||||||
def reset_hook(self) -> None:
|
|
||||||
# There is Nothing to reset.
|
|
||||||
pass
|
|
||||||
|
|
||||||
def check_additional_done(self) -> (bool, dict):
|
|
||||||
super_done, super_dict = super(BatteryFactory, self).check_additional_done()
|
|
||||||
if super_done:
|
|
||||||
return super_done, super_dict
|
|
||||||
else:
|
|
||||||
if self.btry_prop.done_when_discharged:
|
|
||||||
if btry_done := any(battery.is_discharged for battery in self[c.BATTERIES]):
|
|
||||||
super_dict.update(DISCHARGE_DONE=1)
|
|
||||||
return btry_done, super_dict
|
|
||||||
else:
|
|
||||||
pass
|
|
||||||
else:
|
|
||||||
pass
|
|
||||||
pass
|
|
||||||
|
|
||||||
def per_agent_reward_hook(self, agent: Agent) -> Dict[str, dict]:
|
|
||||||
reward_event_dict = super(BatteryFactory, self).per_agent_reward_hook(agent)
|
|
||||||
if self[c.BATTERIES].by_entity(agent).is_discharged:
|
|
||||||
self.print(f'{agent.name} Battery is discharged!')
|
|
||||||
info_dict = {f'{agent.name}_{c.BATTERY_DISCHARGED}': 1}
|
|
||||||
reward_event_dict.update({c.BATTERY_DISCHARGED: {'reward': self.rewards_dest.BATTERY_DISCHARGED,
|
|
||||||
'info': info_dict}}
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# All Fine
|
|
||||||
pass
|
|
||||||
return reward_event_dict
|
|
||||||
|
|
||||||
def render_assets_hook(self):
|
|
||||||
# noinspection PyUnresolvedReferences
|
|
||||||
additional_assets = super().render_assets_hook()
|
|
||||||
charge_pods = [RenderEntity(c.CHARGE_PODS, charge_pod.tile.pos) for charge_pod in self[c.CHARGE_PODS]]
|
|
||||||
additional_assets.extend(charge_pods)
|
|
||||||
return additional_assets
|
|
@ -3,7 +3,9 @@ from typing import Dict, List, Union
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from environments.factory.base.objects import Agent, Entity, Action
|
from environments.factory.base.objects import Agent, Entity, Action
|
||||||
from environments.factory.factory_dirt import Dirt, DirtRegister, DirtFactory
|
from environments.factory.factory_dirt import DirtFactory
|
||||||
|
from environments.factory.additional.dirt.dirt_collections import DirtRegister
|
||||||
|
from environments.factory.additional.dirt.dirt_entity import Dirt
|
||||||
from environments.factory.base.objects import Floor
|
from environments.factory.base.objects import Floor
|
||||||
from environments.factory.base.registers import Floors, Entities, EntityCollection
|
from environments.factory.base.registers import Floors, Entities, EntityCollection
|
||||||
|
|
||||||
@ -28,7 +30,6 @@ class StationaryMachinesDirtFactory(DirtFactory):
|
|||||||
|
|
||||||
def entities_hook(self) -> Dict[(str, Entities)]:
|
def entities_hook(self) -> Dict[(str, Entities)]:
|
||||||
super_entities = super().entities_hook()
|
super_entities = super().entities_hook()
|
||||||
|
|
||||||
return super_entities
|
return super_entities
|
||||||
|
|
||||||
def reset_hook(self) -> None:
|
def reset_hook(self) -> None:
|
||||||
@ -48,8 +49,8 @@ class StationaryMachinesDirtFactory(DirtFactory):
|
|||||||
super_per_agent_raw_observations = super().per_agent_raw_observations_hook(agent)
|
super_per_agent_raw_observations = super().per_agent_raw_observations_hook(agent)
|
||||||
return super_per_agent_raw_observations
|
return super_per_agent_raw_observations
|
||||||
|
|
||||||
def per_agent_reward_hook(self, agent: Agent) -> Dict[str, dict]:
|
def per_agent_reward_hook(self, agent: Agent) -> List[dict]:
|
||||||
pass
|
return super(StationaryMachinesDirtFactory, self).per_agent_reward_hook(agent)
|
||||||
|
|
||||||
def pre_step_hook(self) -> None:
|
def pre_step_hook(self) -> None:
|
||||||
pass
|
pass
|
||||||
|
@ -173,9 +173,9 @@ class RewardsBase(NamedTuple):
|
|||||||
|
|
||||||
class ObservationTranslator:
|
class ObservationTranslator:
|
||||||
|
|
||||||
def __init__(self, obs_shape_2d: (int, int), this_named_observation_space: Dict[str, dict],
|
def __init__(self, this_named_observation_space: Dict[str, dict],
|
||||||
*per_agent_named_obs_spaces: Dict[str, dict],
|
*per_agent_named_obs_spaces: Dict[str, dict],
|
||||||
placeholder_fill_value: Union[int, str] = 'N'):
|
placeholder_fill_value: Union[int, str, None] = None):
|
||||||
"""
|
"""
|
||||||
This is a helper class, which converts agents observations from joined environments.
|
This is a helper class, which converts agents observations from joined environments.
|
||||||
For example, agents trained in different environments may expect different observations.
|
For example, agents trained in different environments may expect different observations.
|
||||||
@ -183,8 +183,6 @@ class ObservationTranslator:
|
|||||||
A string identifier based approach is used.
|
A string identifier based approach is used.
|
||||||
Currently, it is not possible to mix different obs shapes.
|
Currently, it is not possible to mix different obs shapes.
|
||||||
|
|
||||||
:param obs_shape_2d: The shape of the observation the agents expect.
|
|
||||||
:type obs_shape_2d: tuple(int, int)
|
|
||||||
|
|
||||||
:param this_named_observation_space: `Named observation space` of the joined environment.
|
:param this_named_observation_space: `Named observation space` of the joined environment.
|
||||||
:type this_named_observation_space: Dict[str, dict]
|
:type this_named_observation_space: Dict[str, dict]
|
||||||
@ -196,15 +194,15 @@ class ObservationTranslator:
|
|||||||
:type placeholder_fill_value: Union[int, str] = 'N')
|
:type placeholder_fill_value: Union[int, str] = 'N')
|
||||||
"""
|
"""
|
||||||
|
|
||||||
assert len(obs_shape_2d) == 2
|
|
||||||
self.obs_shape = obs_shape_2d
|
|
||||||
if isinstance(placeholder_fill_value, str):
|
if isinstance(placeholder_fill_value, str):
|
||||||
if placeholder_fill_value.lower() in ['normal', 'n']:
|
if placeholder_fill_value.lower() in ['normal', 'n']:
|
||||||
self.random_fill = lambda: np.random.normal(size=self.obs_shape)
|
self.random_fill = np.random.normal
|
||||||
elif placeholder_fill_value.lower() in ['uniform', 'u']:
|
elif placeholder_fill_value.lower() in ['uniform', 'u']:
|
||||||
self.random_fill = lambda: np.random.uniform(size=self.obs_shape)
|
self.random_fill = np.random.uniform
|
||||||
else:
|
else:
|
||||||
raise ValueError('Please chooe between "uniform" or "normal"')
|
raise ValueError('Please chooe between "uniform" or "normal" ("u", "n").')
|
||||||
|
elif isinstance(placeholder_fill_value, int):
|
||||||
|
raise NotImplementedError('"Future Work."')
|
||||||
else:
|
else:
|
||||||
self.random_fill = None
|
self.random_fill = None
|
||||||
|
|
||||||
@ -213,9 +211,21 @@ class ObservationTranslator:
|
|||||||
|
|
||||||
def translate_observation(self, agent_idx: int, obs: np.ndarray):
|
def translate_observation(self, agent_idx: int, obs: np.ndarray):
|
||||||
target_obs_space = self._per_agent_named_obs_space[agent_idx]
|
target_obs_space = self._per_agent_named_obs_space[agent_idx]
|
||||||
translation = [idx_space_dict for name, idx_space_dict in target_obs_space.items()]
|
translation = dict()
|
||||||
flat_translation = [x for y in translation for x in y]
|
for name, idxs in target_obs_space.items():
|
||||||
return np.take(obs, flat_translation, axis=1 if obs.ndim == 4 else 0)
|
if name in self._this_named_obs_space:
|
||||||
|
for target_idx, this_idx in zip(idxs, self._this_named_obs_space[name]):
|
||||||
|
taken_slice = np.take(obs, [this_idx], axis=1 if obs.ndim == 4 else 0)
|
||||||
|
translation[target_idx] = taken_slice
|
||||||
|
elif random_fill := self.random_fill:
|
||||||
|
for target_idx in idxs:
|
||||||
|
translation[target_idx] = random_fill(size=obs.shape[:-3] + (1,) + obs.shape[-2:])
|
||||||
|
else:
|
||||||
|
for target_idx in idxs:
|
||||||
|
translation[target_idx] = np.zeros(shape=(obs.shape[:-3] + (1,) + obs.shape[-2:]))
|
||||||
|
|
||||||
|
translation = dict(sorted(translation.items()))
|
||||||
|
return np.concatenate(list(translation.values()), axis=-3)
|
||||||
|
|
||||||
def translate_observations(self, observations: List[ArrayLike]):
|
def translate_observations(self, observations: List[ArrayLike]):
|
||||||
return [self.translate_observation(idx, observation) for idx, observation in enumerate(observations)]
|
return [self.translate_observation(idx, observation) for idx, observation in enumerate(observations)]
|
||||||
@ -241,6 +251,9 @@ class ActionTranslator:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
self._target_named_action_space = target_named_action_space
|
self._target_named_action_space = target_named_action_space
|
||||||
|
if isinstance(per_agent_named_action_space, (list, tuple)):
|
||||||
|
self._per_agent_named_action_space = per_agent_named_action_space
|
||||||
|
else:
|
||||||
self._per_agent_named_action_space = list(per_agent_named_action_space)
|
self._per_agent_named_action_space = list(per_agent_named_action_space)
|
||||||
self._per_agent_idx_actions = [{idx: a for a, idx in x.items()} for x in self._per_agent_named_action_space]
|
self._per_agent_idx_actions = [{idx: a for a, idx in x.items()} for x in self._per_agent_named_action_space]
|
||||||
|
|
||||||
|
@ -71,8 +71,8 @@ class EnvMonitor(BaseCallback):
|
|||||||
pass
|
pass
|
||||||
return
|
return
|
||||||
|
|
||||||
def save_run(self, filepath: Union[Path, str], auto_plotting_keys=None):
|
def save_run(self, filepath: Union[Path, str, None] = None, auto_plotting_keys=None):
|
||||||
filepath = Path(filepath)
|
filepath = Path(filepath or self._filepath)
|
||||||
filepath.parent.mkdir(exist_ok=True, parents=True)
|
filepath.parent.mkdir(exist_ok=True, parents=True)
|
||||||
with filepath.open('wb') as f:
|
with filepath.open('wb') as f:
|
||||||
pickle.dump(self._monitor_df.reset_index(), f, protocol=pickle.HIGHEST_PROTOCOL)
|
pickle.dump(self._monitor_df.reset_index(), f, protocol=pickle.HIGHEST_PROTOCOL)
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
from os import PathLike
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
||||||
@ -12,11 +13,14 @@ from environments.factory.base.base_factory import REC_TAC
|
|||||||
|
|
||||||
class EnvRecorder(BaseCallback):
|
class EnvRecorder(BaseCallback):
|
||||||
|
|
||||||
def __init__(self, env, entities='all'):
|
def __init__(self, env, entities: str = 'all', filepath: Union[str, PathLike] = None, freq: int = 0):
|
||||||
super(EnvRecorder, self).__init__()
|
super(EnvRecorder, self).__init__()
|
||||||
|
self.filepath = filepath
|
||||||
self.unwrapped = env
|
self.unwrapped = env
|
||||||
|
self.freq = freq
|
||||||
self._recorder_dict = defaultdict(list)
|
self._recorder_dict = defaultdict(list)
|
||||||
self._recorder_out_list = list()
|
self._recorder_out_list = list()
|
||||||
|
self._episode_counter = 1
|
||||||
if isinstance(entities, str):
|
if isinstance(entities, str):
|
||||||
if entities.lower() == 'all':
|
if entities.lower() == 'all':
|
||||||
self._entities = None
|
self._entities = None
|
||||||
@ -29,46 +33,47 @@ class EnvRecorder(BaseCallback):
|
|||||||
return getattr(self.unwrapped, item)
|
return getattr(self.unwrapped, item)
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
self.unwrapped.start_recording()
|
self._on_training_start()
|
||||||
return self.unwrapped.reset()
|
return self.unwrapped.reset()
|
||||||
|
|
||||||
def _on_training_start(self) -> None:
|
def _on_training_start(self) -> None:
|
||||||
self.unwrapped._record_episodes = True
|
assert self.start_recording()
|
||||||
pass
|
|
||||||
|
|
||||||
def _read_info(self, env_idx, info: dict):
|
def _read_info(self, env_idx, info: dict):
|
||||||
if info_dict := {key.replace(REC_TAC, ''): val for key, val in info.items() if key.startswith(f'{REC_TAC}')}:
|
if info_dict := {key.replace(REC_TAC, ''): val for key, val in info.items() if key.startswith(f'{REC_TAC}')}:
|
||||||
if self._entities:
|
if self._entities:
|
||||||
info_dict = {k: v for k, v in info_dict.items() if k in self._entities}
|
info_dict = {k: v for k, v in info_dict.items() if k in self._entities}
|
||||||
|
|
||||||
info_dict.update(episode=(self.num_timesteps + env_idx))
|
|
||||||
self._recorder_dict[env_idx].append(info_dict)
|
self._recorder_dict[env_idx].append(info_dict)
|
||||||
else:
|
else:
|
||||||
pass
|
pass
|
||||||
return
|
return True
|
||||||
|
|
||||||
def _read_done(self, env_idx, done):
|
def _read_done(self, env_idx, done):
|
||||||
if done:
|
if done:
|
||||||
self._recorder_out_list.append({'steps': self._recorder_dict[env_idx],
|
self._recorder_out_list.append({'steps': self._recorder_dict[env_idx],
|
||||||
'episode': len(self._recorder_out_list)})
|
'episode': self._episode_counter})
|
||||||
self._recorder_dict[env_idx] = list()
|
self._recorder_dict[env_idx] = list()
|
||||||
else:
|
else:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def step(self, actions):
|
def step(self, actions):
|
||||||
step_result = self.unwrapped.step(actions)
|
step_result = self.unwrapped.step(actions)
|
||||||
# 0, 1, 2 , 3 = idx
|
self._on_step()
|
||||||
# _, _, done_bool, info_obj = step_result
|
|
||||||
self._read_info(0, step_result[3])
|
|
||||||
self._read_done(0, step_result[2])
|
|
||||||
return step_result
|
return step_result
|
||||||
|
|
||||||
def save_records(self, filepath: Union[Path, str], save_occupation_map=False, save_trajectory_map=False):
|
def finalize(self):
|
||||||
filepath = Path(filepath)
|
self._on_training_end()
|
||||||
|
return True
|
||||||
|
|
||||||
|
def save_records(self, filepath: Union[Path, str, None] = None, save_occupation_map=False, save_trajectory_map=False):
|
||||||
|
filepath = Path(filepath or self.filepath)
|
||||||
filepath.parent.mkdir(exist_ok=True, parents=True)
|
filepath.parent.mkdir(exist_ok=True, parents=True)
|
||||||
# cls.out_file.unlink(missing_ok=True)
|
# cls.out_file.unlink(missing_ok=True)
|
||||||
with filepath.open('w') as f:
|
with filepath.open('w') as f:
|
||||||
out_dict = {'episodes': self._recorder_out_list, 'header': self.unwrapped.params}
|
out_dict = {'n_episodes': self._episode_counter,
|
||||||
|
'header': self.unwrapped.params,
|
||||||
|
'episodes': self._recorder_out_list
|
||||||
|
}
|
||||||
try:
|
try:
|
||||||
simplejson.dump(out_dict, f, indent=4)
|
simplejson.dump(out_dict, f, indent=4)
|
||||||
except TypeError:
|
except TypeError:
|
||||||
@ -76,6 +81,7 @@ class EnvRecorder(BaseCallback):
|
|||||||
|
|
||||||
if save_occupation_map:
|
if save_occupation_map:
|
||||||
a = np.zeros((15, 15))
|
a = np.zeros((15, 15))
|
||||||
|
# noinspection PyTypeChecker
|
||||||
for episode in out_dict['episodes']:
|
for episode in out_dict['episodes']:
|
||||||
df = pd.DataFrame([y for x in episode['steps'] for y in x['Agents']])
|
df = pd.DataFrame([y for x in episode['steps'] for y in x['Agents']])
|
||||||
|
|
||||||
@ -94,15 +100,22 @@ class EnvRecorder(BaseCallback):
|
|||||||
raise NotImplementedError('This has not yet been implemented.')
|
raise NotImplementedError('This has not yet been implemented.')
|
||||||
|
|
||||||
def _on_step(self) -> bool:
|
def _on_step(self) -> bool:
|
||||||
|
do_record = self.freq == -1 or self._episode_counter % self.freq == 0
|
||||||
for env_idx, info in enumerate(self.locals.get('infos', [])):
|
for env_idx, info in enumerate(self.locals.get('infos', [])):
|
||||||
|
if do_record:
|
||||||
self._read_info(env_idx, info)
|
self._read_info(env_idx, info)
|
||||||
|
|
||||||
dones = list(enumerate(self.locals.get('dones', [])))
|
dones = list(enumerate(self.locals.get('dones', [])))
|
||||||
dones.extend(list(enumerate(self.locals.get('done', []))))
|
dones.extend(list(enumerate(self.locals.get('done', []))))
|
||||||
for env_idx, done in dones:
|
for env_idx, done in dones:
|
||||||
|
if do_record:
|
||||||
self._read_done(env_idx, done)
|
self._read_done(env_idx, done)
|
||||||
|
if done:
|
||||||
|
self._episode_counter += 1
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def _on_training_end(self) -> None:
|
def _on_training_end(self) -> None:
|
||||||
|
for env_idx in range(len(self._recorder_dict)):
|
||||||
|
if self._recorder_dict[env_idx]:
|
||||||
|
self._recorder_out_list.append({'steps': self._recorder_dict[env_idx],
|
||||||
|
'episode': self._episode_counter})
|
||||||
pass
|
pass
|
||||||
|
@ -3,6 +3,21 @@ import gym
|
|||||||
from gym.wrappers.frame_stack import FrameStack
|
from gym.wrappers.frame_stack import FrameStack
|
||||||
|
|
||||||
|
|
||||||
|
class EnvCombiner(object):
|
||||||
|
|
||||||
|
def __init__(self, *envs_cls):
|
||||||
|
self._env_dict = {env_cls.__name__: env_cls for env_cls in envs_cls}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def combine_cls(name, *envs_cls):
|
||||||
|
return type(name,envs_cls,{})
|
||||||
|
|
||||||
|
def build(self):
|
||||||
|
name = f'{"".join([x.lower().replace("factory").capitalize() for x in self._env_dict.keys()])}Factory'
|
||||||
|
|
||||||
|
return self.combine_cls(name, tuple(self._env_dict.values()))
|
||||||
|
|
||||||
|
|
||||||
class AgentRenderOptions(object):
|
class AgentRenderOptions(object):
|
||||||
"""
|
"""
|
||||||
Class that specifies the available options for the way agents are represented in the env observation.
|
Class that specifies the available options for the way agents are represented in the env observation.
|
||||||
@ -46,7 +61,7 @@ class ObservationProperties(NamedTuple):
|
|||||||
Property holder; for setting multiple related parameters through a single parameter. Comes with default values.
|
Property holder; for setting multiple related parameters through a single parameter. Comes with default values.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
"""How to represent agents in the observation space. This may also alters the obs-shape."""
|
"""How to represent agents in the observation space. This may also alter the obs-shape."""
|
||||||
render_agents: AgentRenderOptions = AgentRenderOptions.SEPERATE
|
render_agents: AgentRenderOptions = AgentRenderOptions.SEPERATE
|
||||||
|
|
||||||
"""Obserations are build per agent; whether the current agent should be represented in its own observation."""
|
"""Obserations are build per agent; whether the current agent should be represented in its own observation."""
|
||||||
|
@ -76,10 +76,11 @@ def compare_seed_runs(run_path: Union[str, PathLike], use_tex: bool = False):
|
|||||||
skip_n = round(df_melted['Episode'].max() * 0.02)
|
skip_n = round(df_melted['Episode'].max() * 0.02)
|
||||||
df_melted = df_melted[df_melted['Episode'] % skip_n == 0]
|
df_melted = df_melted[df_melted['Episode'] % skip_n == 0]
|
||||||
|
|
||||||
if run_path.is_dir():
|
run_path.mkdir(parents=True, exist_ok=True)
|
||||||
prepare_plot(run_path / f'{run_path}_monitor_lineplot.png', df_melted, use_tex=use_tex)
|
if run_path.exists() and run_path.is_file():
|
||||||
elif run_path.exists() and run_path.is_file():
|
prepare_plot(run_path.parent / f'{run_path.name}_monitor_lineplot.png', df_melted, use_tex=use_tex)
|
||||||
prepare_plot(run_path.parent / f'{run_path.parent}_monitor_lineplot.png', df_melted, use_tex=use_tex)
|
else:
|
||||||
|
prepare_plot(run_path / f'{run_path.name}_monitor_lineplot.png', df_melted, use_tex=use_tex)
|
||||||
print('Plotting done.')
|
print('Plotting done.')
|
||||||
|
|
||||||
|
|
||||||
|
187
quickstart/combine_and_monitor_rerun.py
Normal file
187
quickstart/combine_and_monitor_rerun.py
Normal file
@ -0,0 +1,187 @@
|
|||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
##############################################
|
||||||
|
# keep this for stand alone script execution #
|
||||||
|
##############################################
|
||||||
|
from environments.factory.base.base_factory import BaseFactory
|
||||||
|
from environments.logging.recorder import EnvRecorder
|
||||||
|
|
||||||
|
try:
|
||||||
|
# noinspection PyUnboundLocalVariable
|
||||||
|
if __package__ is None:
|
||||||
|
DIR = Path(__file__).resolve().parent
|
||||||
|
sys.path.insert(0, str(DIR.parent))
|
||||||
|
__package__ = DIR.name
|
||||||
|
else:
|
||||||
|
DIR = None
|
||||||
|
except NameError:
|
||||||
|
DIR = None
|
||||||
|
pass
|
||||||
|
##############################################
|
||||||
|
##############################################
|
||||||
|
##############################################
|
||||||
|
|
||||||
|
|
||||||
|
import simplejson
|
||||||
|
|
||||||
|
from environments import helpers as h
|
||||||
|
from environments.factory.additional.combined_factories import DestBatteryFactory
|
||||||
|
from environments.factory.additional.dest.factory_dest import DestFactory
|
||||||
|
from environments.factory.additional.dirt.factory_dirt import DirtFactory
|
||||||
|
from environments.factory.additional.item.factory_item import ItemFactory
|
||||||
|
from environments.helpers import ObservationTranslator, ActionTranslator
|
||||||
|
from environments.logging.envmonitor import EnvMonitor
|
||||||
|
from environments.utility_classes import ObservationProperties, AgentRenderOptions, MovementProperties
|
||||||
|
|
||||||
|
|
||||||
|
def policy_model_kwargs():
|
||||||
|
return dict(ent_coef=0.01)
|
||||||
|
|
||||||
|
|
||||||
|
def dqn_model_kwargs():
|
||||||
|
return dict(buffer_size=50000,
|
||||||
|
learning_starts=64,
|
||||||
|
batch_size=64,
|
||||||
|
target_update_interval=5000,
|
||||||
|
exploration_fraction=0.25,
|
||||||
|
exploration_final_eps=0.025
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def encapsule_env_factory(env_fctry, env_kwrgs):
|
||||||
|
|
||||||
|
def _init():
|
||||||
|
with env_fctry(**env_kwrgs) as init_env:
|
||||||
|
return init_env
|
||||||
|
|
||||||
|
return _init
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
|
||||||
|
# Define Global Env Parameters
|
||||||
|
# Define properties object parameters
|
||||||
|
factory_kwargs = dict(
|
||||||
|
max_steps=400, parse_doors=True,
|
||||||
|
level_name='rooms',
|
||||||
|
doors_have_area=True, verbose=False,
|
||||||
|
mv_prop=MovementProperties(allow_diagonal_movement=True,
|
||||||
|
allow_square_movement=True,
|
||||||
|
allow_no_op=False),
|
||||||
|
obs_prop=ObservationProperties(
|
||||||
|
frames_to_stack=3,
|
||||||
|
cast_shadows=True,
|
||||||
|
omit_agent_self=True,
|
||||||
|
render_agents=AgentRenderOptions.LEVEL,
|
||||||
|
additional_agent_placeholder=None,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Bundle both environments with global kwargs and parameters
|
||||||
|
# Todo: find a better solution, like outo module loading
|
||||||
|
env_map = {'DirtFactory': DirtFactory,
|
||||||
|
'ItemFactory': ItemFactory,
|
||||||
|
'DestFactory': DestFactory,
|
||||||
|
'DestBatteryFactory': DestBatteryFactory
|
||||||
|
}
|
||||||
|
env_names = list(env_map.keys())
|
||||||
|
|
||||||
|
# Put all your multi-seed agends in a single folder, we do not need specific names etc.
|
||||||
|
available_models = dict()
|
||||||
|
available_envs = dict()
|
||||||
|
available_runs_kwargs = dict()
|
||||||
|
available_runs_agents = dict()
|
||||||
|
max_seed = 0
|
||||||
|
# Define this folder
|
||||||
|
combinations_path = Path('combinations')
|
||||||
|
# Those are all differently trained combinations of mdoels, env and parameters
|
||||||
|
for combination in (x for x in combinations_path.iterdir() if x.is_dir()):
|
||||||
|
# These are all the models for this specific combination
|
||||||
|
for model_run in (x for x in combination.iterdir() if x.is_dir()):
|
||||||
|
model_name, env_name = model_run.name.split('_')[:2]
|
||||||
|
if model_name not in available_models:
|
||||||
|
available_models[model_name] = h.MODEL_MAP[model_name]
|
||||||
|
if env_name not in available_envs:
|
||||||
|
available_envs[env_name] = env_map[env_name]
|
||||||
|
# Those are all available seeds
|
||||||
|
for seed_run in (x for x in model_run.iterdir() if x.is_dir()):
|
||||||
|
max_seed = max(int(seed_run.name.split('_')[0]), max_seed)
|
||||||
|
# Read the env configuration from ROM
|
||||||
|
with next(seed_run.glob('env_params.json')).open('r') as f:
|
||||||
|
env_kwargs = simplejson.load(f)
|
||||||
|
available_runs_kwargs[seed_run.name] = env_kwargs
|
||||||
|
# Read the trained model_path from ROM
|
||||||
|
model_path = next(seed_run.glob('model.zip'))
|
||||||
|
available_runs_agents[seed_run.name] = model_path
|
||||||
|
|
||||||
|
# We start by combining all SAME MODEL CLASSES per available Seed, across ALL available ENVIRONMENTS.
|
||||||
|
for model_name, model_cls in available_models.items():
|
||||||
|
for seed in range(max_seed):
|
||||||
|
combined_env_kwargs = dict()
|
||||||
|
model_paths = list()
|
||||||
|
comparable_runs = {key: val for key, val in available_runs_kwargs.items() if (
|
||||||
|
key.startswith(str(seed)) and model_name in key and key != 'key')
|
||||||
|
}
|
||||||
|
for name, run_kwargs in comparable_runs.items():
|
||||||
|
# Select trained agent as a candidate:
|
||||||
|
model_paths.append(available_runs_agents[name])
|
||||||
|
# Sort Env Kwars:
|
||||||
|
for key, val in run_kwargs.items():
|
||||||
|
if key not in combined_env_kwargs:
|
||||||
|
combined_env_kwargs.update(dict(key=val))
|
||||||
|
else:
|
||||||
|
assert combined_env_kwargs[key] == val, "Check the combinations you try to make!"
|
||||||
|
|
||||||
|
# Update and combine all kwargs to account for multiple agents etc.
|
||||||
|
# We cannot capture all configuration cases!
|
||||||
|
for key, val in factory_kwargs.items():
|
||||||
|
if key not in combined_env_kwargs:
|
||||||
|
combined_env_kwargs[key] = val
|
||||||
|
else:
|
||||||
|
assert combined_env_kwargs[key] == val
|
||||||
|
combined_env_kwargs.update(n_agents=len(comparable_runs))
|
||||||
|
|
||||||
|
with(type("CombinedEnv", tuple(available_envs.values()), {})(**combined_env_kwargs)) as combEnv:
|
||||||
|
# EnvMonitor Init
|
||||||
|
comb = f'comb_{model_name}_{seed}'
|
||||||
|
comb_monitor_path = combinations_path / comb / f'{comb}_monitor.pick'
|
||||||
|
comb_recorder_path = combinations_path / comb / f'{comb}_recorder.pick'
|
||||||
|
comb_monitor_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
monitoredCombEnv = EnvMonitor(combEnv, filepath=comb_monitor_path)
|
||||||
|
# monitoredCombEnv = EnvRecorder(monitoredCombEnv, filepath=comb_monitor_path)
|
||||||
|
|
||||||
|
# Evaluation starts here #####################################################
|
||||||
|
# Load all models
|
||||||
|
loaded_models = [available_models[model_name].load(model_path) for model_path in model_paths]
|
||||||
|
obs_translators = ObservationTranslator(
|
||||||
|
monitoredCombEnv.named_observation_space,
|
||||||
|
*[agent.named_observation_space for agent in loaded_models],
|
||||||
|
placeholder_fill_value='n')
|
||||||
|
act_translators = ActionTranslator(
|
||||||
|
monitoredCombEnv.named_action_space,
|
||||||
|
*(agent.named_action_space for agent in loaded_models)
|
||||||
|
)
|
||||||
|
|
||||||
|
for episode in range(50):
|
||||||
|
obs, _ = monitoredCombEnv.reset(), monitoredCombEnv.render()
|
||||||
|
rew, done_bool = 0, False
|
||||||
|
while not done_bool:
|
||||||
|
actions = []
|
||||||
|
for i, model in enumerate(loaded_models):
|
||||||
|
pred = model.predict(obs_translators.translate_observation(i, obs[i]))[0]
|
||||||
|
actions.append(act_translators.translate_action(i, pred))
|
||||||
|
|
||||||
|
obs, step_r, done_bool, info_obj = monitoredCombEnv.step(actions)
|
||||||
|
|
||||||
|
rew += step_r
|
||||||
|
monitoredCombEnv.render()
|
||||||
|
if done_bool:
|
||||||
|
break
|
||||||
|
print(f'Factory run {episode} done, reward is:\n {rew}')
|
||||||
|
# Eval monitor outputs are automatically stored by the monitor object
|
||||||
|
# TODO: Plotting
|
||||||
|
monitoredCombEnv.save_records(comb_monitor_path)
|
||||||
|
monitoredCombEnv.save_run()
|
||||||
|
pass
|
203
quickstart/single_agent_train_battery_target_env.py
Normal file
203
quickstart/single_agent_train_battery_target_env.py
Normal file
@ -0,0 +1,203 @@
|
|||||||
|
import sys
|
||||||
|
import time
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
import simplejson
|
||||||
|
|
||||||
|
import stable_baselines3 as sb3
|
||||||
|
|
||||||
|
# This is needed, when you put this file in a subfolder.
|
||||||
|
try:
|
||||||
|
# noinspection PyUnboundLocalVariable
|
||||||
|
if __package__ is None:
|
||||||
|
DIR = Path(__file__).resolve().parent
|
||||||
|
sys.path.insert(0, str(DIR.parent))
|
||||||
|
__package__ = DIR.name
|
||||||
|
else:
|
||||||
|
DIR = None
|
||||||
|
except NameError:
|
||||||
|
DIR = None
|
||||||
|
pass
|
||||||
|
|
||||||
|
from environments import helpers as h
|
||||||
|
from environments.factory.additional.dest.dest_util import DestModeOptions, DestProperties
|
||||||
|
from environments.factory.additional.btry.btry_util import BatteryProperties
|
||||||
|
from environments.logging.envmonitor import EnvMonitor
|
||||||
|
from environments.logging.recorder import EnvRecorder
|
||||||
|
from environments.factory.additional.combined_factories import DestBatteryFactory
|
||||||
|
from environments.utility_classes import MovementProperties, ObservationProperties, AgentRenderOptions
|
||||||
|
|
||||||
|
from plotting.compare_runs import compare_seed_runs
|
||||||
|
|
||||||
|
"""
|
||||||
|
Welcome to this quick start file. Here we will see how to:
|
||||||
|
0. Setup I/O Paths
|
||||||
|
1. Setup parameters for the environments (dirt-factory).
|
||||||
|
2. Setup parameters for the agent training (SB3: PPO) and save metrics.
|
||||||
|
Run the training.
|
||||||
|
3. Save env and agent for later analysis.
|
||||||
|
4. Load the agent from drive
|
||||||
|
5. Rendering the env with a run of the trained agent.
|
||||||
|
6. Plot metrics
|
||||||
|
"""
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
#########################################################
|
||||||
|
# 0. Setup I/O Paths
|
||||||
|
# Define some general parameters
|
||||||
|
train_steps = 1e6
|
||||||
|
n_seeds = 3
|
||||||
|
model_class = sb3.PPO
|
||||||
|
env_class = DestBatteryFactory
|
||||||
|
|
||||||
|
env_params_json = 'env_params.json'
|
||||||
|
|
||||||
|
# Define a global studi save path
|
||||||
|
start_time = int(time.time())
|
||||||
|
study_root_path = Path(__file__).parent.parent / 'study_out' / f'{Path(__file__).stem}_{start_time}'
|
||||||
|
# Create an identifier, which is unique for every combination and easy to read in filesystem
|
||||||
|
identifier = f'{model_class.__name__}_{env_class.__name__}_{start_time}'
|
||||||
|
exp_path = study_root_path / identifier
|
||||||
|
|
||||||
|
#########################################################
|
||||||
|
# 1. Setup parameters for the environments (dirt-factory).
|
||||||
|
|
||||||
|
|
||||||
|
# Define property object parameters.
|
||||||
|
# 'ObservationProperties' are for specifying how the agent sees the env.
|
||||||
|
obs_props = ObservationProperties(render_agents=AgentRenderOptions.NOT, # Agents won`t be shown in the obs at all
|
||||||
|
omit_agent_self=True, # This is default
|
||||||
|
additional_agent_placeholder=None, # We will not take care of future agents
|
||||||
|
frames_to_stack=3, # To give the agent a notion of time
|
||||||
|
pomdp_r=2 # the agents view-radius
|
||||||
|
)
|
||||||
|
# 'MovementProperties' are for specifying how the agent is allowed to move in the env.
|
||||||
|
move_props = MovementProperties(allow_diagonal_movement=True, # Euclidean style (vertices)
|
||||||
|
allow_square_movement=True, # Manhattan (edges)
|
||||||
|
allow_no_op=False) # Pause movement (do nothing)
|
||||||
|
|
||||||
|
# 'DirtProperties' control if and how dirt is spawned
|
||||||
|
# TODO: Comments
|
||||||
|
dest_props = DestProperties(
|
||||||
|
n_dests = 2, # How many destinations are there
|
||||||
|
dwell_time = 0, # How long does the agent need to "wait" on a destination
|
||||||
|
spawn_frequency = 0,
|
||||||
|
spawn_in_other_zone = True, #
|
||||||
|
spawn_mode = DestModeOptions.DONE,
|
||||||
|
)
|
||||||
|
btry_props = BatteryProperties(
|
||||||
|
initial_charge = 0.9, #
|
||||||
|
charge_rate = 0.4, #
|
||||||
|
charge_locations = 3, #
|
||||||
|
per_action_costs = 0.01,
|
||||||
|
done_when_discharged = True,
|
||||||
|
multi_charge = False,
|
||||||
|
)
|
||||||
|
|
||||||
|
# These are the EnvKwargs for initializing the env class, holding all former parameter-classes
|
||||||
|
# TODO: Comments
|
||||||
|
factory_kwargs = dict(n_agents=1,
|
||||||
|
max_steps=400,
|
||||||
|
parse_doors=True,
|
||||||
|
level_name='rooms',
|
||||||
|
doors_have_area=True, #
|
||||||
|
verbose=False,
|
||||||
|
mv_prop=move_props, # See Above
|
||||||
|
obs_prop=obs_props, # See Above
|
||||||
|
done_at_collision=True,
|
||||||
|
dest_prop=dest_props,
|
||||||
|
btry_prop=btry_props
|
||||||
|
)
|
||||||
|
|
||||||
|
#########################################################
|
||||||
|
# 2. Setup parameters for the agent training (SB3: PPO) and save metrics.
|
||||||
|
agent_kwargs = dict()
|
||||||
|
|
||||||
|
|
||||||
|
#########################################################
|
||||||
|
# Run the Training
|
||||||
|
for seed in range(n_seeds):
|
||||||
|
# Make a copy if you want to alter things in the training loop; like the seed.
|
||||||
|
env_kwargs = factory_kwargs.copy()
|
||||||
|
env_kwargs.update(env_seed=seed)
|
||||||
|
|
||||||
|
# Output folder
|
||||||
|
seed_path = exp_path / f'{str(seed)}_{identifier}'
|
||||||
|
seed_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# Parameter Storage
|
||||||
|
param_path = seed_path / env_params_json
|
||||||
|
# Observation (measures) Storage
|
||||||
|
monitor_path = seed_path / 'monitor.pick'
|
||||||
|
recorder_path = seed_path / 'recorder.json'
|
||||||
|
# Model save Path for the trained model
|
||||||
|
model_save_path = seed_path / f'model.zip'
|
||||||
|
|
||||||
|
# Env Init & Model kwargs definition
|
||||||
|
with env_class(**env_kwargs) as env_factory:
|
||||||
|
|
||||||
|
# EnvMonitor Init
|
||||||
|
env_monitor_callback = EnvMonitor(env_factory)
|
||||||
|
|
||||||
|
# EnvRecorder Init
|
||||||
|
env_recorder_callback = EnvRecorder(env_factory, freq=int(train_steps / 400 / 10))
|
||||||
|
|
||||||
|
# Model Init
|
||||||
|
model = model_class("MlpPolicy", env_factory, verbose=1, seed=seed, device='cpu')
|
||||||
|
|
||||||
|
# Model train
|
||||||
|
model.learn(total_timesteps=int(train_steps), callback=[env_monitor_callback, env_recorder_callback])
|
||||||
|
|
||||||
|
#########################################################
|
||||||
|
# 3. Save env and agent for later analysis.
|
||||||
|
# Save the trained Model, the monitor (env measures) and the env parameters
|
||||||
|
model.named_observation_space = env_factory.named_observation_space
|
||||||
|
model.named_action_space = env_factory.named_action_space
|
||||||
|
model.save(model_save_path)
|
||||||
|
env_factory.save_params(param_path)
|
||||||
|
env_monitor_callback.save_run(monitor_path)
|
||||||
|
env_recorder_callback.save_records(recorder_path, save_occupation_map=False)
|
||||||
|
|
||||||
|
# Compare performance runs, for each seed within a model
|
||||||
|
try:
|
||||||
|
compare_seed_runs(exp_path, use_tex=False)
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Train ends here ############################################################
|
||||||
|
|
||||||
|
# Evaluation starts here #####################################################
|
||||||
|
# First Iterate over every model and monitor "as trained"
|
||||||
|
print('Start Measurement Tracking')
|
||||||
|
# For trained policy in study_root_path / identifier
|
||||||
|
for policy_path in [x for x in exp_path.iterdir() if x.is_dir()]:
|
||||||
|
|
||||||
|
# retrieve model class
|
||||||
|
model_cls = next(val for key, val in h.MODEL_MAP.items() if key in policy_path.parent.name)
|
||||||
|
# Load the agent agent
|
||||||
|
model = model_cls.load(policy_path / 'model.zip', device='cpu')
|
||||||
|
# Load old env kwargs
|
||||||
|
with next(policy_path.glob(env_params_json)).open('r') as f:
|
||||||
|
env_kwargs = simplejson.load(f)
|
||||||
|
# Make the env stop ar collisions
|
||||||
|
# (you only want to have a single collision per episode hence the statistics)
|
||||||
|
env_kwargs.update(done_at_collision=True)
|
||||||
|
|
||||||
|
# Init Env
|
||||||
|
with env_class(**env_kwargs) as env_factory:
|
||||||
|
monitored_env_factory = EnvMonitor(env_factory)
|
||||||
|
|
||||||
|
# Evaluation Loop for i in range(n Episodes)
|
||||||
|
for episode in range(100):
|
||||||
|
# noinspection PyRedeclaration
|
||||||
|
env_state = monitored_env_factory.reset()
|
||||||
|
rew, done_bool = 0, False
|
||||||
|
while not done_bool:
|
||||||
|
action = model.predict(env_state, deterministic=True)[0]
|
||||||
|
env_state, step_r, done_bool, info_obj = monitored_env_factory.step(action)
|
||||||
|
rew += step_r
|
||||||
|
if done_bool:
|
||||||
|
break
|
||||||
|
print(f'Factory run {episode} done, reward is:\n {rew}')
|
||||||
|
monitored_env_factory.save_run(filepath=policy_path / 'eval_run_monitor.pick')
|
||||||
|
print('Measurements Done')
|
193
quickstart/single_agent_train_dest_env.py
Normal file
193
quickstart/single_agent_train_dest_env.py
Normal file
@ -0,0 +1,193 @@
|
|||||||
|
import sys
|
||||||
|
import time
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
import simplejson
|
||||||
|
|
||||||
|
import stable_baselines3 as sb3
|
||||||
|
|
||||||
|
# This is needed, when you put this file in a subfolder.
|
||||||
|
try:
|
||||||
|
# noinspection PyUnboundLocalVariable
|
||||||
|
if __package__ is None:
|
||||||
|
DIR = Path(__file__).resolve().parent
|
||||||
|
sys.path.insert(0, str(DIR.parent))
|
||||||
|
__package__ = DIR.name
|
||||||
|
else:
|
||||||
|
DIR = None
|
||||||
|
except NameError:
|
||||||
|
DIR = None
|
||||||
|
pass
|
||||||
|
|
||||||
|
from environments import helpers as h
|
||||||
|
from environments.factory.additional.dest.dest_util import DestModeOptions, DestProperties
|
||||||
|
from environments.logging.envmonitor import EnvMonitor
|
||||||
|
from environments.logging.recorder import EnvRecorder
|
||||||
|
from environments.factory.additional.dest.factory_dest import DestFactory
|
||||||
|
from environments.utility_classes import MovementProperties, ObservationProperties, AgentRenderOptions
|
||||||
|
|
||||||
|
from plotting.compare_runs import compare_seed_runs
|
||||||
|
|
||||||
|
"""
|
||||||
|
Welcome to this quick start file. Here we will see how to:
|
||||||
|
0. Setup I/O Paths
|
||||||
|
1. Setup parameters for the environments (dest-factory).
|
||||||
|
2. Setup parameters for the agent training (SB3: PPO) and save metrics.
|
||||||
|
Run the training.
|
||||||
|
3. Save env and agent for later analysis.
|
||||||
|
4. Load the agent from drive
|
||||||
|
5. Rendering the env with a run of the trained agent.
|
||||||
|
6. Plot metrics
|
||||||
|
"""
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
#########################################################
|
||||||
|
# 0. Setup I/O Paths
|
||||||
|
# Define some general parameters
|
||||||
|
train_steps = 1e6
|
||||||
|
n_seeds = 3
|
||||||
|
model_class = sb3.PPO
|
||||||
|
env_class = DestFactory
|
||||||
|
|
||||||
|
env_params_json = 'env_params.json'
|
||||||
|
|
||||||
|
# Define a global studi save path
|
||||||
|
start_time = int(time.time())
|
||||||
|
study_root_path = Path(__file__).parent.parent / 'study_out' / f'{Path(__file__).stem}_{start_time}'
|
||||||
|
# Create an identifier, which is unique for every combination and easy to read in filesystem
|
||||||
|
identifier = f'{model_class.__name__}_{env_class.__name__}_{start_time}'
|
||||||
|
exp_path = study_root_path / identifier
|
||||||
|
|
||||||
|
#########################################################
|
||||||
|
# 1. Setup parameters for the environments (dest-factory).
|
||||||
|
|
||||||
|
|
||||||
|
# Define property object parameters.
|
||||||
|
# 'ObservationProperties' are for specifying how the agent sees the env.
|
||||||
|
obs_props = ObservationProperties(render_agents=AgentRenderOptions.NOT, # Agents won`t be shown in the obs at all
|
||||||
|
omit_agent_self=True, # This is default
|
||||||
|
additional_agent_placeholder=None, # We will not take care of future agents
|
||||||
|
frames_to_stack=3, # To give the agent a notion of time
|
||||||
|
pomdp_r=2 # the agents view-radius
|
||||||
|
)
|
||||||
|
# 'MovementProperties' are for specifying how the agent is allowed to move in the env.
|
||||||
|
move_props = MovementProperties(allow_diagonal_movement=True, # Euclidean style (vertices)
|
||||||
|
allow_square_movement=True, # Manhattan (edges)
|
||||||
|
allow_no_op=False) # Pause movement (do nothing)
|
||||||
|
|
||||||
|
# 'DestProperties' control if and how dest is spawned
|
||||||
|
# TODO: Comments
|
||||||
|
dest_props = DestProperties(
|
||||||
|
n_dests = 2, # How many destinations are there
|
||||||
|
dwell_time = 0, # How long does the agent need to "wait" on a destination
|
||||||
|
spawn_frequency = 0,
|
||||||
|
spawn_in_other_zone = True, #
|
||||||
|
spawn_mode = DestModeOptions.DONE,
|
||||||
|
)
|
||||||
|
|
||||||
|
# These are the EnvKwargs for initializing the env class, holding all former parameter-classes
|
||||||
|
# TODO: Comments
|
||||||
|
factory_kwargs = dict(n_agents=1,
|
||||||
|
max_steps=400,
|
||||||
|
parse_doors=True,
|
||||||
|
level_name='rooms',
|
||||||
|
doors_have_area=True, #
|
||||||
|
verbose=False,
|
||||||
|
mv_prop=move_props, # See Above
|
||||||
|
obs_prop=obs_props, # See Above
|
||||||
|
done_at_collision=True,
|
||||||
|
dest_prop=dest_props
|
||||||
|
)
|
||||||
|
|
||||||
|
#########################################################
|
||||||
|
# 2. Setup parameters for the agent training (SB3: PPO) and save metrics.
|
||||||
|
agent_kwargs = dict()
|
||||||
|
|
||||||
|
|
||||||
|
#########################################################
|
||||||
|
# Run the Training
|
||||||
|
for seed in range(n_seeds):
|
||||||
|
# Make a copy if you want to alter things in the training loop; like the seed.
|
||||||
|
env_kwargs = factory_kwargs.copy()
|
||||||
|
env_kwargs.update(env_seed=seed)
|
||||||
|
|
||||||
|
# Output folder
|
||||||
|
seed_path = exp_path / f'{str(seed)}_{identifier}'
|
||||||
|
seed_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# Parameter Storage
|
||||||
|
param_path = seed_path / env_params_json
|
||||||
|
# Observation (measures) Storage
|
||||||
|
monitor_path = seed_path / 'monitor.pick'
|
||||||
|
recorder_path = seed_path / 'recorder.json'
|
||||||
|
# Model save Path for the trained model
|
||||||
|
model_save_path = seed_path / f'model.zip'
|
||||||
|
|
||||||
|
# Env Init & Model kwargs definition
|
||||||
|
with env_class(**env_kwargs) as env_factory:
|
||||||
|
|
||||||
|
# EnvMonitor Init
|
||||||
|
env_monitor_callback = EnvMonitor(env_factory)
|
||||||
|
|
||||||
|
# EnvRecorder Init
|
||||||
|
env_recorder_callback = EnvRecorder(env_factory, freq=int(train_steps / 400 / 10))
|
||||||
|
|
||||||
|
# Model Init
|
||||||
|
model = model_class("MlpPolicy", env_factory,verbose=1, seed=seed, device='cpu')
|
||||||
|
|
||||||
|
# Model train
|
||||||
|
model.learn(total_timesteps=int(train_steps), callback=[env_monitor_callback, env_recorder_callback])
|
||||||
|
|
||||||
|
#########################################################
|
||||||
|
# 3. Save env and agent for later analysis.
|
||||||
|
# Save the trained Model, the monitor (env measures) and the env parameters
|
||||||
|
model.named_observation_space = env_factory.named_observation_space
|
||||||
|
model.named_action_space = env_factory.named_action_space
|
||||||
|
model.save(model_save_path)
|
||||||
|
env_factory.save_params(param_path)
|
||||||
|
env_monitor_callback.save_run(monitor_path)
|
||||||
|
env_recorder_callback.save_records(recorder_path, save_occupation_map=False)
|
||||||
|
|
||||||
|
# Compare performance runs, for each seed within a model
|
||||||
|
try:
|
||||||
|
compare_seed_runs(exp_path, use_tex=False)
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Train ends here ############################################################
|
||||||
|
|
||||||
|
# Evaluation starts here #####################################################
|
||||||
|
# First Iterate over every model and monitor "as trained"
|
||||||
|
print('Start Measurement Tracking')
|
||||||
|
# For trained policy in study_root_path / identifier
|
||||||
|
for policy_path in [x for x in exp_path.iterdir() if x.is_dir()]:
|
||||||
|
|
||||||
|
# retrieve model class
|
||||||
|
model_cls = next(val for key, val in h.MODEL_MAP.items() if key in policy_path.parent.name)
|
||||||
|
# Load the agent agent
|
||||||
|
model = model_cls.load(policy_path / 'model.zip', device='cpu')
|
||||||
|
# Load old env kwargs
|
||||||
|
with next(policy_path.glob(env_params_json)).open('r') as f:
|
||||||
|
env_kwargs = simplejson.load(f)
|
||||||
|
# Make the env stop ar collisions
|
||||||
|
# (you only want to have a single collision per episode hence the statistics)
|
||||||
|
env_kwargs.update(done_at_collision=True)
|
||||||
|
|
||||||
|
# Init Env
|
||||||
|
with env_class(**env_kwargs) as env_factory:
|
||||||
|
monitored_env_factory = EnvMonitor(env_factory)
|
||||||
|
|
||||||
|
# Evaluation Loop for i in range(n Episodes)
|
||||||
|
for episode in range(100):
|
||||||
|
# noinspection PyRedeclaration
|
||||||
|
env_state = monitored_env_factory.reset()
|
||||||
|
rew, done_bool = 0, False
|
||||||
|
while not done_bool:
|
||||||
|
action = model.predict(env_state, deterministic=True)[0]
|
||||||
|
env_state, step_r, done_bool, info_obj = monitored_env_factory.step(action)
|
||||||
|
rew += step_r
|
||||||
|
if done_bool:
|
||||||
|
break
|
||||||
|
print(f'Factory run {episode} done, reward is:\n {rew}')
|
||||||
|
monitored_env_factory.save_run(filepath=policy_path / 'eval_run_monitor.pick')
|
||||||
|
print('Measurements Done')
|
@ -1,11 +1,12 @@
|
|||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from matplotlib import pyplot as plt
|
import simplejson
|
||||||
import itertools as it
|
|
||||||
|
|
||||||
import stable_baselines3 as sb3
|
import stable_baselines3 as sb3
|
||||||
|
|
||||||
|
# This is needed, when you put this file in a subfolder.
|
||||||
try:
|
try:
|
||||||
# noinspection PyUnboundLocalVariable
|
# noinspection PyUnboundLocalVariable
|
||||||
if __package__ is None:
|
if __package__ is None:
|
||||||
@ -18,19 +19,14 @@ except NameError:
|
|||||||
DIR = None
|
DIR = None
|
||||||
pass
|
pass
|
||||||
|
|
||||||
import simplejson
|
|
||||||
from stable_baselines3.common.vec_env import SubprocVecEnv
|
|
||||||
|
|
||||||
from environments import helpers as h
|
from environments import helpers as h
|
||||||
from environments.factory.factory_dirt import DirtProperties, DirtFactory
|
|
||||||
from environments.logging.envmonitor import EnvMonitor
|
from environments.logging.envmonitor import EnvMonitor
|
||||||
|
from environments.logging.recorder import EnvRecorder
|
||||||
|
from environments.factory.additional.dirt.dirt_util import DirtProperties
|
||||||
|
from environments.factory.additional.dirt.factory_dirt import DirtFactory
|
||||||
from environments.utility_classes import MovementProperties, ObservationProperties, AgentRenderOptions
|
from environments.utility_classes import MovementProperties, ObservationProperties, AgentRenderOptions
|
||||||
import pickle
|
|
||||||
from plotting.compare_runs import compare_seed_runs, compare_model_runs
|
|
||||||
import pandas as pd
|
|
||||||
import seaborn as sns
|
|
||||||
|
|
||||||
import multiprocessing as mp
|
from plotting.compare_runs import compare_seed_runs
|
||||||
|
|
||||||
"""
|
"""
|
||||||
Welcome to this quick start file. Here we will see how to:
|
Welcome to this quick start file. Here we will see how to:
|
||||||
@ -53,6 +49,8 @@ if __name__ == '__main__':
|
|||||||
model_class = sb3.PPO
|
model_class = sb3.PPO
|
||||||
env_class = DirtFactory
|
env_class = DirtFactory
|
||||||
|
|
||||||
|
env_params_json = 'env_params.json'
|
||||||
|
|
||||||
# Define a global studi save path
|
# Define a global studi save path
|
||||||
start_time = int(time.time())
|
start_time = int(time.time())
|
||||||
study_root_path = Path(__file__).parent.parent / 'study_out' / f'{Path(__file__).stem}_{start_time}'
|
study_root_path = Path(__file__).parent.parent / 'study_out' / f'{Path(__file__).stem}_{start_time}'
|
||||||
@ -100,7 +98,7 @@ if __name__ == '__main__':
|
|||||||
mv_prop=move_props, # See Above
|
mv_prop=move_props, # See Above
|
||||||
obs_prop=obs_props, # See Above
|
obs_prop=obs_props, # See Above
|
||||||
done_at_collision=True,
|
done_at_collision=True,
|
||||||
dirt_props=dirt_props
|
dirt_prop=dirt_props
|
||||||
)
|
)
|
||||||
|
|
||||||
#########################################################
|
#########################################################
|
||||||
@ -120,30 +118,37 @@ if __name__ == '__main__':
|
|||||||
seed_path.mkdir(parents=True, exist_ok=True)
|
seed_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
# Parameter Storage
|
# Parameter Storage
|
||||||
param_path = seed_path / f'env_params.json'
|
param_path = seed_path / env_params_json
|
||||||
# Observation (measures) Storage
|
# Observation (measures) Storage
|
||||||
monitor_path = seed_path / 'monitor.pick'
|
monitor_path = seed_path / 'monitor.pick'
|
||||||
|
recorder_path = seed_path / 'recorder.json'
|
||||||
# Model save Path for the trained model
|
# Model save Path for the trained model
|
||||||
model_save_path = seed_path / f'model.zip'
|
model_save_path = seed_path / f'model.zip'
|
||||||
|
|
||||||
# Env Init & Model kwargs definition
|
# Env Init & Model kwargs definition
|
||||||
with DirtFactory(env_kwargs) as env_factory:
|
with env_class(**env_kwargs) as env_factory:
|
||||||
|
|
||||||
# EnvMonitor Init
|
# EnvMonitor Init
|
||||||
env_monitor_callback = EnvMonitor(env_factory)
|
env_monitor_callback = EnvMonitor(env_factory)
|
||||||
|
|
||||||
|
# EnvRecorder Init
|
||||||
|
env_recorder_callback = EnvRecorder(env_factory, freq=int(train_steps / 400 / 10))
|
||||||
|
|
||||||
# Model Init
|
# Model Init
|
||||||
model = model_class("MlpPolicy", env_factory,verbose=1, seed=seed, device='cpu')
|
model = model_class("MlpPolicy", env_factory,verbose=1, seed=seed, device='cpu')
|
||||||
|
|
||||||
# Model train
|
# Model train
|
||||||
model.learn(total_timesteps=int(train_steps), callback=[env_monitor_callback])
|
model.learn(total_timesteps=int(train_steps), callback=[env_monitor_callback, env_recorder_callback])
|
||||||
|
|
||||||
#########################################################
|
#########################################################
|
||||||
# 3. Save env and agent for later analysis.
|
# 3. Save env and agent for later analysis.
|
||||||
# Save the trained Model, the monitor (env measures) and the env parameters
|
# Save the trained Model, the monitor (env measures) and the env parameters
|
||||||
|
model.named_observation_space = env_factory.named_observation_space
|
||||||
|
model.named_action_space = env_factory.named_action_space
|
||||||
model.save(model_save_path)
|
model.save(model_save_path)
|
||||||
env_factory.save_params(param_path)
|
env_factory.save_params(param_path)
|
||||||
env_monitor_callback.save_run(monitor_path)
|
env_monitor_callback.save_run(monitor_path)
|
||||||
|
env_recorder_callback.save_records(recorder_path, save_occupation_map=False)
|
||||||
|
|
||||||
# Compare performance runs, for each seed within a model
|
# Compare performance runs, for each seed within a model
|
||||||
try:
|
try:
|
||||||
@ -164,18 +169,19 @@ if __name__ == '__main__':
|
|||||||
# Load the agent agent
|
# Load the agent agent
|
||||||
model = model_cls.load(policy_path / 'model.zip', device='cpu')
|
model = model_cls.load(policy_path / 'model.zip', device='cpu')
|
||||||
# Load old env kwargs
|
# Load old env kwargs
|
||||||
with next(policy_path.glob('*.json')).open('r') as f:
|
with next(policy_path.glob(env_params_json)).open('r') as f:
|
||||||
env_kwargs = simplejson.load(f)
|
env_kwargs = simplejson.load(f)
|
||||||
# Make the env stop ar collisions
|
# Make the env stop ar collisions
|
||||||
# (you only want to have a single collision per episode hence the statistics)
|
# (you only want to have a single collision per episode hence the statistics)
|
||||||
env_kwargs.update(done_at_collision=True)
|
env_kwargs.update(done_at_collision=True)
|
||||||
|
|
||||||
# Init Env
|
# Init Env
|
||||||
with env_to_run(**env_kwargs) as env_factory:
|
with env_class(**env_kwargs) as env_factory:
|
||||||
monitored_env_factory = EnvMonitor(env_factory)
|
monitored_env_factory = EnvMonitor(env_factory)
|
||||||
|
|
||||||
# Evaluation Loop for i in range(n Episodes)
|
# Evaluation Loop for i in range(n Episodes)
|
||||||
for episode in range(100):
|
for episode in range(100):
|
||||||
|
# noinspection PyRedeclaration
|
||||||
env_state = monitored_env_factory.reset()
|
env_state = monitored_env_factory.reset()
|
||||||
rew, done_bool = 0, False
|
rew, done_bool = 0, False
|
||||||
while not done_bool:
|
while not done_bool:
|
||||||
@ -185,8 +191,5 @@ if __name__ == '__main__':
|
|||||||
if done_bool:
|
if done_bool:
|
||||||
break
|
break
|
||||||
print(f'Factory run {episode} done, reward is:\n {rew}')
|
print(f'Factory run {episode} done, reward is:\n {rew}')
|
||||||
monitored_env_factory.save_run(filepath=policy_path / f'{baseline_monitor_file}.pick')
|
monitored_env_factory.save_run(filepath=policy_path / 'eval_run_monitor.pick')
|
||||||
|
|
||||||
# for policy_path in (y for y in policy_path.iterdir() if y.is_dir()):
|
|
||||||
# load_model_run_baseline(policy_path)
|
|
||||||
print('Measurements Done')
|
print('Measurements Done')
|
||||||
|
191
quickstart/single_agent_train_item_env.py
Normal file
191
quickstart/single_agent_train_item_env.py
Normal file
@ -0,0 +1,191 @@
|
|||||||
|
import sys
|
||||||
|
import time
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
import simplejson
|
||||||
|
|
||||||
|
import stable_baselines3 as sb3
|
||||||
|
|
||||||
|
# This is needed, when you put this file in a subfolder.
|
||||||
|
try:
|
||||||
|
# noinspection PyUnboundLocalVariable
|
||||||
|
if __package__ is None:
|
||||||
|
DIR = Path(__file__).resolve().parent
|
||||||
|
sys.path.insert(0, str(DIR.parent))
|
||||||
|
__package__ = DIR.name
|
||||||
|
else:
|
||||||
|
DIR = None
|
||||||
|
except NameError:
|
||||||
|
DIR = None
|
||||||
|
pass
|
||||||
|
|
||||||
|
from environments import helpers as h
|
||||||
|
from environments.factory.additional.item.factory_item import ItemFactory
|
||||||
|
from environments.factory.additional.item.item_util import ItemProperties
|
||||||
|
from environments.logging.envmonitor import EnvMonitor
|
||||||
|
from environments.logging.recorder import EnvRecorder
|
||||||
|
from environments.utility_classes import MovementProperties, ObservationProperties, AgentRenderOptions
|
||||||
|
|
||||||
|
from plotting.compare_runs import compare_seed_runs
|
||||||
|
|
||||||
|
"""
|
||||||
|
Welcome to this quick start file. Here we will see how to:
|
||||||
|
0. Setup I/O Paths
|
||||||
|
1. Setup parameters for the environments (item-factory).
|
||||||
|
2. Setup parameters for the agent training (SB3: PPO) and save metrics.
|
||||||
|
Run the training.
|
||||||
|
3. Save env and agent for later analysis.
|
||||||
|
4. Load the agent from drive
|
||||||
|
5. Rendering the env with a run of the trained agent.
|
||||||
|
6. Plot metrics
|
||||||
|
"""
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
#########################################################
|
||||||
|
# 0. Setup I/O Paths
|
||||||
|
# Define some general parameters
|
||||||
|
train_steps = 1e6
|
||||||
|
n_seeds = 3
|
||||||
|
model_class = sb3.PPO
|
||||||
|
env_class = ItemFactory
|
||||||
|
|
||||||
|
env_params_json = 'env_params.json'
|
||||||
|
|
||||||
|
# Define a global studi save path
|
||||||
|
start_time = int(time.time())
|
||||||
|
study_root_path = Path(__file__).parent.parent / 'study_out' / f'{Path(__file__).stem}_{start_time}'
|
||||||
|
# Create an identifier, which is unique for every combination and easy to read in filesystem
|
||||||
|
identifier = f'{model_class.__name__}_{env_class.__name__}_{start_time}'
|
||||||
|
exp_path = study_root_path / identifier
|
||||||
|
|
||||||
|
#########################################################
|
||||||
|
# 1. Setup parameters for the environments (item-factory).
|
||||||
|
#
|
||||||
|
# Define property object parameters.
|
||||||
|
# 'ObservationProperties' are for specifying how the agent sees the env.
|
||||||
|
obs_props = ObservationProperties(render_agents=AgentRenderOptions.NOT, # Agents won`t be shown in the obs at all
|
||||||
|
omit_agent_self=True, # This is default
|
||||||
|
additional_agent_placeholder=None, # We will not take care of future agents
|
||||||
|
frames_to_stack=3, # To give the agent a notion of time
|
||||||
|
pomdp_r=2 # the agents view-radius
|
||||||
|
)
|
||||||
|
# 'MovementProperties' are for specifying how the agent is allowed to move in the env.
|
||||||
|
move_props = MovementProperties(allow_diagonal_movement=True, # Euclidean style (vertices)
|
||||||
|
allow_square_movement=True, # Manhattan (edges)
|
||||||
|
allow_no_op=False) # Pause movement (do nothing)
|
||||||
|
|
||||||
|
# 'ItemProperties' control if and how item is spawned
|
||||||
|
# TODO: Comments
|
||||||
|
item_props = ItemProperties(
|
||||||
|
n_items = 7, # How many items are there at the same time
|
||||||
|
spawn_frequency = 50, # Spawn Frequency in Steps
|
||||||
|
n_drop_off_locations = 10, # How many DropOff locations are there at the same time
|
||||||
|
max_dropoff_storage_size = 0, # How many items are needed until the dropoff is full
|
||||||
|
max_agent_inventory_capacity = 5, # How many items are needed until the agent inventory is full)
|
||||||
|
)
|
||||||
|
|
||||||
|
# These are the EnvKwargs for initializing the env class, holding all former parameter-classes
|
||||||
|
# TODO: Comments
|
||||||
|
factory_kwargs = dict(n_agents=1,
|
||||||
|
max_steps=400,
|
||||||
|
parse_doors=True,
|
||||||
|
level_name='rooms',
|
||||||
|
doors_have_area=True, #
|
||||||
|
verbose=False,
|
||||||
|
mv_prop=move_props, # See Above
|
||||||
|
obs_prop=obs_props, # See Above
|
||||||
|
done_at_collision=True,
|
||||||
|
item_prop=item_props
|
||||||
|
)
|
||||||
|
|
||||||
|
#########################################################
|
||||||
|
# 2. Setup parameters for the agent training (SB3: PPO) and save metrics.
|
||||||
|
agent_kwargs = dict()
|
||||||
|
|
||||||
|
#########################################################
|
||||||
|
# Run the Training
|
||||||
|
for seed in range(n_seeds):
|
||||||
|
# Make a copy if you want to alter things in the training loop; like the seed.
|
||||||
|
env_kwargs = factory_kwargs.copy()
|
||||||
|
env_kwargs.update(env_seed=seed)
|
||||||
|
|
||||||
|
# Output folder
|
||||||
|
seed_path = exp_path / f'{str(seed)}_{identifier}'
|
||||||
|
seed_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# Parameter Storage
|
||||||
|
param_path = seed_path / env_params_json
|
||||||
|
# Observation (measures) Storage
|
||||||
|
monitor_path = seed_path / 'monitor.pick'
|
||||||
|
recorder_path = seed_path / 'recorder.json'
|
||||||
|
# Model save Path for the trained model
|
||||||
|
model_save_path = seed_path / f'model.zip'
|
||||||
|
|
||||||
|
# Env Init & Model kwargs definition
|
||||||
|
with ItemFactory(**env_kwargs) as env_factory:
|
||||||
|
|
||||||
|
# EnvMonitor Init
|
||||||
|
env_monitor_callback = EnvMonitor(env_factory)
|
||||||
|
|
||||||
|
# EnvRecorder Init
|
||||||
|
env_recorder_callback = EnvRecorder(env_factory, freq=int(train_steps / 400 / 10))
|
||||||
|
|
||||||
|
# Model Init
|
||||||
|
model = model_class("MlpPolicy", env_factory,verbose=1, seed=seed, device='cpu')
|
||||||
|
|
||||||
|
# Model train
|
||||||
|
model.learn(total_timesteps=int(train_steps), callback=[env_monitor_callback, env_recorder_callback])
|
||||||
|
|
||||||
|
#########################################################
|
||||||
|
# 3. Save env and agent for later analysis.
|
||||||
|
# Save the trained Model, the monitor (env measures) and the env parameters
|
||||||
|
model.named_observation_space = env_factory.named_observation_space
|
||||||
|
model.named_action_space = env_factory.named_action_space
|
||||||
|
model.save(model_save_path)
|
||||||
|
env_factory.save_params(param_path)
|
||||||
|
env_monitor_callback.save_run(monitor_path)
|
||||||
|
env_recorder_callback.save_records(recorder_path, save_occupation_map=False)
|
||||||
|
|
||||||
|
# Compare performance runs, for each seed within a model
|
||||||
|
try:
|
||||||
|
compare_seed_runs(exp_path, use_tex=False)
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Train ends here ############################################################
|
||||||
|
|
||||||
|
# Evaluation starts here #####################################################
|
||||||
|
# First Iterate over every model and monitor "as trained"
|
||||||
|
print('Start Measurement Tracking')
|
||||||
|
# For trained policy in study_root_path / identifier
|
||||||
|
for policy_path in [x for x in exp_path.iterdir() if x.is_dir()]:
|
||||||
|
|
||||||
|
# retrieve model class
|
||||||
|
model_cls = next(val for key, val in h.MODEL_MAP.items() if key in policy_path.parent.name)
|
||||||
|
# Load the agent agent
|
||||||
|
model = model_cls.load(policy_path / 'model.zip', device='cpu')
|
||||||
|
# Load old env kwargs
|
||||||
|
with next(policy_path.glob(env_params_json)).open('r') as f:
|
||||||
|
env_kwargs = simplejson.load(f)
|
||||||
|
# Make the env stop ar collisions
|
||||||
|
# (you only want to have a single collision per episode hence the statistics)
|
||||||
|
env_kwargs.update(done_at_collision=True)
|
||||||
|
|
||||||
|
# Init Env
|
||||||
|
with ItemFactory(**env_kwargs) as env_factory:
|
||||||
|
monitored_env_factory = EnvMonitor(env_factory)
|
||||||
|
|
||||||
|
# Evaluation Loop for i in range(n Episodes)
|
||||||
|
for episode in range(100):
|
||||||
|
# noinspection PyRedeclaration
|
||||||
|
env_state = monitored_env_factory.reset()
|
||||||
|
rew, done_bool = 0, False
|
||||||
|
while not done_bool:
|
||||||
|
action = model.predict(env_state, deterministic=True)[0]
|
||||||
|
env_state, step_r, done_bool, info_obj = monitored_env_factory.step(action)
|
||||||
|
rew += step_r
|
||||||
|
if done_bool:
|
||||||
|
break
|
||||||
|
print(f'Factory run {episode} done, reward is:\n {rew}')
|
||||||
|
monitored_env_factory.save_run(filepath=policy_path / 'eval_run_monitor.pick')
|
||||||
|
print('Measurements Done')
|
@ -4,9 +4,9 @@ from pathlib import Path
|
|||||||
import yaml
|
import yaml
|
||||||
from stable_baselines3 import A2C, PPO, DQN
|
from stable_baselines3 import A2C, PPO, DQN
|
||||||
|
|
||||||
from environments.factory.factory_dirt import Constants as c
|
from environments.factory.additional.dirt.dirt_util import Constants
|
||||||
|
|
||||||
from environments.factory.factory_dirt import DirtFactory
|
from environments.factory.additional.dirt.factory_dirt import DirtFactory
|
||||||
from environments.logging.envmonitor import EnvMonitor
|
from environments.logging.envmonitor import EnvMonitor
|
||||||
from environments.logging.recorder import EnvRecorder
|
from environments.logging.recorder import EnvRecorder
|
||||||
|
|
||||||
@ -23,7 +23,7 @@ if __name__ == '__main__':
|
|||||||
seed = 13
|
seed = 13
|
||||||
n_agents = 1
|
n_agents = 1
|
||||||
# out_path = Path('study_out/e_1_new_reward/no_obs/dirt/A2C_new_reward/0_A2C_new_reward')
|
# out_path = Path('study_out/e_1_new_reward/no_obs/dirt/A2C_new_reward/0_A2C_new_reward')
|
||||||
out_path = Path('study_out/reload')
|
out_path = Path('quickstart/combinations/single_agent_train_dirt_env_1659374984/PPO_DirtFactory_1659374984/0_PPO_DirtFactory_1659374984/')
|
||||||
model_path = out_path
|
model_path = out_path
|
||||||
|
|
||||||
with (out_path / f'env_params.json').open('r') as f:
|
with (out_path / f'env_params.json').open('r') as f:
|
||||||
@ -62,7 +62,7 @@ if __name__ == '__main__':
|
|||||||
if render:
|
if render:
|
||||||
env.render()
|
env.render()
|
||||||
try:
|
try:
|
||||||
door = next(x for x in env.unwrapped.unwrapped[c.DOORS] if x.is_open)
|
door = next(x for x in env.unwrapped.unwrapped[Constants.DOORS] if x.is_open)
|
||||||
print('openDoor found')
|
print('openDoor found')
|
||||||
except StopIteration:
|
except StopIteration:
|
||||||
pass
|
pass
|
||||||
|
@ -19,9 +19,11 @@ import simplejson
|
|||||||
from stable_baselines3.common.vec_env import SubprocVecEnv
|
from stable_baselines3.common.vec_env import SubprocVecEnv
|
||||||
|
|
||||||
from environments import helpers as h
|
from environments import helpers as h
|
||||||
from environments.factory.factory_dirt import DirtProperties, DirtFactory
|
from environments.factory.factory_dirt import DirtFactory
|
||||||
|
from environments.factory.dirt_util import DirtProperties
|
||||||
from environments.factory.combined_factories import DirtItemFactory
|
from environments.factory.combined_factories import DirtItemFactory
|
||||||
from environments.factory.factory_item import ItemProperties, ItemFactory
|
from environments.factory.factory_item import ItemFactory
|
||||||
|
from environments.factory.additional.item.item_util import ItemProperties
|
||||||
from environments.logging.envmonitor import EnvMonitor
|
from environments.logging.envmonitor import EnvMonitor
|
||||||
from environments.utility_classes import MovementProperties, ObservationProperties, AgentRenderOptions
|
from environments.utility_classes import MovementProperties, ObservationProperties, AgentRenderOptions
|
||||||
import pickle
|
import pickle
|
||||||
|
@ -20,9 +20,12 @@ import simplejson
|
|||||||
from environments.helpers import ActionTranslator, ObservationTranslator
|
from environments.helpers import ActionTranslator, ObservationTranslator
|
||||||
from environments.logging.recorder import EnvRecorder
|
from environments.logging.recorder import EnvRecorder
|
||||||
from environments import helpers as h
|
from environments import helpers as h
|
||||||
from environments.factory.factory_dirt import DirtProperties, DirtFactory
|
from environments.factory.factory_dirt import DirtFactory
|
||||||
from environments.factory.factory_item import ItemProperties, ItemFactory
|
from environments.factory.dirt_util import DirtProperties
|
||||||
from environments.factory.factory_dest import DestProperties, DestFactory, DestModeOptions
|
from environments.factory.factory_item import ItemFactory
|
||||||
|
from environments.factory.additional.item.item_util import ItemProperties
|
||||||
|
from environments.factory.factory_dest import DestFactory
|
||||||
|
from environments.factory.additional.dest.dest_util import DestModeOptions, DestProperties
|
||||||
from environments.factory.combined_factories import DirtDestItemFactory
|
from environments.factory.combined_factories import DirtDestItemFactory
|
||||||
from environments.logging.envmonitor import EnvMonitor
|
from environments.logging.envmonitor import EnvMonitor
|
||||||
from environments.utility_classes import MovementProperties, ObservationProperties, AgentRenderOptions
|
from environments.utility_classes import MovementProperties, ObservationProperties, AgentRenderOptions
|
||||||
|
Loading…
x
Reference in New Issue
Block a user