recoder adaption

This commit is contained in:
Steffen Illium
2021-10-04 17:53:19 +02:00
parent 4c21a0af7c
commit 696e520862
21 changed files with 665 additions and 380 deletions

View File

@ -13,8 +13,8 @@ from environments.factory.base.shadow_casting import Map
from environments.factory.renderer import Renderer, RenderEntity
from environments.helpers import Constants as c, Constants
from environments import helpers as h
from environments.factory.base.objects import Agent, Tile, Action, Wall
from environments.factory.base.registers import Actions, Entities, Agents, Doors, FloorTiles, WallTiles
from environments.factory.base.objects import Agent, Tile, Action
from environments.factory.base.registers import Actions, Entities, Agents, Doors, FloorTiles, WallTiles, PlaceHolders
from environments.utility_classes import MovementProperties
import simplejson
@ -58,7 +58,7 @@ class BaseFactory(gym.Env):
def __init__(self, level_name='simple', n_agents=1, max_steps=int(5e2), pomdp_r: Union[None, int] = 0,
movement_properties: MovementProperties = MovementProperties(), parse_doors=False,
combin_agent_obs: bool = False, frames_to_stack=0, record_episodes=False,
omit_agent_in_obs=False, done_at_collision=False, cast_shadows=True,
omit_agent_in_obs=False, done_at_collision=False, cast_shadows=True, additional_agent_placeholder=None,
verbose=False, doors_have_area=True, env_seed=time.time_ns(), **kwargs):
assert frames_to_stack != 1 and frames_to_stack >= 0, "'frames_to_stack' cannot be negative or 1."
if kwargs:
@ -74,6 +74,7 @@ class BaseFactory(gym.Env):
self.level_name = level_name
self._level_shape = None
self.verbose = verbose
self.additional_agent_placeholder = additional_agent_placeholder
self._renderer = None # expensive - don't use it when not required !
self._entities = Entities()
@ -141,6 +142,14 @@ class BaseFactory(gym.Env):
individual_slices=not self.combin_agent_obs)
entities.update({c.AGENT: agents})
if self.additional_agent_placeholder is not None:
# Empty Observations with either [0, 1, N(0, 1)]
placeholder = PlaceHolders.from_tiles([self._NO_POS_TILE], self._level_shape,
fill_value=self.additional_agent_placeholder)
entities.update({c.AGENT_PLACEHOLDER: placeholder})
# All entities
self._entities = Entities()
self._entities.register_additional_items(entities)
@ -155,10 +164,12 @@ class BaseFactory(gym.Env):
def _init_obs_cube(self):
arrays = self._entities.observable_arrays
# FIXME: Move logic to Register
if self.omit_agent_in_obs and self.n_agents == 1:
del arrays[c.AGENT]
elif self.omit_agent_in_obs:
arrays[c.AGENT] = np.delete(arrays[c.AGENT], 0, axis=0)
# This does not seem to be necessary, because this case is already handled by the Agent Register class
# elif self.omit_agent_in_obs:
# arrays[c.AGENT] = np.delete(arrays[c.AGENT], 0, axis=0)
obs_cube_z = sum([a.shape[0] if not self[key].is_per_agent else 1 for key, a in arrays.items()])
self._obs_cube = np.zeros((obs_cube_z, *self._level_shape), dtype=np.float32)
@ -273,6 +284,7 @@ class BaseFactory(gym.Env):
agent_pos_is_omitted = False
agent_omit_idx = None
if self.omit_agent_in_obs and self.n_agents == 1:
# There is only a single agent and we want to omit the agent obs, so just remove the array.
del state_array_dict[c.AGENT]
elif self.omit_agent_in_obs and self.combin_agent_obs and self.n_agents > 1:
state_array_dict[c.AGENT][0, agent.x, agent.y] -= agent.encoding
@ -295,6 +307,9 @@ class BaseFactory(gym.Env):
for array_idx in range(array.shape[0]):
self._obs_cube[running_idx: running_idx+z] = array[[x for x in range(array.shape[0])
if x != agent_omit_idx]]
elif key == c.AGENT and self.omit_agent_in_obs and self.combin_agent_obs:
z = 1
self._obs_cube[running_idx: running_idx + z] = array
else:
z = array.shape[0]
self._obs_cube[running_idx: running_idx+z] = array
@ -499,12 +514,8 @@ class BaseFactory(gym.Env):
def _summarize_state(self):
summary = {f'{REC_TAC}step': self._steps}
if self._steps == 0:
summary.update({f'{REC_TAC}{self[c.WALLS].name}': {self[c.WALLS].summarize_states()},
'FactoryName': self.__class__.__name__})
for entity_group in self._entities:
if not isinstance(entity_group, WallTiles):
summary.update({f'{REC_TAC}{entity_group.name}': entity_group.summarize_states()})
summary.update({f'{REC_TAC}{entity_group.name}': entity_group.summarize_states(n_steps=self._steps)})
return summary
def print(self, string):

