mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2026-06-17 23:22:16 +02:00
New Szenario "Two_Rooms_One_Door"
This commit is contained in:
@@ -21,7 +21,6 @@ Agents:
|
|||||||
- Items
|
- Items
|
||||||
- Inventory
|
- Inventory
|
||||||
- DropOffLocations
|
- DropOffLocations
|
||||||
- Machines
|
|
||||||
- Maintainers
|
- Maintainers
|
||||||
Entities:
|
Entities:
|
||||||
Batteries: {}
|
Batteries: {}
|
||||||
@@ -79,10 +78,6 @@ Rules:
|
|||||||
n_items: 5
|
n_items: 5
|
||||||
n_locations: 5
|
n_locations: 5
|
||||||
spawn_frequency: 15
|
spawn_frequency: 15
|
||||||
MachineRule:
|
|
||||||
n_machines: 2
|
|
||||||
MaintenanceRule:
|
|
||||||
n_maintainer: 1
|
|
||||||
MaxStepsReached:
|
MaxStepsReached:
|
||||||
max_steps: 500
|
max_steps: 500
|
||||||
# AgentSingleZonePlacement:
|
# AgentSingleZonePlacement:
|
||||||
|
|||||||
@@ -1,3 +1,31 @@
|
|||||||
|
General:
|
||||||
|
env_seed: 69
|
||||||
|
individual_rewards: true
|
||||||
|
level_name: two_rooms
|
||||||
|
pomdp_r: 3
|
||||||
|
verbose: false
|
||||||
|
|
||||||
|
Entities:
|
||||||
|
BoundDestinations: {}
|
||||||
|
ReachedDestinations: {}
|
||||||
|
Doors: {}
|
||||||
|
GlobalPositions: {}
|
||||||
|
Zones: {}
|
||||||
|
|
||||||
|
Rules:
|
||||||
|
# Init:
|
||||||
|
AssignGlobalPositions: {}
|
||||||
|
ZoneInit: {}
|
||||||
|
AgentSingleZonePlacement: {}
|
||||||
|
IndividualDestinationZonePlacement: {}
|
||||||
|
# Env Rules
|
||||||
|
MaxStepsReached:
|
||||||
|
max_steps: 10
|
||||||
|
Collision:
|
||||||
|
done_at_collisions: false
|
||||||
|
DoorAutoClose:
|
||||||
|
close_frequency: 10
|
||||||
|
|
||||||
Agents:
|
Agents:
|
||||||
Wolfgang:
|
Wolfgang:
|
||||||
Actions:
|
Actions:
|
||||||
@@ -21,27 +49,4 @@ Agents:
|
|||||||
- Other
|
- Other
|
||||||
- Walls
|
- Walls
|
||||||
- BoundDestination
|
- BoundDestination
|
||||||
- Doors
|
- Doors
|
||||||
Entities:
|
|
||||||
BoundDestinations: {}
|
|
||||||
ReachedDestinations: {}
|
|
||||||
Doors: {}
|
|
||||||
GlobalPositions: {}
|
|
||||||
Zones: {}
|
|
||||||
|
|
||||||
General:
|
|
||||||
env_seed: 69
|
|
||||||
individual_rewards: true
|
|
||||||
level_name: two_rooms
|
|
||||||
pomdp_r: 3
|
|
||||||
verbose: false
|
|
||||||
|
|
||||||
Rules:
|
|
||||||
Collision:
|
|
||||||
done_at_collisions: false
|
|
||||||
AssignGlobalPositions: {}
|
|
||||||
DoorAutoClose:
|
|
||||||
close_frequency: 10
|
|
||||||
ZoneInit: {}
|
|
||||||
AgentSingleZonePlacement: {}
|
|
||||||
IndividualDestinationZonePlacement: {}
|
|
||||||
@@ -62,11 +62,14 @@ class Object:
|
|||||||
|
|
||||||
def add_observer(self, observer):
|
def add_observer(self, observer):
|
||||||
self.observers.append(observer)
|
self.observers.append(observer)
|
||||||
observer.notify_change_pos(self)
|
observer.notify_add_entity(self)
|
||||||
|
|
||||||
def del_observer(self, observer):
|
def del_observer(self, observer):
|
||||||
self.observers.remove(observer)
|
self.observers.remove(observer)
|
||||||
|
|
||||||
|
def summarize_state(self):
|
||||||
|
return dict()
|
||||||
|
|
||||||
|
|
||||||
class EnvObject(Object):
|
class EnvObject(Object):
|
||||||
|
|
||||||
@@ -128,3 +131,6 @@ class EnvObject(Object):
|
|||||||
self._collection.delete_env_object(self)
|
self._collection.delete_env_object(self)
|
||||||
self._collection = other_collection
|
self._collection = other_collection
|
||||||
return self._collection == other_collection
|
return self._collection == other_collection
|
||||||
|
|
||||||
|
def summarize_state(self):
|
||||||
|
return dict(name=str(self.name))
|
||||||
|
|||||||
@@ -16,8 +16,6 @@ import marl_factory_grid.environment.constants as c
|
|||||||
|
|
||||||
from marl_factory_grid.utils.states import Gamestate
|
from marl_factory_grid.utils.states import Gamestate
|
||||||
|
|
||||||
REC_TAC = 'rec_'
|
|
||||||
|
|
||||||
|
|
||||||
class Factory(gym.Env):
|
class Factory(gym.Env):
|
||||||
|
|
||||||
@@ -44,11 +42,6 @@ class Factory(gym.Env):
|
|||||||
config_dict = yaml.safe_load(config_path.open())
|
config_dict = yaml.safe_load(config_path.open())
|
||||||
return config_dict
|
return config_dict
|
||||||
|
|
||||||
@property
|
|
||||||
def summarize_header(self):
|
|
||||||
summary_dict = self._summarize_state(stateless_entities=True)
|
|
||||||
return summary_dict
|
|
||||||
|
|
||||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||||
self.close()
|
self.close()
|
||||||
|
|
||||||
@@ -125,9 +118,6 @@ class Factory(gym.Env):
|
|||||||
info = reward_info
|
info = reward_info
|
||||||
|
|
||||||
info.update(step_reward=sum(reward), step=self.state.curr_step)
|
info.update(step_reward=sum(reward), step=self.state.curr_step)
|
||||||
# TODO:
|
|
||||||
# if self._record_episodes:
|
|
||||||
# info.update(self._summarize_state())
|
|
||||||
|
|
||||||
obs, reset_info = self.obs_builder.refresh_and_build_for_all(self.state)
|
obs, reset_info = self.obs_builder.refresh_and_build_for_all(self.state)
|
||||||
info.update(reset_info)
|
info.update(reset_info)
|
||||||
@@ -171,14 +161,6 @@ class Factory(gym.Env):
|
|||||||
self.state.print(f"reward is {reward}")
|
self.state.print(f"reward is {reward}")
|
||||||
return reward, combined_info_dict, done
|
return reward, combined_info_dict, done
|
||||||
|
|
||||||
def start_recording(self):
|
|
||||||
self.conf.do_record = True
|
|
||||||
return self.conf.do_record
|
|
||||||
|
|
||||||
def stop_recording(self):
|
|
||||||
self.conf.do_record = False
|
|
||||||
return not self.conf.do_record
|
|
||||||
|
|
||||||
# noinspection PyGlobalUndefined
|
# noinspection PyGlobalUndefined
|
||||||
def render(self, mode='human'):
|
def render(self, mode='human'):
|
||||||
if not self._renderer: # lazy init
|
if not self._renderer: # lazy init
|
||||||
@@ -193,12 +175,23 @@ class Factory(gym.Env):
|
|||||||
render_entity.aux = self.obs_builder.curr_lightmaps[render_entity.real_name]
|
render_entity.aux = self.obs_builder.curr_lightmaps[render_entity.real_name]
|
||||||
return self._renderer.render(render_entities)
|
return self._renderer.render(render_entities)
|
||||||
|
|
||||||
def _summarize_state(self, stateless_entities=False):
|
def summarize_header(self):
|
||||||
summary = {f'{REC_TAC}step': self.state.curr_step}
|
header = {'rec_step': self.state.curr_step}
|
||||||
|
for entity_group in (x for x in self.state if x.name in ['Walls', 'Floors', 'DropOffLocations', 'ChargePods']):
|
||||||
|
header.update({f'rec{entity_group.name}': entity_group.summarize_states()})
|
||||||
|
return header
|
||||||
|
|
||||||
for entity_group in self.state:
|
def summarize_state(self):
|
||||||
if entity_group.is_stateless == stateless_entities:
|
summary = {'step': self.state.curr_step}
|
||||||
summary.update({f'{REC_TAC}{entity_group.name}': entity_group.summarize_states()})
|
|
||||||
|
# Todo: Protobuff Compatibility Section #######
|
||||||
|
# for entity_group in (x for x in self.state if x.name not in [c.WALLS, c.FLOORS]):
|
||||||
|
for entity_group in (x for x in self.state if x.name not in [c.FLOORS]):
|
||||||
|
summary.update({entity_group.name.lower(): entity_group.summarize_states()})
|
||||||
|
# TODO Section End ########
|
||||||
|
for key in list(summary.keys()):
|
||||||
|
if key not in ['step', 'walls', 'doors', 'agents', 'items', 'dirtPiles', 'batteries']:
|
||||||
|
del summary[key]
|
||||||
return summary
|
return summary
|
||||||
|
|
||||||
def print(self, string):
|
def print(self, string):
|
||||||
|
|||||||
@@ -23,9 +23,6 @@ class EnvObjects(Objects):
|
|||||||
super(EnvObjects, self).add_item(item)
|
super(EnvObjects, self).add_item(item)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def summarize_states(self):
|
|
||||||
return [entity.summarize_state() for entity in self.values()]
|
|
||||||
|
|
||||||
def delete_env_object(self, env_object: EnvObject):
|
def delete_env_object(self, env_object: EnvObject):
|
||||||
del self[env_object.name]
|
del self[env_object.name]
|
||||||
|
|
||||||
|
|||||||
@@ -45,7 +45,8 @@ class PositionMixin:
|
|||||||
def by_pos(self, pos: (int, int)):
|
def by_pos(self, pos: (int, int)):
|
||||||
pos = tuple(pos)
|
pos = tuple(pos)
|
||||||
try:
|
try:
|
||||||
return next(e for e in self if e.pos == pos)
|
return self.pos_dict[pos]
|
||||||
|
# return next(e for e in self if e.pos == pos)
|
||||||
except StopIteration:
|
except StopIteration:
|
||||||
pass
|
pass
|
||||||
except ValueError:
|
except ValueError:
|
||||||
|
|||||||
@@ -144,7 +144,13 @@ class Objects:
|
|||||||
|
|
||||||
def notify_add_entity(self, entity: Object):
|
def notify_add_entity(self, entity: Object):
|
||||||
try:
|
try:
|
||||||
entity.add_observer(self)
|
if self not in entity.observers:
|
||||||
|
entity.add_observer(self)
|
||||||
self.pos_dict[entity.pos].append(entity)
|
self.pos_dict[entity.pos].append(entity)
|
||||||
except (ValueError, AttributeError):
|
except (ValueError, AttributeError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def summarize_states(self):
|
||||||
|
# FIXME PROTOBUFF
|
||||||
|
# return [e.summarize_state() for e in self]
|
||||||
|
return [e.summarize_state() for e in self]
|
||||||
|
|||||||
@@ -43,38 +43,3 @@ class GlobalPositions(HasBoundMixin, EnvObjects):
|
|||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super(GlobalPositions, self).__init__(*args, **kwargs)
|
super(GlobalPositions, self).__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
class ZonesOLD(Objects):
|
|
||||||
|
|
||||||
_entity = Zone
|
|
||||||
|
|
||||||
@property
|
|
||||||
def accounting_zones(self):
|
|
||||||
return [self[idx] for idx, name in self.items() if name != c.DANGER_ZONE]
|
|
||||||
|
|
||||||
def __init__(self, parsed_level):
|
|
||||||
raise NotImplementedError('This needs a Rework')
|
|
||||||
super(Zones, self).__init__()
|
|
||||||
slices = list()
|
|
||||||
self._accounting_zones = list()
|
|
||||||
self._danger_zones = list()
|
|
||||||
for symbol in np.unique(parsed_level):
|
|
||||||
if symbol == c.VALUE_OCCUPIED_CELL:
|
|
||||||
continue
|
|
||||||
elif symbol == c.DANGER_ZONE:
|
|
||||||
self + symbol
|
|
||||||
slices.append(h.one_hot_level(parsed_level, symbol))
|
|
||||||
self._danger_zones.append(symbol)
|
|
||||||
else:
|
|
||||||
self + symbol
|
|
||||||
slices.append(h.one_hot_level(parsed_level, symbol))
|
|
||||||
self._accounting_zones.append(symbol)
|
|
||||||
|
|
||||||
self._zone_slices = np.stack(slices)
|
|
||||||
|
|
||||||
def __getitem__(self, item):
|
|
||||||
return self._zone_slices[item]
|
|
||||||
|
|
||||||
def add_items(self, other: Union[str, List[str]]):
|
|
||||||
raise AttributeError('You are not allowed to add additional Zones in runtime.')
|
|
||||||
|
|||||||
@@ -26,6 +26,12 @@ class Walls(PositionMixin, EnvObjects):
|
|||||||
def from_tiles(cls, tiles, *args, **kwargs):
|
def from_tiles(cls, tiles, *args, **kwargs):
|
||||||
raise RuntimeError()
|
raise RuntimeError()
|
||||||
|
|
||||||
|
def by_pos(self, pos: (int, int)):
|
||||||
|
try:
|
||||||
|
return super().by_pos(pos)[0]
|
||||||
|
except IndexError:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
class Floors(Walls):
|
class Floors(Walls):
|
||||||
_entity = Floor
|
_entity = Floor
|
||||||
|
|||||||
@@ -1,152 +0,0 @@
|
|||||||
import warnings
|
|
||||||
from collections import defaultdict
|
|
||||||
from os import PathLike
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Union
|
|
||||||
|
|
||||||
import yaml
|
|
||||||
from gymnasium import Wrapper
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import pandas as pd
|
|
||||||
|
|
||||||
from marl_factory_grid.environment.factory import REC_TAC
|
|
||||||
|
|
||||||
|
|
||||||
class EnvRecorder(Wrapper):
|
|
||||||
|
|
||||||
def __init__(self, env, entities: str = 'all', filepath: Union[str, PathLike] = None, freq: int = 0):
|
|
||||||
super(EnvRecorder, self).__init__(env)
|
|
||||||
self.filepath = filepath
|
|
||||||
self.freq = freq
|
|
||||||
self._recorder_dict = defaultdict(list)
|
|
||||||
self._recorder_out_list = list()
|
|
||||||
self._episode_counter = 1
|
|
||||||
self._do_record_dict = defaultdict(lambda: False)
|
|
||||||
if isinstance(entities, str):
|
|
||||||
if entities.lower() == 'all':
|
|
||||||
self._entities = None
|
|
||||||
else:
|
|
||||||
self._entities = [entities]
|
|
||||||
else:
|
|
||||||
self._entities = entities
|
|
||||||
|
|
||||||
def __getattr__(self, item):
|
|
||||||
return getattr(self.unwrapped, item)
|
|
||||||
|
|
||||||
def reset(self):
|
|
||||||
self._on_training_start()
|
|
||||||
return self.unwrapped.reset()
|
|
||||||
|
|
||||||
def _on_training_start(self) -> None:
|
|
||||||
assert self.start_recording()
|
|
||||||
|
|
||||||
def _read_info(self, env_idx, info: dict):
|
|
||||||
if info_dict := {key.replace(REC_TAC, ''): val for key, val in info.items() if key.startswith(f'{REC_TAC}')}:
|
|
||||||
if self._entities:
|
|
||||||
info_dict = {k: v for k, v in info_dict.items() if k in self._entities}
|
|
||||||
self._recorder_dict[env_idx].append(info_dict)
|
|
||||||
else:
|
|
||||||
pass
|
|
||||||
return True
|
|
||||||
|
|
||||||
def _read_done(self, env_idx, done):
|
|
||||||
if done:
|
|
||||||
self._recorder_out_list.append({'steps': self._recorder_dict[env_idx],
|
|
||||||
'episode': self._episode_counter})
|
|
||||||
self._recorder_dict[env_idx] = list()
|
|
||||||
else:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def step(self, actions):
|
|
||||||
step_result = self.unwrapped.step(actions)
|
|
||||||
if self.do_record_episode(0):
|
|
||||||
info = step_result[-1]
|
|
||||||
self._read_info(0, info)
|
|
||||||
if self._do_record_dict[0]:
|
|
||||||
self._read_done(0, step_result[-2])
|
|
||||||
return step_result
|
|
||||||
|
|
||||||
def finalize(self):
|
|
||||||
self._on_training_end()
|
|
||||||
return True
|
|
||||||
|
|
||||||
def save_records(self, filepath: Union[Path, str, None] = None,
|
|
||||||
only_deltas=True,
|
|
||||||
save_occupation_map=False,
|
|
||||||
save_trajectory_map=False,
|
|
||||||
):
|
|
||||||
filepath = Path(filepath or self.filepath)
|
|
||||||
filepath.parent.mkdir(exist_ok=True, parents=True)
|
|
||||||
# cls.out_file.unlink(missing_ok=True)
|
|
||||||
with filepath.open('w') as f:
|
|
||||||
if only_deltas:
|
|
||||||
from deepdiff import DeepDiff
|
|
||||||
diff_dict = [DeepDiff(t1,t2, ignore_order=True)
|
|
||||||
for t1, t2 in zip(self._recorder_out_list, self._recorder_out_list[1:])
|
|
||||||
]
|
|
||||||
out_dict = {'episodes': diff_dict}
|
|
||||||
|
|
||||||
else:
|
|
||||||
out_dict = {'episodes': self._recorder_out_list}
|
|
||||||
out_dict.update(
|
|
||||||
{'n_episodes': self._episode_counter,
|
|
||||||
'env_params': self.env.params,
|
|
||||||
'header': self.env.summarize_header
|
|
||||||
})
|
|
||||||
try:
|
|
||||||
yaml.dump(out_dict, f, indent=4)
|
|
||||||
except TypeError:
|
|
||||||
print('Shit')
|
|
||||||
|
|
||||||
if save_occupation_map:
|
|
||||||
a = np.zeros((15, 15))
|
|
||||||
# noinspection PyTypeChecker
|
|
||||||
for episode in out_dict['episodes']:
|
|
||||||
df = pd.DataFrame([y for x in episode['steps'] for y in x['Agents']])
|
|
||||||
|
|
||||||
b = list(df[['x', 'y']].to_records(index=False))
|
|
||||||
|
|
||||||
np.add.at(a, tuple(zip(*b)), 1)
|
|
||||||
|
|
||||||
# a = np.rot90(a)
|
|
||||||
import seaborn as sns
|
|
||||||
from matplotlib import pyplot as plt
|
|
||||||
hm = sns.heatmap(data=a)
|
|
||||||
hm.set_title('Very Nice Heatmap')
|
|
||||||
plt.show()
|
|
||||||
|
|
||||||
if save_trajectory_map:
|
|
||||||
raise NotImplementedError('This has not yet been implemented.')
|
|
||||||
|
|
||||||
def do_record_episode(self, env_idx):
|
|
||||||
if not self._recorder_dict[env_idx]:
|
|
||||||
if self.freq:
|
|
||||||
self._do_record_dict[env_idx] = (self.freq == -1) or (self._episode_counter % self.freq) == 0
|
|
||||||
else:
|
|
||||||
self._do_record_dict[env_idx] = False
|
|
||||||
warnings.warn('You did wrap your Environment with a recorder, but set the freq to zero\n'
|
|
||||||
'Nothing will be recorded')
|
|
||||||
self._episode_counter += 1
|
|
||||||
else:
|
|
||||||
pass
|
|
||||||
return self._do_record_dict[env_idx]
|
|
||||||
|
|
||||||
def _on_step(self) -> bool:
|
|
||||||
for env_idx, info in enumerate(self.locals.get('infos', [])):
|
|
||||||
if self._do_record_dict[env_idx]:
|
|
||||||
self._read_info(env_idx, info)
|
|
||||||
dones = list(enumerate(self.locals.get('dones', [])))
|
|
||||||
dones.extend(list(enumerate(self.locals.get('done', []))))
|
|
||||||
for env_idx, done in dones:
|
|
||||||
if self._do_record_dict[env_idx]:
|
|
||||||
self._read_done(env_idx, done)
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
def _on_training_end(self) -> None:
|
|
||||||
for env_idx in range(len(self._recorder_dict)):
|
|
||||||
if self._recorder_dict[env_idx]:
|
|
||||||
self._recorder_out_list.append({'steps': self._recorder_dict[env_idx],
|
|
||||||
'episode': self._episode_counter})
|
|
||||||
pass
|
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
from .actions import BtryCharge
|
from .actions import BtryCharge
|
||||||
from .entitites import ChargePod, Battery
|
from .entitites import Pod, Battery
|
||||||
from .groups import ChargePods, Batteries
|
from .groups import ChargePods, Batteries
|
||||||
from .rules import BtryDoneAtDischarge, Btry
|
from .rules import BtryDoneAtDischarge, Btry
|
||||||
|
|||||||
@@ -8,12 +8,3 @@ CHARGE_POD_SYMBOL = 1
|
|||||||
|
|
||||||
|
|
||||||
CHARGE = 'do_charge_action'
|
CHARGE = 'do_charge_action'
|
||||||
|
|
||||||
|
|
||||||
class BatteryProperties(NamedTuple):
|
|
||||||
initial_charge: float = 0.8 #
|
|
||||||
charge_rate: float = 0.4 #
|
|
||||||
charge_locations: int = 20 #
|
|
||||||
per_action_costs: Union[dict, float] = 0.02
|
|
||||||
done_when_discharged: bool = False
|
|
||||||
multi_charge: bool = False
|
|
||||||
|
|||||||
@@ -42,16 +42,16 @@ class Battery(BoundEntityMixin, EnvObject):
|
|||||||
else:
|
else:
|
||||||
return c.NOT_VALID
|
return c.NOT_VALID
|
||||||
|
|
||||||
def summarize_state(self, **_):
|
def summarize_state(self):
|
||||||
attr_dict = {key: str(val) for key, val in self.__dict__.items() if not key.startswith('_') and key != 'data'}
|
summary = super().summarize_state()
|
||||||
attr_dict.update(dict(name=self.name, belongs_to=self._bound_entity.name))
|
summary.update(dict(belongs_to=self._bound_entity.name, chargeLevel=self.charge_level))
|
||||||
return attr_dict
|
return summary
|
||||||
|
|
||||||
def render(self):
|
def render(self):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
class ChargePod(Entity):
|
class Pod(Entity):
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def encoding(self):
|
def encoding(self):
|
||||||
@@ -59,7 +59,7 @@ class ChargePod(Entity):
|
|||||||
|
|
||||||
def __init__(self, *args, charge_rate: float = 0.4,
|
def __init__(self, *args, charge_rate: float = 0.4,
|
||||||
multi_charge: bool = False, **kwargs):
|
multi_charge: bool = False, **kwargs):
|
||||||
super(ChargePod, self).__init__(*args, **kwargs)
|
super(Pod, self).__init__(*args, **kwargs)
|
||||||
self.charge_rate = charge_rate
|
self.charge_rate = charge_rate
|
||||||
self.multi_charge = multi_charge
|
self.multi_charge = multi_charge
|
||||||
|
|
||||||
@@ -73,3 +73,8 @@ class ChargePod(Entity):
|
|||||||
|
|
||||||
def render(self):
|
def render(self):
|
||||||
return RenderEntity(b.CHARGE_PODS, self.pos)
|
return RenderEntity(b.CHARGE_PODS, self.pos)
|
||||||
|
|
||||||
|
def summarize_state(self) -> dict:
|
||||||
|
summery = super().summarize_state()
|
||||||
|
summery.update(charge_rate=self.charge_rate)
|
||||||
|
return summery
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
from marl_factory_grid.environment.groups.env_objects import EnvObjects
|
from marl_factory_grid.environment.groups.env_objects import EnvObjects
|
||||||
from marl_factory_grid.environment.groups.mixins import PositionMixin, HasBoundMixin
|
from marl_factory_grid.environment.groups.mixins import PositionMixin, HasBoundMixin
|
||||||
from marl_factory_grid.modules.batteries.entitites import ChargePod, Battery
|
from marl_factory_grid.modules.batteries.entitites import Pod, Battery
|
||||||
|
|
||||||
|
|
||||||
class Batteries(HasBoundMixin, EnvObjects):
|
class Batteries(HasBoundMixin, EnvObjects):
|
||||||
@@ -20,9 +20,10 @@ class Batteries(HasBoundMixin, EnvObjects):
|
|||||||
batteries = [self._entity(initial_charge_level, agent) for _, agent in enumerate(agents)]
|
batteries = [self._entity(initial_charge_level, agent) for _, agent in enumerate(agents)]
|
||||||
self.add_items(batteries)
|
self.add_items(batteries)
|
||||||
|
|
||||||
|
|
||||||
class ChargePods(PositionMixin, EnvObjects):
|
class ChargePods(PositionMixin, EnvObjects):
|
||||||
|
|
||||||
_entity = ChargePod
|
_entity = Pod
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super(ChargePods, self).__init__(*args, **kwargs)
|
super(ChargePods, self).__init__(*args, **kwargs)
|
||||||
|
|||||||
@@ -59,3 +59,19 @@ class BtryDoneAtDischarge(Rule):
|
|||||||
else:
|
else:
|
||||||
return [DoneResult(self.name, validity=c.NOT_VALID, reward=0)]
|
return [DoneResult(self.name, validity=c.NOT_VALID, reward=0)]
|
||||||
|
|
||||||
|
|
||||||
|
class PodRules(Rule):
|
||||||
|
|
||||||
|
def __init__(self, n_pods: int, charge_rate: float = 0.4, multi_charge: bool = False):
|
||||||
|
super().__init__()
|
||||||
|
self.multi_charge = multi_charge
|
||||||
|
self.charge_rate = charge_rate
|
||||||
|
self.n_pods = n_pods
|
||||||
|
|
||||||
|
def on_init(self, state, lvl_map):
|
||||||
|
pod_collection = state[b.CHARGE_PODS]
|
||||||
|
empty_tiles = state[c.FLOOR].empty_tiles[:self.n_pods]
|
||||||
|
pods = pod_collection.from_tiles(empty_tiles, entity_kwargs=dict(
|
||||||
|
multi_charge=self.multi_charge, charge_rate=self.charge_rate)
|
||||||
|
)
|
||||||
|
pod_collection.add_items(pods)
|
||||||
|
|||||||
@@ -51,7 +51,7 @@ class Destination(Entity):
|
|||||||
def summarize_state(self) -> dict:
|
def summarize_state(self) -> dict:
|
||||||
state_summary = super().summarize_state()
|
state_summary = super().summarize_state()
|
||||||
state_summary.update(per_agent_times=[
|
state_summary.update(per_agent_times=[
|
||||||
dict(belongs_to=key, time=val) for key, val in self._per_agent_times.keys()], dwell_time=self.dwell_time)
|
dict(belongs_to=key, time=val) for key, val in self._per_agent_times.items()], dwell_time=self.dwell_time)
|
||||||
return state_summary
|
return state_summary
|
||||||
|
|
||||||
def render(self):
|
def render(self):
|
||||||
|
|||||||
@@ -26,7 +26,8 @@ class ItemAction(Action):
|
|||||||
reward = r.DROP_OFF_VALID if valid else r.DROP_OFF_FAIL
|
reward = r.DROP_OFF_VALID if valid else r.DROP_OFF_FAIL
|
||||||
return ActionResult(entity=entity, identifier=self._identifier, validity=valid, reward=reward)
|
return ActionResult(entity=entity, identifier=self._identifier, validity=valid, reward=reward)
|
||||||
|
|
||||||
elif item := state[i.ITEM].by_pos(entity.pos):
|
elif items := state[i.ITEM].by_pos(entity.pos):
|
||||||
|
item = items[0]
|
||||||
item.change_parent_collection(inventory)
|
item.change_parent_collection(inventory)
|
||||||
item.set_tile_to(state.NO_POS_TILE)
|
item.set_tile_to(state.NO_POS_TILE)
|
||||||
state.print(f'{entity.name} just picked up an item at {entity.pos}')
|
state.print(f'{entity.name} just picked up an item at {entity.pos}')
|
||||||
|
|||||||
@@ -51,7 +51,7 @@ class Inventories(HasBoundMixin, Objects):
|
|||||||
_entity = Inventory
|
_entity = Inventory
|
||||||
var_can_move = False
|
var_can_move = False
|
||||||
|
|
||||||
def __init__(self, size, *args, **kwargs):
|
def __init__(self, size: int, *args, **kwargs):
|
||||||
super(Inventories, self).__init__(*args, **kwargs)
|
super(Inventories, self).__init__(*args, **kwargs)
|
||||||
self.size = size
|
self.size = size
|
||||||
self._obs = None
|
self._obs = None
|
||||||
|
|||||||
@@ -17,3 +17,4 @@ def init():
|
|||||||
shutil.copytree(template_path, cwd)
|
shutil.copytree(template_path, cwd)
|
||||||
print(f'Templates copied to {cwd}"/"{template_path.name}')
|
print(f'Templates copied to {cwd}"/"{template_path.name}')
|
||||||
print(':wave:')
|
print(':wave:')
|
||||||
|
|
||||||
|
|||||||
@@ -24,7 +24,6 @@ class FactoryConfigParser(object):
|
|||||||
self.config_path = Path(config_path)
|
self.config_path = Path(config_path)
|
||||||
self.custom_modules_path = Path(custom_modules_path) if custom_modules_path is not None else custom_modules_path
|
self.custom_modules_path = Path(custom_modules_path) if custom_modules_path is not None else custom_modules_path
|
||||||
self.config = yaml.safe_load(self.config_path.open())
|
self.config = yaml.safe_load(self.config_path.open())
|
||||||
self.do_record = False
|
|
||||||
|
|
||||||
def __getattr__(self, item):
|
def __getattr__(self, item):
|
||||||
return self['General'][item]
|
return self['General'][item]
|
||||||
|
|||||||
+4
-7
@@ -6,11 +6,10 @@ from typing import Union
|
|||||||
from gymnasium import Wrapper
|
from gymnasium import Wrapper
|
||||||
|
|
||||||
from marl_factory_grid.utils.helpers import IGNORED_DF_COLUMNS
|
from marl_factory_grid.utils.helpers import IGNORED_DF_COLUMNS
|
||||||
from marl_factory_grid.environment.factory import REC_TAC
|
|
||||||
|
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
|
||||||
from marl_factory_grid.plotting.compare_runs import plot_single_run
|
from marl_factory_grid.utils.plotting.compare_runs import plot_single_run
|
||||||
|
|
||||||
|
|
||||||
class EnvMonitor(Wrapper):
|
class EnvMonitor(Wrapper):
|
||||||
@@ -23,8 +22,6 @@ class EnvMonitor(Wrapper):
|
|||||||
self._monitor_df = pd.DataFrame()
|
self._monitor_df = pd.DataFrame()
|
||||||
self._monitor_dict = dict()
|
self._monitor_dict = dict()
|
||||||
|
|
||||||
def __getattr__(self, item):
|
|
||||||
return getattr(self.unwrapped, item)
|
|
||||||
|
|
||||||
def step(self, action):
|
def step(self, action):
|
||||||
obs_type, obs, reward, done, info = self.env.step(action)
|
obs_type, obs, reward, done, info = self.env.step(action)
|
||||||
@@ -33,12 +30,12 @@ class EnvMonitor(Wrapper):
|
|||||||
return obs_type, obs, reward, done, info
|
return obs_type, obs, reward, done, info
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
return self.unwrapped.reset()
|
return self.env.reset()
|
||||||
|
|
||||||
def _read_info(self, info: dict):
|
def _read_info(self, info: dict):
|
||||||
self._monitor_dict[len(self._monitor_dict)] = {
|
self._monitor_dict[len(self._monitor_dict)] = {
|
||||||
key: val for key, val in info.items() if
|
key: val for key, val in info.items() if
|
||||||
key not in ['terminal_observation', 'episode'] and not key.startswith(REC_TAC)}
|
key not in ['terminal_observation', 'episode']}
|
||||||
return
|
return
|
||||||
|
|
||||||
def _read_done(self, done):
|
def _read_done(self, done):
|
||||||
@@ -50,7 +47,7 @@ class EnvMonitor(Wrapper):
|
|||||||
{col: 'mean' if col.endswith('ount') else 'sum' for col in columns}
|
{col: 'mean' if col.endswith('ount') else 'sum' for col in columns}
|
||||||
)
|
)
|
||||||
env_monitor_df['episode'] = len(self._monitor_df)
|
env_monitor_df['episode'] = len(self._monitor_df)
|
||||||
self._monitor_df = self._monitor_df.append([env_monitor_df])
|
self._monitor_df = pd.concat([self._monitor_df, pd.DataFrame([env_monitor_df])], ignore_index=True)
|
||||||
else:
|
else:
|
||||||
pass
|
pass
|
||||||
return
|
return
|
||||||
@@ -0,0 +1,160 @@
|
|||||||
|
from os import PathLike
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Union, List
|
||||||
|
|
||||||
|
import yaml
|
||||||
|
from gymnasium import Wrapper
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
|
|
||||||
|
|
||||||
|
class EnvRecorder(Wrapper):
|
||||||
|
|
||||||
|
def __init__(self, env, filepath: Union[str, PathLike] = None,
|
||||||
|
episodes: Union[List[int], None] = None):
|
||||||
|
super(EnvRecorder, self).__init__(env)
|
||||||
|
self.filepath = filepath
|
||||||
|
self.episodes = episodes
|
||||||
|
self._curr_episode = 0
|
||||||
|
self._curr_ep_recorder = list()
|
||||||
|
self._recorder_out_list = list()
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
self._curr_ep_recorder = list()
|
||||||
|
self._recorder_out_list = list()
|
||||||
|
self._curr_episode += 1
|
||||||
|
return self.env.reset()
|
||||||
|
|
||||||
|
def step(self, actions):
|
||||||
|
obs_type, obs, reward, done, info = self.env.step(actions)
|
||||||
|
if not self.episodes or self._curr_episode in self.episodes:
|
||||||
|
summary: dict = self.env.summarize_state()
|
||||||
|
# summary.update(done=done)
|
||||||
|
# summary.update({'episode': self._curr_episode})
|
||||||
|
# TODO Protobuff Adjustments ######
|
||||||
|
# summary.update(info)
|
||||||
|
self._curr_ep_recorder.append(summary)
|
||||||
|
if done:
|
||||||
|
self._recorder_out_list.append({'steps': self._curr_ep_recorder,
|
||||||
|
'episode_nr': self._curr_episode})
|
||||||
|
self._curr_ep_recorder = list()
|
||||||
|
return obs_type, obs, reward, done, info
|
||||||
|
|
||||||
|
def _finalize(self):
|
||||||
|
if self._curr_ep_recorder:
|
||||||
|
self._recorder_out_list.append({'steps': self._curr_ep_recorder.copy(),
|
||||||
|
'episode_nr': len(self._recorder_out_list)})
|
||||||
|
|
||||||
|
def save_records(self, filepath: Union[Path, str, None] = None,
|
||||||
|
only_deltas=False,
|
||||||
|
save_occupation_map=False,
|
||||||
|
save_trajectory_map=False,
|
||||||
|
):
|
||||||
|
self._finalize()
|
||||||
|
filepath = Path(filepath or self.filepath)
|
||||||
|
filepath.parent.mkdir(exist_ok=True, parents=True)
|
||||||
|
# cls.out_file.unlink(missing_ok=True)
|
||||||
|
with filepath.open('wb') as f:
|
||||||
|
if only_deltas:
|
||||||
|
from deepdiff import DeepDiff
|
||||||
|
diff_dict = [DeepDiff(t1, t2, ignore_order=True)
|
||||||
|
for t1, t2 in zip(self._recorder_out_list, self._recorder_out_list[1:])
|
||||||
|
]
|
||||||
|
out_dict = {'episodes': diff_dict}
|
||||||
|
|
||||||
|
else:
|
||||||
|
# TODO Protobuff Adjustments Revert
|
||||||
|
dest_prop = dict(
|
||||||
|
n_dests=0,
|
||||||
|
dwell_time=0,
|
||||||
|
spawn_frequency=0,
|
||||||
|
spawn_in_other_zone=False,
|
||||||
|
spawn_mode=''
|
||||||
|
)
|
||||||
|
rewards_dest = dict(
|
||||||
|
WAIT_VALID=0.00,
|
||||||
|
WAIT_FAIL=0.00,
|
||||||
|
DEST_REACHED=0.00,
|
||||||
|
)
|
||||||
|
mv_prop = dict(
|
||||||
|
allow_square_movement=False,
|
||||||
|
allow_diagonal_movement=False,
|
||||||
|
allow_no_op=False,
|
||||||
|
)
|
||||||
|
obs_prop = dict(
|
||||||
|
render_agents='',
|
||||||
|
omit_agent_self=False,
|
||||||
|
additional_agent_placeholder=0,
|
||||||
|
cast_shadows=False,
|
||||||
|
frames_to_stack=0,
|
||||||
|
pomdp_r=self.env.params['General']['pomdp_r'],
|
||||||
|
indicate_door_area=False,
|
||||||
|
show_global_position_info=False,
|
||||||
|
|
||||||
|
)
|
||||||
|
rewards_base = dict(
|
||||||
|
MOVEMENTS_VALID=0.00,
|
||||||
|
MOVEMENTS_FAIL=0.00,
|
||||||
|
NOOP=0.00,
|
||||||
|
USE_DOOR_VALID=0.00,
|
||||||
|
USE_DOOR_FAIL=0.00,
|
||||||
|
COLLISION=0.00,
|
||||||
|
|
||||||
|
)
|
||||||
|
|
||||||
|
out_dict = {'episodes': self._recorder_out_list}
|
||||||
|
out_dict.update(
|
||||||
|
{'n_episodes': self._curr_episode,
|
||||||
|
'metadata':dict(
|
||||||
|
level_name=self.env.params['General']['level_name'],
|
||||||
|
verbose=False,
|
||||||
|
n_agents=len(self.env.params['Agents']),
|
||||||
|
max_steps=100,
|
||||||
|
done_at_collision=False,
|
||||||
|
parse_doors=True,
|
||||||
|
doors_have_area=False,
|
||||||
|
individual_rewards=True,
|
||||||
|
class_name='Where does this end up?',
|
||||||
|
env_seed=69,
|
||||||
|
|
||||||
|
dest_prop=dest_prop,
|
||||||
|
rewards_dest=rewards_dest,
|
||||||
|
mv_prop=mv_prop,
|
||||||
|
obs_prop=obs_prop,
|
||||||
|
rewards_base=rewards_base,
|
||||||
|
),
|
||||||
|
# 'env_params': self.env.params,
|
||||||
|
'header': self.env.summarize_header()
|
||||||
|
})
|
||||||
|
try:
|
||||||
|
from marl_factory_grid.utils.proto import fiksProto_pb2
|
||||||
|
from google.protobuf import json_format
|
||||||
|
|
||||||
|
bulk = fiksProto_pb2.Bulk()
|
||||||
|
json_format.ParseDict(out_dict, bulk)
|
||||||
|
f.write(bulk.SerializeToString())
|
||||||
|
# yaml.dump(out_dict, f, indent=4)
|
||||||
|
except TypeError:
|
||||||
|
print('Shit')
|
||||||
|
print('done')
|
||||||
|
|
||||||
|
if save_occupation_map:
|
||||||
|
a = np.zeros((15, 15))
|
||||||
|
# noinspection PyTypeChecker
|
||||||
|
for episode in out_dict['episodes']:
|
||||||
|
df = pd.DataFrame([y for x in episode['steps'] for y in x['Agents']])
|
||||||
|
|
||||||
|
b = list(df[['x', 'y']].to_records(index=False))
|
||||||
|
|
||||||
|
np.add.at(a, tuple(zip(*b)), 1)
|
||||||
|
|
||||||
|
# a = np.rot90(a)
|
||||||
|
import seaborn as sns
|
||||||
|
from matplotlib import pyplot as plt
|
||||||
|
hm = sns.heatmap(data=a)
|
||||||
|
hm.set_title('Very Nice Heatmap')
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
if save_trajectory_map:
|
||||||
|
raise NotImplementedError('This has not yet been implemented.')
|
||||||
+1
-1
@@ -7,7 +7,7 @@ from typing import Union, List
|
|||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
|
||||||
from marl_factory_grid.utils.helpers import IGNORED_DF_COLUMNS
|
from marl_factory_grid.utils.helpers import IGNORED_DF_COLUMNS
|
||||||
from marl_factory_grid.plotting.plotting import prepare_plot
|
from marl_factory_grid.utils.plotting.plotting import prepare_plot
|
||||||
|
|
||||||
MODEL_MAP = None
|
MODEL_MAP = None
|
||||||
|
|
||||||
@@ -120,6 +120,7 @@ class ConfigExplainer:
|
|||||||
|
|
||||||
def _save_to_file(self, data: dict, filepath: PathLike, tag: str = ''):
|
def _save_to_file(self, data: dict, filepath: PathLike, tag: str = ''):
|
||||||
filepath = Path(filepath)
|
filepath = Path(filepath)
|
||||||
|
yaml.Dumper.ignore_aliases = lambda *args: True
|
||||||
with filepath.open('w') as f:
|
with filepath.open('w') as f:
|
||||||
yaml.dump(data, f, encoding='utf-8')
|
yaml.dump(data, f, encoding='utf-8')
|
||||||
print(f'Example config {"for " + tag + " " if tag else " "}dumped')
|
print(f'Example config {"for " + tag + " " if tag else " "}dumped')
|
||||||
|
|||||||
+2
-2
@@ -4,8 +4,8 @@ from pathlib import Path
|
|||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
from marl_factory_grid.environment.factory import Factory
|
from marl_factory_grid.environment.factory import Factory
|
||||||
from marl_factory_grid.logging.envmonitor import EnvMonitor
|
from marl_factory_grid.utils.logging.envmonitor import EnvMonitor
|
||||||
from marl_factory_grid.logging.recorder import EnvRecorder
|
from marl_factory_grid.utils.logging.recorder import EnvRecorder
|
||||||
|
|
||||||
from marl_factory_grid.modules.doors import constants as d
|
from marl_factory_grid.modules.doors import constants as d
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user