1
0
mirror of https://github.com/illiumst/marl-factory-grid.git synced 2026-06-17 23:22:16 +02:00

New Scenario "Two_Rooms_One_Door"

This commit is contained in:
Steffen Illium
2023-09-01 13:04:54 +02:00
parent fb0066d800
commit 714e07a816
29 changed files with 271 additions and 277 deletions
@@ -21,7 +21,6 @@ Agents:
- Items - Items
- Inventory - Inventory
- DropOffLocations - DropOffLocations
- Machines
- Maintainers - Maintainers
Entities: Entities:
Batteries: {} Batteries: {}
@@ -79,10 +78,6 @@ Rules:
n_items: 5 n_items: 5
n_locations: 5 n_locations: 5
spawn_frequency: 15 spawn_frequency: 15
MachineRule:
n_machines: 2
MaintenanceRule:
n_maintainer: 1
MaxStepsReached: MaxStepsReached:
max_steps: 500 max_steps: 500
# AgentSingleZonePlacement: # AgentSingleZonePlacement:
@@ -1,3 +1,31 @@
General:
env_seed: 69
individual_rewards: true
level_name: two_rooms
pomdp_r: 3
verbose: false
Entities:
BoundDestinations: {}
ReachedDestinations: {}
Doors: {}
GlobalPositions: {}
Zones: {}
Rules:
# Init:
AssignGlobalPositions: {}
ZoneInit: {}
AgentSingleZonePlacement: {}
IndividualDestinationZonePlacement: {}
# Env Rules
MaxStepsReached:
max_steps: 10
Collision:
done_at_collisions: false
DoorAutoClose:
close_frequency: 10
Agents: Agents:
Wolfgang: Wolfgang:
Actions: Actions:
@@ -21,27 +49,4 @@ Agents:
- Other - Other
- Walls - Walls
- BoundDestination - BoundDestination
- Doors - Doors
Entities:
BoundDestinations: {}
ReachedDestinations: {}
Doors: {}
GlobalPositions: {}
Zones: {}
General:
env_seed: 69
individual_rewards: true
level_name: two_rooms
pomdp_r: 3
verbose: false
Rules:
Collision:
done_at_collisions: false
AssignGlobalPositions: {}
DoorAutoClose:
close_frequency: 10
ZoneInit: {}
AgentSingleZonePlacement: {}
IndividualDestinationZonePlacement: {}
@@ -62,11 +62,14 @@ class Object:
def add_observer(self, observer): def add_observer(self, observer):
self.observers.append(observer) self.observers.append(observer)
observer.notify_change_pos(self) observer.notify_add_entity(self)
def del_observer(self, observer): def del_observer(self, observer):
self.observers.remove(observer) self.observers.remove(observer)
def summarize_state(self):
    """Return this object's state summary; the base implementation yields an empty dict."""
    return {}
class EnvObject(Object): class EnvObject(Object):
@@ -128,3 +131,6 @@ class EnvObject(Object):
self._collection.delete_env_object(self) self._collection.delete_env_object(self)
self._collection = other_collection self._collection = other_collection
return self._collection == other_collection return self._collection == other_collection
def summarize_state(self):
    """Return a dict that holds ``str(self.name)`` under the key ``name``."""
    return {'name': str(self.name)}
+16 -23
View File
@@ -16,8 +16,6 @@ import marl_factory_grid.environment.constants as c
from marl_factory_grid.utils.states import Gamestate from marl_factory_grid.utils.states import Gamestate
REC_TAC = 'rec_'
class Factory(gym.Env): class Factory(gym.Env):
@@ -44,11 +42,6 @@ class Factory(gym.Env):
config_dict = yaml.safe_load(config_path.open()) config_dict = yaml.safe_load(config_path.open())
return config_dict return config_dict
@property
def summarize_header(self):
    """Return ``self._summarize_state(stateless_entities=True)``."""
    return self._summarize_state(stateless_entities=True)
def __exit__(self, exc_type, exc_val, exc_tb): def __exit__(self, exc_type, exc_val, exc_tb):
self.close() self.close()
@@ -125,9 +118,6 @@ class Factory(gym.Env):
info = reward_info info = reward_info
info.update(step_reward=sum(reward), step=self.state.curr_step) info.update(step_reward=sum(reward), step=self.state.curr_step)
# TODO:
# if self._record_episodes:
# info.update(self._summarize_state())
obs, reset_info = self.obs_builder.refresh_and_build_for_all(self.state) obs, reset_info = self.obs_builder.refresh_and_build_for_all(self.state)
info.update(reset_info) info.update(reset_info)
@@ -171,14 +161,6 @@ class Factory(gym.Env):
self.state.print(f"reward is {reward}") self.state.print(f"reward is {reward}")
return reward, combined_info_dict, done return reward, combined_info_dict, done
def start_recording(self):
    """Set ``conf.do_record`` to ``True`` and return the flag as read back from the config."""
    config = self.conf
    config.do_record = True
    return config.do_record