View File

@ -93,11 +93,11 @@ class Entity(Object):
return self._tile
def __init__(self, tile, **kwargs):
super(Entity, self).__init__(**kwargs)
super().__init__(**kwargs)
self._tile = tile
tile.enter(self)
def summarize_state(self) -> dict:
def summarize_state(self, **_) -> dict:
return dict(name=str(self.name), x=int(self.x), y=int(self.y),
tile=str(self.tile.name), can_collide=bool(self.can_collide))
@ -125,7 +125,7 @@ class MoveableEntity(Entity):
return last_x-curr_x, last_y-curr_y
def __init__(self, *args, **kwargs):
super(MoveableEntity, self).__init__(*args, **kwargs)
super().__init__(*args, **kwargs)
self._last_tile = None
def move(self, next_tile):
@ -143,11 +143,34 @@ class MoveableEntity(Entity):
class Action(Object):
def __init__(self, *args, **kwargs):
super(Action, self).__init__(*args, **kwargs)
super().__init__(*args, **kwargs)
class PlaceHolder(MoveableEntity):
pass
def __init__(self, *args, fill_value=0, **kwargs):
super().__init__(*args, **kwargs)
self._fill_value = fill_value
@property
def last_tile(self):
return self.tile
@property
def direction_of_view(self):
return self.pos
@property
def can_collide(self):
return False
@property
def encoding(self):
return c.NO_POS.value[0]
@property
def name(self):
return "PlaceHolder"
class Tile(Object):
@ -203,8 +226,8 @@ class Tile(Object):
def __repr__(self):
return f'{self.name}(@{self.pos})'
def summarize_state(self):
return dict(name=self.name, x=self.x, y=self.y)
def summarize_state(self, **_):
return dict(name=self.name, x=int(self.x), y=int(self.y))
class Wall(Tile):
@ -254,8 +277,8 @@ class Door(Entity):
if not closed_on_init:
self._open()
def summarize_state(self):
state_dict = super().summarize_state()
def summarize_state(self, **kwargs):
state_dict = super().summarize_state(**kwargs)
state_dict.update(state=str(self.str_state), time_to_close=int(self.time_to_close))
return state_dict
@ -315,7 +338,7 @@ class Agent(MoveableEntity):
self.temp_action = None
self.temp_light_map = None
def summarize_state(self):
state_dict = super().summarize_state()
def summarize_state(self, **kwargs):
state_dict = super().summarize_state(**kwargs)
state_dict.update(valid=bool(self.temp_valid), action=str(self.temp_action))
return state_dict

View File

