mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-06-21 11:21:35 +02:00
recoder adaption
This commit is contained in:
@ -13,8 +13,8 @@ from environments.factory.base.shadow_casting import Map
|
||||
from environments.factory.renderer import Renderer, RenderEntity
|
||||
from environments.helpers import Constants as c, Constants
|
||||
from environments import helpers as h
|
||||
from environments.factory.base.objects import Agent, Tile, Action, Wall
|
||||
from environments.factory.base.registers import Actions, Entities, Agents, Doors, FloorTiles, WallTiles
|
||||
from environments.factory.base.objects import Agent, Tile, Action
|
||||
from environments.factory.base.registers import Actions, Entities, Agents, Doors, FloorTiles, WallTiles, PlaceHolders
|
||||
from environments.utility_classes import MovementProperties
|
||||
|
||||
import simplejson
|
||||
@ -58,7 +58,7 @@ class BaseFactory(gym.Env):
|
||||
def __init__(self, level_name='simple', n_agents=1, max_steps=int(5e2), pomdp_r: Union[None, int] = 0,
|
||||
movement_properties: MovementProperties = MovementProperties(), parse_doors=False,
|
||||
combin_agent_obs: bool = False, frames_to_stack=0, record_episodes=False,
|
||||
omit_agent_in_obs=False, done_at_collision=False, cast_shadows=True,
|
||||
omit_agent_in_obs=False, done_at_collision=False, cast_shadows=True, additional_agent_placeholder=None,
|
||||
verbose=False, doors_have_area=True, env_seed=time.time_ns(), **kwargs):
|
||||
assert frames_to_stack != 1 and frames_to_stack >= 0, "'frames_to_stack' cannot be negative or 1."
|
||||
if kwargs:
|
||||
@ -74,6 +74,7 @@ class BaseFactory(gym.Env):
|
||||
self.level_name = level_name
|
||||
self._level_shape = None
|
||||
self.verbose = verbose
|
||||
self.additional_agent_placeholder = additional_agent_placeholder
|
||||
self._renderer = None # expensive - don't use it when not required !
|
||||
self._entities = Entities()
|
||||
|
||||
@ -141,6 +142,14 @@ class BaseFactory(gym.Env):
|
||||
individual_slices=not self.combin_agent_obs)
|
||||
entities.update({c.AGENT: agents})
|
||||
|
||||
if self.additional_agent_placeholder is not None:
|
||||
|
||||
# Empty Observations with either [0, 1, N(0, 1)]
|
||||
placeholder = PlaceHolders.from_tiles([self._NO_POS_TILE], self._level_shape,
|
||||
fill_value=self.additional_agent_placeholder)
|
||||
|
||||
entities.update({c.AGENT_PLACEHOLDER: placeholder})
|
||||
|
||||
# All entities
|
||||
self._entities = Entities()
|
||||
self._entities.register_additional_items(entities)
|
||||
@ -155,10 +164,12 @@ class BaseFactory(gym.Env):
|
||||
def _init_obs_cube(self):
|
||||
arrays = self._entities.observable_arrays
|
||||
|
||||
# FIXME: Move logic to Register
|
||||
if self.omit_agent_in_obs and self.n_agents == 1:
|
||||
del arrays[c.AGENT]
|
||||
elif self.omit_agent_in_obs:
|
||||
arrays[c.AGENT] = np.delete(arrays[c.AGENT], 0, axis=0)
|
||||
# This does not seem to be necesarry, because this case is allready handled by the Agent Register Class
|
||||
# elif self.omit_agent_in_obs:
|
||||
# arrays[c.AGENT] = np.delete(arrays[c.AGENT], 0, axis=0)
|
||||
obs_cube_z = sum([a.shape[0] if not self[key].is_per_agent else 1 for key, a in arrays.items()])
|
||||
self._obs_cube = np.zeros((obs_cube_z, *self._level_shape), dtype=np.float32)
|
||||
|
||||
@ -273,6 +284,7 @@ class BaseFactory(gym.Env):
|
||||
agent_pos_is_omitted = False
|
||||
agent_omit_idx = None
|
||||
if self.omit_agent_in_obs and self.n_agents == 1:
|
||||
# There is only a single agent and we want to omit the agent obs, so just remove the array.
|
||||
del state_array_dict[c.AGENT]
|
||||
elif self.omit_agent_in_obs and self.combin_agent_obs and self.n_agents > 1:
|
||||
state_array_dict[c.AGENT][0, agent.x, agent.y] -= agent.encoding
|
||||
@ -295,6 +307,9 @@ class BaseFactory(gym.Env):
|
||||
for array_idx in range(array.shape[0]):
|
||||
self._obs_cube[running_idx: running_idx+z] = array[[x for x in range(array.shape[0])
|
||||
if x != agent_omit_idx]]
|
||||
elif key == c.AGENT and self.omit_agent_in_obs and self.combin_agent_obs:
|
||||
z = 1
|
||||
self._obs_cube[running_idx: running_idx + z] = array
|
||||
else:
|
||||
z = array.shape[0]
|
||||
self._obs_cube[running_idx: running_idx+z] = array
|
||||
@ -499,12 +514,8 @@ class BaseFactory(gym.Env):
|
||||
def _summarize_state(self):
|
||||
summary = {f'{REC_TAC}step': self._steps}
|
||||
|
||||
if self._steps == 0:
|
||||
summary.update({f'{REC_TAC}{self[c.WALLS].name}': {self[c.WALLS].summarize_states()},
|
||||
'FactoryName': self.__class__.__name__})
|
||||
for entity_group in self._entities:
|
||||
if not isinstance(entity_group, WallTiles):
|
||||
summary.update({f'{REC_TAC}{entity_group.name}': entity_group.summarize_states()})
|
||||
summary.update({f'{REC_TAC}{entity_group.name}': entity_group.summarize_states(n_steps=self._steps)})
|
||||
return summary
|
||||
|
||||
def print(self, string):
|
||||
|
@ -93,11 +93,11 @@ class Entity(Object):
|
||||
return self._tile
|
||||
|
||||
def __init__(self, tile, **kwargs):
|
||||
super(Entity, self).__init__(**kwargs)
|
||||
super().__init__(**kwargs)
|
||||
self._tile = tile
|
||||
tile.enter(self)
|
||||
|
||||
def summarize_state(self) -> dict:
|
||||
def summarize_state(self, **_) -> dict:
|
||||
return dict(name=str(self.name), x=int(self.x), y=int(self.y),
|
||||
tile=str(self.tile.name), can_collide=bool(self.can_collide))
|
||||
|
||||
@ -125,7 +125,7 @@ class MoveableEntity(Entity):
|
||||
return last_x-curr_x, last_y-curr_y
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(MoveableEntity, self).__init__(*args, **kwargs)
|
||||
super().__init__(*args, **kwargs)
|
||||
self._last_tile = None
|
||||
|
||||
def move(self, next_tile):
|
||||
@ -143,11 +143,34 @@ class MoveableEntity(Entity):
|
||||
class Action(Object):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(Action, self).__init__(*args, **kwargs)
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
|
||||
class PlaceHolder(MoveableEntity):
|
||||
pass
|
||||
|
||||
def __init__(self, *args, fill_value=0, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self._fill_value = fill_value
|
||||
|
||||
@property
|
||||
def last_tile(self):
|
||||
return self.tile
|
||||
|
||||
@property
|
||||
def direction_of_view(self):
|
||||
return self.pos
|
||||
|
||||
@property
|
||||
def can_collide(self):
|
||||
return False
|
||||
|
||||
@property
|
||||
def encoding(self):
|
||||
return c.NO_POS.value[0]
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
return "PlaceHolder"
|
||||
|
||||
|
||||
class Tile(Object):
|
||||
@ -203,8 +226,8 @@ class Tile(Object):
|
||||
def __repr__(self):
|
||||
return f'{self.name}(@{self.pos})'
|
||||
|
||||
def summarize_state(self):
|
||||
return dict(name=self.name, x=self.x, y=self.y)
|
||||
def summarize_state(self, **_):
|
||||
return dict(name=self.name, x=int(self.x), y=int(self.y))
|
||||
|
||||
|
||||
class Wall(Tile):
|
||||
@ -254,8 +277,8 @@ class Door(Entity):
|
||||
if not closed_on_init:
|
||||
self._open()
|
||||
|
||||
def summarize_state(self):
|
||||
state_dict = super().summarize_state()
|
||||
def summarize_state(self, **kwargs):
|
||||
state_dict = super().summarize_state(**kwargs)
|
||||
state_dict.update(state=str(self.str_state), time_to_close=int(self.time_to_close))
|
||||
return state_dict
|
||||
|
||||
@ -315,7 +338,7 @@ class Agent(MoveableEntity):
|
||||
self.temp_action = None
|
||||
self.temp_light_map = None
|
||||
|
||||
def summarize_state(self):
|
||||
state_dict = super().summarize_state()
|
||||
def summarize_state(self, **kwargs):
|
||||
state_dict = super().summarize_state(**kwargs)
|
||||
state_dict.update(valid=bool(self.temp_valid), action=str(self.temp_action))
|
||||
return state_dict
|
||||
|
@ -81,8 +81,8 @@ class ObjectRegister(Register):
|
||||
if self.individual_slices:
|
||||
self._array = np.concatenate((self._array, np.zeros((1, *self._array.shape[1:]))))
|
||||
|
||||
def summarize_states(self):
|
||||
return [val.summarize_state() for val in self.values()]
|
||||
def summarize_states(self, n_steps=None):
|
||||
return [val.summarize_state(n_steps=n_steps) for val in self.values()]
|
||||
|
||||
|
||||
class EntityObjectRegister(ObjectRegister, ABC):
|
||||
@ -156,23 +156,25 @@ class MovingEntityObjectRegister(EntityObjectRegister, ABC):
|
||||
del self[name]
|
||||
|
||||
|
||||
class PlaceHolderRegister(MovingEntityObjectRegister):
|
||||
class PlaceHolders(MovingEntityObjectRegister):
|
||||
|
||||
_accepted_objects = PlaceHolder
|
||||
|
||||
def __init__(self, *args, fill_value: Union[str, int] = 0, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.fill_value = fill_value
|
||||
|
||||
# noinspection DuplicatedCode
|
||||
def as_array(self):
|
||||
self._array[:] = c.FREE_CELL.value
|
||||
# noinspection PyTupleAssignmentBalance
|
||||
for z, x, y, v in zip(range(len(self)), *zip(*[x.pos for x in self]), [x.encoding for x in self]):
|
||||
if self.individual_slices:
|
||||
self._array[z, x, y] += v
|
||||
else:
|
||||
self._array[0, x, y] += v
|
||||
if isinstance(self.fill_value, int):
|
||||
self._array[:] = self.fill_value
|
||||
elif self.fill_value == "normal":
|
||||
self._array = np.random.normal(size=self._array.shape)
|
||||
|
||||
if self.individual_slices:
|
||||
return self._array
|
||||
else:
|
||||
return self._array.sum(axis=0, keepdims=True)
|
||||
return self._array[None, 0]
|
||||
|
||||
|
||||
class Entities(Register):
|
||||
@ -243,6 +245,12 @@ class WallTiles(EntityObjectRegister):
|
||||
def from_tiles(cls, tiles, *args, **kwargs):
|
||||
raise RuntimeError()
|
||||
|
||||
def summarize_states(self, n_steps=None):
|
||||
if n_steps == h.STEPS_START:
|
||||
return super(WallTiles, self).summarize_states(n_steps=n_steps)
|
||||
else:
|
||||
return {}
|
||||
|
||||
|
||||
class FloorTiles(WallTiles):
|
||||
|
||||
@ -272,6 +280,10 @@ class FloorTiles(WallTiles):
|
||||
def from_tiles(cls, tiles, *args, **kwargs):
|
||||
raise RuntimeError()
|
||||
|
||||
def summarize_states(self, n_steps=None):
|
||||
# Do not summarize
|
||||
return {}
|
||||
|
||||
|
||||
class Agents(MovingEntityObjectRegister):
|
||||
|
||||
|
29
environments/factory/env_item_default_param.json
Normal file
29
environments/factory/env_item_default_param.json
Normal file
@ -0,0 +1,29 @@
|
||||
{
|
||||
"item_properties": {
|
||||
"n_items": 5,
|
||||
"spawn_frequency": 10,
|
||||
"n_drop_off_locations": 5,
|
||||
"max_dropoff_storage_size": 0,
|
||||
"max_agent_inventory_capacity": 5,
|
||||
"agent_can_interact": true
|
||||
},
|
||||
"env_seed": 2,
|
||||
"movement_properties": {
|
||||
"allow_square_movement": true,
|
||||
"allow_diagonal_movement": true,
|
||||
"allow_no_op": false
|
||||
},
|
||||
"level_name": "rooms",
|
||||
"verbose": false,
|
||||
"n_agents": 1,
|
||||
"max_steps": 400,
|
||||
"pomdp_r": 2,
|
||||
"combin_agent_obs": true,
|
||||
"omit_agent_in_obs": true,
|
||||
"cast_shadows": true,
|
||||
"frames_to_stack": 3,
|
||||
"done_at_collision": false,
|
||||
"record_episodes": false,
|
||||
"parse_doors": false,
|
||||
"doors_have_area": false
|
||||
}
|
@ -51,8 +51,8 @@ class Dirt(Entity):
|
||||
def set_new_amount(self, amount):
|
||||
self._amount = amount
|
||||
|
||||
def summarize_state(self):
|
||||
state_dict = super().summarize_state()
|
||||
def summarize_state(self, **kwargs):
|
||||
state_dict = super().summarize_state(**kwargs)
|
||||
state_dict.update(amount=float(self.amount))
|
||||
return state_dict
|
||||
|
||||
|
@ -1,7 +1,9 @@
|
||||
import random
|
||||
from pathlib import Path
|
||||
|
||||
from environments.factory.factory_dirt import DirtFactory, DirtProperties
|
||||
from environments.factory.factory_item import ItemFactory, ItemProperties
|
||||
from environments.logging.recorder import RecorderCallback
|
||||
from environments.utility_classes import MovementProperties
|
||||
|
||||
|
||||
@ -12,40 +14,44 @@ class DirtItemFactory(ItemFactory, DirtFactory):
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
with RecorderCallback(filepath=Path('debug_out') / f'recorder_xxxx.json', occupation_map=False,
|
||||
trajectory_map=False) as recorder:
|
||||
|
||||
dirt_props = DirtProperties(clean_amount=2, gain_amount=0.1, max_global_amount=20,
|
||||
max_local_amount=1, spawn_frequency=3, max_spawn_ratio=0.05,
|
||||
dirt_smear_amount=0.0, agent_can_interact=True)
|
||||
item_props = ItemProperties(n_items=5, agent_can_interact=True)
|
||||
move_props = MovementProperties(allow_diagonal_movement=True,
|
||||
allow_square_movement=True,
|
||||
allow_no_op=False)
|
||||
dirt_props = DirtProperties(clean_amount=2, gain_amount=0.1, max_global_amount=20,
|
||||
max_local_amount=1, spawn_frequency=3, max_spawn_ratio=0.05,
|
||||
dirt_smear_amount=0.0, agent_can_interact=True)
|
||||
item_props = ItemProperties(n_items=5, agent_can_interact=True)
|
||||
move_props = MovementProperties(allow_diagonal_movement=True,
|
||||
allow_square_movement=True,
|
||||
allow_no_op=False)
|
||||
|
||||
render = True
|
||||
render = False
|
||||
|
||||
factory = DirtItemFactory(n_agents=1, done_at_collision=False, frames_to_stack=0,
|
||||
level_name='rooms', max_steps=400, combin_agent_obs=True,
|
||||
omit_agent_in_obs=True, parse_doors=True, pomdp_r=3,
|
||||
record_episodes=True, verbose=True, cast_shadows=True,
|
||||
movement_properties=move_props, dirt_properties=dirt_props
|
||||
)
|
||||
factory = DirtItemFactory(n_agents=1, done_at_collision=False, frames_to_stack=0,
|
||||
level_name='rooms', max_steps=200, combin_agent_obs=True,
|
||||
omit_agent_in_obs=True, parse_doors=True, pomdp_r=3,
|
||||
record_episodes=True, verbose=False, cast_shadows=True,
|
||||
movement_properties=move_props, dirt_properties=dirt_props
|
||||
)
|
||||
|
||||
# noinspection DuplicatedCode
|
||||
n_actions = factory.action_space.n - 1
|
||||
_ = factory.observation_space
|
||||
# noinspection DuplicatedCode
|
||||
n_actions = factory.action_space.n - 1
|
||||
_ = factory.observation_space
|
||||
|
||||
for epoch in range(4):
|
||||
random_actions = [[random.randint(0, n_actions) for _
|
||||
in range(factory.n_agents)] for _
|
||||
in range(factory.max_steps + 1)]
|
||||
env_state = factory.reset()
|
||||
r = 0
|
||||
for agent_i_action in random_actions:
|
||||
env_state, step_r, done_bool, info_obj = factory.step(agent_i_action)
|
||||
r += step_r
|
||||
if render:
|
||||
factory.render()
|
||||
if done_bool:
|
||||
break
|
||||
print(f'Factory run {epoch} done, reward is:\n {r}')
|
||||
pass
|
||||
for epoch in range(4):
|
||||
random_actions = [[random.randint(0, n_actions) for _
|
||||
in range(factory.n_agents)] for _
|
||||
in range(factory.max_steps + 1)]
|
||||
env_state = factory.reset()
|
||||
r = 0
|
||||
for agent_i_action in random_actions:
|
||||
env_state, step_r, done_bool, info_obj = factory.step(agent_i_action)
|
||||
recorder.read_info(0, info_obj)
|
||||
r += step_r
|
||||
if render:
|
||||
factory.render()
|
||||
if done_bool:
|
||||
recorder.read_done(0, done_bool)
|
||||
break
|
||||
print(f'Factory run {epoch} done, reward is:\n {r}')
|
||||
pass
|
||||
|
@ -109,8 +109,10 @@ class Inventory(UserList):
|
||||
def belongs_to_entity(self, entity):
|
||||
return self.agent == entity
|
||||
|
||||
def summarize_state(self):
|
||||
return {val.name: val.summarize_state() for val in self}
|
||||
def summarize_state(self, **kwargs):
|
||||
attr_dict = {key: str(val) for key, val in self.__dict__.items() if not key.startswith('_') and key != 'data'}
|
||||
attr_dict.update({val.name: val.summarize_state(**kwargs) for val in self})
|
||||
return attr_dict
|
||||
|
||||
|
||||
class Inventories(ObjectRegister):
|
||||
@ -176,6 +178,10 @@ class DropOffLocation(Entity):
|
||||
def is_full(self):
|
||||
return False if not self.storage.maxlen else self.storage.maxlen == len(self.storage)
|
||||
|
||||
def summarize_state(self, n_steps=None) -> dict:
|
||||
if n_steps == h.STEPS_START:
|
||||
return super().summarize_state(n_steps=n_steps)
|
||||
|
||||
|
||||
class DropOffLocations(EntityObjectRegister):
|
||||
|
||||
|
@ -7,7 +7,10 @@ from pathlib import Path
|
||||
|
||||
from stable_baselines3 import PPO, DQN, A2C
|
||||
|
||||
MODEL_MAP = dict(PPO=PPO, DQN=DQN, A2C=A2C)
|
||||
|
||||
LEVELS_DIR = 'levels'
|
||||
STEPS_START = 1
|
||||
|
||||
TO_BE_AVERAGED = ['dirt_amount', 'dirty_tiles']
|
||||
IGNORED_DF_COLUMNS = ['Episode', 'Run', 'train_step', 'step', 'index', 'dirt_amount',
|
||||
@ -16,34 +19,35 @@ IGNORED_DF_COLUMNS = ['Episode', 'Run', 'train_step', 'step', 'index', 'dirt_amo
|
||||
|
||||
# Constants
|
||||
class Constants(Enum):
|
||||
WALL = '#'
|
||||
WALLS = 'Walls'
|
||||
FLOOR = 'Floor'
|
||||
DOOR = 'D'
|
||||
DANGER_ZONE = 'x'
|
||||
LEVEL = 'Level'
|
||||
AGENT = 'Agent'
|
||||
FREE_CELL = 0
|
||||
OCCUPIED_CELL = 1
|
||||
SHADOWED_CELL = -1
|
||||
NO_POS = (-9999, -9999)
|
||||
WALL = '#'
|
||||
WALLS = 'Walls'
|
||||
FLOOR = 'Floor'
|
||||
DOOR = 'D'
|
||||
DANGER_ZONE = 'x'
|
||||
LEVEL = 'Level'
|
||||
AGENT = 'Agent'
|
||||
AGENT_PLACEHOLDER = 'AGENT_PLACEHOLDER'
|
||||
FREE_CELL = 0
|
||||
OCCUPIED_CELL = 1
|
||||
SHADOWED_CELL = -1
|
||||
NO_POS = (-9999, -9999)
|
||||
|
||||
DOORS = 'Doors'
|
||||
CLOSED_DOOR = 'closed'
|
||||
OPEN_DOOR = 'open'
|
||||
DOORS = 'Doors'
|
||||
CLOSED_DOOR = 'closed'
|
||||
OPEN_DOOR = 'open'
|
||||
|
||||
ACTION = 'action'
|
||||
COLLISIONS = 'collision'
|
||||
VALID = 'valid'
|
||||
NOT_VALID = 'not_valid'
|
||||
ACTION = 'action'
|
||||
COLLISIONS = 'collision'
|
||||
VALID = 'valid'
|
||||
NOT_VALID = 'not_valid'
|
||||
|
||||
# Dirt Env
|
||||
DIRT = 'Dirt'
|
||||
DIRT = 'Dirt'
|
||||
|
||||
# Item Env
|
||||
ITEM = 'Item'
|
||||
INVENTORY = 'Inventory'
|
||||
DROP_OFF = 'Drop_Off'
|
||||
ITEM = 'Item'
|
||||
INVENTORY = 'Inventory'
|
||||
DROP_OFF = 'Drop_Off'
|
||||
|
||||
def __bool__(self):
|
||||
if 'not_' in self.value:
|
||||
@ -144,8 +148,6 @@ def asset_str(agent):
|
||||
return c.AGENT.value, 'idle'
|
||||
|
||||
|
||||
model_map = dict(PPO=PPO, DQN=DQN, A2C=A2C)
|
||||
|
||||
if __name__ == '__main__':
|
||||
parsed_level = parse_level(Path(__file__).parent / 'factory' / 'levels' / 'simple.txt')
|
||||
y = one_hot_level(parsed_level)
|
||||
|
@ -6,7 +6,7 @@ from typing import List, Dict
|
||||
from stable_baselines3.common.callbacks import BaseCallback
|
||||
|
||||
from environments.helpers import IGNORED_DF_COLUMNS
|
||||
from environments.logging.plotting import prepare_plot
|
||||
|
||||
import pandas as pd
|
||||
|
||||
|
||||
@ -14,85 +14,76 @@ class MonitorCallback(BaseCallback):
|
||||
|
||||
ext = 'png'
|
||||
|
||||
def __init__(self, filepath=Path('debug_out/monitor.pick'), plotting=True):
|
||||
def __init__(self, filepath=Path('debug_out/monitor.pick')):
|
||||
super(MonitorCallback, self).__init__()
|
||||
self.filepath = Path(filepath)
|
||||
self._monitor_df = pd.DataFrame()
|
||||
self._monitor_dicts = defaultdict(dict)
|
||||
self.plotting = plotting
|
||||
self.started = False
|
||||
self.closed = False
|
||||
|
||||
def __enter__(self):
|
||||
self._on_training_start()
|
||||
self.start()
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
self._on_training_end()
|
||||
self.stop()
|
||||
|
||||
def _on_training_start(self) -> None:
|
||||
if self.started:
|
||||
pass
|
||||
else:
|
||||
self.filepath.parent.mkdir(exist_ok=True, parents=True)
|
||||
self.started = True
|
||||
self.start()
|
||||
pass
|
||||
|
||||
def _on_training_end(self) -> None:
|
||||
if self.closed:
|
||||
pass
|
||||
else:
|
||||
# self.out_file.unlink(missing_ok=True)
|
||||
with self.filepath.open('wb') as f:
|
||||
pickle.dump(self._monitor_df.reset_index(), f, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
if self.plotting:
|
||||
print('Monitor files were dumped to disk, now plotting....')
|
||||
|
||||
# %% Load MonitorList from Disk
|
||||
with self.filepath.open('rb') as f:
|
||||
monitor_list = pickle.load(f)
|
||||
df = None
|
||||
for m_idx, monitor in enumerate(monitor_list):
|
||||
monitor['episode'] = m_idx
|
||||
if df is None:
|
||||
df = pd.DataFrame(columns=monitor.columns)
|
||||
for _, row in monitor.iterrows():
|
||||
df.loc[df.shape[0]] = row
|
||||
if df is None: # The env exited premature, we catch it.
|
||||
self.closed = True
|
||||
return
|
||||
for column in list(df.columns):
|
||||
if column != 'episode':
|
||||
df[f'{column}_roll'] = df[column].rolling(window=50).mean()
|
||||
# result.tail()
|
||||
prepare_plot(filepath=self.filepath, results_df=df.filter(regex=(".+_roll")))
|
||||
print('Plotting done.')
|
||||
self.closed = True
|
||||
self.stop()
|
||||
|
||||
def _on_step(self, alt_infos: List[Dict] = None, alt_dones: List[bool] = None) -> bool:
|
||||
infos = alt_infos or self.locals.get('infos', [])
|
||||
if alt_dones is not None:
|
||||
dones = alt_dones
|
||||
elif self.locals.get('dones', None) is not None:
|
||||
dones =self.locals.get('dones', None)
|
||||
elif self.locals.get('done', None) is not None:
|
||||
dones = self.locals.get('done', [None])
|
||||
else:
|
||||
dones = []
|
||||
if self.started:
|
||||
for env_idx, info in enumerate(self.locals.get('infos', [])):
|
||||
self.read_info(env_idx, info)
|
||||
|
||||
for env_idx, (info, done) in enumerate(zip(infos, dones)):
|
||||
self._monitor_dicts[env_idx][len(self._monitor_dicts[env_idx])] = {key: val for key, val in info.items()
|
||||
if key not in ['terminal_observation', 'episode']
|
||||
and not key.startswith('rec_')}
|
||||
if done:
|
||||
env_monitor_df = pd.DataFrame.from_dict(self._monitor_dicts[env_idx], orient='index')
|
||||
self._monitor_dicts[env_idx] = dict()
|
||||
columns = [col for col in env_monitor_df.columns if col not in IGNORED_DF_COLUMNS]
|
||||
env_monitor_df = env_monitor_df.aggregate(
|
||||
{col: 'mean' if col.endswith('ount') else 'sum' for col in columns}
|
||||
)
|
||||
env_monitor_df['episode'] = len(self._monitor_df)
|
||||
self._monitor_df = self._monitor_df.append([env_monitor_df])
|
||||
else:
|
||||
pass
|
||||
for env_idx, done in list(
|
||||
enumerate(self.locals.get('dones', []))) + list(enumerate(self.locals.get('done', []))):
|
||||
self.read_done(env_idx, done)
|
||||
else:
|
||||
pass
|
||||
return True
|
||||
|
||||
def read_info(self, env_idx, info: dict):
|
||||
self._monitor_dicts[env_idx][len(self._monitor_dicts[env_idx])] = {
|
||||
key: val for key, val in info.items() if
|
||||
key not in ['terminal_observation', 'episode'] and not key.startswith('rec_')}
|
||||
return
|
||||
|
||||
def read_done(self, env_idx, done):
|
||||
if done:
|
||||
env_monitor_df = pd.DataFrame.from_dict(self._monitor_dicts[env_idx], orient='index')
|
||||
self._monitor_dicts[env_idx] = dict()
|
||||
columns = [col for col in env_monitor_df.columns if col not in IGNORED_DF_COLUMNS]
|
||||
env_monitor_df = env_monitor_df.aggregate(
|
||||
{col: 'mean' if col.endswith('ount') else 'sum' for col in columns}
|
||||
)
|
||||
env_monitor_df['episode'] = len(self._monitor_df)
|
||||
self._monitor_df = self._monitor_df.append([env_monitor_df])
|
||||
else:
|
||||
pass
|
||||
return
|
||||
|
||||
def stop(self):
|
||||
# self.out_file.unlink(missing_ok=True)
|
||||
with self.filepath.open('wb') as f:
|
||||
pickle.dump(self._monitor_df.reset_index(), f, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
self.closed = True
|
||||
|
||||
def start(self):
|
||||
if self.started:
|
||||
pass
|
||||
else:
|
||||
self.filepath.parent.mkdir(exist_ok=True, parents=True)
|
||||
self.started = True
|
||||
pass
|
||||
|
@ -1,46 +0,0 @@
|
||||
import seaborn as sns
|
||||
from matplotlib import pyplot as plt
|
||||
|
||||
PALETTE = 10 * (
|
||||
"#377eb8",
|
||||
"#4daf4a",
|
||||
"#984ea3",
|
||||
"#e41a1c",
|
||||
"#ff7f00",
|
||||
"#a65628",
|
||||
"#f781bf",
|
||||
"#888888",
|
||||
"#a6cee3",
|
||||
"#b2df8a",
|
||||
"#cab2d6",
|
||||
"#fb9a99",
|
||||
"#fdbf6f",
|
||||
)
|
||||
|
||||
|
||||
def plot(filepath, ext='png'):
|
||||
plt.tight_layout()
|
||||
figure = plt.gcf()
|
||||
figure.savefig(str(filepath), format=ext)
|
||||
plt.show()
|
||||
plt.clf()
|
||||
|
||||
|
||||
def prepare_plot(filepath, results_df, ext='png', hue='Measurement', style=None):
|
||||
df = results_df.copy()
|
||||
df[hue] = df[hue].str.replace('_', '-')
|
||||
hue_order = sorted(list(df[hue].unique()))
|
||||
try:
|
||||
sns.set(rc={'text.usetex': True}, style='whitegrid')
|
||||
lineplot = sns.lineplot(data=df, x='Episode', y='Score', ci=95, palette=PALETTE,
|
||||
hue_order=hue_order, hue=hue, style=style)
|
||||
# lineplot.set_title(f'{sorted(list(df["Measurement"].unique()))}')
|
||||
plot(filepath, ext=ext) # plot raises errors not lineplot!
|
||||
except (FileNotFoundError, RuntimeError):
|
||||
print('Struggling to plot Figure using LaTeX - going back to normal.')
|
||||
plt.close('all')
|
||||
sns.set(rc={'text.usetex': False}, style='whitegrid')
|
||||
lineplot = sns.lineplot(data=df, x='Episode', y='Score', hue=hue, style=style,
|
||||
ci=95, palette=PALETTE, hue_order=hue_order)
|
||||
# lineplot.set_title(f'{sorted(list(df["Measurement"].unique()))}')
|
||||
plot(filepath, ext=ext)
|
@ -3,11 +3,10 @@ from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
import pandas as pd
|
||||
import simplejson
|
||||
from stable_baselines3.common.callbacks import BaseCallback
|
||||
|
||||
from environments.factory.base.base_factory import REC_TAC
|
||||
from environments.helpers import IGNORED_DF_COLUMNS
|
||||
|
||||
|
||||
# noinspection PyAttributeOutsideInit
|
||||
@ -18,8 +17,8 @@ class RecorderCallback(BaseCallback):
|
||||
self.trajectory_map = trajectory_map
|
||||
self.occupation_map = occupation_map
|
||||
self.filepath = Path(filepath)
|
||||
self._recorder_dict = defaultdict(dict)
|
||||
self._recorder_json_list = list()
|
||||
self._recorder_dict = defaultdict(list)
|
||||
self._recorder_out_list = list()
|
||||
self.do_record: bool
|
||||
self.started = False
|
||||
self.closed = False
|
||||
@ -27,15 +26,15 @@ class RecorderCallback(BaseCallback):
|
||||
def read_info(self, env_idx, info: dict):
|
||||
if info_dict := {key.replace(REC_TAC, ''): val for key, val in info.items() if key.startswith(f'{REC_TAC}')}:
|
||||
info_dict.update(episode=(self.num_timesteps + env_idx))
|
||||
self._recorder_dict[env_idx][len(self._recorder_dict[env_idx])] = info_dict
|
||||
self._recorder_dict[env_idx].append(info_dict)
|
||||
else:
|
||||
pass
|
||||
return
|
||||
|
||||
def read_done(self, env_idx, done):
|
||||
if done:
|
||||
self._recorder_json_list.append(json.dumps(self._recorder_dict[env_idx]))
|
||||
self._recorder_dict[env_idx] = dict()
|
||||
self._recorder_out_list.append({'steps': self._recorder_dict[env_idx]})
|
||||
self._recorder_dict[env_idx] = list()
|
||||
else:
|
||||
pass
|
||||
|
||||
@ -51,8 +50,11 @@ class RecorderCallback(BaseCallback):
|
||||
if self.do_record and self.started:
|
||||
# self.out_file.unlink(missing_ok=True)
|
||||
with self.filepath.open('w') as f:
|
||||
json_list = self._recorder_json_list
|
||||
json.dump(json_list, f, indent=4)
|
||||
out_dict = {'episodes': self._recorder_out_list}
|
||||
try:
|
||||
simplejson.dump(out_dict, f, indent=4)
|
||||
except TypeError:
|
||||
print('Shit')
|
||||
|
||||
if self.occupation_map:
|
||||
print('Recorder files were dumped to disk, now plotting the occupation map...')
|
||||
|
Reference in New Issue
Block a user