def stop_recording(self):
    """Set ``conf.do_record`` to ``False`` and return the negation of the flag as read back."""
    config = self.conf
    config.do_record = False
    return not config.do_record
# noinspection PyGlobalUndefined # noinspection PyGlobalUndefined
def render(self, mode='human'): def render(self, mode='human'):
if not self._renderer: # lazy init if not self._renderer: # lazy init
@@ -193,12 +175,23 @@ class Factory(gym.Env):
render_entity.aux = self.obs_builder.curr_lightmaps[render_entity.real_name] render_entity.aux = self.obs_builder.curr_lightmaps[render_entity.real_name]
return self._renderer.render(render_entities) return self._renderer.render(render_entities)
def _summarize_state(self, stateless_entities=False): def summarize_header(self):
summary = {f'{REC_TAC}step': self.state.curr_step} header = {'rec_step': self.state.curr_step}
for entity_group in (x for x in self.state if x.name in ['Walls', 'Floors', 'DropOffLocations', 'ChargePods']):
header.update({f'rec{entity_group.name}': entity_group.summarize_states()})
return header
for entity_group in self.state: def summarize_state(self):
if entity_group.is_stateless == stateless_entities: summary = {'step': self.state.curr_step}
summary.update({f'{REC_TAC}{entity_group.name}': entity_group.summarize_states()})
# Todo: Protobuff Compatibility Section #######
# for entity_group in (x for x in self.state if x.name not in [c.WALLS, c.FLOORS]):
for entity_group in (x for x in self.state if x.name not in [c.FLOORS]):
summary.update({entity_group.name.lower(): entity_group.summarize_states()})
# TODO Section End ########
for key in list(summary.keys()):
if key not in ['step', 'walls', 'doors', 'agents', 'items', 'dirtPiles', 'batteries']:
del summary[key]
return summary return summary
def print(self, string): def print(self, string):
@@ -23,9 +23,6 @@ class EnvObjects(Objects):
super(EnvObjects, self).add_item(item) super(EnvObjects, self).add_item(item)
return self return self
def summarize_states(self):
    """Collect the ``summarize_state()`` result of every entity in ``self.values()``, in that order."""
    summaries = []
    for entity in self.values():
        summaries.append(entity.summarize_state())
    return summaries
def delete_env_object(self, env_object: EnvObject): def delete_env_object(self, env_object: EnvObject):
del self[env_object.name] del self[env_object.name]
@@ -45,7 +45,8 @@ class PositionMixin:
def by_pos(self, pos: (int, int)): def by_pos(self, pos: (int, int)):
pos = tuple(pos) pos = tuple(pos)
try: try:
return next(e for e in self if e.pos == pos) return self.pos_dict[pos]
# return next(e for e in self if e.pos == pos)
except StopIteration: except StopIteration:
pass pass
except ValueError: except ValueError:
@@ -144,7 +144,13 @@ class Objects:
def notify_add_entity(self, entity: Object): def notify_add_entity(self, entity: Object):
try: try:
entity.add_observer(self) if self not in entity.observers:
entity.add_observer(self)
self.pos_dict[entity.pos].append(entity) self.pos_dict[entity.pos].append(entity)
except (ValueError, AttributeError): except (ValueError, AttributeError):
pass pass
def summarize_states(self):
    """Return the ``summarize_state()`` result of every entity yielded by iterating this collection."""
    # FIXME PROTOBUFF
    summaries = []
    for entity in self:
        summaries.append(entity.summarize_state())
    return summaries
