diff --git a/environments/factory/base_factory.py b/environments/factory/base_factory.py index d14e775..4954c69 100644 --- a/environments/factory/base_factory.py +++ b/environments/factory/base_factory.py @@ -6,6 +6,8 @@ import gym import numpy as np from gym import spaces +import yaml + from environments import helpers as h @@ -191,6 +193,7 @@ class BaseFactory(gym.Env): abs(a_pos[1]-self.pomdp_radius):abs(a_pos[1]-self.pomdp_radius)+obs.shape[2]] = obs obs = obs_padded else: + assert not self.omit_agent_slice_in_obs obs = self._state if self.omit_agent_slice_in_obs: if obs.shape != (3, 5, 5): @@ -315,7 +318,9 @@ class BaseFactory(gym.Env): raise NotImplementedError def save_params(self, filepath: Path): - d = {key: val for key, val in self.__dict__.items() if not key.startswith('_') or not key.startswith('__')} + d = {key: val for key, val in self.__dict__.items() if not key.startswith('_') and not key.startswith('__')} filepath.parent.mkdir(parents=True, exist_ok=True) + with filepath.open('wb') as f: + # yaml.dump(d, f) pickle.dump(d, f, protocol=pickle.HIGHEST_PROTOCOL) diff --git a/environments/factory/simple_factory.py b/environments/factory/simple_factory.py index 33acdc5..8b89116 100644 --- a/environments/factory/simple_factory.py +++ b/environments/factory/simple_factory.py @@ -14,14 +14,15 @@ from environments.factory.renderer import Renderer, Entity DIRT_INDEX = -1 CLEAN_UP_ACTION = 'clean_up' + @dataclass class DirtProperties: - clean_amount = 2 # How much does the robot clean with one action. - max_spawn_ratio = 0.2 # On max how much tiles does the dirt spawn in percent. - gain_amount = 0.5 # How much dirt does spawn per tile - spawn_frequency = 5 # Spawn Frequency in Steps - max_local_amount = 1 # Max dirt amount per tile. - max_global_amount = 20 # Max dirt amount in the whole environment. + clean_amount: int = 2 # How much does the robot clean with one action. + max_spawn_ratio: float = 0.2 # On max how much tiles does the dirt spawn in percent. + gain_amount: float = 0.5 # How much dirt does spawn per tile + spawn_frequency: int = 5 # Spawn Frequency in Steps + max_local_amount: int = 1 # Max dirt amount per tile. + max_global_amount: int = 20 # Max dirt amount in the whole environment. class SimpleFactory(BaseFactory): @@ -93,11 +94,11 @@ class SimpleFactory(BaseFactory): def step(self, actions): _, r, done, info = super(SimpleFactory, self).step(actions) - if not self.next_dirt_spawn: + if not self._next_dirt_spawn: self.spawn_dirt() - self.next_dirt_spawn = self.dirt_properties.spawn_frequency + self._next_dirt_spawn = self.dirt_properties.spawn_frequency else: - self.next_dirt_spawn -= 1 + self._next_dirt_spawn -= 1 obs = self._return_state() return obs, r, done, info @@ -117,7 +118,7 @@ class SimpleFactory(BaseFactory): dirt_slice = np.zeros((1, *self._state.shape[1:])) self._state = np.concatenate((self._state, dirt_slice)) # dirt is now the last slice self.spawn_dirt() - self.next_dirt_spawn = self.dirt_properties.spawn_frequency + self._next_dirt_spawn = self.dirt_properties.spawn_frequency obs = self._return_state() return obs diff --git a/environments/logging/plotting.py b/environments/logging/plotting.py index 8c20622..91585a2 100644 --- a/environments/logging/plotting.py +++ b/environments/logging/plotting.py @@ -32,8 +32,8 @@ def prepare_plot(filepath, results_df, ext='png', hue='Measurement', style=None) hue_order = sorted(list(df[hue].unique())) try: sns.set(rc={'text.usetex': True}, style='whitegrid') - sns.lineplot(data=df, x='Episode', y='Score', ci=95, palette=PALETTE, - hue_order=hue_order, hue=hue, style=style) + _ = sns.lineplot(data=df, x='Episode', y='Score', ci=95, palette=PALETTE, + hue_order=hue_order, hue=hue, style=style) plot(filepath, ext=ext) # plot raises errors not lineplot! except (FileNotFoundError, RuntimeError): print('Struggling to plot Figure using LaTeX - going back to normal.') diff --git a/main.py b/main.py index 039eeb5..87d808b 100644 --- a/main.py +++ b/main.py @@ -4,7 +4,10 @@ from typing import Union, List from os import PathLike from pathlib import Path import time + +import numpy as np import pandas as pd +from gym.wrappers import FrameStack from stable_baselines3.common.callbacks import CallbackList from stable_baselines3.common.vec_env import VecFrameStack, DummyVecEnv @@ -50,7 +53,7 @@ def combine_runs(run_path: Union[str, PathLike]): def compare_runs(run_path: Path, run_identifier: int, parameter: Union[str, List[str]]): run_path = Path(run_path) df_list = list() - parameter = list(parameter) if isinstance(parameter, str) else parameter + parameter = [parameter] if isinstance(parameter, str) else parameter for path in run_path.iterdir(): if path.is_dir() and str(run_identifier) in path.name: for run, monitor_file in enumerate(path.rglob('monitor_*.pick')): @@ -83,29 +86,36 @@ def compare_runs(run_path: Path, run_identifier: int, parameter: Union[str, List if __name__ == '__main__': + # compare_runs(Path('debug_out'), 1622650432, 'step_reward') + # exit() + from stable_baselines3 import PPO, DQN, A2C from algorithms.dqn_reg import RegDQN + # from sb3_contrib import QRDQN dirt_props = DirtProperties() time_stamp = int(time.time()) out_path = None - for modeL_type in [PPO, A2C, RegDQN, DQN]: - for seed in range(5): + # for modeL_type in [PPO, A2C, RegDQN, DQN]: + modeL_type = PPO + for coef in [0.01, 0.1, 0.25]: + for seed in range(3): env = SimpleFactory(n_agents=1, dirt_properties=dirt_props, pomdp_radius=2, max_steps=400, allow_diagonal_movement=True, allow_no_op=False, verbose=False, omit_agent_slice_in_obs=True) + env.save_params(Path('debug_out', 'yaml.txt')) - vec_wrap = DummyVecEnv([lambda: env for _ in range(4)]) - stack_wrap = VecFrameStack(vec_wrap, n_stack=4, channels_order='first') + # env = FrameStack(env, 4) model = modeL_type("MlpPolicy", env, verbose=1, seed=seed, device='cpu') out_path = Path('debug_out') / f'{model.__class__.__name__}_{time_stamp}' - identifier = f'{seed}_{model.__class__.__name__}_{time_stamp}' + # identifier = f'{seed}_{model.__class__.__name__}_{time_stamp}' + identifier = f'{seed}_{str(coef).replace(".", "")}_{time_stamp}' out_path /= identifier callbacks = CallbackList( diff --git a/reload_agent.py b/reload_agent.py index cb88e19..16155eb 100644 --- a/reload_agent.py +++ b/reload_agent.py @@ -14,7 +14,7 @@ warnings.filterwarnings('ignore', category=UserWarning) if __name__ == '__main__': - model_name = 'A2C_1622571986' + model_name = 'A2C_1622650432' run_id = 0 out_path = Path(__file__).parent / 'debug_out' model_path = out_path / model_name