Compare commits
8 Commits: jannis_exp ... main

SHA1:
bcbd4a8078
bdbd1c4e25
6c2df735d4
4f3924d3ab
6a24e7b518
e7461d7dcf
33f144fc93
0218f8f4e9
@@ -11,7 +11,7 @@ from environments.helpers import EnvActions as BaseActions


 class Constants(BaseConstants):
-    DIRT = 'Dirt'
+    DIRT = 'DirtPile'


 class Actions(BaseActions):
@@ -2,8 +2,11 @@ def make(env_name, pomdp_r=2, max_steps=400, stack_n_frames=3, n_agents=1, indiv
 import yaml
 from pathlib import Path
 from environments.factory.combined_factories import DirtItemFactory
-from environments.factory.factory_item import ItemFactory, ItemProperties
-from environments.factory.factory_dirt import DirtProperties, DirtFactory, RewardsDirt
+from environments.factory.factory_item import ItemFactory
+from environments.factory.additional.item.item_util import ItemProperties
+from environments.factory.factory_dirt import DirtFactory
+from environments.factory.dirt_util import DirtProperties
+from environments.factory.dirt_util import RewardsDirt
 from environments.utility_classes import AgentRenderOptions

 with (Path(__file__).parent / 'levels' / 'parameters' / f'{env_name}.yaml').open('r') as stream:
0   environments/factory/additional/__init__.py   Normal file
0   environments/factory/additional/btry/__init__.py   Normal file
41  environments/factory/additional/btry/btry_collections.py   Normal file
@@ -0,0 +1,41 @@
from environments.factory.additional.btry.btry_objects import Battery, ChargePod
from environments.factory.base.registers import EnvObjectCollection, EntityCollection


class Batteries(EnvObjectCollection):

    _accepted_objects = Battery

    def __init__(self, *args, **kwargs):
        super(Batteries, self).__init__(*args, individual_slices=True,
                                        is_blocking_light=False, can_be_shadowed=False, **kwargs)
        self.is_observable = True

    def spawn_batteries(self, agents, initial_charge_level):
        batteries = [self._accepted_objects(initial_charge_level, agent, self) for _, agent in enumerate(agents)]
        self.add_additional_items(batteries)

    # TODO: Move this to a Mixin!
    def by_entity(self, entity):
        try:
            return next((x for x in self if x.belongs_to_entity(entity)))
        except StopIteration:
            return None

    def idx_by_entity(self, entity):
        try:
            return next((idx for idx, x in enumerate(self) if x.belongs_to_entity(entity)))
        except StopIteration:
            return None

    def as_array_by_entity(self, entity):
        return self._array[self.idx_by_entity(entity)]


class ChargePods(EntityCollection):

    _accepted_objects = ChargePod
    _stateless_entities = True

    def __repr__(self):
        return super(ChargePods, self).__repr__()
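For orientation, by_entity and idx_by_entity above are just a first-match scan over the collection. A minimal self-contained sketch of the pattern; Agent, Battery, and the list subclass here are simplified stand-ins, not the repo classes:

    class Agent:
        def __init__(self, name):
            self.name = name


    class Battery:
        def __init__(self, owner):
            self._bound_entity = owner

        def belongs_to_entity(self, entity):
            return self._bound_entity is entity


    class Batteries(list):
        def by_entity(self, entity):
            # First battery bound to `entity`, or None if the scan is exhausted.
            try:
                return next(x for x in self if x.belongs_to_entity(entity))
            except StopIteration:
                return None


    a1, a2 = Agent('agent_1'), Agent('agent_2')
    pool = Batteries([Battery(a1), Battery(a2)])
    assert pool.by_entity(a2) is pool[1]
    assert pool.by_entity(Agent('agent_3')) is None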
60  environments/factory/additional/btry/btry_objects.py   Normal file
@@ -0,0 +1,60 @@
from environments import helpers as h
from environments.factory.base.objects import BoundingMixin, EnvObject, Entity
from environments.factory.additional.btry.btry_util import Constants as c


class Battery(BoundingMixin, EnvObject):

    @property
    def is_discharged(self):
        return self.charge_level == 0

    def __init__(self, initial_charge_level: float, *args, **kwargs):
        super(Battery, self).__init__(*args, **kwargs)
        self.charge_level = initial_charge_level

    def encoding(self):
        return self.charge_level

    def do_charge_action(self, amount):
        if self.charge_level < 1:
            # noinspection PyTypeChecker
            self.charge_level = min(1, amount + self.charge_level)
            return c.VALID
        else:
            return c.NOT_VALID

    def decharge(self, amount) -> c:
        if self.charge_level != 0:
            # noinspection PyTypeChecker
            # Drain by `amount`, clamped at zero.
            self.charge_level = max(0, self.charge_level - amount)
            self._collection.notify_change_to_value(self)
            return c.VALID
        else:
            return c.NOT_VALID

    def summarize_state(self, **_):
        attr_dict = {key: str(val) for key, val in self.__dict__.items() if not key.startswith('_') and key != 'data'}
        attr_dict.update(dict(name=self.name, belongs_to=self._bound_entity.name))
        return attr_dict


class ChargePod(Entity):

    @property
    def encoding(self):
        return c.CHARGE_POD

    def __init__(self, *args, charge_rate: float = 0.4,
                 multi_charge: bool = False, **kwargs):
        super(ChargePod, self).__init__(*args, **kwargs)
        self.charge_rate = charge_rate
        self.multi_charge = multi_charge

    def charge_battery(self, battery: Battery):
        if battery.charge_level == 1.0:
            return c.NOT_VALID
        # Count the agents on this tile; more than one blocks charging.
        if sum(1 for guest in self.tile.guests if 'agent' in guest.name.lower()) > 1:
            return c.NOT_VALID
        valid = battery.do_charge_action(self.charge_rate)
        return valid
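The charge arithmetic above reduces to clamping a float into [0, 1]. A standalone sketch with plain numbers (values are illustrative, no env classes involved):

    def charge(level, amount):
        # Mirrors Battery.do_charge_action: only charge below full, clamp at 1.0.
        return min(1.0, level + amount) if level < 1.0 else level


    def decharge(level, amount):
        # Mirrors decharge: drain by `amount`, clamp at 0.0.
        return max(0.0, level - amount) if level != 0.0 else level


    level = 0.8
    level = charge(level, 0.4)       # 1.0, not 1.2
    level = decharge(level, 0.02)    # 0.98
    print(level)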
30  environments/factory/additional/btry/btry_util.py   Normal file
@@ -0,0 +1,30 @@
from typing import NamedTuple, Union

from environments.helpers import Constants as BaseConstants, EnvActions as BaseActions


class Constants(BaseConstants):
    # Battery Env
    CHARGE_PODS = 'Charge_Pod'
    BATTERIES = 'BATTERIES'
    BATTERY_DISCHARGED = 'DISCHARGED'
    CHARGE_POD = 1


class Actions(BaseActions):
    CHARGE = 'do_charge_action'


class RewardsBtry(NamedTuple):
    CHARGE_VALID: float = 0.1
    CHARGE_FAIL: float = -0.1
    BATTERY_DISCHARGED: float = -1.0


class BatteryProperties(NamedTuple):
    initial_charge: float = 0.8  #
    charge_rate: float = 0.4  #
    charge_locations: int = 20  #
    per_action_costs: Union[dict, float] = 0.02
    done_when_discharged: bool = False
    multi_charge: bool = False
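BatteryFactory accepts either a BatteryProperties instance or a plain dict (for example, parsed from YAML) and rebuilds the tuple via BatteryProperties(**dict); fields left out of the dict keep their defaults. A minimal sketch:

    from typing import NamedTuple, Union


    class BatteryProperties(NamedTuple):
        initial_charge: float = 0.8
        charge_rate: float = 0.4
        charge_locations: int = 20
        per_action_costs: Union[dict, float] = 0.02
        done_when_discharged: bool = False
        multi_charge: bool = False


    cfg = {'initial_charge': 0.5, 'done_when_discharged': True}
    btry_prop = BatteryProperties(**cfg)
    print(btry_prop.initial_charge, btry_prop.charge_rate)  # 0.5 0.4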
139  environments/factory/additional/btry/factory_battery.py   Normal file
@@ -0,0 +1,139 @@
from typing import Dict, List

import numpy as np

from environments.factory.additional.btry.btry_collections import Batteries, ChargePods
from environments.factory.additional.btry.btry_util import Constants, Actions, RewardsBtry, BatteryProperties
from environments.factory.base.base_factory import BaseFactory
from environments.factory.base.objects import Agent, Action
from environments.factory.base.renderer import RenderEntity

c = Constants
a = Actions


class BatteryFactory(BaseFactory):

    def __init__(self, *args, btry_prop=BatteryProperties(), rewards_btry: RewardsBtry = RewardsBtry(),
                 **kwargs):
        if isinstance(btry_prop, dict):
            btry_prop = BatteryProperties(**btry_prop)
        if isinstance(rewards_btry, dict):
            rewards_btry = RewardsBtry(**rewards_btry)
        self.btry_prop = btry_prop
        self.rewards_dest = rewards_btry
        super().__init__(*args, **kwargs)

    def per_agent_raw_observations_hook(self, agent) -> Dict[str, np.typing.ArrayLike]:
        additional_raw_observations = super().per_agent_raw_observations_hook(agent)
        additional_raw_observations.update({c.BATTERIES: self[c.BATTERIES].as_array_by_entity(agent)})
        return additional_raw_observations

    def observations_hook(self) -> Dict[str, np.typing.ArrayLike]:
        additional_observations = super().observations_hook()
        additional_observations.update({c.CHARGE_PODS: self[c.CHARGE_PODS].as_array()})
        return additional_observations

    @property
    def entities_hook(self):
        super_entities = super().entities_hook

        empty_tiles = self[c.FLOOR].empty_tiles[:self.btry_prop.charge_locations]
        charge_pods = ChargePods.from_tiles(
            empty_tiles, self._level_shape,
            entity_kwargs=dict(charge_rate=self.btry_prop.charge_rate,
                               multi_charge=self.btry_prop.multi_charge)
        )

        batteries = Batteries(self._level_shape if not self._pomdp_r else ((self.pomdp_diameter,) * 2),
                              )
        batteries.spawn_batteries(self[c.AGENT], self.btry_prop.initial_charge)
        super_entities.update({c.BATTERIES: batteries, c.CHARGE_PODS: charge_pods})
        return super_entities

    def step_hook(self) -> (List[dict], dict):
        super_reward_info = super(BatteryFactory, self).step_hook()

        # Decharge
        batteries = self[c.BATTERIES]

        for agent in self[c.AGENT]:
            if isinstance(self.btry_prop.per_action_costs, dict):
                energy_consumption = self.btry_prop.per_action_costs[agent.temp_action]
            else:
                energy_consumption = self.btry_prop.per_action_costs

            batteries.by_entity(agent).decharge(energy_consumption)

        return super_reward_info

    def do_charge_action(self, agent) -> (dict, dict):
        if charge_pod := self[c.CHARGE_PODS].by_pos(agent.pos):
            valid = charge_pod.charge_battery(self[c.BATTERIES].by_entity(agent))
            if valid:
                info_dict = {f'{agent.name}_{a.CHARGE}_VALID': 1}
                self.print(f'{agent.name} just charged batteries at {charge_pod.name}.')
            else:
                info_dict = {f'{agent.name}_{a.CHARGE}_FAIL': 1}
                self.print(f'{agent.name} failed to charge batteries at {charge_pod.name}.')
        else:
            valid = c.NOT_VALID
            info_dict = {f'{agent.name}_{a.CHARGE}_FAIL': 1}
            # info_dict = {f'{agent.name}_no_charger': 1}
            self.print(f'{agent.name} failed to charge batteries at {agent.pos}.')
        reward = dict(value=self.rewards_dest.CHARGE_VALID if valid else self.rewards_dest.CHARGE_FAIL,
                      reason=a.CHARGE, info=info_dict)
        return valid, reward

    def do_additional_actions(self, agent: Agent, action: Action) -> (bool, dict):
        action_result = super().do_additional_actions(agent, action)
        if action_result is None:
            if action == a.CHARGE:
                action_result = self.do_charge_action(agent)
                return action_result
            else:
                return None
        else:
            return action_result

    def reset_hook(self) -> (List[dict], dict):
        super_reward_info = super(BatteryFactory, self).reset_hook()
        # There is nothing to reset.
        return super_reward_info

    def check_additional_done(self) -> (bool, dict):
        super_done, super_dict = super(BatteryFactory, self).check_additional_done()
        if super_done:
            return super_done, super_dict
        else:
            if self.btry_prop.done_when_discharged:
                if btry_done := any(battery.is_discharged for battery in self[c.BATTERIES]):
                    super_dict.update(DISCHARGE_DONE=1)
                    return btry_done, super_dict
                else:
                    pass
            else:
                pass
            return super_done, super_dict

    def per_agent_reward_hook(self, agent: Agent) -> List[dict]:
        reward_event_list = super(BatteryFactory, self).per_agent_reward_hook(agent)
        if self[c.BATTERIES].by_entity(agent).is_discharged:
            self.print(f'{agent.name} Battery is discharged!')
            info_dict = {f'{agent.name}_{c.BATTERY_DISCHARGED}': 1}
            reward_event_list.append({'value': self.rewards_dest.BATTERY_DISCHARGED,
                                      'reason': c.BATTERY_DISCHARGED,
                                      'info': info_dict}
                                     )
        else:
            # All fine
            pass
        return reward_event_list

    def render_assets_hook(self):
        # noinspection PyUnresolvedReferences
        additional_assets = super().render_assets_hook()
        charge_pods = [RenderEntity(c.CHARGE_PODS, charge_pod.tile.pos) for charge_pod in self[c.CHARGE_PODS]]
        additional_assets.extend(charge_pods)
        return additional_assets
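The per-action energy cost dispatch in step_hook, isolated: per_action_costs may be a dict keyed by action name or a flat float. A sketch with made-up action names:

    def energy_consumption(per_action_costs, action_name):
        if isinstance(per_action_costs, dict):
            return per_action_costs[action_name]   # per-action lookup
        return per_action_costs                    # flat cost for every action


    assert energy_consumption(0.02, 'move') == 0.02
    assert energy_consumption({'move': 0.05, 'noop': 0.01}, 'noop') == 0.01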
@@ -1,12 +1,15 @@
 import random

-from environments.factory.factory_battery import BatteryFactory, BatteryProperties
-from environments.factory.factory_dest import DestFactory
-from environments.factory.factory_dirt import DirtFactory, DirtProperties
-from environments.factory.factory_item import ItemFactory
+from environments.factory.additional.btry.btry_util import BatteryProperties
+from environments.factory.additional.btry.factory_battery import BatteryFactory
+from environments.factory.additional.dest.factory_dest import DestFactory
+from environments.factory.additional.dirt.dirt_util import DirtProperties
+from environments.factory.additional.dirt.factory_dirt import DirtFactory
+from environments.factory.additional.item.factory_item import ItemFactory


 # noinspection PyAbstractClass
 class DirtItemFactory(ItemFactory, DirtFactory):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)

@@ -24,6 +27,12 @@ class DirtDestItemFactory(ItemFactory, DirtFactory, DestFactory):
         super().__init__(*args, **kwargs)


+# noinspection PyAbstractClass
+class DestBatteryFactory(BatteryFactory, DestFactory):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+
 if __name__ == '__main__':
     from environments.utility_classes import AgentRenderOptions as ARO, ObservationProperties
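Why these combined factories can each call a bare super().__init__: Python's MRO chains every class in the bases list exactly once, left to right. A toy demonstration with stand-in classes (the real factories mix in their hooks the same way):

    class Base:
        def __init__(self, **kwargs):
            print('Base')


    class BatteryMixin(Base):
        def __init__(self, **kwargs):
            print('Battery setup')
            super().__init__(**kwargs)


    class DestMixin(Base):
        def __init__(self, **kwargs):
            print('Dest setup')
            super().__init__(**kwargs)


    class DestBattery(BatteryMixin, DestMixin):
        pass


    DestBattery()  # prints: Battery setup, Dest setup, Base
    print([cls.__name__ for cls in DestBattery.__mro__])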
0   environments/factory/additional/dest/__init__.py   Normal file
38  environments/factory/additional/dest/dest_collections.py   Normal file
@@ -0,0 +1,38 @@
from environments.factory.base.registers import EntityCollection
from environments.factory.additional.dest.dest_util import Constants as c
from environments.factory.additional.dest.dest_enitites import Destination


class Destinations(EntityCollection):

    _accepted_objects = Destination

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.is_blocking_light = False
        self.can_be_shadowed = False

    def as_array(self):
        self._array[:] = c.FREE_CELL
        # ToDo: Switch to new-style array put
        # indices = list(zip(range(len(cls)), *zip(*[x.pos for x in cls])))
        # np.put(cls._array, [np.ravel_multi_index(x, cls._array.shape) for x in indices], cls.encodings)
        for item in self:
            if item.pos != c.NO_POS:
                self._array[0, item.x, item.y] = item.encoding
        return self._array

    def __repr__(self):
        return super(Destinations, self).__repr__()


class ReachedDestinations(Destinations):
    _accepted_objects = Destination

    def __init__(self, *args, **kwargs):
        super(ReachedDestinations, self).__init__(*args, **kwargs)
        self.can_be_shadowed = False
        self.is_blocking_light = False

    def __repr__(self):
        return super(ReachedDestinations, self).__repr__()
45  environments/factory/additional/dest/dest_enitites.py   Normal file
@@ -0,0 +1,45 @@
from collections import defaultdict

from environments.factory.base.objects import Entity, Agent
from environments.factory.additional.dest.dest_util import Constants as c


class Destination(Entity):

    @property
    def any_agent_has_dwelled(self):
        return bool(len(self._per_agent_times))

    @property
    def currently_dwelling_names(self):
        return self._per_agent_times.keys()

    @property
    def encoding(self):
        return c.DESTINATION

    def __init__(self, *args, dwell_time: int = 0, **kwargs):
        super(Destination, self).__init__(*args, **kwargs)
        self.dwell_time = dwell_time
        self._per_agent_times = defaultdict(lambda: dwell_time)

    def do_wait_action(self, agent: Agent):
        self._per_agent_times[agent.name] -= 1
        return c.VALID

    def leave(self, agent: Agent):
        del self._per_agent_times[agent.name]

    @property
    def is_considered_reached(self):
        agent_at_position = any(c.AGENT.lower() in x.name.lower() for x in self.tile.guests_that_can_collide)
        return (agent_at_position and not self.dwell_time) or any(x == 0 for x in self._per_agent_times.values())

    def agent_is_dwelling(self, agent: Agent):
        return self._per_agent_times[agent.name] < self.dwell_time

    def summarize_state(self) -> dict:
        state_summary = super().summarize_state()
        state_summary.update(per_agent_times=[
            dict(belongs_to=key, time=val) for key, val in self._per_agent_times.items()], dwell_time=self.dwell_time)
        return state_summary
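The dwell-time bookkeeping above, isolated: defaultdict(lambda: dwell_time) creates a fresh countdown per agent name on first access, and a value of 0 means the agent has dwelled long enough. A standalone sketch:

    from collections import defaultdict

    dwell_time = 2
    per_agent_times = defaultdict(lambda: dwell_time)

    per_agent_times['agent_1'] -= 1   # first wait action: 2 -> 1
    per_agent_times['agent_1'] -= 1   # second wait action: 1 -> 0
    reached = any(t == 0 for t in per_agent_times.values())
    print(reached)  # True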
41  environments/factory/additional/dest/dest_util.py   Normal file
@@ -0,0 +1,41 @@
from typing import NamedTuple

from environments.helpers import Constants as BaseConstants, EnvActions as BaseActions


class Constants(BaseConstants):
    # Destination Env
    DEST = 'Destination'
    DESTINATION = 1
    DESTINATION_DONE = 0.5
    DEST_REACHED = 'ReachedDestination'


class Actions(BaseActions):
    WAIT_ON_DEST = 'WAIT'


class RewardsDest(NamedTuple):

    WAIT_VALID: float = 0.1
    WAIT_FAIL: float = -0.1
    DEST_REACHED: float = 5.0


class DestModeOptions(object):
    DONE = 'DONE'
    GROUPED = 'GROUPED'
    PER_DEST = 'PER_DEST'


class DestProperties(NamedTuple):
    n_dests: int = 1                  # How many destinations are there
    dwell_time: int = 0               # How long does the agent need to "wait" on a destination
    spawn_frequency: int = 0
    spawn_in_other_zone: bool = True  #
    spawn_mode: str = DestModeOptions.DONE

    assert dwell_time >= 0, 'dwell_time cannot be < 0!'
    assert spawn_frequency >= 0, 'spawn_frequency cannot be < 0!'
    assert n_dests >= 0, 'n_destinations cannot be < 0!'
    assert (spawn_mode == DestModeOptions.DONE) != bool(spawn_frequency)
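A caveat worth knowing about the asserts in DestProperties: statements in a class body run once, at class-definition time, against the default values only; they never validate instance values. A quick demonstration with a hypothetical tuple (Props is illustrative, not a repo class):

    from typing import NamedTuple


    class Props(NamedTuple):
        n_dests: int = 1
        assert n_dests >= 0          # checked once, against the default


    p = Props(n_dests=-5)            # no error: instance values are not re-checked
    print(p.n_dests)                 # -5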
@@ -1,132 +1,19 @@
 import time
-from collections import defaultdict
-from enum import Enum
-from typing import List, Union, NamedTuple, Dict
+from typing import List, Union, Dict
 import numpy as np
 import random

+from environments.factory.additional.dest.dest_collections import Destinations, ReachedDestinations
+from environments.factory.additional.dest.dest_enitites import Destination
+from environments.factory.additional.dest.dest_util import Constants, Actions, RewardsDest, DestModeOptions, \
+    DestProperties
 from environments.factory.base.base_factory import BaseFactory
-from environments.helpers import Constants as BaseConstants
-from environments.helpers import EnvActions as BaseActions
-from environments.factory.base.objects import Agent, Entity, Action
-from environments.factory.base.registers import Entities, EntityRegister
+from environments.factory.base.objects import Agent, Action
+from environments.factory.base.registers import Entities

 from environments.factory.base.renderer import RenderEntity


-class Constants(BaseConstants):
-    # Destination Env
-    DEST = 'Destination'
-    DESTINATION = 1
-    DESTINATION_DONE = 0.5
-    DEST_REACHED = 'ReachedDestination'
-
-
-class Actions(BaseActions):
-    WAIT_ON_DEST = 'WAIT'
-
-
-class RewardsDest(NamedTuple):
-
-    WAIT_VALID: float = 0.1
-    WAIT_FAIL: float = -0.1
-    DEST_REACHED: float = 5.0
-
-
-class Destination(Entity):
-
-    @property
-    def any_agent_has_dwelled(self):
-        return bool(len(self._per_agent_times))
-
-    @property
-    def currently_dwelling_names(self):
-        return self._per_agent_times.keys()
-
-    @property
-    def encoding(self):
-        return c.DESTINATION
-
-    def __init__(self, *args, dwell_time: int = 0, **kwargs):
-        super(Destination, self).__init__(*args, **kwargs)
-        self.dwell_time = dwell_time
-        self._per_agent_times = defaultdict(lambda: dwell_time)
-
-    def do_wait_action(self, agent: Agent):
-        self._per_agent_times[agent.name] -= 1
-        return c.VALID
-
-    def leave(self, agent: Agent):
-        del self._per_agent_times[agent.name]
-
-    @property
-    def is_considered_reached(self):
-        agent_at_position = any(c.AGENT.lower() in x.name.lower() for x in self.tile.guests_that_can_collide)
-        return (agent_at_position and not self.dwell_time) or any(x == 0 for x in self._per_agent_times.values())
-
-    def agent_is_dwelling(self, agent: Agent):
-        return self._per_agent_times[agent.name] < self.dwell_time
-
-    def summarize_state(self, n_steps=None) -> dict:
-        state_summary = super().summarize_state(n_steps=n_steps)
-        state_summary.update(per_agent_times=self._per_agent_times)
-        return state_summary
-
-
-class Destinations(EntityRegister):
-
-    _accepted_objects = Destination
-
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.is_blocking_light = False
-        self.can_be_shadowed = False
-
-    def as_array(self):
-        self._array[:] = c.FREE_CELL
-        # ToDo: Switch to new-style array put
-        # indices = list(zip(range(len(cls)), *zip(*[x.pos for x in cls])))
-        # np.put(cls._array, [np.ravel_multi_index(x, cls._array.shape) for x in indices], cls.encodings)
-        for item in self:
-            if item.pos != c.NO_POS:
-                self._array[0, item.x, item.y] = item.encoding
-        return self._array
-
-    def __repr__(self):
-        super(Destinations, self).__repr__()
-
-
-class ReachedDestinations(Destinations):
-    _accepted_objects = Destination
-
-    def __init__(self, *args, **kwargs):
-        super(ReachedDestinations, self).__init__(*args, **kwargs)
-        self.can_be_shadowed = False
-        self.is_blocking_light = False
-
-    def summarize_states(self, n_steps=None):
-        return {}
-
-
-class DestModeOptions(object):
-    DONE = 'DONE'
-    GROUPED = 'GROUPED'
-    PER_DEST = 'PER_DEST'
-
-
-class DestProperties(NamedTuple):
-    n_dests: int = 1                  # How many destinations are there
-    dwell_time: int = 0               # How long does the agent need to "wait" on a destination
-    spawn_frequency: int = 0
-    spawn_in_other_zone: bool = True  #
-    spawn_mode: str = DestModeOptions.DONE
-
-    assert dwell_time >= 0, 'dwell_time cannot be < 0!'
-    assert spawn_frequency >= 0, 'spawn_frequency cannot be < 0!'
-    assert n_dests >= 0, 'n_destinations cannot be < 0!'
-    assert (spawn_mode == DestModeOptions.DONE) != bool(spawn_frequency)
-
-
 c = Constants
 a = Actions

@@ -135,7 +22,7 @@ a = Actions
 class DestFactory(BaseFactory):
     # noinspection PyMissingConstructor

-    def __init__(self, *args, dest_prop: DestProperties = DestProperties(), rewards_dest: RewardsDest = RewardsDest(),
+    def __init__(self, *args, dest_prop: DestProperties = DestProperties(), rewards_dest: RewardsDest = RewardsDest(),
                  env_seed=time.time_ns(), **kwargs):
         if isinstance(dest_prop, dict):
             dest_prop = DestProperties(**dest_prop)

@@ -151,6 +38,7 @@ class DestFactory(BaseFactory):
     def actions_hook(self) -> Union[Action, List[Action]]:
         # noinspection PyUnresolvedReferences
         super_actions = super().actions_hook
+        # If targets are considered reached after some time, agents need an action for that.
         if self.dest_prop.dwell_time:
             super_actions.append(Action(enum_ident=a.WAIT_ON_DEST))
         return super_actions

@@ -207,14 +95,14 @@ class DestFactory(BaseFactory):
         if destinations_to_spawn:
             n_dest_to_spawn = len(destinations_to_spawn)
             if self.dest_prop.spawn_mode != DestModeOptions.GROUPED:
-                destinations = [Destination(tile, c.DEST) for tile in self[c.FLOOR].empty_tiles[:n_dest_to_spawn]]
-                self[c.DEST].register_additional_items(destinations)
+                destinations = [Destination(tile, self[c.DEST]) for tile in self[c.FLOOR].empty_tiles[:n_dest_to_spawn]]
+                self[c.DEST].add_additional_items(destinations)
                 for dest in destinations_to_spawn:
                     del self._dest_spawn_timer[dest]
                 self.print(f'{n_dest_to_spawn} new destinations have been spawned')
             elif self.dest_prop.spawn_mode == DestModeOptions.GROUPED and n_dest_to_spawn == self.dest_prop.n_dests:
                 destinations = [Destination(tile, self[c.DEST]) for tile in self[c.FLOOR].empty_tiles[:n_dest_to_spawn]]
-                self[c.DEST].register_additional_items(destinations)
+                self[c.DEST].add_additional_items(destinations)
                 for dest in destinations_to_spawn:
                     del self._dest_spawn_timer[dest]
                 self.print(f'{n_dest_to_spawn} new destinations have been spawned')

@@ -229,9 +117,10 @@ class DestFactory(BaseFactory):
         super_reward_info = super().step_hook()
         for key, val in self._dest_spawn_timer.items():
             self._dest_spawn_timer[key] = min(self.dest_prop.spawn_frequency, self._dest_spawn_timer[key] + 1)
+
         for dest in list(self[c.DEST].values()):
             if dest.is_considered_reached:
-                dest.change_register(self[c.DEST])
+                dest.change_parent_collection(self[c.DEST_REACHED])
                 self._dest_spawn_timer[dest.name] = 0
                 self.print(f'{dest.name} is reached now, removing...')
             else:

@@ -251,18 +140,19 @@ class DestFactory(BaseFactory):
         additional_observations.update({c.DEST: self[c.DEST].as_array()})
         return additional_observations

-    def per_agent_reward_hook(self, agent: Agent) -> Dict[str, dict]:
+    def per_agent_reward_hook(self, agent: Agent) -> List[dict]:
         # noinspection PyUnresolvedReferences
-        reward_event_dict = super().per_agent_reward_hook(agent)
+        reward_event_list = super().per_agent_reward_hook(agent)
         if len(self[c.DEST_REACHED]):
             for reached_dest in list(self[c.DEST_REACHED]):
                 if agent.pos == reached_dest.pos:
                     self.print(f'{agent.name} just reached destination at {agent.pos}')
                     self[c.DEST_REACHED].delete_env_object(reached_dest)
                     info_dict = {f'{agent.name}_{c.DEST_REACHED}': 1}
-                    reward_event_dict.update({c.DEST_REACHED: {'reward': self.rewards_dest.DEST_REACHED,
-                                                               'info': info_dict}})
-        return reward_event_dict
+                    reward_event_list.append({'value': self.rewards_dest.DEST_REACHED,
+                                              'reason': c.DEST_REACHED,
+                                              'info': info_dict})
+        return reward_event_list

     def render_assets_hook(self, mode='human'):
         # noinspection PyUnresolvedReferences
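The refactor above replaces a dict of reward events keyed by reason with a flat list of event records, so one agent can collect several rewards per step without key collisions. The shape of such a record list, with made-up values:

    reward_event_list = []
    reward_event_list.append({'value': 5.0, 'reason': 'ReachedDestination',
                              'info': {'agent_1_ReachedDestination': 1}})
    reward_event_list.append({'value': -0.1, 'reason': 'WAIT',
                              'info': {'agent_1_WAIT_FAIL': 1}})
    total = sum(event['value'] for event in reward_event_list)
    print(total)  # 4.9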
0   environments/factory/additional/dirt/__init__.py   Normal file
42  environments/factory/additional/dirt/dirt_collections.py   Normal file
@@ -0,0 +1,42 @@
from environments.factory.additional.dirt.dirt_entity import DirtPile
from environments.factory.additional.dirt.dirt_util import DirtProperties
from environments.factory.base.objects import Floor
from environments.factory.base.registers import EntityCollection
from environments.factory.additional.dirt.dirt_util import Constants as c


class DirtPiles(EntityCollection):

    _accepted_objects = DirtPile

    @property
    def amount(self):
        return sum([dirt.amount for dirt in self])

    @property
    def dirt_properties(self):
        return self._dirt_properties

    def __init__(self, dirt_properties, *args):
        super(DirtPiles, self).__init__(*args)
        self._dirt_properties: DirtProperties = dirt_properties

    def spawn_dirt(self, then_dirty_tiles) -> bool:
        if isinstance(then_dirty_tiles, Floor):
            then_dirty_tiles = [then_dirty_tiles]
        for tile in then_dirty_tiles:
            if not self.amount > self.dirt_properties.max_global_amount:
                dirt = self.by_pos(tile.pos)
                if dirt is None:
                    dirt = DirtPile(tile, self, amount=self.dirt_properties.max_spawn_amount)
                    self.add_item(dirt)
                else:
                    new_value = dirt.amount + self.dirt_properties.max_spawn_amount
                    dirt.set_new_amount(min(new_value, self.dirt_properties.max_local_amount))
            else:
                return c.NOT_VALID
        return c.VALID

    def __repr__(self):
        s = super(DirtPiles, self).__repr__()
        return f'{s[:-1]}, {self.amount})'
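spawn_dirt enforces two caps: a global budget over all piles and a per-tile ceiling. The same logic reduced to plain numbers (the values below are stand-ins, not the repo defaults):

    max_global_amount = 20
    max_local_amount = 2
    max_spawn_amount = 0.3

    piles = {}                                  # pos -> amount


    def spawn(pos):
        if sum(piles.values()) > max_global_amount:
            return False                        # global budget exhausted
        new_value = piles.get(pos, 0) + max_spawn_amount
        piles[pos] = min(new_value, max_local_amount)
        return True


    for _ in range(10):
        spawn((1, 1))
    print(piles[(1, 1)])                        # capped at 2, not 3.0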
26  environments/factory/additional/dirt/dirt_entity.py   Normal file
@@ -0,0 +1,26 @@
from environments.factory.base.objects import Entity


class DirtPile(Entity):

    @property
    def amount(self):
        return self._amount

    @property
    def encoding(self):
        # Edit this if you want items to be drawn in the ops differently
        return self._amount

    def __init__(self, *args, amount=None, **kwargs):
        super(DirtPile, self).__init__(*args, **kwargs)
        self._amount = amount

    def set_new_amount(self, amount):
        self._amount = amount
        self._collection.notify_change_to_value(self)

    def summarize_state(self):
        state_dict = super().summarize_state()
        state_dict.update(amount=float(self.amount))
        return state_dict
30  environments/factory/additional/dirt/dirt_util.py   Normal file
@@ -0,0 +1,30 @@
from typing import NamedTuple

from environments.helpers import Constants as BaseConstants, EnvActions as BaseActions


class Constants(BaseConstants):
    DIRT = 'DirtPile'


class Actions(BaseActions):
    CLEAN_UP = 'do_cleanup_action'


class RewardsDirt(NamedTuple):
    CLEAN_UP_VALID: float = 0.5
    CLEAN_UP_FAIL: float = -0.1
    CLEAN_UP_LAST_PIECE: float = 4.5


class DirtProperties(NamedTuple):
    initial_dirt_ratio: float = 0.3         # On INIT, on how many tiles at most does dirt spawn, in percent.
    initial_dirt_spawn_r_var: float = 0.05  # How much does the dirt spawn amount vary?
    clean_amount: float = 1                 # How much does the robot clean with one action.
    max_spawn_ratio: float = 0.20           # On how many tiles at most does dirt spawn, in percent.
    max_spawn_amount: float = 0.3           # How much dirt spawns per tile, at most.
    spawn_frequency: int = 0                # Spawn frequency in steps.
    max_local_amount: int = 2               # Max dirt amount per tile.
    max_global_amount: int = 20             # Max dirt amount in the whole environment.
    dirt_smear_amount: float = 0.2          # Agents smear dirt when not cleaning up in place.
    done_when_clean: bool = True
@@ -1,111 +1,22 @@
 import time
 from pathlib import Path
-from typing import List, Union, NamedTuple, Dict
+from typing import List, Union, Dict
 import random

 import numpy as np

 from algorithms.TSP_dirt_agent import TSPDirtAgent
-from environments.helpers import Constants as BaseConstants
-from environments.helpers import EnvActions as BaseActions
+from environments.factory.additional.dirt.dirt_collections import DirtPiles
+from environments.factory.additional.dirt.dirt_entity import DirtPile
+from environments.factory.additional.dirt.dirt_util import Constants, Actions, RewardsDirt, DirtProperties

 from environments.factory.base.base_factory import BaseFactory
-from environments.factory.base.objects import Agent, Action, Entity, Floor
-from environments.factory.base.registers import Entities, EntityRegister
+from environments.factory.base.objects import Agent, Action
+from environments.factory.base.registers import Entities

 from environments.factory.base.renderer import RenderEntity
 from environments.utility_classes import ObservationProperties


-class Constants(BaseConstants):
-    DIRT = 'Dirt'
-
-
-class Actions(BaseActions):
-    CLEAN_UP = 'do_cleanup_action'
-
-
-class RewardsDirt(NamedTuple):
-    CLEAN_UP_VALID: float = 0.5
-    CLEAN_UP_FAIL: float = -0.1
-    CLEAN_UP_LAST_PIECE: float = 4.5
-
-
-class DirtProperties(NamedTuple):
-    initial_dirt_ratio: float = 0.3         # On INIT, on how many tiles at most does dirt spawn, in percent.
-    initial_dirt_spawn_r_var: float = 0.05  # How much does the dirt spawn amount vary?
-    clean_amount: float = 1                 # How much does the robot clean with one action.
-    max_spawn_ratio: float = 0.20           # On how many tiles at most does dirt spawn, in percent.
-    max_spawn_amount: float = 0.3           # How much dirt spawns per tile, at most.
-    spawn_frequency: int = 0                # Spawn frequency in steps.
-    max_local_amount: int = 2               # Max dirt amount per tile.
-    max_global_amount: int = 20             # Max dirt amount in the whole environment.
-    dirt_smear_amount: float = 0.2          # Agents smear dirt when not cleaning up in place.
-    done_when_clean: bool = True
-
-
-class Dirt(Entity):
-
-    @property
-    def amount(self):
-        return self._amount
-
-    @property
-    def encoding(self):
-        # Edit this if you want items to be drawn in the ops differently
-        return self._amount
-
-    def __init__(self, *args, amount=None, **kwargs):
-        super(Dirt, self).__init__(*args, **kwargs)
-        self._amount = amount
-
-    def set_new_amount(self, amount):
-        self._amount = amount
-        self._register.notify_change_to_value(self)
-
-    def summarize_state(self, **kwargs):
-        state_dict = super().summarize_state(**kwargs)
-        state_dict.update(amount=float(self.amount))
-        return state_dict
-
-
-class DirtRegister(EntityRegister):
-
-    _accepted_objects = Dirt
-
-    @property
-    def amount(self):
-        return sum([dirt.amount for dirt in self])
-
-    @property
-    def dirt_properties(self):
-        return self._dirt_properties
-
-    def __init__(self, dirt_properties, *args):
-        super(DirtRegister, self).__init__(*args)
-        self._dirt_properties: DirtProperties = dirt_properties
-
-    def spawn_dirt(self, then_dirty_tiles) -> bool:
-        if isinstance(then_dirty_tiles, Floor):
-            then_dirty_tiles = [then_dirty_tiles]
-        for tile in then_dirty_tiles:
-            if not self.amount > self.dirt_properties.max_global_amount:
-                dirt = self.by_pos(tile.pos)
-                if dirt is None:
-                    dirt = Dirt(tile, self, amount=self.dirt_properties.max_spawn_amount)
-                    self.register_item(dirt)
-                else:
-                    new_value = dirt.amount + self.dirt_properties.max_spawn_amount
-                    dirt.set_new_amount(min(new_value, self.dirt_properties.max_local_amount))
-            else:
-                return c.NOT_VALID
-        return c.VALID
-
-    def __repr__(self):
-        s = super(DirtRegister, self).__repr__()
-        return f'{s[:-1]}, {self.amount})'
-
-
 def softmax(x):
     """Compute softmax values for each sets of scores in x."""
     e_x = np.exp(x - np.max(x))

@@ -132,7 +43,7 @@ class DirtFactory(BaseFactory):
     @property
     def entities_hook(self) -> Dict[(str, Entities)]:
         super_entities = super().entities_hook
-        dirt_register = DirtRegister(self.dirt_prop, self._level_shape)
+        dirt_register = DirtPiles(self.dirt_prop, self._level_shape)
         super_entities.update(({c.DIRT: dirt_register}))
         return super_entities

@@ -146,7 +57,7 @@ class DirtFactory(BaseFactory):
         self.dirt_prop = dirt_prop
         self.rewards_dirt = rewards_dirt
         self._dirt_rng = np.random.default_rng(env_seed)
-        self._dirt: DirtRegister
+        self._dirt: DirtPiles
         kwargs.update(env_seed=env_seed)
         # TODO: Reset ---> document this
         super().__init__(*args, **kwargs)

@@ -185,7 +96,7 @@ class DirtFactory(BaseFactory):
     def trigger_dirt_spawn(self, initial_spawn=False):
         dirt_rng = self._dirt_rng
         free_for_dirt = [x for x in self[c.FLOOR]
-                         if len(x.guests) == 0 or (len(x.guests) == 1 and isinstance(next(y for y in x.guests), Dirt))
+                         if len(x.guests) == 0 or (len(x.guests) == 1 and isinstance(next(y for y in x.guests), DirtPile))
                         ]
         self._dirt_rng.shuffle(free_for_dirt)
         if initial_spawn:

@@ -198,21 +109,21 @@ class DirtFactory(BaseFactory):

     def step_hook(self) -> (List[dict], dict):
         super_reward_info = super().step_hook()
-        # if smear_amount := self.dirt_prop.dirt_smear_amount:
-        #     for agent in self[c.AGENT]:
-        #         if agent.temp_valid and agent.last_pos != c.NO_POS:
-        #             if self._actions.is_moving_action(agent.temp_action):
-        #                 if old_pos_dirt := self[c.DIRT].by_pos(agent.last_pos):
-        #                     if smeared_dirt := round(old_pos_dirt.amount * smear_amount, 2):
-        #                         old_pos_dirt.set_new_amount(max(0, old_pos_dirt.amount - smeared_dirt))
-        #                         if new_pos_dirt := self[c.DIRT].by_pos(agent.pos):
-        #                             new_pos_dirt.set_new_amount(max(0, new_pos_dirt.amount + smeared_dirt))
-        #                         else:
-        #                             if self[c.DIRT].spawn_dirt(agent.tile):
-        #                                 new_pos_dirt = self[c.DIRT].by_pos(agent.pos)
-        #                                 new_pos_dirt.set_new_amount(max(0, new_pos_dirt.amount + smeared_dirt))
+        if smear_amount := self.dirt_prop.dirt_smear_amount:
+            for agent in self[c.AGENT]:
+                if agent.step_result['action_valid'] and agent.last_pos != c.NO_POS:
+                    if self._actions.is_moving_action(agent.step_result['action_name']):
+                        if old_pos_dirt := self[c.DIRT].by_pos(agent.last_pos):
+                            if smeared_dirt := round(old_pos_dirt.amount * smear_amount, 2):
+                                old_pos_dirt.set_new_amount(max(0, old_pos_dirt.amount - smeared_dirt))
+                                if new_pos_dirt := self[c.DIRT].by_pos(agent.pos):
+                                    new_pos_dirt.set_new_amount(max(0, new_pos_dirt.amount + smeared_dirt))
+                                else:
+                                    if self[c.DIRT].spawn_dirt(agent.tile):
+                                        new_pos_dirt = self[c.DIRT].by_pos(agent.pos)
+                                        new_pos_dirt.set_new_amount(max(0, new_pos_dirt.amount + smeared_dirt))
         if self._next_dirt_spawn < 0:
-            pass  # No Dirt Spawn
+            pass  # No DirtPile Spawn
         elif not self._next_dirt_spawn:
            self.trigger_dirt_spawn()
            self._next_dirt_spawn = self.dirt_prop.spawn_frequency

@@ -248,8 +159,8 @@ class DirtFactory(BaseFactory):
         additional_observations.update({c.DIRT: self[c.DIRT].as_array()})
         return additional_observations

-    def gather_additional_info(self, agent: Agent) -> dict:
-        event_reward_dict = super().per_agent_reward_hook(agent)
+    def post_step_hook(self) -> List[Dict[str, int]]:
+        super_post_step = super(DirtFactory, self).post_step_hook()
         info_dict = dict()

         dirt = [dirt.amount for dirt in self[c.DIRT]]

@@ -264,8 +175,8 @@ class DirtFactory(BaseFactory):
         info_dict.update(dirt_amount=current_dirt_amount)
         info_dict.update(dirty_tile_count=dirty_tile_count)

-        event_reward_dict.update({'info': info_dict})
-        return event_reward_dict
+        super_post_step.append(info_dict)
+        return super_post_step


 if __name__ == '__main__':

@@ -304,7 +215,6 @@ if __name__ == '__main__':
         # inject_agents=[TSPDirtAgent],
     )

-    factory.save_params(Path('rewards_param'))

     # noinspection DuplicatedCode
     n_actions = factory.action_space.n - 1
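The re-enabled smear block above leans heavily on the walrus operator: assign and test in one expression, so each nested if only fires when the looked-up value is truthy. The same pattern in isolation (positions and amounts are made up):

    dirt_amounts = {(0, 0): 0.5}
    smear_fraction = 0.2

    if old_amount := dirt_amounts.get((0, 0)):
        if smeared := round(old_amount * smear_fraction, 2):
            dirt_amounts[(0, 0)] = max(0, old_amount - smeared)
            dirt_amounts[(0, 1)] = dirt_amounts.get((0, 1), 0) + smeared
    print(dirt_amounts)  # {(0, 0): 0.4, (0, 1): 0.1}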
0   environments/factory/additional/item/__init__.py   Normal file
@@ -1,179 +1,16 @@
 import time
-from collections import deque
-from typing import List, Union, NamedTuple, Dict
+from typing import List, Union, Dict
 import numpy as np
 import random

+from environments.factory.additional.item.item_collections import Items, Inventories, DropOffLocations
+from environments.factory.additional.item.item_util import Constants, Actions, RewardsItem, ItemProperties
 from environments.factory.base.base_factory import BaseFactory
-from environments.helpers import Constants as BaseConstants
-from environments.helpers import EnvActions as BaseActions
-from environments import helpers as h
-from environments.factory.base.objects import Agent, Entity, Action, Floor
-from environments.factory.base.registers import Entities, EntityRegister, BoundEnvObjRegister, ObjectRegister
+from environments.factory.base.objects import Agent, Action
+from environments.factory.base.registers import Entities

 from environments.factory.base.renderer import RenderEntity


-class Constants(BaseConstants):
-    NO_ITEM = 0
-    ITEM_DROP_OFF = 1
-    # Item Env
-    ITEM = 'Item'
-    INVENTORY = 'Inventory'
-    DROP_OFF = 'Drop_Off'
-
-
-class Actions(BaseActions):
-    ITEM_ACTION = 'ITEMACTION'
-
-
-class RewardsItem(NamedTuple):
-    DROP_OFF_VALID: float = 0.1
-    DROP_OFF_FAIL: float = -0.1
-    PICK_UP_FAIL: float = -0.1
-    PICK_UP_VALID: float = 0.1
-
-
-class Item(Entity):
-
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self._auto_despawn = -1
-
-    @property
-    def auto_despawn(self):
-        return self._auto_despawn
-
-    @property
-    def encoding(self):
-        # Edit this if you want items to be drawn in the ops differently
-        return 1
-
-    def set_auto_despawn(self, auto_despawn):
-        self._auto_despawn = auto_despawn
-
-    def set_tile_to(self, no_pos_tile):
-        assert self._register.__class__.__name__ != ItemRegister.__class__
-        self._tile = no_pos_tile
-
-
-class ItemRegister(EntityRegister):
-
-    _accepted_objects = Item
-
-    def spawn_items(self, tiles: List[Floor]):
-        items = [Item(tile, self) for tile in tiles]
-        self.register_additional_items(items)
-
-    def despawn_items(self, items: List[Item]):
-        items = [items] if isinstance(items, Item) else items
-        for item in items:
-            del self[item]
-
-
-class Inventory(BoundEnvObjRegister):
-
-    @property
-    def name(self):
-        return f'{self.__class__.__name__}({self._bound_entity.name})'
-
-    def __init__(self, agent: Agent, capacity: int, *args, **kwargs):
-        super(Inventory, self).__init__(agent, *args, is_blocking_light=False, can_be_shadowed=False, **kwargs)
-        self.capacity = capacity
-
-    def as_array(self):
-        if self._array is None:
-            self._array = np.zeros((1, *self._shape))
-        return super(Inventory, self).as_array()
-
-    def summarize_states(self, **kwargs):
-        attr_dict = {key: str(val) for key, val in self.__dict__.items() if not key.startswith('_') and key != 'data'}
-        attr_dict.update(dict(items={key: val.summarize_state(**kwargs) for key, val in self.items()}))
-        attr_dict.update(dict(name=self.name))
-        return attr_dict
-
-    def pop(self):
-        item_to_pop = self[0]
-        self.delete_env_object(item_to_pop)
-        return item_to_pop
-
-
-class Inventories(ObjectRegister):
-
-    _accepted_objects = Inventory
-    is_blocking_light = False
-    can_be_shadowed = False
-
-    def __init__(self, obs_shape, *args, **kwargs):
-        super(Inventories, self).__init__(*args, is_per_agent=True, individual_slices=True, **kwargs)
-        self._obs_shape = obs_shape
-
-    def as_array(self):
-        return np.stack([inventory.as_array() for inv_idx, inventory in enumerate(self)])
-
-    def spawn_inventories(self, agents, capacity):
-        inventories = [self._accepted_objects(agent, capacity, self._obs_shape)
-                       for _, agent in enumerate(agents)]
-        self.register_additional_items(inventories)
-
-    def idx_by_entity(self, entity):
-        try:
-            return next((idx for idx, inv in enumerate(self) if inv.belongs_to_entity(entity)))
-        except StopIteration:
-            return None
-
-    def by_entity(self, entity):
-        try:
-            return next((inv for inv in self if inv.belongs_to_entity(entity)))
-        except StopIteration:
-            return None
-
-    def summarize_states(self, **kwargs):
-        return {key: val.summarize_states(**kwargs) for key, val in self.items()}
-
-
-class DropOffLocation(Entity):
-
-    @property
-    def encoding(self):
-        return Constants.ITEM_DROP_OFF
-
-    def __init__(self, *args, storage_size_until_full: int = 5, auto_item_despawn_interval: int = 5, **kwargs):
-        super(DropOffLocation, self).__init__(*args, **kwargs)
-        self.auto_item_despawn_interval = auto_item_despawn_interval
-        self.storage = deque(maxlen=storage_size_until_full or None)
-
-    def place_item(self, item: Item):
-        if self.is_full:
-            raise RuntimeWarning("There is currently no way to clear the storage or make it unfull.")
-            return c.NOT_VALID
-        else:
-            self.storage.append(item)
-            item.set_auto_despawn(self.auto_item_despawn_interval)
-            return c.VALID
-
-    @property
-    def is_full(self):
-        return False if not self.storage.maxlen else self.storage.maxlen == len(self.storage)
-
-    def summarize_state(self, n_steps=None) -> dict:
-        if n_steps == h.STEPS_START:
-            return super().summarize_state(n_steps=n_steps)
-
-
-class DropOffLocations(EntityRegister):
-
-    _accepted_objects = DropOffLocation
-
-
-class ItemProperties(NamedTuple):
-    n_items: int = 5                       # How many items are there at the same time
-    spawn_frequency: int = 10              # Spawn frequency in steps
-    n_drop_off_locations: int = 5          # How many drop-off locations are there at the same time
-    max_dropoff_storage_size: int = 0      # How many items fit before the drop-off is full
-    max_agent_inventory_capacity: int = 5  # How many items fit before the agent inventory is full
-
-
 c = Constants
 a = Actions

@@ -212,7 +49,7 @@ class ItemFactory(BaseFactory):
             entity_kwargs=dict(
                 storage_size_until_full=self.item_prop.max_dropoff_storage_size)
         )
-        item_register = ItemRegister(self._level_shape)
+        item_register = Items(self._level_shape)
         empty_tiles = self[c.FLOOR].empty_tiles[:self.item_prop.n_items]
         item_register.spawn_items(empty_tiles)

@@ -250,7 +87,7 @@ class ItemFactory(BaseFactory):
                           reason=a.ITEM_ACTION, info=info_dict)
             return valid, reward
         elif item := self[c.ITEM].by_pos(agent.pos):
-            item.change_register(inventory)
+            item.change_parent_collection(inventory)
             item.set_tile_to(self._NO_POS_TILE)
             self.print(f'{agent.name} just picked up an item at {agent.pos}')
             info_dict = {f'{agent.name}_{a.ITEM_ACTION}_VALID': 1, f'{a.ITEM_ACTION}_VALID': 1}
89  environments/factory/additional/item/item_collections.py   Normal file
@@ -0,0 +1,89 @@
from typing import List

import numpy as np

from environments.factory.base.objects import Floor, Agent
from environments.factory.base.registers import EntityCollection, BoundEnvObjCollection, ObjectCollection
from environments.factory.additional.item.item_entities import Item, DropOffLocation


class Items(EntityCollection):

    _accepted_objects = Item

    def spawn_items(self, tiles: List[Floor]):
        items = [Item(tile, self) for tile in tiles]
        self.add_additional_items(items)

    def despawn_items(self, items: List[Item]):
        items = [items] if isinstance(items, Item) else items
        for item in items:
            del self[item]


class Inventory(BoundEnvObjCollection):

    @property
    def name(self):
        return f'{self.__class__.__name__}({self._bound_entity.name})'

    def __init__(self, agent: Agent, capacity: int, *args, **kwargs):
        super(Inventory, self).__init__(agent, *args, is_blocking_light=False, can_be_shadowed=False, **kwargs)
        self.capacity = capacity

    def as_array(self):
        if self._array is None:
            self._array = np.zeros((1, *self._shape))
        return super(Inventory, self).as_array()

    def summarize_states(self, **kwargs):
        attr_dict = {key: val for key, val in self.__dict__.items() if not key.startswith('_') and key != 'data'}
        attr_dict.update(dict(items=[val.summarize_state(**kwargs) for key, val in self.items()]))
        attr_dict.update(dict(name=self.name, belongs_to=self._bound_entity.name))
        return attr_dict

    def pop(self):
        item_to_pop = self[0]
        self.delete_env_object(item_to_pop)
        return item_to_pop


class Inventories(ObjectCollection):

    _accepted_objects = Inventory
    is_blocking_light = False
    can_be_shadowed = False

    def __init__(self, obs_shape, *args, **kwargs):
        super(Inventories, self).__init__(*args, is_per_agent=True, individual_slices=True, **kwargs)
        self._obs_shape = obs_shape

    def as_array(self):
        return np.stack([inventory.as_array() for inv_idx, inventory in enumerate(self)])

    def spawn_inventories(self, agents, capacity):
        inventories = [self._accepted_objects(agent, capacity, self._obs_shape)
                       for _, agent in enumerate(agents)]
        self.add_additional_items(inventories)

    def idx_by_entity(self, entity):
        try:
            return next((idx for idx, inv in enumerate(self) if inv.belongs_to_entity(entity)))
        except StopIteration:
            return None

    def by_entity(self, entity):
        try:
            return next((inv for inv in self if inv.belongs_to_entity(entity)))
        except StopIteration:
            return None

    def summarize_states(self, **kwargs):
        return [val.summarize_states(**kwargs) for key, val in self.items()]


class DropOffLocations(EntityCollection):

    _accepted_objects = DropOffLocation
    _stateless_entities = True
57  environments/factory/additional/item/item_entities.py   Normal file
@@ -0,0 +1,57 @@
from collections import deque

from environments import helpers as h
from environments.factory.additional.item.item_util import Constants
from environments.factory.base.objects import Entity


class Item(Entity):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._auto_despawn = -1

    @property
    def auto_despawn(self):
        return self._auto_despawn

    @property
    def encoding(self):
        # Edit this if you want items to be drawn in the ops differently
        return 1

    def set_auto_despawn(self, auto_despawn):
        self._auto_despawn = auto_despawn

    def set_tile_to(self, no_pos_tile):
        self._tile = no_pos_tile

    def summarize_state(self) -> dict:
        super_summarization = super(Item, self).summarize_state()
        super_summarization.update(dict(auto_despawn=self.auto_despawn))
        return super_summarization


class DropOffLocation(Entity):

    @property
    def encoding(self):
        return Constants.ITEM_DROP_OFF

    def __init__(self, *args, storage_size_until_full: int = 5, auto_item_despawn_interval: int = 5, **kwargs):
        super(DropOffLocation, self).__init__(*args, **kwargs)
        self.auto_item_despawn_interval = auto_item_despawn_interval
        self.storage = deque(maxlen=storage_size_until_full or None)

    def place_item(self, item: Item):
        if self.is_full:
            raise RuntimeWarning("There is currently no way to clear the storage or make it unfull.")
            return Constants.NOT_VALID
        else:
            self.storage.append(item)
            item.set_auto_despawn(self.auto_item_despawn_interval)
            return Constants.VALID

    @property
    def is_full(self):
        return False if not self.storage.maxlen else self.storage.maxlen == len(self.storage)
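DropOffLocation's storage in isolation: deque(maxlen=n or None) gives a bounded queue when n > 0 and an unbounded one when n == 0, since `0 or None` evaluates to None. A standalone sketch:

    from collections import deque

    bounded = deque(maxlen=5 or None)
    unbounded = deque(maxlen=0 or None)
    print(bounded.maxlen, unbounded.maxlen)   # 5 None


    def is_full(storage):
        return False if not storage.maxlen else storage.maxlen == len(storage)


    bounded.extend(range(5))
    print(is_full(bounded))                   # True
    print(is_full(unbounded))                 # False, it can never fill up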
31  environments/factory/additional/item/item_util.py   Normal file
@@ -0,0 +1,31 @@
from typing import NamedTuple

from environments.helpers import Constants as BaseConstants, EnvActions as BaseActions


class Constants(BaseConstants):
    NO_ITEM = 0
    ITEM_DROP_OFF = 1
    # Item Env
    ITEM = 'Item'
    INVENTORY = 'Inventory'
    DROP_OFF = 'Drop_Off'


class Actions(BaseActions):
    ITEM_ACTION = 'ITEMACTION'


class RewardsItem(NamedTuple):
    DROP_OFF_VALID: float = 0.1
    DROP_OFF_FAIL: float = -0.1
    PICK_UP_FAIL: float = -0.1
    PICK_UP_VALID: float = 0.1


class ItemProperties(NamedTuple):
    n_items: int = 5                       # How many items are there at the same time
    spawn_frequency: int = 10              # Spawn frequency in steps
    n_drop_off_locations: int = 5          # How many drop-off locations are there at the same time
    max_dropoff_storage_size: int = 0      # How many items fit before the drop-off is full
    max_agent_inventory_capacity: int = 5  # How many items fit before the agent inventory is full
@ -71,6 +71,12 @@ class BaseFactory(gym.Env):
|
||||
d['class_name'] = self.__class__.__name__
|
||||
return d
|
||||
|
||||
@property
|
||||
def summarize_header(self):
|
||||
summary_dict = self._summarize_state(stateless_entities=True)
|
||||
summary_dict.update(actions=self._actions.summarize())
|
||||
return summary_dict
|
||||
|
||||
def __enter__(self):
|
||||
return self if self.obs_prop.frames_to_stack == 0 else \
|
||||
MarlFrameStack(FrameStack(self, self.obs_prop.frames_to_stack))
|
||||
@ -144,9 +150,9 @@ class BaseFactory(gym.Env):
|
||||
# Objects
|
||||
self._entities = Entities()
|
||||
# Level
|
||||
|
||||
level_array = h.one_hot_level(self._parsed_level)
|
||||
level_array = np.pad(level_array, self.obs_prop.pomdp_r, 'constant', constant_values=1)
|
||||
self._level_init_shape = level_array.shape
|
||||
level_array = np.pad(level_array, self.obs_prop.pomdp_r, 'constant', constant_values=c.OCCUPIED_CELL)
|
||||
|
||||
self._level_shape = level_array.shape
|
||||
self._obs_shape = self._level_shape if not self.obs_prop.pomdp_r else (self.pomdp_diameter, ) * 2
|
||||
@ -156,14 +162,14 @@ class BaseFactory(gym.Env):
            np.argwhere(level_array == c.OCCUPIED_CELL),
            self._level_shape
        )
        self._entities.register_additional_items({c.WALLS: walls})
        self._entities.add_additional_items({c.WALLS: walls})

        # Floor
        floor = Floors.from_argwhere_coordinates(
            np.argwhere(level_array == c.FREE_CELL),
            self._level_shape
        )
        self._entities.register_additional_items({c.FLOOR: floor})
        self._entities.add_additional_items({c.FLOOR: floor})

        # NOPOS
        self._NO_POS_TILE = Floor(c.NO_POS, None)
@ -177,12 +183,13 @@ class BaseFactory(gym.Env):
            doors = Doors.from_tiles(door_tiles, self._level_shape, have_area=self.obs_prop.indicate_door_area,
                                     entity_kwargs=dict(context=floor)
                                     )
            self._entities.register_additional_items({c.DOORS: doors})
            self._entities.add_additional_items({c.DOORS: doors})

        # Actions
        # TODO: Move this to Agent init, so that agents can have individual action sets.
        self._actions = Actions(self.mv_prop, can_use_doors=self.parse_doors)
        if additional_actions := self.actions_hook:
            self._actions.register_additional_items(additional_actions)
            self._actions.add_additional_items(additional_actions)

        # Agents
        agents_to_spawn = self.n_agents-len(self._injected_agents)
@ -196,10 +203,10 @@ class BaseFactory(gym.Env):
        if self._injected_agents:
            initialized_injections = list()
            for i, injection in enumerate(self._injected_agents):
                agents.register_item(injection(self, floor.empty_tiles[0], agents, static_problem=False))
                agents.add_item(injection(self, floor.empty_tiles[0], agents, static_problem=False))
                initialized_injections.append(agents[-1])
            self._initialized_injections = initialized_injections
        self._entities.register_additional_items({c.AGENT: agents})
        self._entities.add_additional_items({c.AGENT: agents})

        if self.obs_prop.additional_agent_placeholder is not None:
            # TODO: Make this accept Lists for multiple placeholders
@ -210,18 +217,18 @@ class BaseFactory(gym.Env):
                                                 fill_value=self.obs_prop.additional_agent_placeholder)
                                                 )

            self._entities.register_additional_items({c.AGENT_PLACEHOLDER: placeholder})
            self._entities.add_additional_items({c.AGENT_PLACEHOLDER: placeholder})

        # Additional Entities from SubEnvs
        if additional_entities := self.entities_hook:
            self._entities.register_additional_items(additional_entities)
            self._entities.add_additional_items(additional_entities)

        if self.obs_prop.show_global_position_info:
            global_positions = GlobalPositions(self._level_shape)
            # This moved into the GlobalPosition object
            # obs_shape_2d = self._level_shape if not self._pomdp_r else ((self.pomdp_diameter,) * 2)
            global_positions.spawn_global_position_objects(self[c.AGENT])
            self._entities.register_additional_items({c.GLOBAL_POSITION: global_positions})
            self._entities.add_additional_items({c.GLOBAL_POSITION: global_positions})

        # Return
        return self._entities
@ -308,7 +315,8 @@ class BaseFactory(gym.Env):
        info.update(self._summarize_state())

        # Post step Hook for later use
        info.update(self.post_step_hook())
        for post_step_info in self.post_step_hook():
            info.update(post_step_info)

        obs, _ = self._build_observations()

@ -367,14 +375,16 @@ class BaseFactory(gym.Env):
                agent_obs = global_agent_obs.copy()
                agent_obs[(0, *agent.pos)] -= agent.encoding
            else:
                agent_obs = global_agent_obs
                agent_obs = global_agent_obs.copy()
        else:
            # agent_obs == None!!!!!
            agent_obs = global_agent_obs

        # Build Level Observations
        if self.obs_prop.render_agents == a_obs.LEVEL:
            assert agent_obs is not None
            lvl_obs = lvl_obs.copy()
            lvl_obs += global_agent_obs
            lvl_obs += agent_obs

        obs_dict[c.WALLS] = lvl_obs
        if self.obs_prop.render_agents in [a_obs.SEPERATE, a_obs.COMBINED] and agent_obs is not None:
@ -535,7 +545,7 @@ class BaseFactory(gym.Env):

    def _check_agent_move(self, agent, action: Action) -> (Floor, bool):
        # Actions
        x_diff, y_diff = h.ACTIONMAP[action.identifier]
        x_diff, y_diff = a.resolve_movement_action_to_coords(action.identifier)
        x_new = agent.x + x_diff
        y_new = agent.y + y_diff

@ -600,7 +610,9 @@ class BaseFactory(gym.Env):
            for reward in agent.step_result['rewards']:
                combined_info_dict.update(reward['info'])

        # Combine Info dicts into a global one
        combined_info_dict = dict(combined_info_dict)

        combined_info_dict.update(info)

        global_reward_sum = sum(global_env_rewards)
@ -616,9 +628,11 @@ class BaseFactory(gym.Env):

    def start_recording(self):
        self._record_episodes = True
        return self._record_episodes

    def stop_recording(self):
        self._record_episodes = False
        return not self._record_episodes

    # noinspection PyGlobalUndefined
    def render(self, mode='human'):
@ -657,12 +671,12 @@ class BaseFactory(gym.Env):
        else:
            return []

    def _summarize_state(self):
    def _summarize_state(self, stateless_entities=False):
        summary = {f'{REC_TAC}step': self._steps}

        for entity_group in self._entities:
            summary.update({f'{REC_TAC}{entity_group.name}': entity_group.summarize_states(n_steps=self._steps)})

            if entity_group.is_stateless == stateless_entities:
                summary.update({f'{REC_TAC}{entity_group.name}': entity_group.summarize_states()})
        return summary

    def print(self, string):
@ -719,12 +733,12 @@ class BaseFactory(gym.Env):
        return {}

    @abc.abstractmethod
    def per_agent_reward_hook(self, agent: Agent) -> Dict[str, dict]:
        return {}
    def per_agent_reward_hook(self, agent: Agent) -> List[dict]:
        return []

    @abc.abstractmethod
    def post_step_hook(self) -> dict:
        return {}
    def post_step_hook(self) -> List[dict]:
        return []

    @abc.abstractmethod
    def per_agent_raw_observations_hook(self, agent) -> Dict[str, np.typing.ArrayLike]:
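The abstract hooks above now return lists of event dicts instead of a single mapping, which is what the new `for post_step_info in self.post_step_hook()` loop consumes. A hedged sketch of a conforming subclass (the class and field names are illustrative, patterned on the reward dicts used elsewhere in this diff):

    from typing import List

    class SketchFactory(BaseFactory):                          # illustrative only
        def post_step_hook(self) -> List[dict]:
            # one info-dict per event; step() merges each via info.update(...)
            return [{'sketch_events': 1}]

        def per_agent_reward_hook(self, agent) -> List[dict]:
            # one entry per reward event: value / reason / info
            return [dict(value=0.0, reason='noop', info={f'{agent.name}_noop': 1})]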
@ -72,21 +72,21 @@ class EnvObject(Object):
    def encoding(self):
        return c.OCCUPIED_CELL

    def __init__(self, register, **kwargs):
    def __init__(self, collection, **kwargs):
        super(EnvObject, self).__init__(**kwargs)
        self._register = register
        self._collection = collection

    def change_register(self, register):
        register.register_item(self)
        self._register.delete_env_object(self)
        self._register = register
        return self._register == register
    def change_parent_collection(self, other_collection):
        other_collection.add_item(self)
        self._collection.delete_env_object(self)
        self._collection = other_collection
        return self._collection == other_collection
# With Rendering


# TODO: Missing Documentation
class Entity(EnvObject):
    """Full Env Entity that lives on the env Grid. Doors, Items, Dirt etc..."""
    """Full Env Entity that lives on the env Grid. Doors, Items, DirtPile etc..."""

    @property
    def can_collide(self):
@ -113,13 +113,12 @@ class Entity(EnvObject):
        self._tile = tile
        tile.enter(self)

    def summarize_state(self, **_) -> dict:
    def summarize_state(self) -> dict:
        return dict(name=str(self.name), x=int(self.x), y=int(self.y),
                    tile=str(self.tile.name), can_collide=bool(self.can_collide))

    def __repr__(self):
        return super(Entity, self).__repr__() + f'(@{self.pos})'
# With Position in Env


# TODO: Missing Documentation
@ -153,7 +152,7 @@ class MoveableEntity(Entity):
            curr_tile.leave(self)
            self._tile = next_tile
            self._last_tile = curr_tile
            self._register.notify_change_to_value(self)
            self._collection.notify_change_to_value(self)
            return c.VALID
        else:
            return c.NOT_VALID
@ -339,8 +338,8 @@ class Door(Entity):
        if not closed_on_init:
            self._open()

    def summarize_state(self, **kwargs):
        state_dict = super().summarize_state(**kwargs)
    def summarize_state(self):
        state_dict = super().summarize_state()
        state_dict.update(state=str(self.str_state), time_to_close=int(self.time_to_close))
        return state_dict

@ -371,13 +370,13 @@ class Door(Entity):
    def _open(self):
        self.connectivity.add_edges_from([(self.pos, x) for x in range(len(self.connectivity_subgroups))])
        self._state = c.OPEN_DOOR
        self._register.notify_change_to_value(self)
        self._collection.notify_change_to_value(self)
        self.time_to_close = self.auto_close_interval

    def _close(self):
        self.connectivity.remove_node(self.pos)
        self._state = c.CLOSED_DOOR
        self._register.notify_change_to_value(self)
        self._collection.notify_change_to_value(self)

    def is_linked(self, old_pos, new_pos):
        try:
@ -403,7 +402,7 @@ class Agent(MoveableEntity):
        # if attr.startswith('temp'):
        self.step_result = None

    def summarize_state(self, **kwargs):
        state_dict = super().summarize_state(**kwargs)
    def summarize_state(self):
        state_dict = super().summarize_state()
        state_dict.update(valid=bool(self.step_result['action_valid']), action=str(self.step_result['action_name']))
        return state_dict
@ -13,71 +13,76 @@ from environments import helpers as h
from environments.helpers import Constants as c

##########################################################################
# ##################### Base Register Definition ####################### #
# ################## Base Collections Definition ####################### #
##########################################################################


class ObjectRegister:
class ObjectCollection:
    _accepted_objects = Object
    _stateless_entities = False

    @property
    def is_stateless(self):
        return self._stateless_entities

    @property
    def name(self):
        return f'{self.__class__.__name__}'

    def __init__(self, *args, **kwargs):
        self._register = dict()
        self._collection = dict()

    def __len__(self):
        return len(self._register)
        return len(self._collection)

    def __iter__(self):
        return iter(self.values())

    def register_item(self, other: _accepted_objects):
    def add_item(self, other: _accepted_objects):
        assert isinstance(other, self._accepted_objects), f'All item names have to be of type ' \
                                                          f'{self._accepted_objects}, ' \
                                                          f'but were {other.__class__}.,'
        self._register.update({other.name: other})
        self._collection.update({other.name: other})
        return self

    def register_additional_items(self, others: List[_accepted_objects]):
    def add_additional_items(self, others: List[_accepted_objects]):
        for other in others:
            self.register_item(other)
            self.add_item(other)
        return self

    def keys(self):
        return self._register.keys()
        return self._collection.keys()

    def values(self):
        return self._register.values()
        return self._collection.values()

    def items(self):
        return self._register.items()
        return self._collection.items()

    def _get_index(self, item):
        try:
            return next(i for i, v in enumerate(self._register.values()) if v == item)
            return next(i for i, v in enumerate(self._collection.values()) if v == item)
        except StopIteration:
            return None

    def __getitem__(self, item):
        if isinstance(item, (int, np.int64, np.int32)):
            if item < 0:
                item = len(self._register) - abs(item)
                item = len(self._collection) - abs(item)
            try:
                return next(v for i, v in enumerate(self._register.values()) if i == item)
                return next(v for i, v in enumerate(self._collection.values()) if i == item)
            except StopIteration:
                return None
        try:
            return self._register[item]
            return self._collection[item]
        except KeyError:
            return None

    def __repr__(self):
        return f'{self.__class__.__name__}[{self._register}]'
        return f'{self.__class__.__name__}[{self._collection}]'
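Behaviourally, `ObjectCollection` is an insertion-ordered, name-keyed container with list-style integer indexing that returns `None` instead of raising. A self-contained sketch against that contract (the stand-in classes are invented):

    class _Obj:                                   # minimal stand-in with a .name
        def __init__(self, name):
            self.name = name

    class _Coll(ObjectCollection):
        _accepted_objects = _Obj                  # accept the stand-in type

    coll = _Coll()
    coll.add_additional_items([_Obj('a'), _Obj('b')])
    assert coll[0].name == 'a'                    # insertion-order integer indexing
    assert coll['b'] is coll[1]                   # dict-style lookup by name
    assert coll['missing'] is None                # failed lookups return None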
class EnvObjectRegister(ObjectRegister):
class EnvObjectCollection(ObjectCollection):

    _accepted_objects = EnvObject

@ -90,7 +95,7 @@ class EnvObjectRegister(ObjectRegister):
                 is_blocking_light: bool = False,
                 can_collide: bool = False,
                 can_be_shadowed: bool = True, **kwargs):
        super(EnvObjectRegister, self).__init__(*args, **kwargs)
        super(EnvObjectCollection, self).__init__(*args, **kwargs)
        self._shape = obs_shape
        self._array = None
        self._individual_slices = individual_slices
@ -99,8 +104,8 @@ class EnvObjectRegister(ObjectRegister):
        self.can_be_shadowed = can_be_shadowed
        self.can_collide = can_collide

    def register_item(self, other: EnvObject):
        super(EnvObjectRegister, self).register_item(other)
    def add_item(self, other: EnvObject):
        super(EnvObjectCollection, self).add_item(other)
        if self._array is None:
            self._array = np.zeros((1, *self._shape))
        else:
@ -116,8 +121,8 @@ class EnvObjectRegister(ObjectRegister):
            self._lazy_eval_transforms = []
        return self._array

    def summarize_states(self, n_steps=None):
        return [val.summarize_state(n_steps=n_steps) for val in self.values()]
    def summarize_states(self):
        return [entity.summarize_state() for entity in self.values()]

    def notify_change_to_free(self, env_object: EnvObject):
        self._array_change_notifyer(env_object, value=c.FREE_CELL)
@ -145,13 +150,13 @@ class EnvObjectRegister(ObjectRegister):
        if self._individual_slices:
            self._array = np.delete(self._array, idx, axis=0)
        else:
            self.notify_change_to_free(self._register[name])
            self.notify_change_to_free(self._collection[name])
            # Dirty hack to check if we are not being subclassed. In that case we need to refresh the array, since
            # positions in the observation array are the result of enumeration. They can override each other.
            # Todo: Find a better solution
            if not issubclass(self.__class__, EntityRegister) and issubclass(self.__class__, EnvObjectRegister):
            if not issubclass(self.__class__, EntityCollection) and issubclass(self.__class__, EnvObjectCollection):
                self._refresh_arrays()
        del self._register[name]
        del self._collection[name]

    def delete_env_object(self, env_object: EnvObject):
        del self[env_object.name]
@ -160,19 +165,19 @@ class EnvObjectRegister(ObjectRegister):
        del self[name]


class EntityRegister(EnvObjectRegister, ABC):
class EntityCollection(EnvObjectCollection, ABC):

    _accepted_objects = Entity

    @classmethod
    def from_tiles(cls, tiles, *args, entity_kwargs=None, **kwargs):
        # objects_name = cls._accepted_objects.__name__
        register_obj = cls(*args, **kwargs)
        entities = [cls._accepted_objects(tile, register_obj, str_ident=i,
        collection = cls(*args, **kwargs)
        entities = [cls._accepted_objects(tile, collection, str_ident=i,
                                          **entity_kwargs if entity_kwargs is not None else {})
                    for i, tile in enumerate(tiles)]
        register_obj.register_additional_items(entities)
        return register_obj
        collection.add_additional_items(entities)
        return collection

    @classmethod
    def from_argwhere_coordinates(cls, positions: [(int, int)], tiles, *args, entity_kwargs=None, **kwargs, ):
@ -188,13 +193,13 @@ class EntityRegister(EnvObjectRegister, ABC):
        return [entity.tile for entity in self]

    def __init__(self, level_shape, *args, **kwargs):
        super(EntityRegister, self).__init__(level_shape, *args, **kwargs)
        super(EntityCollection, self).__init__(level_shape, *args, **kwargs)
        self._lazy_eval_transforms = []

    def __delitem__(self, name):
        idx, obj = next((i, obj) for i, obj in enumerate(self) if obj.name == name)
        obj.tile.leave(obj)
        super(EntityRegister, self).__delitem__(name)
        super(EntityCollection, self).__delitem__(name)

    def as_array(self):
        if self._lazy_eval_transforms:
@ -223,7 +228,7 @@ class EntityRegister(EnvObjectRegister, ABC):
        return None


class BoundEnvObjRegister(EnvObjectRegister, ABC):
class BoundEnvObjCollection(EnvObjectCollection, ABC):

    def __init__(self, entity_to_be_bound, *args, **kwargs):
        super().__init__(*args, **kwargs)
@ -248,13 +253,13 @@ class BoundEnvObjRegister(EnvObjectRegister, ABC):
        return self._array[self.idx_by_entity(entity)]


class MovingEntityObjectRegister(EntityRegister, ABC):
class MovingEntityObjectCollection(EntityCollection, ABC):

    def __init__(self, *args, **kwargs):
        super(MovingEntityObjectRegister, self).__init__(*args, **kwargs)
        super(MovingEntityObjectCollection, self).__init__(*args, **kwargs)

    def notify_change_to_value(self, entity):
        super(MovingEntityObjectRegister, self).notify_change_to_value(entity)
        super(MovingEntityObjectCollection, self).notify_change_to_value(entity)
        if entity.last_pos != c.NO_POS:
            try:
                self._array_change_notifyer(entity, entity.last_pos, value=c.FREE_CELL)
@ -263,11 +268,11 @@ class MovingEntityObjectRegister(EntityRegister, ABC):


##########################################################################
# ################# Objects and Entity Registers ####################### #
# ################# Objects and Entity Collection ###################### #
##########################################################################


class GlobalPositions(EnvObjectRegister):
class GlobalPositions(EnvObjectCollection):

    _accepted_objects = GlobalPosition

@ -288,10 +293,7 @@ class GlobalPositions(EnvObjectRegister):
        global_positions = [self._accepted_objects(self._shape, agent, self)
                            for _, agent in enumerate(agents)]
        # noinspection PyTypeChecker
        self.register_additional_items(global_positions)

    def summarize_states(self, n_steps=None):
        return {}
        self.add_additional_items(global_positions)

    def idx_by_entity(self, entity):
        try:
@ -306,7 +308,7 @@ class GlobalPositions(EnvObjectRegister):
            return None


class PlaceHolders(EnvObjectRegister):
class PlaceHolders(EnvObjectCollection):
    _accepted_objects = PlaceHolder

    def __init__(self, *args, **kwargs):
@ -320,12 +322,12 @@ class PlaceHolders(EnvObjectRegister):
        # objects_name = cls._accepted_objects.__name__
        if isinstance(values, (str, numbers.Number)):
            values = [values]
        register_obj = cls(*args, **kwargs)
        objects = [cls._accepted_objects(register_obj, str_ident=i, fill_value=value,
        collection = cls(*args, **kwargs)
        objects = [cls._accepted_objects(collection, str_ident=i, fill_value=value,
                                         **object_kwargs if object_kwargs is not None else {})
                   for i, value in enumerate(values)]
        register_obj.register_additional_items(objects)
        return register_obj
        collection.add_additional_items(objects)
        return collection

    # noinspection DuplicatedCode
    def as_array(self):
@ -343,8 +345,8 @@ class PlaceHolders(EnvObjectRegister):
        return self._array


class Entities(ObjectRegister):
    _accepted_objects = EntityRegister
class Entities(ObjectCollection):
    _accepted_objects = EntityCollection

    @property
    def arrays(self):
@ -352,7 +354,7 @@ class Entities(ObjectRegister):

    @property
    def names(self):
        return list(self._register.keys())
        return list(self._collection.keys())

    def __init__(self):
        super(Entities, self).__init__()
@ -360,22 +362,23 @@ class Entities(ObjectRegister):
    def iter_individual_entitites(self):
        return iter((x for sublist in self.values() for x in sublist))

    def register_item(self, other: dict):
    def add_item(self, other: dict):
        assert not any([key for key in other.keys() if key in self.keys()]), \
            "This group of entities has already been registered!"
        self._register.update(other)
            "This group of entities has already been added!"
        self._collection.update(other)
        return self

    def register_additional_items(self, others: Dict):
        return self.register_item(others)
    def add_additional_items(self, others: Dict):
        return self.add_item(others)

    def by_pos(self, pos: (int, int)):
        found_entities = [y for y in (x.by_pos(pos) for x in self.values() if hasattr(x, 'by_pos')) if y is not None]
        return found_entities


class Walls(EntityRegister):
class Walls(EntityCollection):
    _accepted_objects = Wall
    _stateless_entities = True

    def as_array(self):
        if not np.any(self._array):
@ -396,7 +399,7 @@ class Walls(EntityRegister):
    def from_argwhere_coordinates(cls, argwhere_coordinates, *args, **kwargs):
        tiles = cls(*args, **kwargs)
        # noinspection PyTypeChecker
        tiles.register_additional_items(
        tiles.add_additional_items(
            [cls._accepted_objects(pos, tiles)
             for pos in argwhere_coordinates]
        )
@ -406,15 +409,10 @@ class Walls(EntityRegister):
    def from_tiles(cls, tiles, *args, **kwargs):
        raise RuntimeError()

    def summarize_states(self, n_steps=None):
        if n_steps == h.STEPS_START:
            return super(Walls, self).summarize_states(n_steps=n_steps)
        else:
            return {}


class Floors(Walls):
    _accepted_objects = Floor
    _stateless_entities = True

    def __init__(self, *args, is_blocking_light=False, **kwargs):
        super(Floors, self).__init__(*args, is_blocking_light=is_blocking_light, **kwargs)
@ -436,12 +434,8 @@ class Floors(Walls):
    def from_tiles(cls, tiles, *args, **kwargs):
        raise RuntimeError()

    def summarize_states(self, n_steps=None):
        # Do not summarize
        return {}


class Agents(MovingEntityObjectRegister):
class Agents(MovingEntityObjectCollection):
    _accepted_objects = Agent

    def __init__(self, *args, **kwargs):
@ -455,10 +449,10 @@ class Agents(MovingEntityObjectRegister):
        old_agent = self[key]
        self[key].tile.leave(self[key])
        agent._name = old_agent.name
        self._register[agent.name] = agent
        self._collection[agent.name] = agent


class Doors(EntityRegister):
class Doors(EntityCollection):

    def __init__(self, *args, have_area: bool = False, **kwargs):
        self.have_area = have_area
@ -490,7 +484,7 @@ class Doors(EntityRegister):
        return super(Doors, self).as_array()


class Actions(ObjectRegister):
class Actions(ObjectCollection):
    _accepted_objects = Action

    @property
@ -507,22 +501,25 @@ class Actions(ObjectRegister):

        # Move this to Baseclass, Env init?
        if self.allow_square_movement:
            self.register_additional_items([self._accepted_objects(str_ident=direction)
                                            for direction in h.EnvActions.square_move()])
            self.add_additional_items([self._accepted_objects(str_ident=direction)
                                       for direction in h.EnvActions.square_move()])
        if self.allow_diagonal_movement:
            self.register_additional_items([self._accepted_objects(str_ident=direction)
                                            for direction in h.EnvActions.diagonal_move()])
        self._movement_actions = self._register.copy()
            self.add_additional_items([self._accepted_objects(str_ident=direction)
                                       for direction in h.EnvActions.diagonal_move()])
        self._movement_actions = self._collection.copy()
        if self.can_use_doors:
            self.register_additional_items([self._accepted_objects(str_ident=h.EnvActions.USE_DOOR)])
            self.add_additional_items([self._accepted_objects(str_ident=h.EnvActions.USE_DOOR)])
        if self.allow_no_op:
            self.register_additional_items([self._accepted_objects(str_ident=h.EnvActions.NOOP)])
            self.add_additional_items([self._accepted_objects(str_ident=h.EnvActions.NOOP)])

    def is_moving_action(self, action: Union[int]):
        return action in self.movement_actions.values()

    def summarize(self):
        return [dict(name=action.identifier) for action in self]


class Zones(ObjectRegister):

class Zones(ObjectCollection):

    @property
    def accounting_zones(self):
@ -551,5 +548,5 @@ class Zones(ObjectRegister):
    def __getitem__(self, item):
        return self._zone_slices[item]

    def register_additional_items(self, other: Union[str, List[str]]):
    def add_additional_items(self, other: Union[str, List[str]]):
        raise AttributeError('You are not allowed to add additional Zones at runtime.')
@ -20,21 +20,33 @@ class RenderEntity(NamedTuple):
    aux: Any = None


class RenderNames:
    AGENT: str = 'agent'
    BLANK: str = 'blank'
    DOOR: str = 'door'
    OPACITY: str = 'opacity'
    SCALE: str = 'scale'
rn = RenderNames


class Renderer:
    BG_COLOR = (178, 190, 195)  # (99, 110, 114)
    WHITE = (223, 230, 233)  # (200, 200, 200)
    AGENT_VIEW_COLOR = (9, 132, 227)
    ASSETS = Path(__file__).parent.parent / 'assets'

    def __init__(self, grid_w=16, grid_h=16, cell_size=40, fps=7, grid_lines=True, view_radius=2):
        self.grid_h = grid_h
        self.grid_w = grid_w
    def __init__(self, lvl_shape=(16, 16),
                 lvl_padded_shape=None,
                 cell_size=40, fps=7,
                 grid_lines=True, view_radius=2):
        self.grid_h, self.grid_w = lvl_shape
        self.lvl_padded_shape = lvl_padded_shape if lvl_padded_shape is not None else lvl_shape
        self.cell_size = cell_size
        self.fps = fps
        self.grid_lines = grid_lines
        self.view_radius = view_radius
        pygame.init()
        self.screen_size = (grid_w*cell_size, grid_h*cell_size)
        self.screen_size = (self.grid_w*cell_size, self.grid_h*cell_size)
        self.screen = pygame.display.set_mode(self.screen_size)
        self.clock = pygame.time.Clock()
        assets = list(self.ASSETS.rglob('*.png'))
@ -43,7 +55,7 @@ class Renderer:

        now = time.time()
        self.font = pygame.font.Font(None, 20)
        self.font.set_bold(1)
        self.font.set_bold(True)
        print('Loading System font with pygame.font.Font took', time.time() - now)

    def fill_bg(self):
@ -56,11 +68,16 @@ class Renderer:
            pygame.draw.rect(self.screen, Renderer.WHITE, rect, 1)

    def blit_params(self, entity):
        offset_r, offset_c = (self.lvl_padded_shape[0] - self.grid_h) // 2, \
                             (self.lvl_padded_shape[1] - self.grid_w) // 2

        r, c = entity.pos
        r, c = r - offset_r, c - offset_c

        img = self.assets[entity.name.lower()]
        if entity.value_operation == 'opacity':
        if entity.value_operation == rn.OPACITY:
            img.set_alpha(255*entity.value)
        elif entity.value_operation == 'scale':
        elif entity.value_operation == rn.SCALE:
            re = img.get_rect()
            img = pygame.transform.smoothscale(
                img, (int(entity.value*re.width), int(entity.value*re.height))
@ -99,19 +116,19 @@ class Renderer:
                sys.exit()
        self.fill_bg()
        blits = deque()
        for entity in [x for x in entities if 'door' in x.name]:
        for entity in [x for x in entities if rn.DOOR in x.name]:
            bp = self.blit_params(entity)
            blits.append(bp)
        for entity in [x for x in entities if 'door' not in x.name]:
        for entity in [x for x in entities if rn.DOOR not in x.name]:
            bp = self.blit_params(entity)
            blits.append(bp)
            if entity.name.lower() == 'agent':
            if entity.name.lower() == rn.AGENT:
                if self.view_radius > 0:
                    vis_rects = self.visibility_rects(bp, entity.aux)
                    blits.extendleft(vis_rects)
                if entity.state != 'blank':
                if entity.state != rn.BLANK:
                    agent_state_blits = self.blit_params(
                        RenderEntity(entity.state, (entity.pos[0] + 0.12, entity.pos[1]), 0.48, 'scale')
                        RenderEntity(entity.state, (entity.pos[0] + 0.12, entity.pos[1]), 0.48, rn.SCALE)
                    )
                    textsurface = self.font.render(str(entity.id), False, (0, 0, 0))
                    text_blit = dict(source=textsurface, dest=(bp['dest'].center[0]-.07*self.cell_size,
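The new `lvl_padded_shape` offsets in `blit_params` translate entity positions from padded-level coordinates back onto the visible grid. A worked example with invented sizes:

    lvl_padded_shape, grid_h, grid_w = (20, 20), 16, 16        # padded by 2 per side
    offset_r = (lvl_padded_shape[0] - grid_h) // 2             # (20 - 16) // 2 == 2
    offset_c = (lvl_padded_shape[1] - grid_w) // 2             # == 2
    r, c = 5, 7                                                # position in padded coords
    assert (r - offset_r, c - offset_c) == (3, 5)              # cell actually drawn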
@ -1,273 +0,0 @@
from typing import Union, NamedTuple, Dict, List

import numpy as np

from environments.factory.base.base_factory import BaseFactory
from environments.factory.base.objects import Agent, Action, Entity, EnvObject, BoundingMixin
from environments.factory.base.registers import EntityRegister, EnvObjectRegister
from environments.factory.base.renderer import RenderEntity
from environments.helpers import Constants as BaseConstants
from environments.helpers import EnvActions as BaseActions

from environments import helpers as h


class Constants(BaseConstants):
    # Battery Env
    CHARGE_PODS = 'Charge_Pod'
    BATTERIES = 'BATTERIES'
    BATTERY_DISCHARGED = 'DISCHARGED'
    CHARGE_POD = 1


class Actions(BaseActions):
    CHARGE = 'do_charge_action'


class RewardsBtry(NamedTuple):
    CHARGE_VALID: float = 0.1
    CHARGE_FAIL: float = -0.1
    BATTERY_DISCHARGED: float = -1.0


class BatteryProperties(NamedTuple):
    initial_charge: float = 0.8  #
    charge_rate: float = 0.4  #
    charge_locations: int = 20  #
    per_action_costs: Union[dict, float] = 0.02
    done_when_discharged = False
    multi_charge: bool = False


c = Constants
a = Actions


class Battery(BoundingMixin, EnvObject):

    @property
    def is_discharged(self):
        return self.charge_level == 0

    def __init__(self, initial_charge_level: float, *args, **kwargs):
        super(Battery, self).__init__(*args, **kwargs)
        self.charge_level = initial_charge_level

    def encoding(self):
        return self.charge_level

    def do_charge_action(self, amount):
        if self.charge_level < 1:
            # noinspection PyTypeChecker
            self.charge_level = min(1, amount + self.charge_level)
            return c.VALID
        else:
            return c.NOT_VALID

    def decharge(self, amount) -> c:
        if self.charge_level != 0:
            # noinspection PyTypeChecker
            self.charge_level = max(0, amount + self.charge_level)
            self._register.notify_change_to_value(self)
            return c.VALID
        else:
            return c.NOT_VALID

    def summarize_state(self, **_):
        attr_dict = {key: str(val) for key, val in self.__dict__.items() if not key.startswith('_') and key != 'data'}
        attr_dict.update(dict(name=self.name))
        return attr_dict


class BatteriesRegister(EnvObjectRegister):

    _accepted_objects = Battery

    def __init__(self, *args, **kwargs):
        super(BatteriesRegister, self).__init__(*args, individual_slices=True,
                                                is_blocking_light=False, can_be_shadowed=False, **kwargs)
        self.is_observable = True

    def spawn_batteries(self, agents, initial_charge_level):
        batteries = [self._accepted_objects(initial_charge_level, agent, self) for _, agent in enumerate(agents)]
        self.register_additional_items(batteries)

    def summarize_states(self, n_steps=None):
        # as dict with additional nesting
        # return dict(items=super(Inventories, cls).summarize_states())
        return super(BatteriesRegister, self).summarize_states(n_steps=n_steps)

    # Todo Move this to Mixin!
    def by_entity(self, entity):
        try:
            return next((x for x in self if x.belongs_to_entity(entity)))
        except StopIteration:
            return None

    def idx_by_entity(self, entity):
        try:
            return next((idx for idx, x in enumerate(self) if x.belongs_to_entity(entity)))
        except StopIteration:
            return None

    def as_array_by_entity(self, entity):
        return self._array[self.idx_by_entity(entity)]


class ChargePod(Entity):

    @property
    def encoding(self):
        return c.CHARGE_POD

    def __init__(self, *args, charge_rate: float = 0.4,
                 multi_charge: bool = False, **kwargs):
        super(ChargePod, self).__init__(*args, **kwargs)
        self.charge_rate = charge_rate
        self.multi_charge = multi_charge

    def charge_battery(self, battery: Battery):
        if battery.charge_level == 1.0:
            return c.NOT_VALID
        if sum(guest for guest in self.tile.guests if 'agent' in guest.name.lower()) > 1:
            return c.NOT_VALID
        valid = battery.do_charge_action(self.charge_rate)
        return valid

    def summarize_state(self, n_steps=None) -> dict:
        if n_steps == h.STEPS_START:
            summary = super().summarize_state(n_steps=n_steps)
            return summary


class ChargePods(EntityRegister):

    _accepted_objects = ChargePod

    def __repr__(self):
        super(ChargePods, self).__repr__()


class BatteryFactory(BaseFactory):

    def __init__(self, *args, btry_prop=BatteryProperties(), rewards_dest: RewardsBtry = RewardsBtry(),
                 **kwargs):
        if isinstance(btry_prop, dict):
            btry_prop = BatteryProperties(**btry_prop)
        if isinstance(rewards_dest, dict):
            rewards_dest = BatteryProperties(**rewards_dest)
        self.btry_prop = btry_prop
        self.rewards_dest = rewards_dest
        super().__init__(*args, **kwargs)

    def per_agent_raw_observations_hook(self, agent) -> Dict[str, np.typing.ArrayLike]:
        additional_raw_observations = super().per_agent_raw_observations_hook(agent)
        additional_raw_observations.update({c.BATTERIES: self[c.BATTERIES].as_array_by_entity(agent)})
        return additional_raw_observations

    def observations_hook(self) -> Dict[str, np.typing.ArrayLike]:
        additional_observations = super().observations_hook()
        additional_observations.update({c.CHARGE_PODS: self[c.CHARGE_PODS].as_array()})
        return additional_observations

    @property
    def entities_hook(self):
        super_entities = super().entities_hook

        empty_tiles = self[c.FLOOR].empty_tiles[:self.btry_prop.charge_locations]
        charge_pods = ChargePods.from_tiles(
            empty_tiles, self._level_shape,
            entity_kwargs=dict(charge_rate=self.btry_prop.charge_rate,
                               multi_charge=self.btry_prop.multi_charge)
        )

        batteries = BatteriesRegister(self._level_shape if not self._pomdp_r else ((self.pomdp_diameter,) * 2),
                                      )
        batteries.spawn_batteries(self[c.AGENT], self.btry_prop.initial_charge)
        super_entities.update({c.BATTERIES: batteries, c.CHARGE_PODS: charge_pods})
        return super_entities

    def step_hook(self) -> (List[dict], dict):
        super_reward_info = super(BatteryFactory, self).step_hook()

        # Decharge
        batteries = self[c.BATTERIES]

        for agent in self[c.AGENT]:
            if isinstance(self.btry_prop.per_action_costs, dict):
                energy_consumption = self.btry_prop.per_action_costs[agent.temp_action]
            else:
                energy_consumption = self.btry_prop.per_action_costs

            batteries.by_entity(agent).decharge(energy_consumption)

        return super_reward_info

    def do_charge_action(self, agent) -> (dict, dict):
        if charge_pod := self[c.CHARGE_PODS].by_pos(agent.pos):
            valid = charge_pod.charge_battery(self[c.BATTERIES].by_entity(agent))
            if valid:
                info_dict = {f'{agent.name}_{a.CHARGE}_VALID': 1}
                self.print(f'{agent.name} just charged batteries at {charge_pod.name}.')
            else:
                info_dict = {f'{agent.name}_{a.CHARGE}_FAIL': 1}
                self.print(f'{agent.name} failed to charge batteries at {charge_pod.name}.')
        else:
            valid = c.NOT_VALID
            info_dict = {f'{agent.name}_{a.CHARGE}_FAIL': 1}
            # info_dict = {f'{agent.name}_no_charger': 1}
            self.print(f'{agent.name} failed to charge batteries at {agent.pos}.')
        reward = dict(value=self.rewards_dest.CHARGE_VALID if valid else self.rewards_dest.CHARGE_FAIL,
                      reason=a.CHARGE, info=info_dict)
        return valid, reward

    def do_additional_actions(self, agent: Agent, action: Action) -> (bool, dict):
        action_result = super().do_additional_actions(agent, action)
        if action_result is None:
            if action == a.CHARGE:
                action_result = self.do_charge_action(agent)
                return action_result
            else:
                return None
        else:
            return action_result
        pass

    def reset_hook(self) -> None:
        # There is nothing to reset.
        pass

    def check_additional_done(self) -> (bool, dict):
        super_done, super_dict = super(BatteryFactory, self).check_additional_done()
        if super_done:
            return super_done, super_dict
        else:
            if self.btry_prop.done_when_discharged:
                if btry_done := any(battery.is_discharged for battery in self[c.BATTERIES]):
                    super_dict.update(DISCHARGE_DONE=1)
                    return btry_done, super_dict
                else:
                    pass
            else:
                pass
        pass

    def per_agent_reward_hook(self, agent: Agent) -> Dict[str, dict]:
        reward_event_dict = super(BatteryFactory, self).per_agent_reward_hook(agent)
        if self[c.BATTERIES].by_entity(agent).is_discharged:
            self.print(f'{agent.name} Battery is discharged!')
            info_dict = {f'{agent.name}_{c.BATTERY_DISCHARGED}': 1}
            reward_event_dict.update({c.BATTERY_DISCHARGED: {'reward': self.rewards_dest.BATTERY_DISCHARGED,
                                                             'info': info_dict}}
                                     )
        else:
            # All fine
            pass
        return reward_event_dict

    def render_assets_hook(self):
        # noinspection PyUnresolvedReferences
        additional_assets = super().render_assets_hook()
        charge_pods = [RenderEntity(c.CHARGE_PODS, charge_pod.tile.pos) for charge_pod in self[c.CHARGE_PODS]]
        additional_assets.extend(charge_pods)
        return additional_assets
@ -3,12 +3,14 @@ from typing import Dict, List, Union
import numpy as np

from environments.factory.base.objects import Agent, Entity, Action
from environments.factory.factory_dirt import Dirt, DirtRegister, DirtFactory
from environments.factory.factory_dirt import DirtFactory
from environments.factory.additional.dirt.dirt_collections import DirtPiles
from environments.factory.additional.dirt.dirt_entity import DirtPile
from environments.factory.base.objects import Floor
from environments.factory.base.registers import Floors, Entities, EntityRegister
from environments.factory.base.registers import Floors, Entities, EntityCollection


class Machines(EntityRegister):
class Machines(EntityCollection):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
@ -28,7 +30,6 @@ class StationaryMachinesDirtFactory(DirtFactory):

    def entities_hook(self) -> Dict[(str, Entities)]:
        super_entities = super().entities_hook()

        return super_entities

    def reset_hook(self) -> None:
@ -48,8 +49,8 @@ class StationaryMachinesDirtFactory(DirtFactory):
        super_per_agent_raw_observations = super().per_agent_raw_observations_hook(agent)
        return super_per_agent_raw_observations

    def per_agent_reward_hook(self, agent: Agent) -> Dict[str, dict]:
        pass
    def per_agent_reward_hook(self, agent: Agent) -> List[dict]:
        return super(StationaryMachinesDirtFactory, self).per_agent_reward_hook(agent)

    def pre_step_hook(self) -> None:
        pass
@ -7,47 +7,76 @@ import numpy as np
from numpy.typing import ArrayLike
from stable_baselines3 import PPO, DQN, A2C

MODEL_MAP = dict(PPO=PPO, DQN=DQN, A2C=A2C)

LEVELS_DIR = 'levels'
STEPS_START = 1

TO_BE_AVERAGED = ['dirt_amount', 'dirty_tiles']
IGNORED_DF_COLUMNS = ['Episode', 'Run', 'train_step', 'step', 'index', 'dirt_amount',
                      'dirty_tile_count', 'terminal_observation', 'episode']
"""
This file is used for:
    1. string based definition
        Use a class like `Constants` to define attributes which then reveal strings.
        These can be used as naming conventions across the environments, as well as keys for mappings such as dicts.
        When defining new envs, use class inheritance.

    2. utility function definition
        There are static utility functions which are not bound to a specific environment.
        In this file they are defined to be used across the entire package.
"""


MODEL_MAP = dict(PPO=PPO, DQN=DQN, A2C=A2C)  # For use in studies and experiments


LEVELS_DIR = 'levels'  # for use in studies and experiments
STEPS_START = 1  # Defines where the step count starts; i.e. which is the first step

# Not used anymore? Clean!
# TO_BE_AVERAGED = ['dirt_amount', 'dirty_tiles']
IGNORED_DF_COLUMNS = ['Episode', 'Run',  # For plotting, which values are ignored when loading monitor files
                      'train_step', 'step', 'index', 'dirt_amount', 'dirty_tile_count', 'terminal_observation',
                      'episode']


# Constants
class Constants:
    WALL = '#'
    WALLS = 'Walls'
    FLOOR = 'Floor'
    DOOR = 'D'
    DANGER_ZONE = 'x'
    LEVEL = 'Level'
    AGENT = 'Agent'
    AGENT_PLACEHOLDER = 'AGENT_PLACEHOLDER'
    GLOBAL_POSITION = 'GLOBAL_POSITION'
    FREE_CELL = 0
    OCCUPIED_CELL = 1
    SHADOWED_CELL = -1
    ACCESS_DOOR_CELL = 1/3
    OPEN_DOOR_CELL = 2/3
    CLOSED_DOOR_CELL = 3/3
    NO_POS = (-9999, -9999)

    DOORS = 'Doors'
    CLOSED_DOOR = 'closed'
    OPEN_DOOR = 'open'
    ACCESS_DOOR = 'access'
    """
    String based mapping. Use these to handle keys or define values, which can then be used globally.
    Please use class inheritance when defining new environments.
    """

    ACTION = 'action'
    COLLISION = 'collision'
    VALID = True
    NOT_VALID = False
    WALL = '#'  # Wall tile identifier for resolving the string based map files.
    DOOR = 'D'  # Door identifier for resolving the string based map files.
    DANGER_ZONE = 'x'  # Danger Zone tile identifier for resolving the string based map files.

    WALLS = 'Walls'  # Identifier of Wall-objects and sets (collections).
    FLOOR = 'Floor'  # Identifier of Floor-objects and sets (collections).
    DOORS = 'Doors'  # Identifier of Door-objects and sets (collections).
    LEVEL = 'Level'  # Identifier of Level-objects and sets (collections).
    AGENT = 'Agent'  # Identifier of Agent-objects and sets (collections).
    AGENT_PLACEHOLDER = 'AGENT_PLACEHOLDER'  # Identifier of Placeholder-objects and sets (collections).
    GLOBAL_POSITION = 'GLOBAL_POSITION'  # Identifier of the global position slice

    FREE_CELL = 0  # Free-Cell value used in observation
    OCCUPIED_CELL = 1  # Occupied-Cell value used in observation
    SHADOWED_CELL = -1  # Shadowed-Cell value used in observation
    ACCESS_DOOR_CELL = 1/3  # Access-door-Cell value used in observation
    OPEN_DOOR_CELL = 2/3  # Open-door-Cell value used in observation
    CLOSED_DOOR_CELL = 3/3  # Closed-door-Cell value used in observation

    NO_POS = (-9999, -9999)  # Invalid Position value used in the environment (something is off-grid)

    CLOSED_DOOR = 'closed'  # Identifier to compare door-is-closed state
    OPEN_DOOR = 'open'  # Identifier to compare door-is-open state
    # ACCESS_DOOR = 'access'  # Identifier to compare access positions

    ACTION = 'action'  # Identifier of Action-objects and sets (collections).
    COLLISION = 'collision'  # Identifier to use in the context of collisions.
    VALID = True  # Identifier to rename boolean values in the context of actions.
    NOT_VALID = False  # Identifier to rename boolean values in the context of actions.
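As the docstring above prescribes, new environments extend this mapping by inheritance rather than editing it. A hedged sketch (the battery-flavoured names echo the battery module elsewhere in this compare, but the subclass itself is invented):

    class BatteryConstants(Constants):            # illustrative subclass
        BATTERIES = 'BATTERIES'                   # identifier for a new collection
        CHARGE_POD = 1                            # encoding for a new observation slice

    c = BatteryConstants                          # the short-alias pattern used below
    assert c.WALL == '#' and c.CHARGE_POD == 1    # inherited and new values coexist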
class EnvActions:
    """
    String based mapping. Use these to identify actions; can be used globally.
    Please use class inheritance when defining new environments with new actions.
    """
    # Movements
    NORTH = 'north'
    EAST = 'east'
@ -63,24 +92,77 @@ class EnvActions:
    NOOP = 'no_op'
    USE_DOOR = 'use_door'

    _ACTIONMAP = defaultdict(lambda: (0, 0),
                             {NORTH: (-1, 0), NORTHEAST: (-1, 1),
                              EAST: (0, 1), SOUTHEAST: (1, 1),
                              SOUTH: (1, 0), SOUTHWEST: (1, -1),
                              WEST: (0, -1), NORTHWEST: (-1, -1)
                              }
                             )

    @classmethod
    def is_move(cls, other):
        return any([other == direction for direction in cls.movement_actions()])
    def is_move(cls, action):
        """
        Classmethod; checks whether the given action is a movement action. Depending on the env configuration,
        movement actions are either `manhattan` (square) style movements (up, down, left, right) and/or diagonal.

        :param action: Action to be checked
        :type action: str
        :return: Whether the given action is a movement action.
        :rtype: bool
        """
        return any([action == direction for direction in cls.movement_actions()])

    @classmethod
    def square_move(cls):
        """
        Classmethod; returns a list of movement actions that are considered square or `manhattan` style movements.

        :return: A list of movement actions.
        :rtype: list(str)
        """
        return [cls.NORTH, cls.EAST, cls.SOUTH, cls.WEST]

    @classmethod
    def diagonal_move(cls):
        """
        Classmethod; returns a list of movement actions that are considered diagonal movements.

        :return: A list of movement actions.
        :rtype: list(str)
        """
        return [cls.NORTHEAST, cls.SOUTHEAST, cls.SOUTHWEST, cls.NORTHWEST]

    @classmethod
    def movement_actions(cls):
        """
        Classmethod; returns a list of all available movement actions.
        Please note that this is independent of the env properties.

        :return: A list of movement actions.
        :rtype: list(str)
        """
        return list(itertools.chain(cls.square_move(), cls.diagonal_move()))

    @classmethod
    def resolve_movement_action_to_coords(cls, action):
        """
        Classmethod; resolves movement actions. Given a movement action, returns the coordinate delta it stands for:
        how does the current entity coordinate change if it performs the given action?
        Please note that this is independent of the env properties.

        :return: Delta coordinates.
        :rtype: tuple(int, int)
        """
        return cls._ACTIONMAP[action]
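A quick worked example of the classmethods above; everything follows directly from the `_ACTIONMAP` definition:

    assert EnvActions.is_move(EnvActions.NORTH)
    assert not EnvActions.is_move(EnvActions.USE_DOOR)
    dr, dc = EnvActions.resolve_movement_action_to_coords(EnvActions.NORTHEAST)
    assert (dr, dc) == (-1, 1)                    # one row up, one column right
    # Non-movement actions hit the defaultdict fallback: no displacement.
    assert EnvActions.resolve_movement_action_to_coords(EnvActions.NOOP) == (0, 0)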
class RewardsBase(NamedTuple):
    """
    Value based mapping. Use these to define reward values for specific conditions (i.e. the action
    in a given context); can be used globally.
    Please use class inheritance when defining new environments with new rewards.
    """
    MOVEMENTS_VALID: float = -0.001
    MOVEMENTS_FAIL: float = -0.05
    NOOP: float = -0.01
@ -89,43 +171,61 @@ class RewardsBase(NamedTuple):
    COLLISION: float = -0.5


m = EnvActions
c = Constants

ACTIONMAP = defaultdict(lambda: (0, 0),
                        {m.NORTH: (-1, 0), m.NORTHEAST: (-1, 1),
                         m.EAST: (0, 1), m.SOUTHEAST: (1, 1),
                         m.SOUTH: (1, 0), m.SOUTHWEST: (1, -1),
                         m.WEST: (0, -1), m.NORTHWEST: (-1, -1)
                         }
                        )


class ObservationTranslator:

    def __init__(self, obs_shape_2d: (int, int), this_named_observation_space: Dict[str, dict],
                 *per_agent_named_obs_space: Dict[str, dict],
                 placeholder_fill_value: Union[int, str] = 'N'):
        assert len(obs_shape_2d) == 2
        self.obs_shape = obs_shape_2d
    def __init__(self, this_named_observation_space: Dict[str, dict],
                 *per_agent_named_obs_spaces: Dict[str, dict],
                 placeholder_fill_value: Union[int, str, None] = None):
        """
        This is a helper class which converts agents' observations between joined environments.
        For example, agents trained in different environments may expect different observations.
        This class translates from larger observation spaces to smaller ones.
        A string identifier based approach is used.
        Currently, it is not possible to mix different obs shapes.


        :param this_named_observation_space: `Named observation space` of the joined environment.
        :type this_named_observation_space: Dict[str, dict]

        :param per_agent_named_obs_spaces: `Named observation space`, one for each agent. Overloaded.
        :type per_agent_named_obs_spaces: Dict[str, dict]

        :param placeholder_fill_value: Currently not fully implemented!!!
        :type placeholder_fill_value: Union[int, str, None]
        """

        if isinstance(placeholder_fill_value, str):
            if placeholder_fill_value.lower() in ['normal', 'n']:
                self.random_fill = lambda: np.random.normal(size=self.obs_shape)
                self.random_fill = np.random.normal
            elif placeholder_fill_value.lower() in ['uniform', 'u']:
                self.random_fill = lambda: np.random.uniform(size=self.obs_shape)
                self.random_fill = np.random.uniform
            else:
                raise ValueError('Please choose between "uniform" or "normal"')
                raise ValueError('Please choose between "uniform" or "normal" ("u", "n").')
        elif isinstance(placeholder_fill_value, int):
            raise NotImplementedError('"Future Work."')
        else:
            self.random_fill = None

        self._this_named_obs_space = this_named_observation_space
        self._per_agent_named_obs_space = list(per_agent_named_obs_space)
        self._per_agent_named_obs_space = list(per_agent_named_obs_spaces)

    def translate_observation(self, agent_idx: int, obs: np.ndarray):
        target_obs_space = self._per_agent_named_obs_space[agent_idx]
        translation = [idx_space_dict for name, idx_space_dict in target_obs_space.items()]
        flat_translation = [x for y in translation for x in y]
        return np.take(obs, flat_translation, axis=1 if obs.ndim == 4 else 0)
        translation = dict()
        for name, idxs in target_obs_space.items():
            if name in self._this_named_obs_space:
                for target_idx, this_idx in zip(idxs, self._this_named_obs_space[name]):
                    taken_slice = np.take(obs, [this_idx], axis=1 if obs.ndim == 4 else 0)
                    translation[target_idx] = taken_slice
            elif random_fill := self.random_fill:
                for target_idx in idxs:
                    translation[target_idx] = random_fill(size=obs.shape[:-3] + (1,) + obs.shape[-2:])
            else:
                for target_idx in idxs:
                    translation[target_idx] = np.zeros(shape=(obs.shape[:-3] + (1,) + obs.shape[-2:]))

        translation = dict(sorted(translation.items()))
        return np.concatenate(list(translation.values()), axis=-3)

    def translate_observations(self, observations: List[ArrayLike]):
        return [self.translate_observation(idx, observation) for idx, observation in enumerate(observations)]
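A hedged sketch of the translator in action. The named-space dicts map slice names to channel indices, which is exactly what `translate_observation` iterates over; the shapes and names here are invented:

    import numpy as np

    joined_space = {'Walls': [0], 'Agent': [1], 'Doors': [2]}
    agent_space = {'Walls': [0], 'Agent': [1]}    # this agent never saw doors
    translator = ObservationTranslator(joined_space, agent_space)

    obs = np.random.rand(3, 5, 5)                 # (channels, h, w) joined observation
    translated = translator.translate_observation(0, obs)
    assert translated.shape == (2, 5, 5)          # only the channels this agent knows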
@ -137,8 +237,24 @@ class ObservationTranslator:

class ActionTranslator:

    def __init__(self, target_named_action_space: Dict[str, int], *per_agent_named_action_space: Dict[str, int]):
        """
        This is a helper class which converts agents' action spaces to a joined environment's action space.
        For example, agents trained in different environments may have different action spaces.
        This class translates from smaller individual agent action spaces to larger joined spaces.
        A string identifier based approach is used.

        :param target_named_action_space: Joined `Named action space` for the current environment.
        :type target_named_action_space: Dict[str, dict]

        :param per_agent_named_action_space: `Named action space`, one for each agent. Overloaded.
        :type per_agent_named_action_space: Dict[str, dict]
        """

        self._target_named_action_space = target_named_action_space
        self._per_agent_named_action_space = list(per_agent_named_action_space)
        if isinstance(per_agent_named_action_space, (list, tuple)):
            self._per_agent_named_action_space = per_agent_named_action_space
        else:
            self._per_agent_named_action_space = list(per_agent_named_action_space)
        self._per_agent_idx_actions = [{idx: a for a, idx in x.items()} for x in self._per_agent_named_action_space]

    def translate_action(self, agent_idx: int, action: int):
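For symmetry, a sketch of the action side. The body of `translate_action` is cut off by the hunk; the assumption here, based on the `_per_agent_idx_actions` inversion built above, is that it maps a local index to a name and then to the joined index:

    target = {'north': 0, 'east': 1, 'use_door': 2}
    agent = {'north': 0, 'use_door': 1}           # smaller per-agent action space
    translator = ActionTranslator(target, agent)

    # Agent 0 emits local index 1 ('use_door'); the joined env expects index 2.
    assert translator.translate_action(0, 1) == 2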
@ -155,6 +271,16 @@ class ActionTranslator:
|
||||
|
||||
# Utility functions
|
||||
def parse_level(path):
|
||||
"""
|
||||
Given the path to a strin based `level` or `map` representation, this function reads the content.
|
||||
Cleans `space`, checks for equal length of each row and returns a list of lists.
|
||||
|
||||
:param path: Path to the `level` or `map` file on harddrive.
|
||||
:type path: os.Pathlike
|
||||
|
||||
:return: The read string representation of the `level` or `map`
|
||||
:rtype: List[List[str]]
|
||||
"""
|
||||
with path.open('r') as lvl:
|
||||
level = list(map(lambda x: list(x.strip()), lvl.readlines()))
|
||||
if len(set([len(line) for line in level])) > 1:
|
||||
@ -162,29 +288,56 @@ def parse_level(path):
|
||||
return level
|
||||
|
||||
|
||||
def one_hot_level(level, wall_char: str = c.WALL):
|
||||
def one_hot_level(level, wall_char: str = Constants.WALL):
|
||||
"""
|
||||
Given a string based level representation (list of lists, see function `parse_level`), this function creates a
|
||||
binary numpy array or `grid`. Grid values that equal `wall_char` become of `Constants.OCCUPIED_CELL` value.
|
||||
Can be changed to filter for any symbol.
|
||||
|
||||
:param level: String based level representation (list of lists, see function `parse_level`).
|
||||
:param wall_char: List[List[str]]
|
||||
|
||||
:return: Binary numpy array
|
||||
:rtype: np.typing._array_like.ArrayLike
|
||||
"""
|
||||
|
||||
grid = np.array(level)
|
||||
binary_grid = np.zeros(grid.shape, dtype=np.int8)
|
||||
binary_grid[grid == wall_char] = c.OCCUPIED_CELL
|
||||
binary_grid[grid == wall_char] = Constants.OCCUPIED_CELL
|
||||
return binary_grid
|
||||


def check_position(slice_to_check_against: ArrayLike, position_to_check: Tuple[int, int]):
    """
    Given a slice (2-D array-like object), check whether a position within it can be accessed.

    :param slice_to_check_against: The slice to check for accessibility
    :type slice_to_check_against: np.typing._array_like.ArrayLike

    :param position_to_check: Position in slice that should be checked. Can be outside of the slice boundaries.
    :type position_to_check: tuple(int, int)

    :return: Whether a position can be moved to.
    :rtype: bool
    """
    x_pos, y_pos = position_to_check

    # Check if the agent collides with the grid boundaries
    valid = not (
        x_pos < 0 or y_pos < 0
        or x_pos >= slice_to_check_against.shape[0]
        or y_pos >= slice_to_check_against.shape[0]
        or y_pos >= slice_to_check_against.shape[1]
    )

    # Check for collision with level walls
    valid = valid and not slice_to_check_against[x_pos, y_pos]
    return c.VALID if valid else c.NOT_VALID
    return Constants.VALID if valid else Constants.NOT_VALID
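A worked example, assuming `OCCUPIED_CELL == 1` (the return values are whatever `Constants.VALID` / `Constants.NOT_VALID` are bound to):

import numpy as np

# 1 = occupied, 0 = free; e.g. the output of `one_hot_level`
lvl = np.array([[1, 1, 1],
                [1, 0, 1],
                [1, 1, 1]], dtype=np.int8)

check_position(lvl, (1, 1))  # free cell      -> VALID
check_position(lvl, (0, 1))  # wall           -> NOT_VALID
check_position(lvl, (5, 5))  # out of bounds  -> NOT_VALID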


def asset_str(agent):
    """
    FIXME @ romue
    """
    # What does this abomination do?
    # if any([x is None for x in [cls._slices[j] for j in agent.collisions]]):
    #     print('error')
@ -192,33 +345,50 @@ def asset_str(agent):
    action = step_result['action_name']
    valid = step_result['action_valid']
    col_names = [x.name for x in step_result['collisions']]
    if any(c.AGENT in name for name in col_names):
    if any(Constants.AGENT in name for name in col_names):
        return 'agent_collision', 'blank'
    elif not valid or c.LEVEL in col_names or c.AGENT in col_names:
        return c.AGENT, 'invalid'
    elif not valid or Constants.LEVEL in col_names or Constants.AGENT in col_names:
        return Constants.AGENT, 'invalid'
    elif valid and not EnvActions.is_move(action):
        return c.AGENT, 'valid'
        return Constants.AGENT, 'valid'
    elif valid and EnvActions.is_move(action):
        return c.AGENT, 'move'
        return Constants.AGENT, 'move'
    else:
        return c.AGENT, 'idle'
        return Constants.AGENT, 'idle'
    else:
        return c.AGENT, 'idle'
        return Constants.AGENT, 'idle'


def points_to_graph(coordiniates_or_tiles, allow_euclidean_connections=True, allow_manhattan_connections=True):
    """
    Given a set of coordinates, this function constructs an undirected graph by connecting adjacent points.
    There are three combinations of settings:
        Allow all neighbors:    Distance(a, b) <= sqrt(2)
        Allow only manhattan:   Distance(a, b) == 1
        Allow only euclidean:   Distance(a, b) == sqrt(2)

    :param coordiniates_or_tiles: A set of coordinates.
    :type coordiniates_or_tiles: Tiles
    :param allow_euclidean_connections: Whether to regard diagonally adjacent cells as neighbors
    :type: bool
    :param allow_manhattan_connections: Whether to regard directly adjacent cells as neighbors
    :type: bool

    :return: A graph whose nodes are connected as specified by the parameters.
    :rtype: nx.Graph
    """
    assert allow_euclidean_connections or allow_manhattan_connections
    if hasattr(coordiniates_or_tiles, 'positions'):
        coordiniates_or_tiles = coordiniates_or_tiles.positions
    possible_connections = itertools.combinations(coordiniates_or_tiles, 2)
    graph = nx.Graph()
    for a, b in possible_connections:
        diff = abs(np.subtract(a, b))
        if not max(diff) > 1:
            if allow_manhattan_connections and allow_euclidean_connections:
                graph.add_edge(a, b)
            elif not allow_manhattan_connections and allow_euclidean_connections and all(diff):
                graph.add_edge(a, b)
            elif allow_manhattan_connections and not allow_euclidean_connections and not all(diff) and any(diff):
                graph.add_edge(a, b)
        diff = np.linalg.norm(np.asarray(a)-np.asarray(b))
        if allow_manhattan_connections and allow_euclidean_connections and diff <= np.sqrt(2):
            graph.add_edge(a, b)
        elif not allow_manhattan_connections and allow_euclidean_connections and diff == np.sqrt(2):
            graph.add_edge(a, b)
        elif allow_manhattan_connections and not allow_euclidean_connections and diff == 1:
            graph.add_edge(a, b)
    return graph
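To illustrate the three connection modes on a tiny point set (edge lists written informally):

coords = [(0, 0), (0, 1), (1, 1)]

points_to_graph(coords).edges                                     # (0,0)-(0,1), (0,1)-(1,1) plus the diagonal (0,0)-(1,1)
points_to_graph(coords, allow_euclidean_connections=False).edges  # manhattan only: (0,0)-(0,1), (0,1)-(1,1)
points_to_graph(coords, allow_manhattan_connections=False).edges  # euclidean only: (0,0)-(1,1)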

@ -47,7 +47,7 @@ class EnvMonitor(BaseCallback):
        self._read_info(env_idx, info)

        for env_idx, done in list(
            enumerate(self.locals.get('dones', []))):  # + list(enumerate(self.locals.get('done', []))):
            enumerate(self.locals.get('dones', []))) + list(enumerate(self.locals.get('done', []))):
            self._read_done(env_idx, done)
        return True

@ -71,8 +71,8 @@ class EnvMonitor(BaseCallback):
        pass
        return

    def save_run(self, filepath: Union[Path, str], auto_plotting_keys=None):
        filepath = Path(filepath)
    def save_run(self, filepath: Union[Path, str, None] = None, auto_plotting_keys=None):
        filepath = Path(filepath or self._filepath)
        filepath.parent.mkdir(exist_ok=True, parents=True)
        with filepath.open('wb') as f:
            pickle.dump(self._monitor_df.reset_index(), f, protocol=pickle.HIGHEST_PROTOCOL)
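A minimal training sketch with the monitor, mirroring how it is used elsewhere in this changeset (paths are illustrative); with the new signature, `save_run()` can fall back to a filepath given at construction time:

from stable_baselines3 import PPO

env = DirtFactory(**factory_kwargs)
monitor = EnvMonitor(env)                   # also accepts filepath=... in this changeset
model = PPO("MlpPolicy", env, verbose=0, device='cpu')
model.learn(total_timesteps=10_000, callback=monitor)
monitor.save_run('study_out/monitor.pick')  # or monitor.save_run() with a stored filepath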
@ -1,10 +1,13 @@
import warnings
from collections import defaultdict
from os import PathLike
from pathlib import Path
from typing import Union

import numpy as np
import pandas as pd
import simplejson
from deepdiff.operator import BaseOperator
from stable_baselines3.common.callbacks import BaseCallback

from environments.factory.base.base_factory import REC_TAC
@ -12,11 +15,15 @@ from environments.factory.base.base_factory import REC_TAC

class EnvRecorder(BaseCallback):

    def __init__(self, env, entities='all'):
    def __init__(self, env, entities: str = 'all', filepath: Union[str, PathLike] = None, freq: int = 0):
        super(EnvRecorder, self).__init__()
        self.filepath = filepath
        self.unwrapped = env
        self.freq = freq
        self._recorder_dict = defaultdict(list)
        self._recorder_out_list = list()
        self._episode_counter = 1
        self._do_record_dict = defaultdict(lambda: False)
        if isinstance(entities, str):
            if entities.lower() == 'all':
                self._entities = None
@ -29,46 +36,65 @@ class EnvRecorder(BaseCallback):
        return getattr(self.unwrapped, item)

    def reset(self):
        self.unwrapped.start_recording()
        self._on_training_start()
        return self.unwrapped.reset()

    def _on_training_start(self) -> None:
        self.unwrapped._record_episodes = True
        pass
        assert self.start_recording()

    def _read_info(self, env_idx, info: dict):
        if info_dict := {key.replace(REC_TAC, ''): val for key, val in info.items() if key.startswith(f'{REC_TAC}')}:
            if self._entities:
                info_dict = {k: v for k, v in info_dict.items() if k in self._entities}

            info_dict.update(episode=(self.num_timesteps + env_idx))
            self._recorder_dict[env_idx].append(info_dict)
        else:
            pass
        return
        return True

    def _read_done(self, env_idx, done):
        if done:
            self._recorder_out_list.append({'steps': self._recorder_dict[env_idx],
                                            'episode': len(self._recorder_out_list)})
                                            'episode': self._episode_counter})
            self._recorder_dict[env_idx] = list()
        else:
            pass

    def step(self, actions):
        step_result = self.unwrapped.step(actions)
        # 0, 1, 2, 3 = idx
        # _, _, done_bool, info_obj = step_result
        self._read_info(0, step_result[3])
        self._read_done(0, step_result[2])
        if self.do_record_episode(0):
            info = step_result[-1]
            self._read_info(0, info)
        if self._do_record_dict[0]:
            self._read_done(0, step_result[-2])
        return step_result

    def save_records(self, filepath: Union[Path, str], save_occupation_map=False, save_trajectory_map=False):
        filepath = Path(filepath)
    def finalize(self):
        self._on_training_end()
        return True

    def save_records(self, filepath: Union[Path, str, None] = None,
                     only_deltas=True,
                     save_occupation_map=False,
                     save_trajectory_map=False,
                     ):
        filepath = Path(filepath or self.filepath)
        filepath.parent.mkdir(exist_ok=True, parents=True)
        # cls.out_file.unlink(missing_ok=True)
        with filepath.open('w') as f:
            out_dict = {'episodes': self._recorder_out_list, 'header': self.unwrapped.params}
            if only_deltas:
                from deepdiff import DeepDiff, Delta
                diff_dict = [DeepDiff(t1, t2, ignore_order=True)
                             for t1, t2 in zip(self._recorder_out_list, self._recorder_out_list[1:])
                             ]
                out_dict = {'episodes': diff_dict}

            else:
                out_dict = {'episodes': self._recorder_out_list}
            out_dict.update(
                {'n_episodes': self._episode_counter,
                 'env_params': self.unwrapped.params,
                 'header': self.unwrapped.summarize_header
                 })
            try:
                simplejson.dump(out_dict, f, indent=4)
            except TypeError:
@ -76,6 +102,7 @@ class EnvRecorder(BaseCallback):

        if save_occupation_map:
            a = np.zeros((15, 15))
            # noinspection PyTypeChecker
            for episode in out_dict['episodes']:
                df = pd.DataFrame([y for x in episode['steps'] for y in x['Agents']])

@ -93,16 +120,34 @@ class EnvRecorder(BaseCallback):
        if save_trajectory_map:
            raise NotImplementedError('This has not yet been implemented.')

    def do_record_episode(self, env_idx):
        if not self._recorder_dict[env_idx]:
            if self.freq:
                self._do_record_dict[env_idx] = (self.freq == -1) or (self._episode_counter % self.freq) == 0
            else:
                self._do_record_dict[env_idx] = False
                warnings.warn('You wrapped your environment with a recorder, but set the freq to zero.\n'
                              'Nothing will be recorded.')
            self._episode_counter += 1
        else:
            pass
        return self._do_record_dict[env_idx]
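Given `do_record_episode` above, the `freq` semantics are: `-1` records every episode, `N > 0` records every N-th episode (`episode_counter % N == 0`), and `0` records nothing and emits the warning. A construction sketch (path is illustrative):

# Record every 100th episode to the given file; both kwargs were added in this change.
recorder = EnvRecorder(env, filepath='study_out/recorder.json', freq=100)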

    def _on_step(self) -> bool:
        for env_idx, info in enumerate(self.locals.get('infos', [])):
            self._read_info(env_idx, info)

            if self._do_record_dict[env_idx]:
                self._read_info(env_idx, info)
        dones = list(enumerate(self.locals.get('dones', [])))
        dones.extend(list(enumerate(self.locals.get('done', []))))
        for env_idx, done in dones:
            self._read_done(env_idx, done)
            if self._do_record_dict[env_idx]:
                self._read_done(env_idx, done)

        return True

    def _on_training_end(self) -> None:
        for env_idx in range(len(self._recorder_dict)):
            if self._recorder_dict[env_idx]:
                self._recorder_out_list.append({'steps': self._recorder_dict[env_idx],
                                                'episode': self._episode_counter})
        pass
@ -3,7 +3,38 @@ import gym
from gym.wrappers.frame_stack import FrameStack


class EnvCombiner(object):

    def __init__(self, *envs_cls):
        self._env_dict = {env_cls.__name__: env_cls for env_cls in envs_cls}

    @staticmethod
    def combine_cls(name, *envs_cls):
        return type(name, envs_cls, {})

    def build(self):
        # Strip the 'factory' suffix from each class name, e.g. 'DirtFactory' -> 'Dirt'
        # (str.replace needs both arguments; the original call was missing the second one)
        name = f'{"".join([x.lower().replace("factory", "").capitalize() for x in self._env_dict.keys()])}Factory'

        # Unpack the classes so that `type()` receives a flat tuple of base classes
        return self.combine_cls(name, *self._env_dict.values())
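With the two fixes above, `build()` derives the combined class name from its parts; a usage sketch:

# 'DirtFactory' + 'ItemFactory' -> a class named 'DirtItemFactory',
# inheriting from both factories (MRO order = argument order).
combined_cls = EnvCombiner(DirtFactory, ItemFactory).build()
env = combined_cls(**factory_kwargs)  # factory_kwargs as defined for the single factories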


class AgentRenderOptions(object):
    """
    Class that specifies the available options for the way agents are represented in the env observation.

    SEPERATE:
    Each agent is represented in a separate slice as Constant.OCCUPIED_CELL value (one hot)

    COMBINED:
    For all agents, the value of Constant.OCCUPIED_CELL is added to a zero-value slice at the agent's position (sum(SEPERATE))

    LEVEL:
    The combined slice is added to the LEVEL-slice. (Agents appear as obstacle / wall)

    NOT:
    The position of individual agents cannot be read from the observation.
    """

    SEPERATE = 'seperate'
    COMBINED = 'combined'
    LEVEL = 'lvl'
@ -11,24 +42,61 @@ class AgentRenderOptions(object):


class MovementProperties(NamedTuple):
    """
    Property holder; for setting multiple related parameters through a single parameter. Comes with default values.
    """

    """Allow the manhattan style movement on a grid (move to cells that are connected by square edges)."""
    allow_square_movement: bool = True

    """Allow diagonal movement on the grid (move to cells that are connected by square corners)."""
    allow_diagonal_movement: bool = False

    """Allow the agent to just do nothing; not move (NO-OP)."""
    allow_no_op: bool = False


class ObservationProperties(NamedTuple):
    # Todo: Add Description
    """
    Property holder; for setting multiple related parameters through a single parameter. Comes with default values.
    """

    """How to represent agents in the observation space. This may also alter the obs-shape."""
    render_agents: AgentRenderOptions = AgentRenderOptions.SEPERATE

    """Observations are built per agent; whether the current agent should be represented in its own observation."""
    omit_agent_self: bool = True

    """There might be cases where you want to modify the agent's obs-space so that it can be used with additional obs.
    The additional slice can be filled with any number."""
    additional_agent_placeholder: Union[None, str, int] = None

    """Whether to cast shadows (make floor tiles and items hidden)."""
    cast_shadows: bool = True

    """Frame stacking is a method to give some temporal information to the agents.
    This parameter controls how many "old" frames are stacked."""
    frames_to_stack: int = 0
    pomdp_r: int = 0

    """Specifies the radius (_r) of the agent's field of view. Please note that the agent's own grid cell is not
    taken into account. This means that the resulting field of view diameter = `pomdp_r * 2 + 1`.
    A 'pomdp_r' of 0 always returns the full env == no partial observability."""
    pomdp_r: int = 2

    """Whether to place a visual encoding on walkable tiles around the doors. This is helpful when the doors can be
    operated from their surrounding area, so the agent can more easily get a notion of where to choose the door option.
    However, this is not necessary at all.
    """
    indicate_door_area: bool = False

    """Whether to add the agent's normalized global position as float values (2, 1) to a separate information slice.
    More optional information may follow.
    """
    show_global_position_info: bool = False
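A short construction sketch tying the two property holders together (values are illustrative):

move_props = MovementProperties(allow_square_movement=True,    # Manhattan moves
                                allow_diagonal_movement=True,  # plus diagonals
                                allow_no_op=False)
obs_props = ObservationProperties(render_agents=AgentRenderOptions.COMBINED,
                                  frames_to_stack=3,  # stack the last three frames
                                  pomdp_r=2)          # 5x5 field of view: 2 * 2 + 1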


class MarlFrameStack(gym.ObservationWrapper):
    """todo @romue404"""
    def __init__(self, env):
        super().__init__(env)

@ -1,119 +0,0 @@
import warnings
from pathlib import Path

import yaml

from stable_baselines3 import PPO

from environments.factory.factory_dirt import DirtProperties, DirtFactory, RewardsDirt
from environments.logging.envmonitor import EnvMonitor
from environments.logging.recorder import EnvRecorder
from environments.utility_classes import MovementProperties, ObservationProperties, AgentRenderOptions
from environments.factory.factory_dirt import Constants as c

warnings.filterwarnings('ignore', category=FutureWarning)
warnings.filterwarnings('ignore', category=UserWarning)

if __name__ == '__main__':
    TRAIN_AGENT = True
    LOAD_AND_REPLAY = True
    record = True
    render = False

    study_root_path = Path(__file__).parent.parent / 'experiment_out'

    parameter_path = Path(__file__).parent.parent / 'environments' / 'factory' / 'levels' / 'parameters' / 'DirtyFactory-v0.yaml'

    save_path = study_root_path / f'model.zip'

    # Output folder

    study_root_path.mkdir(parents=True, exist_ok=True)

    train_steps = 2*1e5
    frames_to_stack = 0

    u = dict(
        show_global_position_info=True,
        pomdp_r=3,
        cast_shadows=True,
        allow_diagonal_movement=False,
        parse_doors=True,
        doors_have_area=False,
        done_at_collision=True
    )
    obs_props = ObservationProperties(render_agents=AgentRenderOptions.SEPERATE,
                                      additional_agent_placeholder=None,
                                      omit_agent_self=True,
                                      frames_to_stack=frames_to_stack,
                                      pomdp_r=u['pomdp_r'], cast_shadows=u['cast_shadows'],
                                      show_global_position_info=u['show_global_position_info'])
    move_props = MovementProperties(allow_diagonal_movement=u['allow_diagonal_movement'],
                                    allow_square_movement=True,
                                    allow_no_op=False)
    dirt_props = DirtProperties(initial_dirt_ratio=0.35, initial_dirt_spawn_r_var=0.1,
                                clean_amount=0.34,
                                max_spawn_amount=0.1, max_global_amount=20,
                                max_local_amount=1, spawn_frequency=0, max_spawn_ratio=0.05,
                                dirt_smear_amount=0.0)
    rewards_dirt = RewardsDirt(CLEAN_UP_FAIL=-0.5, CLEAN_UP_VALID=1, CLEAN_UP_LAST_PIECE=5)
    factory_kwargs = dict(n_agents=1, max_steps=500, parse_doors=u['parse_doors'],
                          level_name='rooms', doors_have_area=u['doors_have_area'],
                          verbose=True,
                          mv_prop=move_props,
                          obs_prop=obs_props,
                          rewards_dirt=rewards_dirt,
                          done_at_collision=u['done_at_collision']
                          )

    # with (parameter_path).open('r') as f:
    #     factory_kwargs = yaml.load(f, Loader=yaml.FullLoader)
    #     factory_kwargs.update(n_agents=1, done_at_collision=False, verbose=True)

    if TRAIN_AGENT:
        env = DirtFactory(**factory_kwargs)
        callbacks = EnvMonitor(env)
        obs_shape = env.observation_space.shape

        model = PPO("MlpPolicy", env, verbose=1, device='cpu')

        model.learn(total_timesteps=train_steps, callback=callbacks)

        callbacks.save_run(study_root_path / 'monitor.pick', auto_plotting_keys=['step_reward', 'collision'] + ['cleanup_valid', 'cleanup_fail'])  # + env_plot_keys)

        model.save(save_path)

    if LOAD_AND_REPLAY:
        with DirtFactory(**factory_kwargs) as env:
            env = EnvMonitor(env)
            env = EnvRecorder(env) if record else env
            obs_shape = env.observation_space.shape
            model = PPO.load(save_path)
            # Evaluation Loop for i in range(n Episodes)
            for episode in range(10):
                env_state = env.reset()
                rew, done_bool = 0, False
                while not done_bool:
                    actions = model.predict(env_state, deterministic=True)[0]
                    env_state, step_r, done_bool, info_obj = env.step(actions)

                    rew += step_r

                    if render:
                        env.render()

                    try:
                        door = next(x for x in env.unwrapped.unwrapped.unwrapped[c.DOORS] if x.is_open)
                        print('openDoor found')
                    except StopIteration:
                        pass

                    if done_bool:
                        break
                print(
                    f'Factory run {episode} done, steps taken {env.unwrapped.unwrapped.unwrapped._steps}, reward is:\n {rew}')

            env.save_records(study_root_path / 'reload_recorder.pick', save_occupation_map=False)
            #env.save_run(study_root_path / 'reload_monitor.pick',
            #             auto_plotting_keys=['step_reward', 'cleanup_valid', 'cleanup_fail'])
@ -76,10 +76,11 @@ def compare_seed_runs(run_path: Union[str, PathLike], use_tex: bool = False):
    skip_n = round(df_melted['Episode'].max() * 0.02)
    df_melted = df_melted[df_melted['Episode'] % skip_n == 0]

    if run_path.is_dir():
        prepare_plot(run_path / f'{run_path}_monitor_lineplot.png', df_melted, use_tex=use_tex)
    elif run_path.exists() and run_path.is_file():
        prepare_plot(run_path.parent / f'{run_path.parent}_monitor_lineplot.png', df_melted, use_tex=use_tex)
    run_path.mkdir(parents=True, exist_ok=True)
    if run_path.exists() and run_path.is_file():
        prepare_plot(run_path.parent / f'{run_path.name}_monitor_lineplot.png', df_melted, use_tex=use_tex)
    else:
        prepare_plot(run_path / f'{run_path.name}_monitor_lineplot.png', df_melted, use_tex=use_tex)
    print('Plotting done.')

189
quickstart/combine_and_monitor_rerun.py
Normal file
@ -0,0 +1,189 @@
import sys
from pathlib import Path

##############################################
# keep this for stand alone script execution #
##############################################
from environments.factory.base.base_factory import BaseFactory
from environments.logging.recorder import EnvRecorder

try:
    # noinspection PyUnboundLocalVariable
    if __package__ is None:
        DIR = Path(__file__).resolve().parent
        sys.path.insert(0, str(DIR.parent))
        __package__ = DIR.name
    else:
        DIR = None
except NameError:
    DIR = None
    pass
##############################################
##############################################
##############################################


import simplejson

from environments import helpers as h
from environments.factory.additional.combined_factories import DestBatteryFactory
from environments.factory.additional.dest.factory_dest import DestFactory
from environments.factory.additional.dirt.factory_dirt import DirtFactory
from environments.factory.additional.item.factory_item import ItemFactory
from environments.helpers import ObservationTranslator, ActionTranslator
from environments.logging.envmonitor import EnvMonitor
from environments.utility_classes import ObservationProperties, AgentRenderOptions, MovementProperties


def policy_model_kwargs():
    return dict(ent_coef=0.01)


def dqn_model_kwargs():
    return dict(buffer_size=50000,
                learning_starts=64,
                batch_size=64,
                target_update_interval=5000,
                exploration_fraction=0.25,
                exploration_final_eps=0.025
                )


def encapsule_env_factory(env_fctry, env_kwrgs):

    def _init():
        with env_fctry(**env_kwrgs) as init_env:
            return init_env

    return _init


if __name__ == '__main__':

    render = False
    # Define global env parameters
    # Define properties object parameters
    factory_kwargs = dict(
        max_steps=400, parse_doors=True,
        level_name='rooms',
        doors_have_area=True, verbose=False,
        mv_prop=MovementProperties(allow_diagonal_movement=True,
                                   allow_square_movement=True,
                                   allow_no_op=False),
        obs_prop=ObservationProperties(
            frames_to_stack=3,
            cast_shadows=True,
            omit_agent_self=True,
            render_agents=AgentRenderOptions.LEVEL,
            additional_agent_placeholder=None,
        )
    )

    # Bundle both environments with global kwargs and parameters
    # Todo: find a better solution, like auto module loading
    env_map = {'DirtFactory': DirtFactory,
               'ItemFactory': ItemFactory,
               'DestFactory': DestFactory,
               'DestBatteryFactory': DestBatteryFactory
               }
    env_names = list(env_map.keys())

    # Put all your multi-seed agents in a single folder; we do not need specific names etc.
    available_models = dict()
    available_envs = dict()
    available_runs_kwargs = dict()
    available_runs_agents = dict()
    max_seed = 0
    # Define this folder
    combinations_path = Path('combinations')
    # Those are all differently trained combinations of models, envs and parameters
    for combination in (x for x in combinations_path.iterdir() if x.is_dir()):
        # These are all the models for this specific combination
        for model_run in (x for x in combination.iterdir() if x.is_dir()):
            model_name, env_name = model_run.name.split('_')[:2]
            if model_name not in available_models:
                available_models[model_name] = h.MODEL_MAP[model_name]
            if env_name not in available_envs:
                available_envs[env_name] = env_map[env_name]
            # Those are all available seeds
            for seed_run in (x for x in model_run.iterdir() if x.is_dir()):
                max_seed = max(int(seed_run.name.split('_')[0]), max_seed)
                # Read the env configuration from disk
                with next(seed_run.glob('env_params.json')).open('r') as f:
                    env_kwargs = simplejson.load(f)
                available_runs_kwargs[seed_run.name] = env_kwargs
                # Read the trained model_path from disk
                model_path = next(seed_run.glob('model.zip'))
                available_runs_agents[seed_run.name] = model_path

    # We start by combining all SAME MODEL CLASSES per available Seed, across ALL available ENVIRONMENTS.
    for model_name, model_cls in available_models.items():
        for seed in range(max_seed):
            combined_env_kwargs = dict()
            model_paths = list()
            comparable_runs = {key: val for key, val in available_runs_kwargs.items() if (
                    key.startswith(str(seed)) and model_name in key and key != 'key')
                               }
            for name, run_kwargs in comparable_runs.items():
                # Select trained agent as a candidate:
                model_paths.append(available_runs_agents[name])
                # Sort env kwargs:
                for key, val in run_kwargs.items():
                    if key not in combined_env_kwargs:
                        combined_env_kwargs[key] = val  # (was update(dict(key=val)), which wrote a literal 'key' entry)
                    else:
                        assert combined_env_kwargs[key] == val, "Check the combinations you try to make!"

            # Update and combine all kwargs to account for multiple agents etc.
            # We cannot capture all configuration cases!
            for key, val in factory_kwargs.items():
                if key not in combined_env_kwargs:
                    combined_env_kwargs[key] = val
                else:
                    assert combined_env_kwargs[key] == val
            combined_env_kwargs.pop('key', None)  # defensive clean-up of the former 'key' artifact
            combined_env_kwargs.update(n_agents=len(comparable_runs))
            with type("CombinedEnv", tuple(available_envs.values()), {})(**combined_env_kwargs) as combEnv:
                # EnvMonitor Init
                comb = f'comb_{model_name}_{seed}'
                comb_monitor_path = combinations_path / comb / f'{comb}_monitor.pick'
                comb_recorder_path = combinations_path / comb / f'{comb}_recorder.json'
                comb_monitor_path.parent.mkdir(parents=True, exist_ok=True)

                monitoredCombEnv = EnvMonitor(combEnv, filepath=comb_monitor_path)
                monitoredCombEnv = EnvRecorder(monitoredCombEnv, filepath=comb_recorder_path, freq=1)

                # Evaluation starts here #####################################################
                # Load all models
                loaded_models = [available_models[model_name].load(model_path) for model_path in model_paths]
                obs_translators = ObservationTranslator(
                    monitoredCombEnv.named_observation_space,
                    *[agent.named_observation_space for agent in loaded_models],
                    placeholder_fill_value='n')
                act_translators = ActionTranslator(
                    monitoredCombEnv.named_action_space,
                    *(agent.named_action_space for agent in loaded_models)
                )

                for episode in range(1):
                    obs = monitoredCombEnv.reset()
                    if render: monitoredCombEnv.render()
                    rew, done_bool = 0, False
                    while not done_bool:
                        actions = []
                        for i, model in enumerate(loaded_models):
                            pred = model.predict(obs_translators.translate_observation(i, obs[i]))[0]
                            actions.append(act_translators.translate_action(i, pred))

                        obs, step_r, done_bool, info_obj = monitoredCombEnv.step(actions)

                        rew += step_r
                        if render: monitoredCombEnv.render()
                        if done_bool:
                            break
                    print(f'Factory run {episode} done, reward is:\n {rew}')
                # Eval monitor outputs are automatically stored by the monitor object
                # TODO: Plotting
                monitoredCombEnv.save_records()
                monitoredCombEnv.save_run()
    pass
203
quickstart/single_agent_train_battery_target_env.py
Normal file
@ -0,0 +1,203 @@
import sys
import time

from pathlib import Path
import simplejson

import stable_baselines3 as sb3

# This is needed, when you put this file in a subfolder.
try:
    # noinspection PyUnboundLocalVariable
    if __package__ is None:
        DIR = Path(__file__).resolve().parent
        sys.path.insert(0, str(DIR.parent))
        __package__ = DIR.name
    else:
        DIR = None
except NameError:
    DIR = None
    pass

from environments import helpers as h
from environments.factory.additional.dest.dest_util import DestModeOptions, DestProperties
from environments.factory.additional.btry.btry_util import BatteryProperties
from environments.logging.envmonitor import EnvMonitor
from environments.logging.recorder import EnvRecorder
from environments.factory.additional.combined_factories import DestBatteryFactory
from environments.utility_classes import MovementProperties, ObservationProperties, AgentRenderOptions

from plotting.compare_runs import compare_seed_runs

"""
Welcome to this quick start file. Here we will see how to:
    0. Setup I/O Paths
    1. Setup parameters for the environments (dest-battery factory).
    2. Setup parameters for the agent training (SB3: PPO) and save metrics.
        Run the training.
    3. Save env and agent for later analysis.
    4. Load the agent from drive
    5. Render the env with a run of the trained agent.
    6. Plot metrics
"""

if __name__ == '__main__':
    #########################################################
    # 0. Setup I/O Paths
    # Define some general parameters
    train_steps = 1e6
    n_seeds = 3
    model_class = sb3.PPO
    env_class = DestBatteryFactory

    env_params_json = 'env_params.json'

    # Define a global study save path
    start_time = int(time.time())
    study_root_path = Path(__file__).parent.parent / 'study_out' / f'{Path(__file__).stem}_{start_time}'
    # Create an identifier, which is unique for every combination and easy to read in filesystem
    identifier = f'{model_class.__name__}_{env_class.__name__}_{start_time}'
    exp_path = study_root_path / identifier

    #########################################################
    # 1. Setup parameters for the environments (dest-battery factory).

    # Define property object parameters.
    # 'ObservationProperties' are for specifying how the agent sees the env.
    obs_props = ObservationProperties(render_agents=AgentRenderOptions.NOT,  # Agents won't be shown in the obs at all
                                      omit_agent_self=True,                  # This is the default
                                      additional_agent_placeholder=None,     # We will not take care of future agents
                                      frames_to_stack=3,                     # To give the agent a notion of time
                                      pomdp_r=2                              # the agent's view-radius
                                      )
    # 'MovementProperties' are for specifying how the agent is allowed to move in the env.
    move_props = MovementProperties(allow_diagonal_movement=True,  # Euclidean style (vertices)
                                    allow_square_movement=True,    # Manhattan (edges)
                                    allow_no_op=False)             # Pause movement (do nothing)

    # 'DestProperties' control if and how destinations are spawned
    # TODO: Comments
    dest_props = DestProperties(
        n_dests=2,  # How many destinations are there
        dwell_time=0,  # How long does the agent need to "wait" on a destination
        spawn_frequency=0,
        spawn_in_other_zone=True,  #
        spawn_mode=DestModeOptions.DONE,
    )
    btry_props = BatteryProperties(
        initial_charge=0.9,  #
        charge_rate=0.4,  #
        charge_locations=3,  #
        per_action_costs=0.01,
        done_when_discharged=True,
        multi_charge=False,
    )

    # These are the EnvKwargs for initializing the env class, holding all former parameter-classes
    # TODO: Comments
    factory_kwargs = dict(n_agents=1,
                          max_steps=400,
                          parse_doors=True,
                          level_name='rooms',
                          doors_have_area=True,  #
                          verbose=False,
                          mv_prop=move_props,  # See Above
                          obs_prop=obs_props,  # See Above
                          done_at_collision=True,
                          dest_prop=dest_props,
                          btry_prop=btry_props
                          )

    #########################################################
    # 2. Setup parameters for the agent training (SB3: PPO) and save metrics.
    agent_kwargs = dict()

    #########################################################
    # Run the Training
    for seed in range(n_seeds):
        # Make a copy if you want to alter things in the training loop; like the seed.
        env_kwargs = factory_kwargs.copy()
        env_kwargs.update(env_seed=seed)

        # Output folder
        seed_path = exp_path / f'{str(seed)}_{identifier}'
        seed_path.mkdir(parents=True, exist_ok=True)

        # Parameter Storage
        param_path = seed_path / env_params_json
        # Observation (measures) Storage
        monitor_path = seed_path / 'monitor.pick'
        recorder_path = seed_path / 'recorder.json'
        # Model save path for the trained model
        model_save_path = seed_path / f'model.zip'

        # Env Init & Model kwargs definition
        with env_class(**env_kwargs) as env_factory:

            # EnvMonitor Init
            env_monitor_callback = EnvMonitor(env_factory)

            # EnvRecorder Init
            env_recorder_callback = EnvRecorder(env_factory, freq=int(train_steps / 400 / 10))

            # Model Init
            model = model_class("MlpPolicy", env_factory, verbose=1, seed=seed, device='cpu')

            # Model train
            model.learn(total_timesteps=int(train_steps), callback=[env_monitor_callback, env_recorder_callback])

            #########################################################
            # 3. Save env and agent for later analysis.
            # Save the trained model, the monitor (env measures) and the env parameters
            model.named_observation_space = env_factory.named_observation_space
            model.named_action_space = env_factory.named_action_space
            model.save(model_save_path)
            env_factory.save_params(param_path)
            env_monitor_callback.save_run(monitor_path)
            env_recorder_callback.save_records(recorder_path, save_occupation_map=False)

    # Compare performance runs, for each seed within a model
    try:
        compare_seed_runs(exp_path, use_tex=False)
    except ValueError:
        pass

    # Train ends here ############################################################

    # Evaluation starts here #####################################################
    # First iterate over every model and monitor "as trained"
    print('Start Measurement Tracking')
    # For each trained policy in study_root_path / identifier
    for policy_path in [x for x in exp_path.iterdir() if x.is_dir()]:

        # Retrieve the model class
        model_cls = next(val for key, val in h.MODEL_MAP.items() if key in policy_path.parent.name)
        # Load the trained agent
        model = model_cls.load(policy_path / 'model.zip', device='cpu')
        # Load old env kwargs
        with next(policy_path.glob(env_params_json)).open('r') as f:
            env_kwargs = simplejson.load(f)
        # Make the env stop at collisions
        # (you only want to have a single collision per episode for clean statistics)
        env_kwargs.update(done_at_collision=True)

        # Init Env
        with env_class(**env_kwargs) as env_factory:
            monitored_env_factory = EnvMonitor(env_factory)

            # Evaluation Loop for i in range(n Episodes)
            for episode in range(100):
                # noinspection PyRedeclaration
                env_state = monitored_env_factory.reset()
                rew, done_bool = 0, False
                while not done_bool:
                    action = model.predict(env_state, deterministic=True)[0]
                    env_state, step_r, done_bool, info_obj = monitored_env_factory.step(action)
                    rew += step_r
                    if done_bool:
                        break
                print(f'Factory run {episode} done, reward is:\n {rew}')
            monitored_env_factory.save_run(filepath=policy_path / 'eval_run_monitor.pick')
    print('Measurements Done')
193
quickstart/single_agent_train_dest_env.py
Normal file
@ -0,0 +1,193 @@
import sys
import time

from pathlib import Path
import simplejson

import stable_baselines3 as sb3

# This is needed, when you put this file in a subfolder.
try:
    # noinspection PyUnboundLocalVariable
    if __package__ is None:
        DIR = Path(__file__).resolve().parent
        sys.path.insert(0, str(DIR.parent))
        __package__ = DIR.name
    else:
        DIR = None
except NameError:
    DIR = None
    pass

from environments import helpers as h
from environments.factory.additional.dest.dest_util import DestModeOptions, DestProperties
from environments.logging.envmonitor import EnvMonitor
from environments.logging.recorder import EnvRecorder
from environments.factory.additional.dest.factory_dest import DestFactory
from environments.utility_classes import MovementProperties, ObservationProperties, AgentRenderOptions

from plotting.compare_runs import compare_seed_runs

"""
Welcome to this quick start file. Here we will see how to:
    0. Setup I/O Paths
    1. Setup parameters for the environments (dest-factory).
    2. Setup parameters for the agent training (SB3: PPO) and save metrics.
        Run the training.
    3. Save env and agent for later analysis.
    4. Load the agent from drive
    5. Render the env with a run of the trained agent.
    6. Plot metrics
"""

if __name__ == '__main__':
    #########################################################
    # 0. Setup I/O Paths
    # Define some general parameters
    train_steps = 1e6
    n_seeds = 3
    model_class = sb3.PPO
    env_class = DestFactory

    env_params_json = 'env_params.json'

    # Define a global study save path
    start_time = int(time.time())
    study_root_path = Path(__file__).parent.parent / 'study_out' / f'{Path(__file__).stem}_{start_time}'
    # Create an identifier, which is unique for every combination and easy to read in filesystem
    identifier = f'{model_class.__name__}_{env_class.__name__}_{start_time}'
    exp_path = study_root_path / identifier

    #########################################################
    # 1. Setup parameters for the environments (dest-factory).

    # Define property object parameters.
    # 'ObservationProperties' are for specifying how the agent sees the env.
    obs_props = ObservationProperties(render_agents=AgentRenderOptions.NOT,  # Agents won't be shown in the obs at all
                                      omit_agent_self=True,                  # This is the default
                                      additional_agent_placeholder=None,     # We will not take care of future agents
                                      frames_to_stack=3,                     # To give the agent a notion of time
                                      pomdp_r=2                              # the agent's view-radius
                                      )
    # 'MovementProperties' are for specifying how the agent is allowed to move in the env.
    move_props = MovementProperties(allow_diagonal_movement=True,  # Euclidean style (vertices)
                                    allow_square_movement=True,    # Manhattan (edges)
                                    allow_no_op=False)             # Pause movement (do nothing)

    # 'DestProperties' control if and how destinations are spawned
    # TODO: Comments
    dest_props = DestProperties(
        n_dests=2,  # How many destinations are there
        dwell_time=0,  # How long does the agent need to "wait" on a destination
        spawn_frequency=0,
        spawn_in_other_zone=True,  #
        spawn_mode=DestModeOptions.DONE,
    )

    # These are the EnvKwargs for initializing the env class, holding all former parameter-classes
    # TODO: Comments
    factory_kwargs = dict(n_agents=1,
                          max_steps=400,
                          parse_doors=True,
                          level_name='rooms',
                          doors_have_area=True,  #
                          verbose=False,
                          mv_prop=move_props,  # See Above
                          obs_prop=obs_props,  # See Above
                          done_at_collision=True,
                          dest_prop=dest_props
                          )

    #########################################################
    # 2. Setup parameters for the agent training (SB3: PPO) and save metrics.
    agent_kwargs = dict()

    #########################################################
    # Run the Training
    for seed in range(n_seeds):
        # Make a copy if you want to alter things in the training loop; like the seed.
        env_kwargs = factory_kwargs.copy()
        env_kwargs.update(env_seed=seed)

        # Output folder
        seed_path = exp_path / f'{str(seed)}_{identifier}'
        seed_path.mkdir(parents=True, exist_ok=True)

        # Parameter Storage
        param_path = seed_path / env_params_json
        # Observation (measures) Storage
        monitor_path = seed_path / 'monitor.pick'
        recorder_path = seed_path / 'recorder.json'
        # Model save path for the trained model
        model_save_path = seed_path / f'model.zip'

        # Env Init & Model kwargs definition
        with env_class(**env_kwargs) as env_factory:

            # EnvMonitor Init
            env_monitor_callback = EnvMonitor(env_factory)

            # EnvRecorder Init
            env_recorder_callback = EnvRecorder(env_factory, freq=int(train_steps / 400 / 10))

            # Model Init
            model = model_class("MlpPolicy", env_factory, verbose=1, seed=seed, device='cpu')

            # Model train
            model.learn(total_timesteps=int(train_steps), callback=[env_monitor_callback, env_recorder_callback])

            #########################################################
            # 3. Save env and agent for later analysis.
            # Save the trained model, the monitor (env measures) and the env parameters
            model.named_observation_space = env_factory.named_observation_space
            model.named_action_space = env_factory.named_action_space
            model.save(model_save_path)
            env_factory.save_params(param_path)
            env_monitor_callback.save_run(monitor_path)
            env_recorder_callback.save_records(recorder_path, save_occupation_map=False)

    # Compare performance runs, for each seed within a model
    try:
        compare_seed_runs(exp_path, use_tex=False)
    except ValueError:
        pass

    # Train ends here ############################################################

    # Evaluation starts here #####################################################
    # First iterate over every model and monitor "as trained"
    print('Start Measurement Tracking')
    # For each trained policy in study_root_path / identifier
    for policy_path in [x for x in exp_path.iterdir() if x.is_dir()]:

        # Retrieve the model class
        model_cls = next(val for key, val in h.MODEL_MAP.items() if key in policy_path.parent.name)
        # Load the trained agent
        model = model_cls.load(policy_path / 'model.zip', device='cpu')
        # Load old env kwargs
        with next(policy_path.glob(env_params_json)).open('r') as f:
            env_kwargs = simplejson.load(f)
        # Make the env stop at collisions
        # (you only want to have a single collision per episode for clean statistics)
        env_kwargs.update(done_at_collision=True)

        # Init Env
        with env_class(**env_kwargs) as env_factory:
            monitored_env_factory = EnvMonitor(env_factory)

            # Evaluation Loop for i in range(n Episodes)
            for episode in range(100):
                # noinspection PyRedeclaration
                env_state = monitored_env_factory.reset()
                rew, done_bool = 0, False
                while not done_bool:
                    action = model.predict(env_state, deterministic=True)[0]
                    env_state, step_r, done_bool, info_obj = monitored_env_factory.step(action)
                    rew += step_r
                    if done_bool:
                        break
                print(f'Factory run {episode} done, reward is:\n {rew}')
            monitored_env_factory.save_run(filepath=policy_path / 'eval_run_monitor.pick')
    print('Measurements Done')
195
quickstart/single_agent_train_dirt_env.py
Normal file
@ -0,0 +1,195 @@
import sys
import time

from pathlib import Path
import simplejson

import stable_baselines3 as sb3

# This is needed, when you put this file in a subfolder.
try:
    # noinspection PyUnboundLocalVariable
    if __package__ is None:
        DIR = Path(__file__).resolve().parent
        sys.path.insert(0, str(DIR.parent))
        __package__ = DIR.name
    else:
        DIR = None
except NameError:
    DIR = None
    pass

from environments import helpers as h
from environments.logging.envmonitor import EnvMonitor
from environments.logging.recorder import EnvRecorder
from environments.factory.additional.dirt.dirt_util import DirtProperties
from environments.factory.additional.dirt.factory_dirt import DirtFactory
from environments.utility_classes import MovementProperties, ObservationProperties, AgentRenderOptions

from plotting.compare_runs import compare_seed_runs

"""
Welcome to this quick start file. Here we will see how to:
    0. Setup I/O Paths
    1. Setup parameters for the environments (dirt-factory).
    2. Setup parameters for the agent training (SB3: PPO) and save metrics.
        Run the training.
    3. Save env and agent for later analysis.
    4. Load the agent from drive
    5. Render the env with a run of the trained agent.
    6. Plot metrics
"""

if __name__ == '__main__':
    #########################################################
    # 0. Setup I/O Paths
    # Define some general parameters
    train_steps = 1e6
    n_seeds = 3
    model_class = sb3.PPO
    env_class = DirtFactory

    env_params_json = 'env_params.json'

    # Define a global study save path
    start_time = int(time.time())
    study_root_path = Path(__file__).parent.parent / 'study_out' / f'{Path(__file__).stem}_{start_time}'
    # Create an identifier, which is unique for every combination and easy to read in filesystem
    identifier = f'{model_class.__name__}_{env_class.__name__}_{start_time}'
    exp_path = study_root_path / identifier

    #########################################################
    # 1. Setup parameters for the environments (dirt-factory).

    # Define property object parameters.
    # 'ObservationProperties' are for specifying how the agent sees the env.
    obs_props = ObservationProperties(render_agents=AgentRenderOptions.NOT,  # Agents won't be shown in the obs at all
                                      omit_agent_self=True,                  # This is the default
                                      additional_agent_placeholder=None,     # We will not take care of future agents
                                      frames_to_stack=3,                     # To give the agent a notion of time
                                      pomdp_r=2                              # the agent's view-radius
                                      )
    # 'MovementProperties' are for specifying how the agent is allowed to move in the env.
    move_props = MovementProperties(allow_diagonal_movement=True,  # Euclidean style (vertices)
                                    allow_square_movement=True,    # Manhattan (edges)
                                    allow_no_op=False)             # Pause movement (do nothing)

    # 'DirtProperties' control if and how dirt is spawned
    # TODO: Comments
    dirt_props = DirtProperties(initial_dirt_ratio=0.35,
                                initial_dirt_spawn_r_var=0.1,
                                clean_amount=0.34,
                                max_spawn_amount=0.1,
                                max_global_amount=20,
                                max_local_amount=1,
                                spawn_frequency=0,
                                max_spawn_ratio=0.05,
                                dirt_smear_amount=0.0)

    # These are the EnvKwargs for initializing the env class, holding all former parameter-classes
    # TODO: Comments
    factory_kwargs = dict(n_agents=1,
                          max_steps=400,
                          parse_doors=True,
                          level_name='rooms',
                          doors_have_area=True,  #
                          verbose=False,
                          mv_prop=move_props,  # See Above
                          obs_prop=obs_props,  # See Above
                          done_at_collision=True,
                          dirt_prop=dirt_props
                          )

    #########################################################
    # 2. Setup parameters for the agent training (SB3: PPO) and save metrics.
    agent_kwargs = dict()

    #########################################################
    # Run the Training
    for seed in range(n_seeds):
        # Make a copy if you want to alter things in the training loop; like the seed.
        env_kwargs = factory_kwargs.copy()
        env_kwargs.update(env_seed=seed)

        # Output folder
        seed_path = exp_path / f'{str(seed)}_{identifier}'
        seed_path.mkdir(parents=True, exist_ok=True)

        # Parameter Storage
        param_path = seed_path / env_params_json
        # Observation (measures) Storage
        monitor_path = seed_path / 'monitor.pick'
        recorder_path = seed_path / 'recorder.json'
        # Model save path for the trained model
        model_save_path = seed_path / f'model.zip'

        # Env Init & Model kwargs definition
        with env_class(**env_kwargs) as env_factory:

            # EnvMonitor Init
            env_monitor_callback = EnvMonitor(env_factory)

            # EnvRecorder Init
            env_recorder_callback = EnvRecorder(env_factory, freq=int(train_steps / 400 / 10))

            # Model Init
            model = model_class("MlpPolicy", env_factory, verbose=1, seed=seed, device='cpu')

            # Model train
            model.learn(total_timesteps=int(train_steps), callback=[env_monitor_callback, env_recorder_callback])

            #########################################################
            # 3. Save env and agent for later analysis.
            # Save the trained model, the monitor (env measures) and the env parameters
            model.named_observation_space = env_factory.named_observation_space
            model.named_action_space = env_factory.named_action_space
            model.save(model_save_path)
            env_factory.save_params(param_path)
            env_monitor_callback.save_run(monitor_path)
            env_recorder_callback.save_records(recorder_path, save_occupation_map=False)

    # Compare performance runs, for each seed within a model
    try:
        compare_seed_runs(exp_path, use_tex=False)
    except ValueError:
        pass

    # Train ends here ############################################################

    # Evaluation starts here #####################################################
    # First iterate over every model and monitor "as trained"
    print('Start Measurement Tracking')
    # For each trained policy in study_root_path / identifier
    for policy_path in [x for x in exp_path.iterdir() if x.is_dir()]:

        # Retrieve the model class
        model_cls = next(val for key, val in h.MODEL_MAP.items() if key in policy_path.parent.name)
        # Load the trained agent
        model = model_cls.load(policy_path / 'model.zip', device='cpu')
        # Load old env kwargs
        with next(policy_path.glob(env_params_json)).open('r') as f:
            env_kwargs = simplejson.load(f)
        # Make the env stop at collisions
        # (you only want to have a single collision per episode for clean statistics)
        env_kwargs.update(done_at_collision=True)

        # Init Env
        with env_class(**env_kwargs) as env_factory:
            monitored_env_factory = EnvMonitor(env_factory)

            # Evaluation Loop for i in range(n Episodes)
            for episode in range(100):
                # noinspection PyRedeclaration
                env_state = monitored_env_factory.reset()
                rew, done_bool = 0, False
                while not done_bool:
                    action = model.predict(env_state, deterministic=True)[0]
                    env_state, step_r, done_bool, info_obj = monitored_env_factory.step(action)
                    rew += step_r
                    if done_bool:
                        break
                print(f'Factory run {episode} done, reward is:\n {rew}')
            monitored_env_factory.save_run(filepath=policy_path / 'eval_run_monitor.pick')
    print('Measurements Done')
191
quickstart/single_agent_train_item_env.py
Normal file
@ -0,0 +1,191 @@
import sys
|
||||
import time
|
||||
|
||||
from pathlib import Path
|
||||
import simplejson
|
||||
|
||||
import stable_baselines3 as sb3
|
||||
|
||||
# This is needed, when you put this file in a subfolder.
|
||||
try:
|
||||
# noinspection PyUnboundLocalVariable
|
||||
if __package__ is None:
|
||||
DIR = Path(__file__).resolve().parent
|
||||
sys.path.insert(0, str(DIR.parent))
|
||||
__package__ = DIR.name
|
||||
else:
|
||||
DIR = None
|
||||
except NameError:
|
||||
DIR = None
|
||||
pass
|
||||
|
||||
from environments import helpers as h
|
||||
from environments.factory.additional.item.factory_item import ItemFactory
|
||||
from environments.factory.additional.item.item_util import ItemProperties
|
||||
from environments.logging.envmonitor import EnvMonitor
|
||||
from environments.logging.recorder import EnvRecorder
|
||||
from environments.utility_classes import MovementProperties, ObservationProperties, AgentRenderOptions
|
||||
|
||||
from plotting.compare_runs import compare_seed_runs
|
||||
|
||||
"""
|
||||
Welcome to this quick start file. Here we will see how to:
|
||||
0. Setup I/O Paths
|
||||
1. Setup parameters for the environments (item-factory).
|
||||
2. Setup parameters for the agent training (SB3: PPO) and save metrics.
|
||||
Run the training.
|
||||
3. Save env and agent for later analysis.
|
||||
4. Load the agent from drive
|
||||
5. Rendering the env with a run of the trained agent.
|
||||
6. Plot metrics
|
||||
"""
|
||||
|
||||
if __name__ == '__main__':
|
||||
#########################################################
|
||||
# 0. Setup I/O Paths
|
||||
# Define some general parameters
|
||||
train_steps = 1e6
|
||||
n_seeds = 3
|
||||
model_class = sb3.PPO
|
||||
env_class = ItemFactory
|
||||
|
||||
env_params_json = 'env_params.json'
|
||||
|
||||
# Define a global studi save path
|
||||
start_time = int(time.time())
|
||||
study_root_path = Path(__file__).parent.parent / 'study_out' / f'{Path(__file__).stem}_{start_time}'
|
||||
# Create an identifier, which is unique for every combination and easy to read in filesystem
|
||||
identifier = f'{model_class.__name__}_{env_class.__name__}_{start_time}'
|
||||
exp_path = study_root_path / identifier
|
||||
|
||||
#########################################################
|
||||
# 1. Setup parameters for the environments (item-factory).
|
||||
#
|
||||
# Define property object parameters.
|
||||
# 'ObservationProperties' are for specifying how the agent sees the env.
|
||||
obs_props = ObservationProperties(render_agents=AgentRenderOptions.NOT, # Agents won`t be shown in the obs at all
|
||||
omit_agent_self=True, # This is default
|
||||
additional_agent_placeholder=None, # We will not take care of future agents
|
||||
frames_to_stack=3, # To give the agent a notion of time
|
||||
pomdp_r=2 # the agents view-radius
|
||||
)
|
||||
# 'MovementProperties' are for specifying how the agent is allowed to move in the env.
|
||||
move_props = MovementProperties(allow_diagonal_movement=True, # Euclidean style (vertices)
|
||||
allow_square_movement=True, # Manhattan (edges)
|
||||
allow_no_op=False) # Pause movement (do nothing)
|
||||
|
||||
# 'ItemProperties' control if and how item is spawned
|
||||
# TODO: Comments
|
||||
item_props = ItemProperties(
|
||||
n_items = 7, # How many items are there at the same time
|
||||
spawn_frequency = 50, # Spawn Frequency in Steps
|
||||
n_drop_off_locations = 10, # How many DropOff locations are there at the same time
|
||||
max_dropoff_storage_size = 0, # How many items are needed until the dropoff is full
|
||||
max_agent_inventory_capacity = 5, # How many items are needed until the agent inventory is full)
|
||||
)
|
||||
|
||||
    # These are the EnvKwargs for initializing the env class, holding all former parameter classes.
    factory_kwargs = dict(n_agents=1,               # A single agent for this example
                          max_steps=400,            # Steps per episode
                          parse_doors=True,         # Parse doors from the level file
                          level_name='rooms',       # Name of the level layout to load
                          doors_have_area=True,
                          verbose=False,
                          mv_prop=move_props,       # See above
                          obs_prop=obs_props,       # See above
                          done_at_collision=True,   # End the episode on agent collision
                          item_prop=item_props)
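
    # Optional sanity check: a minimal sketch, assuming the factory follows the
    # gym API (reset/step/action_space) as used further below. Flip the
    # (hypothetical) flag to roll out a few random steps and verify the env
    # builds with these kwargs before training.
    run_sanity_check = False
    if run_sanity_check:
        with ItemFactory(**factory_kwargs) as sanity_env:
            sanity_state = sanity_env.reset()
            for _ in range(10):
                # Random actions are enough to exercise the step loop.
                sanity_state, sanity_r, sanity_done, _ = sanity_env.step(sanity_env.action_space.sample())
                if sanity_done:
                    sanity_state = sanity_env.reset()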
    #########################################################
    # 2. Set up parameters for the agent training (SB3: PPO) and save metrics.
    agent_kwargs = dict()
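    # agent_kwargs stays empty in this quickstart; the model below is built with
    # explicit arguments. As a sketch (illustrative values, standard
    # stable-baselines3 PPO arguments), you could instead collect hyperparameters
    # here and splat them into the constructor via `model_class(..., **agent_kwargs)`:
    # agent_kwargs = dict(learning_rate=3e-4, n_steps=2048, batch_size=64, gamma=0.99)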
    #########################################################
    # Run the Training
    for seed in range(n_seeds):
        # Make a copy if you want to alter things in the training loop, like the seed.
        env_kwargs = factory_kwargs.copy()
        env_kwargs.update(env_seed=seed)

        # Output folder
        seed_path = exp_path / f'{str(seed)}_{identifier}'
        seed_path.mkdir(parents=True, exist_ok=True)

        # Parameter storage
        param_path = seed_path / env_params_json
        # Observation (measures) storage
        monitor_path = seed_path / 'monitor.pick'
        recorder_path = seed_path / 'recorder.json'
        # Save path for the trained model
        model_save_path = seed_path / 'model.zip'

        # Env Init & Model kwargs definition
        with ItemFactory(**env_kwargs) as env_factory:

            # EnvMonitor Init
            env_monitor_callback = EnvMonitor(env_factory)

            # EnvRecorder Init
            env_recorder_callback = EnvRecorder(env_factory, freq=int(train_steps / 400 / 10))

            # Model Init
            model = model_class("MlpPolicy", env_factory, verbose=1, seed=seed, device='cpu')

            # Model train
            model.learn(total_timesteps=int(train_steps), callback=[env_monitor_callback, env_recorder_callback])
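
            # Optional: stable-baselines3 also ships a CheckpointCallback if you
            # want intermediate snapshots in addition to the final save below
            # (a sketch; save_freq is counted in env steps):
            # from stable_baselines3.common.callbacks import CheckpointCallback
            # checkpoint_cb = CheckpointCallback(save_freq=100_000, save_path=str(seed_path))
            # ... then append it to the `callback` list passed to `model.learn`.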
            #########################################################
            # 3. Save env and agent for later analysis.
            # Save the trained model, the monitor (env measures) and the env parameters.
            model.named_observation_space = env_factory.named_observation_space
            model.named_action_space = env_factory.named_action_space
            model.save(model_save_path)
            env_factory.save_params(param_path)
            env_monitor_callback.save_run(monitor_path)
            env_recorder_callback.save_records(recorder_path, save_occupation_map=False)

    # Compare performance runs for each seed within a model
    try:
        compare_seed_runs(exp_path, use_tex=False)
    except ValueError:
        pass

    # Train ends here ############################################################

    # Evaluation starts here #####################################################
    # First, iterate over every model and monitor "as trained".
    print('Start Measurement Tracking')
    # For each trained policy in study_root_path / identifier
    for policy_path in [x for x in exp_path.iterdir() if x.is_dir()]:

        # Retrieve the model class
        model_cls = next(val for key, val in h.MODEL_MAP.items() if key in policy_path.parent.name)
        # Load the agent
        model = model_cls.load(policy_path / 'model.zip', device='cpu')
        # Load the old env kwargs
        with next(policy_path.glob(env_params_json)).open('r') as f:
            env_kwargs = simplejson.load(f)
        # Make the env stop at collisions
        # (you only want a single collision per episode for clean statistics)
        env_kwargs.update(done_at_collision=True)

        # Init Env
        with ItemFactory(**env_kwargs) as env_factory:
            monitored_env_factory = EnvMonitor(env_factory)

            # Evaluation loop: for episode in range(n_episodes)
            for episode in range(100):
                # noinspection PyRedeclaration
                env_state = monitored_env_factory.reset()
                rew, done_bool = 0, False
                while not done_bool:
                    action = model.predict(env_state, deterministic=True)[0]
                    env_state, step_r, done_bool, info_obj = monitored_env_factory.step(action)
                    rew += step_r
                    if done_bool:
                        break
                print(f'Factory run {episode} done, reward is:\n {rew}')
            monitored_env_factory.save_run(filepath=policy_path / 'eval_run_monitor.pick')
    print('Measurements Done')
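
    # Step 5 of the docstring (rendering) is not shown in this file; a minimal
    # sketch, reusing the loaded `model` and `env_kwargs` from above, could look
    # like this (assumes the factory exposes `render()`, as in the reload script
    # in the diffs below):
    # with ItemFactory(**env_kwargs) as env:
    #     env_state = env.reset()
    #     done_bool = False
    #     while not done_bool:
    #         action = model.predict(env_state, deterministic=True)[0]
    #         env_state, step_r, done_bool, _ = env.step(action)
    #         env.render()
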
@@ -4,9 +4,9 @@ from pathlib import Path
 import yaml
 from stable_baselines3 import A2C, PPO, DQN

-from environments.factory.factory_dirt import Constants as c
+from environments.factory.additional.dirt.dirt_util import Constants

-from environments.factory.factory_dirt import DirtFactory
+from environments.factory.additional.dirt.factory_dirt import DirtFactory
 from environments.logging.envmonitor import EnvMonitor
 from environments.logging.recorder import EnvRecorder

@@ -23,7 +23,7 @@ if __name__ == '__main__':
     seed = 13
     n_agents = 1
     # out_path = Path('study_out/e_1_new_reward/no_obs/dirt/A2C_new_reward/0_A2C_new_reward')
-    out_path = Path('study_out/reload')
+    out_path = Path('quickstart/combinations/single_agent_train_dirt_env_1659374984/PPO_DirtFactory_1659374984/0_PPO_DirtFactory_1659374984/')
     model_path = out_path

     with (out_path / f'env_params.json').open('r') as f:

@@ -62,7 +62,7 @@ if __name__ == '__main__':
         if render:
             env.render()
         try:
-            door = next(x for x in env.unwrapped.unwrapped[c.DOORS] if x.is_open)
+            door = next(x for x in env.unwrapped.unwrapped[Constants.DOORS] if x.is_open)
             print('openDoor found')
         except StopIteration:
             pass

@@ -19,9 +19,11 @@ import simplejson
 from stable_baselines3.common.vec_env import SubprocVecEnv

 from environments import helpers as h
-from environments.factory.factory_dirt import DirtProperties, DirtFactory
+from environments.factory.factory_dirt import DirtFactory
+from environments.factory.dirt_util import DirtProperties
 from environments.factory.combined_factories import DirtItemFactory
-from environments.factory.factory_item import ItemProperties, ItemFactory
+from environments.factory.factory_item import ItemFactory
+from environments.factory.additional.item.item_util import ItemProperties
 from environments.logging.envmonitor import EnvMonitor
 from environments.utility_classes import MovementProperties, ObservationProperties, AgentRenderOptions
 import pickle

@@ -215,7 +217,7 @@ if __name__ == '__main__':
                                 clean_amount=0.34,
                                 max_spawn_amount=0.1, max_global_amount=20,
                                 max_local_amount=1, spawn_frequency=0, max_spawn_ratio=0.05,
-                                dirt_smear_amount=0.0, agent_can_interact=True)
+                                dirt_smear_amount=0.0)
     item_props = ItemProperties(n_items=10,
                                 spawn_frequency=30, n_drop_off_locations=2,
                                 max_agent_inventory_capacity=15)

@@ -349,6 +351,7 @@ if __name__ == '__main__':
         # Env Init & Model kwargs definition
         if model_cls.__name__ in ["PPO", "A2C"]:
             # env_factory = env_class(**env_kwargs)
+
             env_factory = SubprocVecEnv([encapsule_env_factory(env_class, env_kwargs)
                                          for _ in range(6)], start_method="spawn")
             model_kwargs = policy_model_kwargs()

@@ -20,9 +20,12 @@ import simplejson
 from environments.helpers import ActionTranslator, ObservationTranslator
 from environments.logging.recorder import EnvRecorder
 from environments import helpers as h
-from environments.factory.factory_dirt import DirtProperties, DirtFactory
-from environments.factory.factory_item import ItemProperties, ItemFactory
-from environments.factory.factory_dest import DestProperties, DestFactory, DestModeOptions
+from environments.factory.factory_dirt import DirtFactory
+from environments.factory.dirt_util import DirtProperties
+from environments.factory.factory_item import ItemFactory
+from environments.factory.additional.item.item_util import ItemProperties
+from environments.factory.factory_dest import DestFactory
+from environments.factory.additional.dest.dest_util import DestModeOptions, DestProperties
 from environments.factory.combined_factories import DirtDestItemFactory
 from environments.logging.envmonitor import EnvMonitor
 from environments.utility_classes import MovementProperties, ObservationProperties, AgentRenderOptions

@@ -213,7 +216,8 @@ if __name__ == '__main__':
         env_factory.save_params(param_path)

         # EnvMonitor Init
-        callbacks = [EnvMonitor(env_factory)]
+        env_monitor = EnvMonitor(env_factory)
+        callbacks = [env_monitor]

         # Model Init
         model = model_cls("MlpPolicy", env_factory, **policy_model_kwargs,

@@ -233,7 +237,7 @@ if __name__ == '__main__':
         model.save(save_path)

         # Monitor Save
-        callbacks[0].save_run(combination_path / 'monitor.pick',
+        env_monitor.save_run(combination_path / 'monitor.pick',
                              auto_plotting_keys=['step_reward', 'collision'] + env_plot_keys)

         # Better be safe than sorry: Clean up!