@@ -43,38 +43,3 @@ class GlobalPositions(HasBoundMixin, EnvObjects):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super(GlobalPositions, self).__init__(*args, **kwargs) super(GlobalPositions, self).__init__(*args, **kwargs)
class ZonesOLD(Objects):
    """Legacy zone collection that is currently disabled: ``__init__`` raises ``NotImplementedError`` first."""

    _entity = Zone

    @property
    def accounting_zones(self):
        # Returns ``self[idx]`` for every ``(idx, name)`` pair of ``items()`` whose second element differs
        # from ``c.DANGER_ZONE``.
        # NOTE(review): confirm that the values of ``items()`` really are zone names/symbols.
        return [self[idx] for idx, name in self.items() if name != c.DANGER_ZONE]

    def __init__(self, parsed_level):
        raise NotImplementedError('This needs a Rework')
        # Everything below is unreachable because of the raise above.
        # NOTE(review): ``super(Zones, self)`` names ``Zones`` although this class is called ``ZonesOLD``.
        super(Zones, self).__init__()
        slices = list()
        self._accounting_zones = list()
        self._danger_zones = list()
        # One one-hot slice per distinct symbol in the level; occupied cells are skipped.
        for symbol in np.unique(parsed_level):
            if symbol == c.VALUE_OCCUPIED_CELL:
                continue
            elif symbol == c.DANGER_ZONE:
                # NOTE(review): the result of ``self + symbol`` is discarded; presumably ``__add__`` registers
                # the symbol as a side effect -- verify.
                self + symbol
                slices.append(h.one_hot_level(parsed_level, symbol))
                self._danger_zones.append(symbol)
            else:
                self + symbol
                slices.append(h.one_hot_level(parsed_level, symbol))
                self._accounting_zones.append(symbol)

        self._zone_slices = np.stack(slices)

    def __getitem__(self, item):
        # Index directly into the stacked zone slices.
        return self._zone_slices[item]

    def add_items(self, other: Union[str, List[str]]):
        # Adding zones after construction is rejected unconditionally.
        raise AttributeError('You are not allowed to add additional Zones in runtime.')
@@ -26,6 +26,12 @@ class Walls(PositionMixin, EnvObjects):
def from_tiles(cls, tiles, *args, **kwargs): def from_tiles(cls, tiles, *args, **kwargs):
raise RuntimeError() raise RuntimeError()
def by_pos(self, pos: (int, int)):
    """Return the first element of the parent's ``by_pos(pos)`` result, or ``None`` on ``IndexError``."""
    try:
        first = super().by_pos(pos)[0]
    except IndexError:
        first = None
    return first