@ -81,8 +81,8 @@ class ObjectRegister(Register):
if self.individual_slices:
self._array = np.concatenate((self._array, np.zeros((1, *self._array.shape[1:]))))
def summarize_states(self):
return [val.summarize_state() for val in self.values()]
def summarize_states(self, n_steps=None):
return [val.summarize_state(n_steps=n_steps) for val in self.values()]
class EntityObjectRegister(ObjectRegister, ABC):
@ -156,23 +156,25 @@ class MovingEntityObjectRegister(EntityObjectRegister, ABC):
del self[name]
class PlaceHolderRegister(MovingEntityObjectRegister):
class PlaceHolders(MovingEntityObjectRegister):
_accepted_objects = PlaceHolder
def __init__(self, *args, fill_value: Union[str, int] = 0, **kwargs):
super().__init__(*args, **kwargs)
self.fill_value = fill_value
# noinspection DuplicatedCode
def as_array(self):
self._array[:] = c.FREE_CELL.value
# noinspection PyTupleAssignmentBalance
for z, x, y, v in zip(range(len(self)), *zip(*[x.pos for x in self]), [x.encoding for x in self]):
if self.individual_slices:
self._array[z, x, y] += v
else:
self._array[0, x, y] += v
if isinstance(self.fill_value, int):
self._array[:] = self.fill_value
elif self.fill_value == "normal":
self._array = np.random.normal(size=self._array.shape)
if self.individual_slices:
return self._array
else:
return self._array.sum(axis=0, keepdims=True)
return self._array[None, 0]
class Entities(Register):
@ -243,6 +245,12 @@ class WallTiles(EntityObjectRegister):
def from_tiles(cls, tiles, *args, **kwargs):
raise RuntimeError()
def summarize_states(self, n_steps=None):
    """Summarise the registered tiles only when ``n_steps`` equals ``h.STEPS_START``.

    :param n_steps: step counter forwarded by the caller (``None`` by default).
    :return: the parent class' summary when ``n_steps == h.STEPS_START``,
             otherwise an empty dict.
    """
    is_start_step = n_steps == h.STEPS_START
    if not is_start_step:
        return {}
    return super(WallTiles, self).summarize_states(n_steps=n_steps)
class FloorTiles(WallTiles):
@ -272,6 +280,10 @@ class FloorTiles(WallTiles):
def from_tiles(cls, tiles, *args, **kwargs):
raise RuntimeError()
def summarize_states(self, n_steps=None):
    """Return an empty summary; this register is deliberately not summarised.

    :param n_steps: accepted for signature compatibility, ignored.
    :return: always a new, empty dict.
    """
    return dict()
class Agents(MovingEntityObjectRegister):

View File

@ -0,0 +1,29 @@
{
"item_properties": {
"n_items": 5,
"spawn_frequency": 10,
"n_drop_off_locations": 5,
"max_dropoff_storage_size": 0,
"max_agent_inventory_capacity": 5,
"agent_can_interact": true
},
"env_seed": 2,
"movement_properties": {
"allow_square_movement": true,
"allow_diagonal_movement": true,
"allow_no_op": false
},
"level_name": "rooms",
"verbose": false,
"n_agents": 1,
"max_steps": 400,
"pomdp_r": 2,
"combin_agent_obs": true,
"omit_agent_in_obs": true,
"cast_shadows": true,
"frames_to_stack": 3,
"done_at_collision": false,
"record_episodes": false,
"parse_doors": false,
"doors_have_area": false
}

View File

@ -51,8 +51,8 @@ class Dirt(Entity):
def set_new_amount(self, amount):
self._amount = amount
def summarize_state(self):
state_dict = super().summarize_state()
def summarize_state(self, **kwargs):
state_dict = super().summarize_state(**kwargs)
state_dict.update(amount=float(self.amount))
return state_dict

View File

