new plotting, omit_agent_obs

This commit is contained in:
steffen-illium
2021-06-02 18:12:56 +02:00
parent 8810955e86
commit b72013407e
5 changed files with 79 additions and 20 deletions

View File

@@ -54,6 +54,12 @@ class Register:
self_with_additional_items = self + other self_with_additional_items = self + other
return self_with_additional_items return self_with_additional_items
def keys(self):
return self._register.keys()
def items(self):
return self._register.items()
def __getitem__(self, item): def __getitem__(self, item):
return self._register[item] return self._register[item]
@@ -103,7 +109,8 @@ class BaseFactory(gym.Env):
@property @property
def observation_space(self): def observation_space(self):
if self.pomdp_radius: if self.pomdp_radius:
return spaces.Box(low=0, high=1, shape=(self._state.shape[0], self.pomdp_radius * 2 + 1, agent_slice = self.n_agents if self.omit_agent_slice_in_obs else 0
return spaces.Box(low=0, high=1, shape=(self._state.shape[0] - agent_slice, self.pomdp_radius * 2 + 1,
self.pomdp_radius * 2 + 1), dtype=np.float32) self.pomdp_radius * 2 + 1), dtype=np.float32)
else: else:
space = spaces.Box(low=0, high=1, shape=self._state.shape, dtype=np.float32) space = spaces.Box(low=0, high=1, shape=self._state.shape, dtype=np.float32)
@@ -114,13 +121,15 @@ class BaseFactory(gym.Env):
return self._actions.movement_actions return self._actions.movement_actions
def __init__(self, level='simple', n_agents=1, max_steps=int(5e2), pomdp_radius: Union[None, int] = None, def __init__(self, level='simple', n_agents=1, max_steps=int(5e2), pomdp_radius: Union[None, int] = None,
allow_square_movement=True, allow_diagonal_movement=True, allow_no_op=True, **kwargs): allow_square_movement=True, allow_diagonal_movement=True, allow_no_op=True,
omit_agent_slice_in_obs=False, **kwargs):
self.allow_no_op = allow_no_op self.allow_no_op = allow_no_op
self.allow_diagonal_movement = allow_diagonal_movement self.allow_diagonal_movement = allow_diagonal_movement
self.allow_square_movement = allow_square_movement self.allow_square_movement = allow_square_movement
self.n_agents = n_agents self.n_agents = n_agents
self.max_steps = max_steps self.max_steps = max_steps
self.pomdp_radius = pomdp_radius self.pomdp_radius = pomdp_radius
self.omit_agent_slice_in_obs = omit_agent_slice_in_obs
self.done_at_collision = False self.done_at_collision = False
_actions = Actions(allow_square_movement=self.allow_square_movement, _actions = Actions(allow_square_movement=self.allow_square_movement,
@@ -132,6 +141,8 @@ class BaseFactory(gym.Env):
h.parse_level(Path(__file__).parent / h.LEVELS_DIR / f'{level}.txt') h.parse_level(Path(__file__).parent / h.LEVELS_DIR / f'{level}.txt')
) )
self._state_slices = StateSlice(n_agents) self._state_slices = StateSlice(n_agents)
if 'additional_slices' in kwargs:
self._state_slices.register_additional_items(kwargs.get('additional_slices'))
self.reset() self.reset()
@property @property
@@ -162,7 +173,7 @@ class BaseFactory(gym.Env):
# state.shape = level, agent 1,..., agent n, # state.shape = level, agent 1,..., agent n,
self._state = np.concatenate((np.expand_dims(self._level, axis=0), agents), axis=0) self._state = np.concatenate((np.expand_dims(self._level, axis=0), agents), axis=0)
# Returns State # Returns State
return self._return_state() return None
def _return_state(self): def _return_state(self):
if self.pomdp_radius: if self.pomdp_radius:
@@ -181,7 +192,15 @@ class BaseFactory(gym.Env):
obs = obs_padded obs = obs_padded
else: else:
obs = self._state obs = self._state
return obs if self.omit_agent_slice_in_obs:
if obs.shape != (3, 5, 5):
print('Shiiiiiit')
obs_new = obs[[key for key, val in self._state_slices.items() if 'agent' not in val]]
if obs_new.shape != self.observation_space.shape:
print('Shiiiiiit')
return obs_new
else:
return obs
def do_additional_actions(self, agent_i: int, action: int) -> ((int, int), bool): def do_additional_actions(self, agent_i: int, action: int) -> ((int, int), bool):
raise NotImplementedError raise NotImplementedError