class Floors(Walls): class Floors(Walls):
_entity = Floor _entity = Floor
-152
View File
@@ -1,152 +0,0 @@
import warnings
from collections import defaultdict
from os import PathLike
from pathlib import Path
from typing import Union
import yaml
from gymnasium import Wrapper
import numpy as np
import pandas as pd
from marl_factory_grid.environment.factory import REC_TAC
class EnvRecorder(Wrapper):
    """Environment wrapper that collects the ``rec_``-prefixed entries of the step info and dumps them as YAML.

    Attributes that are not found on the wrapper are looked up on ``self.unwrapped`` (see ``__getattr__``).
    """

    def __init__(self, env, entities: str = 'all', filepath: Union[str, PathLike] = None, freq: int = 0):
        """
        :param env: Environment to wrap.
        :param entities: ``'all'`` keeps every recorded key; any other string, or a collection of strings,
                         restricts the recorded keys to those names.
        :param filepath: Default target for ``save_records``.
        :param freq: Record every ``freq``-th episode; ``-1`` records every episode, ``0`` records nothing.
        """
        super(EnvRecorder, self).__init__(env)
        self.filepath = filepath
        self.freq = freq
        self._recorder_dict = defaultdict(list)
        self._recorder_out_list = list()
        self._episode_counter = 1
        self._do_record_dict = defaultdict(lambda: False)
        if isinstance(entities, str):
            if entities.lower() == 'all':
                self._entities = None
            else:
                self._entities = [entities]
        else:
            self._entities = entities

    def __getattr__(self, item):
        # Fallback lookup: forward unknown attributes to the unwrapped environment.
        return getattr(self.unwrapped, item)

    def reset(self):
        """Switch recording on and reset the unwrapped environment."""
        self._on_training_start()
        return self.unwrapped.reset()

    def _on_training_start(self) -> None:
        # The call must live outside of ``assert``: assert statements are removed under ``python -O``,
        # which would silently skip ``start_recording()`` altogether.
        recording_started = self.start_recording()
        assert recording_started

    def _read_info(self, env_idx, info: dict):
        """Buffer the ``REC_TAC``-prefixed entries of ``info`` (prefix removed) for ``env_idx``."""
        if info_dict := {key.replace(REC_TAC, ''): val for key, val in info.items() if key.startswith(f'{REC_TAC}')}:
            if self._entities:
                info_dict = {k: v for k, v in info_dict.items() if k in self._entities}
            self._recorder_dict[env_idx].append(info_dict)
        return True

    def _read_done(self, env_idx, done):
        """Move the buffered steps of ``env_idx`` to the output list once the episode is done."""
        if done:
            self._recorder_out_list.append({'steps': self._recorder_dict[env_idx],
                                            'episode': self._episode_counter})
            self._recorder_dict[env_idx] = list()

    def step(self, actions):
        """Step the unwrapped environment and record info (last item) and done (second to last item)."""
        step_result = self.unwrapped.step(actions)
        if self.do_record_episode(0):
            info = step_result[-1]
            self._read_info(0, info)
        if self._do_record_dict[0]:
            self._read_done(0, step_result[-2])
        return step_result

    def finalize(self):
        """Flush all still-buffered steps to the output list; always returns ``True``."""
        self._on_training_end()
        return True

    def save_records(self, filepath: Union[Path, str, None] = None,
                     only_deltas=True,
                     save_occupation_map=False,
                     save_trajectory_map=False,
                     ):
        """Write the recorded episodes as YAML to ``filepath`` (falls back to ``self.filepath``).

        :param only_deltas: Store ``DeepDiff`` results of consecutive episodes instead of the episodes.
        :param save_occupation_map: Additionally show a heatmap of the recorded agent positions.
        :param save_trajectory_map: Not implemented; raises ``NotImplementedError`` when truthy.
        """
        filepath = Path(filepath or self.filepath)
        filepath.parent.mkdir(exist_ok=True, parents=True)
        # cls.out_file.unlink(missing_ok=True)
        with filepath.open('w') as f:
            if only_deltas:
                from deepdiff import DeepDiff
                diff_dict = [DeepDiff(t1, t2, ignore_order=True)
                             for t1, t2 in zip(self._recorder_out_list, self._recorder_out_list[1:])
                             ]
                out_dict = {'episodes': diff_dict}

            else:
                out_dict = {'episodes': self._recorder_out_list}
            out_dict.update(
                {'n_episodes': self._episode_counter,
                 'env_params': self.env.params,
                 'header': self.env.summarize_header
                 })
            try:
                yaml.dump(out_dict, f, indent=4)
            except TypeError as err:
                # Still best effort, but tell the user what went wrong.
                warnings.warn(f'Records could not be dumped to "{filepath}": {err}')

        if save_occupation_map:
            a = np.zeros((15, 15))
            # Always use the raw episodes: with ``only_deltas`` the entries of ``out_dict['episodes']`` are
            # DeepDiff results that carry no 'steps'.
            # noinspection PyTypeChecker
            for episode in self._recorder_out_list:
                df = pd.DataFrame([y for x in episode['steps'] for y in x['Agents']])

                b = list(df[['x', 'y']].to_records(index=False))

                np.add.at(a, tuple(zip(*b)), 1)

            # a = np.rot90(a)
            import seaborn as sns
            from matplotlib import pyplot as plt
            hm = sns.heatmap(data=a)
            hm.set_title('Very Nice Heatmap')
            plt.show()

        if save_trajectory_map:
            raise NotImplementedError('This has not yet been implemented.')

    def do_record_episode(self, env_idx):
        """Decide at the start of an episode (empty buffer) whether ``env_idx`` is recorded; return the flag."""
        if not self._recorder_dict[env_idx]:
            if self.freq:
                self._do_record_dict[env_idx] = (self.freq == -1) or (self._episode_counter % self.freq) == 0
            else:
                self._do_record_dict[env_idx] = False
                warnings.warn('You did wrap your Environment with a recorder, but set the freq to zero\n'
                              'Nothing will be recorded')
            self._episode_counter += 1
        return self._do_record_dict[env_idx]

    def _on_step(self) -> bool:
        # NOTE(review): ``self.locals`` is never assigned in this class, so it is resolved through
        # ``__getattr__`` on the unwrapped environment -- confirm that it exists there.
        for env_idx, info in enumerate(self.locals.get('infos', [])):
            if self._do_record_dict[env_idx]:
                self._read_info(env_idx, info)
        dones = list(enumerate(self.locals.get('dones', [])))
        dones.extend(list(enumerate(self.locals.get('done', []))))
        for env_idx, done in dones:
            if self._do_record_dict[env_idx]:
                self._read_done(env_idx, done)

        return True

    def _on_training_end(self) -> None:
        # Flush every non-empty per-environment buffer as one more episode entry.
        for env_idx in range(len(self._recorder_dict)):
            if self._recorder_dict[env_idx]:
                self._recorder_out_list.append({'steps': self._recorder_dict[env_idx],
                                                'episode': self._episode_counter})
@@ -1,4 +1,4 @@
from .actions import BtryCharge from .actions import BtryCharge
from .entitites import ChargePod, Battery from .entitites import Pod, Battery
from .groups import ChargePods, Batteries from .groups import ChargePods, Batteries
from .rules import BtryDoneAtDischarge, Btry from .rules import BtryDoneAtDischarge, Btry
@@ -8,12 +8,3 @@ CHARGE_POD_SYMBOL = 1
CHARGE = 'do_charge_action' CHARGE = 'do_charge_action'
class BatteryProperties(NamedTuple):
    # Immutable bundle of battery settings; every field has a default value.
    # NOTE(review): the field meanings below are inferred from the names only -- confirm before relying on them.
    initial_charge: float = 0.8  # presumably the charge a battery starts with
    charge_rate: float = 0.4  # presumably the charge gained per charge action
    charge_locations: int = 20  # presumably the number of charging spots
    per_action_costs: Union[dict, float] = 0.02  # either one cost for all actions or a dict of costs
    done_when_discharged: bool = False
    multi_charge: bool = False