@ -1,7 +1,9 @@
import random
from pathlib import Path
from environments.factory.factory_dirt import DirtFactory, DirtProperties
from environments.factory.factory_item import ItemFactory, ItemProperties
from environments.logging.recorder import RecorderCallback
from environments.utility_classes import MovementProperties
@ -12,40 +14,44 @@ class DirtItemFactory(ItemFactory, DirtFactory):
if __name__ == '__main__':
with RecorderCallback(filepath=Path('debug_out') / f'recorder_xxxx.json', occupation_map=False,
trajectory_map=False) as recorder:
dirt_props = DirtProperties(clean_amount=2, gain_amount=0.1, max_global_amount=20,
max_local_amount=1, spawn_frequency=3, max_spawn_ratio=0.05,
dirt_smear_amount=0.0, agent_can_interact=True)
item_props = ItemProperties(n_items=5, agent_can_interact=True)
move_props = MovementProperties(allow_diagonal_movement=True,
allow_square_movement=True,
allow_no_op=False)
dirt_props = DirtProperties(clean_amount=2, gain_amount=0.1, max_global_amount=20,
max_local_amount=1, spawn_frequency=3, max_spawn_ratio=0.05,
dirt_smear_amount=0.0, agent_can_interact=True)
item_props = ItemProperties(n_items=5, agent_can_interact=True)
move_props = MovementProperties(allow_diagonal_movement=True,
allow_square_movement=True,
allow_no_op=False)
render = True
render = False
factory = DirtItemFactory(n_agents=1, done_at_collision=False, frames_to_stack=0,
level_name='rooms', max_steps=400, combin_agent_obs=True,
omit_agent_in_obs=True, parse_doors=True, pomdp_r=3,
record_episodes=True, verbose=True, cast_shadows=True,
movement_properties=move_props, dirt_properties=dirt_props
)
factory = DirtItemFactory(n_agents=1, done_at_collision=False, frames_to_stack=0,
level_name='rooms', max_steps=200, combin_agent_obs=True,
omit_agent_in_obs=True, parse_doors=True, pomdp_r=3,
record_episodes=True, verbose=False, cast_shadows=True,
movement_properties=move_props, dirt_properties=dirt_props
)
# noinspection DuplicatedCode
n_actions = factory.action_space.n - 1
_ = factory.observation_space
# noinspection DuplicatedCode
n_actions = factory.action_space.n - 1
_ = factory.observation_space
for epoch in range(4):
random_actions = [[random.randint(0, n_actions) for _
in range(factory.n_agents)] for _
in range(factory.max_steps + 1)]
env_state = factory.reset()
r = 0
for agent_i_action in random_actions:
env_state, step_r, done_bool, info_obj = factory.step(agent_i_action)
r += step_r
if render:
factory.render()
if done_bool:
break
print(f'Factory run {epoch} done, reward is:\n {r}')
pass
for epoch in range(4):
random_actions = [[random.randint(0, n_actions) for _
in range(factory.n_agents)] for _
in range(factory.max_steps + 1)]
env_state = factory.reset()
r = 0
for agent_i_action in random_actions:
env_state, step_r, done_bool, info_obj = factory.step(agent_i_action)
recorder.read_info(0, info_obj)
r += step_r
if render:
factory.render()
if done_bool:
recorder.read_done(0, done_bool)
break
print(f'Factory run {epoch} done, reward is:\n {r}')
pass

View File

@ -109,8 +109,10 @@ class Inventory(UserList):
def belongs_to_entity(self, entity):
return self.agent == entity
def summarize_state(self):
return {val.name: val.summarize_state() for val in self}
def summarize_state(self, **kwargs):
attr_dict = {key: str(val) for key, val in self.__dict__.items() if not key.startswith('_') and key != 'data'}
attr_dict.update({val.name: val.summarize_state(**kwargs) for val in self})
return attr_dict
class Inventories(ObjectRegister):
@ -176,6 +178,10 @@ class DropOffLocation(Entity):
def is_full(self):
return False if not self.storage.maxlen else self.storage.maxlen == len(self.storage)
def summarize_state(self, n_steps=None) -> dict:
    """Summarise this entity only when ``n_steps`` equals ``h.STEPS_START``.

    :param n_steps: step counter forwarded by the caller (``None`` by default).
    :return: the parent class' summary dict when ``n_steps == h.STEPS_START``;
             for every other value nothing is summarised and ``None`` is
             returned, despite the ``dict`` annotation.
    """
    if n_steps != h.STEPS_START:
        return None
    return super().summarize_state(n_steps=n_steps)
class DropOffLocations(EntityObjectRegister):

View File