View File

@@ -37,8 +37,7 @@ class SimpleFactory(BaseFactory):
self.dirt_properties = dirt_properties self.dirt_properties = dirt_properties
self.verbose = verbose self.verbose = verbose
self.max_dirt = 20 self.max_dirt = 20
super(SimpleFactory, self).__init__(*args, **kwargs) super(SimpleFactory, self).__init__(*args, additional_slices='dirt', **kwargs)
self._state_slices.register_additional_items('dirt')
self._renderer = None # expensive - don't use it when not required ! self._renderer = None # expensive - don't use it when not required !
def render(self): def render(self):

View File

@@ -12,12 +12,11 @@ class MonitorCallback(BaseCallback):
ext = 'png' ext = 'png'
def __init__(self, env, filepath=Path('debug_out/monitor.pick'), plotting=True): def __init__(self, filepath=Path('debug_out/monitor.pick'), plotting=True):
super(MonitorCallback, self).__init__() super(MonitorCallback, self).__init__()
self.filepath = Path(filepath) self.filepath = Path(filepath)
self._monitor_df = pd.DataFrame() self._monitor_df = pd.DataFrame()
self._monitor_dict = dict() self._monitor_dict = dict()
self.env = env
self.plotting = plotting self.plotting = plotting
self.started = False self.started = False
self.closed = False self.closed = False

View File

@@ -26,18 +26,19 @@ def plot(filepath, ext='png'):
plt.clf() plt.clf()
def prepare_plot(filepath, results_df, ext='png'): def prepare_plot(filepath, results_df, ext='png', hue='Measurement', style=None):
results_df.Measurement = results_df.Measurement.str.replace('_', '-') df = results_df.copy()
hue_order = sorted(list(results_df.Measurement.unique())) df[hue] = df[hue].str.replace('_', '-')
hue_order = sorted(list(df[hue].unique()))
try: try:
sns.set(rc={'text.usetex': True}, style='whitegrid') sns.set(rc={'text.usetex': True}, style='whitegrid')
sns.lineplot(data=results_df, x='Episode', y='Score', hue='Measurement', sns.lineplot(data=df, x='Episode', y='Score', ci=95, palette=PALETTE,
ci=95, palette=PALETTE, hue_order=hue_order) hue_order=hue_order, hue=hue, style=style)
plot(filepath, ext=ext) # plot raises errors not lineplot! plot(filepath, ext=ext) # plot raises errors not lineplot!
except (FileNotFoundError, RuntimeError): except (FileNotFoundError, RuntimeError):
print('Struggling to plot Figure using LaTeX - going back to normal.') print('Struggling to plot Figure using LaTeX - going back to normal.')
plt.close('all') plt.close('all')
sns.set(rc={'text.usetex': False}, style='whitegrid') sns.set(rc={'text.usetex': False}, style='whitegrid')
sns.lineplot(data=results_df, x='Episode', y='Score', hue='Measurement', sns.lineplot(data=df, x='Episode', y='Score', hue=hue, style=style,
ci=95, palette=PALETTE, hue_order=hue_order) ci=95, palette=PALETTE, hue_order=hue_order)
plot(filepath, ext=ext) plot(filepath, ext=ext)

53
main.py
View File

