major redesign ob observations and entittes

This commit is contained in:
Steffen Illium
2023-06-09 14:04:17 +02:00
parent 901fbcbc32
commit c552c35f66
161 changed files with 4458 additions and 4163 deletions

View File

View File

@@ -0,0 +1,64 @@
import pickle
from os import PathLike
from pathlib import Path
from typing import Union
from gymnasium import Wrapper
from environment.utils.helpers import IGNORED_DF_COLUMNS
from environment.factory import REC_TAC
import pandas as pd
from plotting.compare_runs import plot_single_run
class EnvMonitor(Wrapper):
ext = 'png'
def __init__(self, env, filepath: Union[str, PathLike] = None):
super(EnvMonitor, self).__init__(env)
self._filepath = filepath
self._monitor_df = pd.DataFrame()
self._monitor_dict = dict()
def __getattr__(self, item):
return getattr(self.unwrapped, item)
def step(self, action):
obs_type, obs, reward, done, info = self.env.step(action)
self._read_info(info)
self._read_done(done)
return obs_type, obs, reward, done, info
def reset(self):
return self.unwrapped.reset()
def _read_info(self, info: dict):
self._monitor_dict[len(self._monitor_dict)] = {
key: val for key, val in info.items() if
key not in ['terminal_observation', 'episode'] and not key.startswith(REC_TAC)}
return
def _read_done(self, done):
if done:
env_monitor_df = pd.DataFrame.from_dict(self._monitor_dict, orient='index')
self._monitor_dict = dict()
columns = [col for col in env_monitor_df.columns if col not in IGNORED_DF_COLUMNS]
env_monitor_df = env_monitor_df.aggregate(
{col: 'mean' if col.endswith('ount') else 'sum' for col in columns}
)
env_monitor_df['episode'] = len(self._monitor_df)
self._monitor_df = self._monitor_df.append([env_monitor_df])
else:
pass
return
def save_run(self, filepath: Union[Path, str, None] = None, auto_plotting_keys=None):
filepath = Path(filepath or self._filepath)
filepath.parent.mkdir(exist_ok=True, parents=True)
with filepath.open('wb') as f:
pickle.dump(self._monitor_df.reset_index(), f, protocol=pickle.HIGHEST_PROTOCOL)
if auto_plotting_keys:
plot_single_run(filepath, column_keys=auto_plotting_keys)

View File

@@ -0,0 +1,152 @@
import warnings
from collections import defaultdict
from os import PathLike
from pathlib import Path
from typing import Union
from gymnasium import Wrapper
import numpy as np
import pandas as pd
import simplejson
from environment.factory import REC_TAC
class EnvRecorder(Wrapper):
def __init__(self, env, entities: str = 'all', filepath: Union[str, PathLike] = None, freq: int = 0):
super(EnvRecorder, self).__init__(env)
self.filepath = filepath
self.freq = freq
self._recorder_dict = defaultdict(list)
self._recorder_out_list = list()
self._episode_counter = 1
self._do_record_dict = defaultdict(lambda: False)
if isinstance(entities, str):
if entities.lower() == 'all':
self._entities = None
else:
self._entities = [entities]
else:
self._entities = entities
def __getattr__(self, item):
return getattr(self.unwrapped, item)
def reset(self):
self._on_training_start()
return self.unwrapped.reset()
def _on_training_start(self) -> None:
assert self.start_recording()
def _read_info(self, env_idx, info: dict):
if info_dict := {key.replace(REC_TAC, ''): val for key, val in info.items() if key.startswith(f'{REC_TAC}')}:
if self._entities:
info_dict = {k: v for k, v in info_dict.items() if k in self._entities}
self._recorder_dict[env_idx].append(info_dict)
else:
pass
return True
def _read_done(self, env_idx, done):
if done:
self._recorder_out_list.append({'steps': self._recorder_dict[env_idx],
'episode': self._episode_counter})
self._recorder_dict[env_idx] = list()
else:
pass
def step(self, actions):
step_result = self.unwrapped.step(actions)
if self.do_record_episode(0):
info = step_result[-1]
self._read_info(0, info)
if self._do_record_dict[0]:
self._read_done(0, step_result[-2])
return step_result
def finalize(self):
self._on_training_end()
return True
def save_records(self, filepath: Union[Path, str, None] = None,
only_deltas=True,
save_occupation_map=False,
save_trajectory_map=False,
):
filepath = Path(filepath or self.filepath)
filepath.parent.mkdir(exist_ok=True, parents=True)
# cls.out_file.unlink(missing_ok=True)
with filepath.open('w') as f:
if only_deltas:
from deepdiff import DeepDiff, Delta
diff_dict = [DeepDiff(t1,t2, ignore_order=True)
for t1, t2 in zip(self._recorder_out_list, self._recorder_out_list[1:])
]
out_dict = {'episodes': diff_dict}
else:
out_dict = {'episodes': self._recorder_out_list}
out_dict.update(
{'n_episodes': self._episode_counter,
'env_params': self.env.params,
'header': self.env.summarize_header
})
try:
simplejson.dump(out_dict, f, indent=4)
except TypeError:
print('Shit')
if save_occupation_map:
a = np.zeros((15, 15))
# noinspection PyTypeChecker
for episode in out_dict['episodes']:
df = pd.DataFrame([y for x in episode['steps'] for y in x['Agents']])
b = list(df[['x', 'y']].to_records(index=False))
np.add.at(a, tuple(zip(*b)), 1)
# a = np.rot90(a)
import seaborn as sns
from matplotlib import pyplot as plt
hm = sns.heatmap(data=a)
hm.set_title('Very Nice Heatmap')
plt.show()
if save_trajectory_map:
raise NotImplementedError('This has not yet been implemented.')
def do_record_episode(self, env_idx):
if not self._recorder_dict[env_idx]:
if self.freq:
self._do_record_dict[env_idx] = (self.freq == -1) or (self._episode_counter % self.freq) == 0
else:
self._do_record_dict[env_idx] = False
warnings.warn('You did wrap your Environment with a recorder, but set the freq to zero\n'
'Nothing will be recorded')
self._episode_counter += 1
else:
pass
return self._do_record_dict[env_idx]
def _on_step(self) -> bool:
for env_idx, info in enumerate(self.locals.get('infos', [])):
if self._do_record_dict[env_idx]:
self._read_info(env_idx, info)
dones = list(enumerate(self.locals.get('dones', [])))
dones.extend(list(enumerate(self.locals.get('done', []))))
for env_idx, done in dones:
if self._do_record_dict[env_idx]:
self._read_done(env_idx, done)
return True
def _on_training_end(self) -> None:
for env_idx in range(len(self._recorder_dict)):
if self._recorder_dict[env_idx]:
self._recorder_out_list.append({'steps': self._recorder_dict[env_idx],
'episode': self._episode_counter})
pass