@@ -42,16 +42,16 @@ class Battery(BoundEntityMixin, EnvObject):
else: else:
return c.NOT_VALID return c.NOT_VALID
def summarize_state(self, **_): def summarize_state(self):
attr_dict = {key: str(val) for key, val in self.__dict__.items() if not key.startswith('_') and key != 'data'} summary = super().summarize_state()
attr_dict.update(dict(name=self.name, belongs_to=self._bound_entity.name)) summary.update(dict(belongs_to=self._bound_entity.name, chargeLevel=self.charge_level))
return attr_dict return summary
def render(self): def render(self):
return None return None
class ChargePod(Entity): class Pod(Entity):
@property @property
def encoding(self): def encoding(self):
@@ -59,7 +59,7 @@ class ChargePod(Entity):
def __init__(self, *args, charge_rate: float = 0.4, def __init__(self, *args, charge_rate: float = 0.4,
multi_charge: bool = False, **kwargs): multi_charge: bool = False, **kwargs):
super(ChargePod, self).__init__(*args, **kwargs) super(Pod, self).__init__(*args, **kwargs)
self.charge_rate = charge_rate self.charge_rate = charge_rate
self.multi_charge = multi_charge self.multi_charge = multi_charge
@@ -73,3 +73,8 @@ class ChargePod(Entity):
def render(self): def render(self):
return RenderEntity(b.CHARGE_PODS, self.pos) return RenderEntity(b.CHARGE_PODS, self.pos)
def summarize_state(self) -> dict:
    """Extend the parent's state summary with this pod's ``charge_rate``."""
    state = super().summarize_state()
    state.update(charge_rate=self.charge_rate)
    return state
@@ -1,6 +1,6 @@
from marl_factory_grid.environment.groups.env_objects import EnvObjects from marl_factory_grid.environment.groups.env_objects import EnvObjects
from marl_factory_grid.environment.groups.mixins import PositionMixin, HasBoundMixin from marl_factory_grid.environment.groups.mixins import PositionMixin, HasBoundMixin
from marl_factory_grid.modules.batteries.entitites import ChargePod, Battery from marl_factory_grid.modules.batteries.entitites import Pod, Battery
class Batteries(HasBoundMixin, EnvObjects): class Batteries(HasBoundMixin, EnvObjects):
@@ -20,9 +20,10 @@ class Batteries(HasBoundMixin, EnvObjects):
batteries = [self._entity(initial_charge_level, agent) for _, agent in enumerate(agents)] batteries = [self._entity(initial_charge_level, agent) for _, agent in enumerate(agents)]
self.add_items(batteries) self.add_items(batteries)
class ChargePods(PositionMixin, EnvObjects): class ChargePods(PositionMixin, EnvObjects):
_entity = ChargePod _entity = Pod
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super(ChargePods, self).__init__(*args, **kwargs) super(ChargePods, self).__init__(*args, **kwargs)
@@ -59,3 +59,19 @@ class BtryDoneAtDischarge(Rule):
else: else:
return [DoneResult(self.name, validity=c.NOT_VALID, reward=0)] return [DoneResult(self.name, validity=c.NOT_VALID, reward=0)]
class PodRules(Rule):
    """Rule whose ``on_init`` hook creates charge pods on empty floor tiles."""

    def __init__(self, n_pods: int, charge_rate: float = 0.4, multi_charge: bool = False):
        """
        :param n_pods: How many of the empty floor tiles (taken from the front of the list) receive a pod.
        :param charge_rate: Forwarded to every created pod as ``charge_rate``.
        :param multi_charge: Forwarded to every created pod as ``multi_charge``.
        """
        super().__init__()
        self.multi_charge = multi_charge
        self.charge_rate = charge_rate
        self.n_pods = n_pods

    def on_init(self, state, lvl_map):
        """Build pods from the first ``n_pods`` empty floor tiles and add them to the charge pod collection."""
        charge_pods = state[b.CHARGE_PODS]
        target_tiles = state[c.FLOOR].empty_tiles[:self.n_pods]
        pod_kwargs = dict(multi_charge=self.multi_charge, charge_rate=self.charge_rate)
        new_pods = charge_pods.from_tiles(target_tiles, entity_kwargs=pod_kwargs)
        charge_pods.add_items(new_pods)