@ -7,7 +7,10 @@ from pathlib import Path
from stable_baselines3 import PPO, DQN, A2C
MODEL_MAP = dict(PPO=PPO, DQN=DQN, A2C=A2C)
LEVELS_DIR = 'levels'
STEPS_START = 1
TO_BE_AVERAGED = ['dirt_amount', 'dirty_tiles']
IGNORED_DF_COLUMNS = ['Episode', 'Run', 'train_step', 'step', 'index', 'dirt_amount',
@ -16,34 +19,35 @@ IGNORED_DF_COLUMNS = ['Episode', 'Run', 'train_step', 'step', 'index', 'dirt_amo
# Constants
class Constants(Enum):
WALL = '#'
WALLS = 'Walls'
FLOOR = 'Floor'
DOOR = 'D'
DANGER_ZONE = 'x'
LEVEL = 'Level'
AGENT = 'Agent'
FREE_CELL = 0
OCCUPIED_CELL = 1
SHADOWED_CELL = -1
NO_POS = (-9999, -9999)
WALL = '#'
WALLS = 'Walls'
FLOOR = 'Floor'
DOOR = 'D'
DANGER_ZONE = 'x'
LEVEL = 'Level'
AGENT = 'Agent'
AGENT_PLACEHOLDER = 'AGENT_PLACEHOLDER'
FREE_CELL = 0
OCCUPIED_CELL = 1
SHADOWED_CELL = -1
NO_POS = (-9999, -9999)
DOORS = 'Doors'
CLOSED_DOOR = 'closed'
OPEN_DOOR = 'open'
DOORS = 'Doors'
CLOSED_DOOR = 'closed'
OPEN_DOOR = 'open'
ACTION = 'action'
COLLISIONS = 'collision'
VALID = 'valid'
NOT_VALID = 'not_valid'
ACTION = 'action'
COLLISIONS = 'collision'
VALID = 'valid'
NOT_VALID = 'not_valid'
# Dirt Env
DIRT = 'Dirt'
DIRT = 'Dirt'
# Item Env
ITEM = 'Item'
INVENTORY = 'Inventory'
DROP_OFF = 'Drop_Off'
ITEM = 'Item'
INVENTORY = 'Inventory'
DROP_OFF = 'Drop_Off'
def __bool__(self):
if 'not_' in self.value:
@ -144,8 +148,6 @@ def asset_str(agent):
return c.AGENT.value, 'idle'
model_map = dict(PPO=PPO, DQN=DQN, A2C=A2C)
if __name__ == '__main__':
parsed_level = parse_level(Path(__file__).parent / 'factory' / 'levels' / 'simple.txt')
y = one_hot_level(parsed_level)

View File

@ -6,7 +6,7 @@ from typing import List, Dict
from stable_baselines3.common.callbacks import BaseCallback
from environments.helpers import IGNORED_DF_COLUMNS
from environments.logging.plotting import prepare_plot
import pandas as pd
@ -14,85 +14,76 @@ class MonitorCallback(BaseCallback):
ext = 'png'
def __init__(self, filepath=Path('debug_out/monitor.pick'), plotting=True):
def __init__(self, filepath=Path('debug_out/monitor.pick')):
super(MonitorCallback, self).__init__()
self.filepath = Path(filepath)
self._monitor_df = pd.DataFrame()
self._monitor_dicts = defaultdict(dict)
self.plotting = plotting
self.started = False
self.closed = False
def __enter__(self):
self._on_training_start()
self.start()
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self._on_training_end()
self.stop()
def _on_training_start(self) -> None:
if self.started:
pass
else:
self.filepath.parent.mkdir(exist_ok=True, parents=True)
self.started = True
self.start()
pass
def _on_training_end(self) -> None:
if self.closed:
pass
else:
# self.out_file.unlink(missing_ok=True)
with self.filepath.open('wb') as f:
pickle.dump(self._monitor_df.reset_index(), f, protocol=pickle.HIGHEST_PROTOCOL)
if self.plotting:
print('Monitor files were dumped to disk, now plotting....')
# %% Load MonitorList from Disk
with self.filepath.open('rb') as f:
monitor_list = pickle.load(f)
df = None
for m_idx, monitor in enumerate(monitor_list):
monitor['episode'] = m_idx
if df is None:
df = pd.DataFrame(columns=monitor.columns)
for _, row in monitor.iterrows():
df.loc[df.shape[0]] = row
if df is None: # The env exited premature, we catch it.
self.closed = True
return
for column in list(df.columns):
if column != 'episode':
df[f'{column}_roll'] = df[column].rolling(window=50).mean()
# result.tail()
prepare_plot(filepath=self.filepath, results_df=df.filter(regex=(".+_roll")))
print('Plotting done.')
self.closed = True
self.stop()
def _on_step(self, alt_infos: List[Dict] = None, alt_dones: List[bool] = None) -> bool:
infos = alt_infos or self.locals.get('infos', [])
if alt_dones is not None:
dones = alt_dones
elif self.locals.get('dones', None) is not None:
dones =self.locals.get('dones', None)
elif self.locals.get('done', None) is not None:
dones = self.locals.get('done', [None])
else:
dones = []
if self.started:
for env_idx, info in enumerate(self.locals.get('infos', [])):
self.read_info(env_idx, info)
for env_idx, (info, done) in enumerate(zip(infos, dones)):
self._monitor_dicts[env_idx][len(self._monitor_dicts[env_idx])] = {key: val for key, val in info.items()
if key not in ['terminal_observation', 'episode']
and not key.startswith('rec_')}
if done:
env_monitor_df = pd.DataFrame.from_dict(self._monitor_dicts[env_idx], orient='index')
self._monitor_dicts[env_idx] = dict()
columns = [col for col in env_monitor_df.columns if col not in IGNORED_DF_COLUMNS]
env_monitor_df = env_monitor_df.aggregate(
{col: 'mean' if col.endswith('ount') else 'sum' for col in columns}
)
env_monitor_df['episode'] = len(self._monitor_df)
self._monitor_df = self._monitor_df.append([env_monitor_df])
else:
pass
for env_idx, done in list(
enumerate(self.locals.get('dones', []))) + list(enumerate(self.locals.get('done', []))):
self.read_done(env_idx, done)
else:
pass
return True
def read_info(self, env_idx, info: dict):
    """Store the monitorable part of one step's ``info`` dict for ``env_idx``.

    The keys ``'terminal_observation'`` and ``'episode'`` as well as every key
    starting with ``'rec_'`` are dropped. What remains is stored in
    ``self._monitor_dicts[env_idx]`` under the next running index, i.e. the
    number of records stored there so far.

    :param env_idx: key of the environment the info belongs to.
    :param info: info dict of a single environment step.
    """
    skipped_keys = ('terminal_observation', 'episode')
    step_record = {}
    for key, val in info.items():
        if key in skipped_keys or key.startswith('rec_'):
            continue
        step_record[key] = val
    env_records = self._monitor_dicts[env_idx]
    env_records[len(env_records)] = step_record
    return
def read_done(self, env_idx, done):
    """Fold the buffered step records of ``env_idx`` into one episode row.

    Nothing happens while ``done`` is falsy. Once it is truthy, the records
    buffered in ``self._monitor_dicts[env_idx]`` are turned into a DataFrame,
    the buffer is reset, and every column not listed in ``IGNORED_DF_COLUMNS``
    is aggregated: columns whose name ends in ``'ount'`` are averaged, all
    others are summed. The resulting row gets an ``'episode'`` entry (the
    current length of ``self._monitor_df``) and is appended to
    ``self._monitor_df``.

    :param env_idx: key of the environment whose episode may have ended.
    :param done: truthy when the episode of ``env_idx`` is finished.
    """
    if not done:
        return
    env_monitor_df = pd.DataFrame.from_dict(self._monitor_dicts[env_idx], orient='index')
    # Start with a fresh buffer for the next episode of this environment.
    self._monitor_dicts[env_idx] = dict()
    columns = [col for col in env_monitor_df.columns if col not in IGNORED_DF_COLUMNS]
    episode_summary = env_monitor_df.aggregate(
        {col: 'mean' if col.endswith('ount') else 'sum' for col in columns}
    )
    episode_summary['episode'] = len(self._monitor_df)
    # DataFrame.append was deprecated in pandas 1.4 and removed in 2.0;
    # pd.concat with a one-row frame is the supported equivalent.
    self._monitor_df = pd.concat([self._monitor_df, pd.DataFrame([episode_summary])])
    return
def stop(self):
    """Pickle the collected monitor data to ``self.filepath`` and mark the callback closed.

    The DataFrame is written with its index reset, using the highest pickle
    protocol available. ``self.closed`` is set to ``True`` afterwards.
    """
    with self.filepath.open('wb') as out_file:
        payload = self._monitor_df.reset_index()
        pickle.dump(payload, out_file, protocol=pickle.HIGHEST_PROTOCOL)
    self.closed = True
def start(self):
    """Prepare the output directory once and flag the callback as started.

    On the first call the parent directory of ``self.filepath`` is created
    (including missing parents) and ``self.started`` is set to ``True``.
    Calls made while ``self.started`` is already truthy do nothing.
    """
    if not self.started:
        self.filepath.parent.mkdir(exist_ok=True, parents=True)
        self.started = True

View File

@ -1,46 +0,0 @@
import seaborn as sns
from matplotlib import pyplot as plt
# Colour palette handed to seaborn's lineplot: thirteen hex colours, repeated
# ten times (130 entries in total).
PALETTE = 10 * (
"#377eb8",
"#4daf4a",
"#984ea3",
"#e41a1c",
"#ff7f00",
"#a65628",
"#f781bf",
"#888888",
"#a6cee3",
"#b2df8a",
"#cab2d6",
"#fb9a99",
"#fdbf6f",
)
def plot(filepath, ext='png'):
    """Save the current matplotlib figure to ``filepath``, show it, then clear it.

    :param filepath: target path; converted with ``str`` before saving.
    :param ext: file format passed to ``savefig`` (default ``'png'``).
    """
    plt.tight_layout()
    plt.gcf().savefig(str(filepath), format=ext)
    plt.show()
    plt.clf()
def prepare_plot(filepath, results_df, ext='png', hue='Measurement', style=None):
    """Draw a seaborn line plot of ``Score`` over ``Episode`` and save it via ``plot``.

    Works on a copy of ``results_df``; underscores in the ``hue`` column are
    replaced by hyphens and the hue levels are plotted in sorted order. The
    plot is first attempted with LaTeX text rendering; if that raises
    ``FileNotFoundError`` or ``RuntimeError``, all figures are closed and the
    plot is redrawn without LaTeX.

    :param filepath: target path handed on to ``plot``.
    :param results_df: DataFrame providing the ``'Episode'``, ``'Score'`` and ``hue`` columns.
    :param ext: file format handed on to ``plot``.
    :param hue: name of the column used for colour grouping.
    :param style: optional column name used for line styles.
    """
    plot_df = results_df.copy()
    plot_df[hue] = plot_df[hue].str.replace('_', '-')
    ordered_hues = sorted(plot_df[hue].unique())

    def draw(use_latex):
        # One full attempt: configure seaborn, draw the lines, save the figure.
        sns.set(rc={'text.usetex': use_latex}, style='whitegrid')
        sns.lineplot(data=plot_df, x='Episode', y='Score', hue=hue, style=style,
                     ci=95, palette=PALETTE, hue_order=ordered_hues)
        plot(filepath, ext=ext)  # plot raises errors not lineplot!

    try:
        draw(True)
    except (FileNotFoundError, RuntimeError):
        print('Struggling to plot Figure using LaTeX - going back to normal.')
        plt.close('all')
        draw(False)

View File

@ -3,11 +3,10 @@ from collections import defaultdict
from pathlib import Path
from typing import Union
import pandas as pd
import simplejson
from stable_baselines3.common.callbacks import BaseCallback
from environments.factory.base.base_factory import REC_TAC
from environments.helpers import IGNORED_DF_COLUMNS
# noinspection PyAttributeOutsideInit
@ -18,8 +17,8 @@ class RecorderCallback(BaseCallback):
self.trajectory_map = trajectory_map
self.occupation_map = occupation_map
self.filepath = Path(filepath)
self._recorder_dict = defaultdict(dict)
self._recorder_json_list = list()
self._recorder_dict = defaultdict(list)
self._recorder_out_list = list()
self.do_record: bool
self.started = False
self.closed = False
@ -27,15 +26,15 @@ class RecorderCallback(BaseCallback):
def read_info(self, env_idx, info: dict):
if info_dict := {key.replace(REC_TAC, ''): val for key, val in info.items() if key.startswith(f'{REC_TAC}')}:
info_dict.update(episode=(self.num_timesteps + env_idx))
self._recorder_dict[env_idx][len(self._recorder_dict[env_idx])] = info_dict
self._recorder_dict[env_idx].append(info_dict)
else:
pass
return
def read_done(self, env_idx, done):
if done:
self._recorder_json_list.append(json.dumps(self._recorder_dict[env_idx]))
self._recorder_dict[env_idx] = dict()
self._recorder_out_list.append({'steps': self._recorder_dict[env_idx]})
self._recorder_dict[env_idx] = list()
else:
pass
@ -51,8 +50,11 @@ class RecorderCallback(BaseCallback):
if self.do_record and self.started:
# self.out_file.unlink(missing_ok=True)
with self.filepath.open('w') as f:
json_list = self._recorder_json_list
json.dump(json_list, f, indent=4)
out_dict = {'episodes': self._recorder_out_list}
try:
simplejson.dump(out_dict, f, indent=4)
except TypeError:
print('Shit')
if self.occupation_map:
print('Recorder files were dumped to disk, now plotting the occupation map...')