@@ -1,12 +1,13 @@
import pickle import pickle
import warnings import warnings
from typing import Union from typing import Union, List
from os import PathLike from os import PathLike
from pathlib import Path from pathlib import Path
import time import time
import pandas as pd import pandas as pd
from stable_baselines3.common.callbacks import CallbackList from stable_baselines3.common.callbacks import CallbackList
from stable_baselines3.common.vec_env import VecFrameStack, DummyVecEnv
from environments.factory.simple_factory import DirtProperties, SimpleFactory from environments.factory.simple_factory import DirtProperties, SimpleFactory
from environments.helpers import IGNORED_DF_COLUMNS from environments.helpers import IGNORED_DF_COLUMNS
@@ -41,26 +42,64 @@ def combine_runs(run_path: Union[str, PathLike]):
value_vars=columns, var_name="Measurement", value_vars=columns, var_name="Measurement",
value_name="Score") value_name="Score")
df_melted = df_melted[df_melted['Episode'] % skip_n == 0] df_melted = df_melted[df_melted['Episode'] % skip_n == 0]
#df_melted['Episode'] = df_melted['Episode'] * skip_n # only needed for old version
prepare_plot(run_path / f'{run_path.name}_monitor_lineplot.png', df_melted) prepare_plot(run_path / f'{run_path.name}_monitor_lineplot.png', df_melted)
print('Plotting done.') print('Plotting done.')
def compare_runs(run_path: Path, run_identifier: int, parameter: Union[str, List[str]]):
run_path = Path(run_path)
df_list = list()
parameter = list(parameter) if isinstance(parameter, str) else parameter
for path in run_path.iterdir():
if path.is_dir() and str(run_identifier) in path.name:
for run, monitor_file in enumerate(path.rglob('monitor_*.pick')):
with monitor_file.open('rb') as f:
monitor_df = pickle.load(f)
monitor_df['run'] = run
monitor_df['model'] = path.name.split('_')[0]
monitor_df = monitor_df.fillna(0)
df_list.append(monitor_df)
df = pd.concat(df_list, ignore_index=True)
df = df.fillna(0).rename(columns={'episode': 'Episode', 'run': 'Run', 'model': 'Model'})
columns = [col for col in df.columns if col in parameter]
roll_n = 30
skip_n = 10
non_overlapp_window = df.groupby(['Model', 'Run', 'Episode']).rolling(roll_n, min_periods=1).mean()
df_melted = non_overlapp_window[columns].reset_index().melt(id_vars=['Episode', 'Run', 'Model'],
value_vars=columns, var_name="Measurement",
value_name="Score")
df_melted = df_melted[df_melted['Episode'] % skip_n == 0]
style = 'Measurement' if len(columns) > 1 else None
prepare_plot(run_path / f'{run_identifier}_compare_{parameter}.png', df_melted, hue='Model', style=style)
print('Plotting done.')
if __name__ == '__main__': if __name__ == '__main__':
from stable_baselines3 import PPO, DQN, A2C from stable_baselines3 import PPO, DQN, A2C
from algorithms.dqn_reg import RegDQN
dirt_props = DirtProperties() dirt_props = DirtProperties()
time_stamp = int(time.time()) time_stamp = int(time.time())
out_path = None out_path = None
for modeL_type in [A2C, PPO, DQN]: for modeL_type in [PPO, A2C, RegDQN, DQN]:
for seed in range(5): for seed in range(5):
env = SimpleFactory(n_agents=1, dirt_properties=dirt_props, pomdp_radius=2, max_steps=400, env = SimpleFactory(n_agents=1, dirt_properties=dirt_props, pomdp_radius=2, max_steps=400,
allow_diagonal_movement=False, allow_no_op=False, verbose=False) allow_diagonal_movement=True, allow_no_op=False, verbose=False,
omit_agent_slice_in_obs=True)
vec_wrap = DummyVecEnv([lambda: env for _ in range(4)])
stack_wrap = VecFrameStack(vec_wrap, n_stack=4, channels_order='first')
model = modeL_type("MlpPolicy", env, verbose=1, seed=seed, device='cpu') model = modeL_type("MlpPolicy", env, verbose=1, seed=seed, device='cpu')
@@ -70,7 +109,7 @@ if __name__ == '__main__':
out_path /= identifier out_path /= identifier
callbacks = CallbackList( callbacks = CallbackList(
[MonitorCallback(env, filepath=out_path / f'monitor_{identifier}.pick', plotting=False)] [MonitorCallback(filepath=out_path / f'monitor_{identifier}.pick', plotting=False)]
) )
model.learn(total_timesteps=int(2e5), callback=callbacks) model.learn(total_timesteps=int(2e5), callback=callbacks)
@@ -82,3 +121,5 @@ if __name__ == '__main__':
if out_path: if out_path:
combine_runs(out_path.parent) combine_runs(out_path.parent)
if out_path:
compare_runs(Path('debug_out'), time_stamp, 'step_reward')