@@ -51,7 +51,7 @@ class Destination(Entity):
def summarize_state(self) -> dict: def summarize_state(self) -> dict:
state_summary = super().summarize_state() state_summary = super().summarize_state()
state_summary.update(per_agent_times=[ state_summary.update(per_agent_times=[
dict(belongs_to=key, time=val) for key, val in self._per_agent_times.keys()], dwell_time=self.dwell_time) dict(belongs_to=key, time=val) for key, val in self._per_agent_times.items()], dwell_time=self.dwell_time)
return state_summary return state_summary
def render(self): def render(self):
+2 -1
View File
@@ -26,7 +26,8 @@ class ItemAction(Action):
reward = r.DROP_OFF_VALID if valid else r.DROP_OFF_FAIL reward = r.DROP_OFF_VALID if valid else r.DROP_OFF_FAIL
return ActionResult(entity=entity, identifier=self._identifier, validity=valid, reward=reward) return ActionResult(entity=entity, identifier=self._identifier, validity=valid, reward=reward)
elif item := state[i.ITEM].by_pos(entity.pos): elif items := state[i.ITEM].by_pos(entity.pos):
item = items[0]
item.change_parent_collection(inventory) item.change_parent_collection(inventory)
item.set_tile_to(state.NO_POS_TILE) item.set_tile_to(state.NO_POS_TILE)
state.print(f'{entity.name} just picked up an item at {entity.pos}') state.print(f'{entity.name} just picked up an item at {entity.pos}')
+1 -1
View File
@@ -51,7 +51,7 @@ class Inventories(HasBoundMixin, Objects):
_entity = Inventory _entity = Inventory
var_can_move = False var_can_move = False
def __init__(self, size, *args, **kwargs): def __init__(self, size: int, *args, **kwargs):
super(Inventories, self).__init__(*args, **kwargs) super(Inventories, self).__init__(*args, **kwargs)
self.size = size self.size = size
self._obs = None self._obs = None
+1
View File
@@ -17,3 +17,4 @@ def init():
shutil.copytree(template_path, cwd) shutil.copytree(template_path, cwd)
print(f'Templates copied to {cwd}"/"{template_path.name}') print(f'Templates copied to {cwd}"/"{template_path.name}')
print(':wave:') print(':wave:')
-1
View File
@@ -24,7 +24,6 @@ class FactoryConfigParser(object):
self.config_path = Path(config_path) self.config_path = Path(config_path)
self.custom_modules_path = Path(custom_modules_path) if custom_modules_path is not None else custom_modules_path self.custom_modules_path = Path(custom_modules_path) if custom_modules_path is not None else custom_modules_path
self.config = yaml.safe_load(self.config_path.open()) self.config = yaml.safe_load(self.config_path.open())
self.do_record = False
def __getattr__(self, item): def __getattr__(self, item):
return self['General'][item] return self['General'][item]
@@ -6,11 +6,10 @@ from typing import Union
from gymnasium import Wrapper from gymnasium import Wrapper
from marl_factory_grid.utils.helpers import IGNORED_DF_COLUMNS from marl_factory_grid.utils.helpers import IGNORED_DF_COLUMNS
from marl_factory_grid.environment.factory import REC_TAC
import pandas as pd import pandas as pd
from marl_factory_grid.plotting.compare_runs import plot_single_run from marl_factory_grid.utils.plotting.compare_runs import plot_single_run
class EnvMonitor(Wrapper): class EnvMonitor(Wrapper):
@@ -23,8 +22,6 @@ class EnvMonitor(Wrapper):
self._monitor_df = pd.DataFrame() self._monitor_df = pd.DataFrame()
self._monitor_dict = dict() self._monitor_dict = dict()
def __getattr__(self, item):
    # Fallback attribute lookup: forward the requested name to ``self.unwrapped``.
    return getattr(self.unwrapped, item)
def step(self, action): def step(self, action):
obs_type, obs, reward, done, info = self.env.step(action) obs_type, obs, reward, done, info = self.env.step(action)
@@ -33,12 +30,12 @@ class EnvMonitor(Wrapper):
return obs_type, obs, reward, done, info return obs_type, obs, reward, done, info
def reset(self): def reset(self):
return self.unwrapped.reset() return self.env.reset()
def _read_info(self, info: dict): def _read_info(self, info: dict):
self._monitor_dict[len(self._monitor_dict)] = { self._monitor_dict[len(self._monitor_dict)] = {
key: val for key, val in info.items() if key: val for key, val in info.items() if
key not in ['terminal_observation', 'episode'] and not key.startswith(REC_TAC)} key not in ['terminal_observation', 'episode']}
return return
def _read_done(self, done): def _read_done(self, done):
@@ -50,7 +47,7 @@ class EnvMonitor(Wrapper):
{col: 'mean' if col.endswith('ount') else 'sum' for col in columns} {col: 'mean' if col.endswith('ount') else 'sum' for col in columns}
) )
env_monitor_df['episode'] = len(self._monitor_df) env_monitor_df['episode'] = len(self._monitor_df)
self._monitor_df = self._monitor_df.append([env_monitor_df]) self._monitor_df = pd.concat([self._monitor_df, pd.DataFrame([env_monitor_df])], ignore_index=True)
else: else:
pass pass
return return
+160
View File
@@ -0,0 +1,160 @@
from os import PathLike
from pathlib import Path
from typing import Union, List
import yaml
from gymnasium import Wrapper
import numpy as np
import pandas as pd
class EnvRecorder(Wrapper):
    """Environment wrapper that stores one state summary per step and serialises the result to a file.

    The wrapped environment has to offer ``summarize_state()``, ``summarize_header()`` and a ``params``
    mapping; its ``step`` must return the 5-tuple ``(obs_type, obs, reward, done, info)``.
    """

    def __init__(self, env, filepath: Union[str, PathLike] = None,
                 episodes: Union[List[int], None] = None):
        """
        :param env: Environment to wrap.
        :param filepath: Default target for ``save_records``.
        :param episodes: Episode numbers to record; ``None`` or an empty list records every episode.
                         The episode counter starts at 0 and is incremented by every ``reset``.
        """
        super(EnvRecorder, self).__init__(env)
        self.filepath = filepath
        self.episodes = episodes
        self._curr_episode = 0
        self._curr_ep_recorder = list()
        self._recorder_out_list = list()

    def reset(self):
        """Drop all buffered records, increment the episode counter and reset the wrapped environment."""
        # NOTE(review): this also discards episodes that are finished but not yet saved, so at most the
        # latest episode can ever be written by ``save_records`` -- confirm that this is intended.
        self._curr_ep_recorder = list()
        self._recorder_out_list = list()
        self._curr_episode += 1
        return self.env.reset()

    def step(self, actions):
        """Step the wrapped environment and, if the current episode is selected, record its state summary."""
        obs_type, obs, reward, done, info = self.env.step(actions)
        if not self.episodes or self._curr_episode in self.episodes:
            summary: dict = self.env.summarize_state()
            # summary.update(done=done)
            # summary.update({'episode': self._curr_episode})
            # TODO Protobuff Adjustments ######
            # summary.update(info)
            self._curr_ep_recorder.append(summary)
            if done:
                self._recorder_out_list.append({'steps': self._curr_ep_recorder,
                                                'episode_nr': self._curr_episode})
                self._curr_ep_recorder = list()
        return obs_type, obs, reward, done, info

    def _finalize(self):
        """Append a snapshot of the still-running episode (if any steps are buffered) to the output list."""
        if self._curr_ep_recorder:
            # Label the partial episode with the episode counter, exactly like finished episodes in ``step``.
            self._recorder_out_list.append({'steps': self._curr_ep_recorder.copy(),
                                            'episode_nr': self._curr_episode})

    def _build_metadata(self) -> dict:
        """Assemble the metadata section of the export; most values are fixed placeholders."""
        # TODO Protobuff Adjustments Revert
        dest_prop = dict(
            n_dests=0,
            dwell_time=0,
            spawn_frequency=0,
            spawn_in_other_zone=False,
            spawn_mode=''
        )
        rewards_dest = dict(
            WAIT_VALID=0.00,
            WAIT_FAIL=0.00,
            DEST_REACHED=0.00,
        )
        mv_prop = dict(
            allow_square_movement=False,
            allow_diagonal_movement=False,
            allow_no_op=False,
        )
        obs_prop = dict(
            render_agents='',
            omit_agent_self=False,
            additional_agent_placeholder=0,
            cast_shadows=False,
            frames_to_stack=0,
            pomdp_r=self.env.params['General']['pomdp_r'],
            indicate_door_area=False,
            show_global_position_info=False,
        )
        rewards_base = dict(
            MOVEMENTS_VALID=0.00,
            MOVEMENTS_FAIL=0.00,
            NOOP=0.00,
            USE_DOOR_VALID=0.00,
            USE_DOOR_FAIL=0.00,
            COLLISION=0.00,
        )
        return dict(
            level_name=self.env.params['General']['level_name'],
            verbose=False,
            n_agents=len(self.env.params['Agents']),
            max_steps=100,
            done_at_collision=False,
            parse_doors=True,
            doors_have_area=False,
            individual_rewards=True,
            class_name='Where does this end up?',
            env_seed=69,
            dest_prop=dest_prop,
            rewards_dest=rewards_dest,
            mv_prop=mv_prop,
            obs_prop=obs_prop,
            rewards_base=rewards_base,
        )

    def save_records(self, filepath: Union[Path, str, None] = None,
                     only_deltas=False,
                     save_occupation_map=False,
                     save_trajectory_map=False,
                     ):
        """Serialise the recorded episodes as a protobuf ``Bulk`` message into ``filepath``.

        :param filepath: Target file; falls back to ``self.filepath``.
        :param only_deltas: Store ``DeepDiff`` results of consecutive episodes instead of the episodes.
        :param save_occupation_map: Additionally show a heatmap of the recorded agent positions.
        :param save_trajectory_map: Not implemented; raises ``NotImplementedError`` when truthy.
        """
        self._finalize()
        filepath = Path(filepath or self.filepath)
        filepath.parent.mkdir(exist_ok=True, parents=True)
        # cls.out_file.unlink(missing_ok=True)
        with filepath.open('wb') as f:
            if only_deltas:
                from deepdiff import DeepDiff
                episodes = [DeepDiff(t1, t2, ignore_order=True)
                            for t1, t2 in zip(self._recorder_out_list, self._recorder_out_list[1:])
                            ]
            else:
                episodes = self._recorder_out_list
            out_dict = {'episodes': episodes}
            # The metadata comes from a helper so that it is defined for both branches above.
            out_dict.update(
                {'n_episodes': self._curr_episode,
                 'metadata': self._build_metadata(),
                 # 'env_params': self.env.params,
                 'header': self.env.summarize_header()
                 })
            try:
                from marl_factory_grid.utils.proto import fiksProto_pb2
                from google.protobuf import json_format

                bulk = fiksProto_pb2.Bulk()
                json_format.ParseDict(out_dict, bulk)
                f.write(bulk.SerializeToString())
                # yaml.dump(out_dict, f, indent=4)
            except TypeError as err:
                # Still best effort, but tell the user what went wrong.
                print(f'Records could not be serialised to "{filepath}": {err}')
            else:
                print('done')

        if save_occupation_map:
            # NOTE(review): the 15x15 grid is hard-coded and not derived from the level.
            a = np.zeros((15, 15))
            # Use the raw episodes (DeepDiff results carry no 'steps') and the lower-cased 'agents' key,
            # which is what ``Factory.summarize_state`` emits.
            # noinspection PyTypeChecker
            for episode in self._recorder_out_list:
                df = pd.DataFrame([y for x in episode['steps'] for y in x['agents']])

                b = list(df[['x', 'y']].to_records(index=False))

                np.add.at(a, tuple(zip(*b)), 1)

            # a = np.rot90(a)
            import seaborn as sns
            from matplotlib import pyplot as plt
            hm = sns.heatmap(data=a)
            hm.set_title('Very Nice Heatmap')
            plt.show()

        if save_trajectory_map:
            raise NotImplementedError('This has not yet been implemented.')
@@ -7,7 +7,7 @@ from typing import Union, List
import pandas as pd import pandas as pd
from marl_factory_grid.utils.helpers import IGNORED_DF_COLUMNS from marl_factory_grid.utils.helpers import IGNORED_DF_COLUMNS
from marl_factory_grid.plotting.plotting import prepare_plot from marl_factory_grid.utils.plotting.plotting import prepare_plot
MODEL_MAP = None MODEL_MAP = None
+1
View File
@@ -120,6 +120,7 @@ class ConfigExplainer:
def _save_to_file(self, data: dict, filepath: PathLike, tag: str = ''): def _save_to_file(self, data: dict, filepath: PathLike, tag: str = ''):
filepath = Path(filepath) filepath = Path(filepath)
yaml.Dumper.ignore_aliases = lambda *args: True
with filepath.open('w') as f: with filepath.open('w') as f:
yaml.dump(data, f, encoding='utf-8') yaml.dump(data, f, encoding='utf-8')
print(f'Example config {"for " + tag + " " if tag else " "}dumped') print(f'Example config {"for " + tag + " " if tag else " "}dumped')
+2 -2
View File
@@ -4,8 +4,8 @@ from pathlib import Path
import yaml import yaml
from marl_factory_grid.environment.factory import Factory from marl_factory_grid.environment.factory import Factory
from marl_factory_grid.logging.envmonitor import EnvMonitor from marl_factory_grid.utils.logging.envmonitor import EnvMonitor
from marl_factory_grid.logging.recorder import EnvRecorder from marl_factory_grid.utils.logging.recorder import EnvRecorder
from marl_factory_grid.modules.doors import constants as d from marl_factory_grid.modules.doors import constants as d