frame stack

This commit is contained in:
steffen-illium
2021-06-04 12:04:24 +02:00
parent b72013407e
commit 5668f5cb82
5 changed files with 36 additions and 20 deletions

View File

@ -6,6 +6,8 @@ import gym
import numpy as np import numpy as np
from gym import spaces from gym import spaces
import yaml
from environments import helpers as h from environments import helpers as h
@ -191,6 +193,7 @@ class BaseFactory(gym.Env):
abs(a_pos[1]-self.pomdp_radius):abs(a_pos[1]-self.pomdp_radius)+obs.shape[2]] = obs abs(a_pos[1]-self.pomdp_radius):abs(a_pos[1]-self.pomdp_radius)+obs.shape[2]] = obs
obs = obs_padded obs = obs_padded
else: else:
assert not self.omit_agent_slice_in_obs
obs = self._state obs = self._state
if self.omit_agent_slice_in_obs: if self.omit_agent_slice_in_obs:
if obs.shape != (3, 5, 5): if obs.shape != (3, 5, 5):
@ -315,7 +318,9 @@ class BaseFactory(gym.Env):
raise NotImplementedError raise NotImplementedError
def save_params(self, filepath: Path): def save_params(self, filepath: Path):
d = {key: val for key, val in self.__dict__.items() if not key.startswith('_') or not key.startswith('__')} d = {key: val for key, val in self.__dict__.items() if not key.startswith('_') and not key.startswith('__')}
filepath.parent.mkdir(parents=True, exist_ok=True) filepath.parent.mkdir(parents=True, exist_ok=True)
with filepath.open('wb') as f: with filepath.open('wb') as f:
# yaml.dump(d, f)
pickle.dump(d, f, protocol=pickle.HIGHEST_PROTOCOL) pickle.dump(d, f, protocol=pickle.HIGHEST_PROTOCOL)

View File

@ -14,14 +14,15 @@ from environments.factory.renderer import Renderer, Entity
DIRT_INDEX = -1 DIRT_INDEX = -1
CLEAN_UP_ACTION = 'clean_up' CLEAN_UP_ACTION = 'clean_up'
@dataclass @dataclass
class DirtProperties: class DirtProperties:
clean_amount = 2 # How much does the robot clean with one action. clean_amount: int = 2 # How much does the robot clean with one action.
max_spawn_ratio = 0.2 # On max how much tiles does the dirt spawn in percent. max_spawn_ratio: float = 0.2 # On max how much tiles does the dirt spawn in percent.
gain_amount = 0.5 # How much dirt does spawn per tile gain_amount: float = 0.5 # How much dirt does spawn per tile
spawn_frequency = 5 # Spawn Frequency in Steps spawn_frequency: int = 5 # Spawn Frequency in Steps
max_local_amount = 1 # Max dirt amount per tile. max_local_amount: int = 1 # Max dirt amount per tile.
max_global_amount = 20 # Max dirt amount in the whole environment. max_global_amount: int = 20 # Max dirt amount in the whole environment.
class SimpleFactory(BaseFactory): class SimpleFactory(BaseFactory):
@ -93,11 +94,11 @@ class SimpleFactory(BaseFactory):
def step(self, actions): def step(self, actions):
_, r, done, info = super(SimpleFactory, self).step(actions) _, r, done, info = super(SimpleFactory, self).step(actions)
if not self.next_dirt_spawn: if not self._next_dirt_spawn:
self.spawn_dirt() self.spawn_dirt()
self.next_dirt_spawn = self.dirt_properties.spawn_frequency self._next_dirt_spawn = self.dirt_properties.spawn_frequency
else: else:
self.next_dirt_spawn -= 1 self._next_dirt_spawn -= 1
obs = self._return_state() obs = self._return_state()
return obs, r, done, info return obs, r, done, info
@ -117,7 +118,7 @@ class SimpleFactory(BaseFactory):
dirt_slice = np.zeros((1, *self._state.shape[1:])) dirt_slice = np.zeros((1, *self._state.shape[1:]))
self._state = np.concatenate((self._state, dirt_slice)) # dirt is now the last slice self._state = np.concatenate((self._state, dirt_slice)) # dirt is now the last slice
self.spawn_dirt() self.spawn_dirt()
self.next_dirt_spawn = self.dirt_properties.spawn_frequency self._next_dirt_spawn = self.dirt_properties.spawn_frequency
obs = self._return_state() obs = self._return_state()
return obs return obs

View File

@ -32,8 +32,8 @@ def prepare_plot(filepath, results_df, ext='png', hue='Measurement', style=None)
hue_order = sorted(list(df[hue].unique())) hue_order = sorted(list(df[hue].unique()))
try: try:
sns.set(rc={'text.usetex': True}, style='whitegrid') sns.set(rc={'text.usetex': True}, style='whitegrid')
sns.lineplot(data=df, x='Episode', y='Score', ci=95, palette=PALETTE, _ = sns.lineplot(data=df, x='Episode', y='Score', ci=95, palette=PALETTE,
hue_order=hue_order, hue=hue, style=style) hue_order=hue_order, hue=hue, style=style)
plot(filepath, ext=ext) # plot raises errors not lineplot! plot(filepath, ext=ext) # plot raises errors not lineplot!
except (FileNotFoundError, RuntimeError): except (FileNotFoundError, RuntimeError):
print('Struggling to plot Figure using LaTeX - going back to normal.') print('Struggling to plot Figure using LaTeX - going back to normal.')

22
main.py
View File

@ -4,7 +4,10 @@ from typing import Union, List
from os import PathLike from os import PathLike
from pathlib import Path from pathlib import Path
import time import time
import numpy as np
import pandas as pd import pandas as pd
from gym.wrappers import FrameStack
from stable_baselines3.common.callbacks import CallbackList from stable_baselines3.common.callbacks import CallbackList
from stable_baselines3.common.vec_env import VecFrameStack, DummyVecEnv from stable_baselines3.common.vec_env import VecFrameStack, DummyVecEnv
@ -50,7 +53,7 @@ def combine_runs(run_path: Union[str, PathLike]):
def compare_runs(run_path: Path, run_identifier: int, parameter: Union[str, List[str]]): def compare_runs(run_path: Path, run_identifier: int, parameter: Union[str, List[str]]):
run_path = Path(run_path) run_path = Path(run_path)
df_list = list() df_list = list()
parameter = list(parameter) if isinstance(parameter, str) else parameter parameter = [parameter] if isinstance(parameter, str) else parameter
for path in run_path.iterdir(): for path in run_path.iterdir():
if path.is_dir() and str(run_identifier) in path.name: if path.is_dir() and str(run_identifier) in path.name:
for run, monitor_file in enumerate(path.rglob('monitor_*.pick')): for run, monitor_file in enumerate(path.rglob('monitor_*.pick')):
@ -83,29 +86,36 @@ def compare_runs(run_path: Path, run_identifier: int, parameter: Union[str, List
if __name__ == '__main__': if __name__ == '__main__':
# compare_runs(Path('debug_out'), 1622650432, 'step_reward')
# exit()
from stable_baselines3 import PPO, DQN, A2C from stable_baselines3 import PPO, DQN, A2C
from algorithms.dqn_reg import RegDQN from algorithms.dqn_reg import RegDQN
# from sb3_contrib import QRDQN
dirt_props = DirtProperties() dirt_props = DirtProperties()
time_stamp = int(time.time()) time_stamp = int(time.time())
out_path = None out_path = None
for modeL_type in [PPO, A2C, RegDQN, DQN]: # for modeL_type in [PPO, A2C, RegDQN, DQN]:
for seed in range(5): modeL_type = PPO
for coef in [0.01, 0.1, 0.25]:
for seed in range(3):
env = SimpleFactory(n_agents=1, dirt_properties=dirt_props, pomdp_radius=2, max_steps=400, env = SimpleFactory(n_agents=1, dirt_properties=dirt_props, pomdp_radius=2, max_steps=400,
allow_diagonal_movement=True, allow_no_op=False, verbose=False, allow_diagonal_movement=True, allow_no_op=False, verbose=False,
omit_agent_slice_in_obs=True) omit_agent_slice_in_obs=True)
env.save_params(Path('debug_out', 'yaml.txt'))
vec_wrap = DummyVecEnv([lambda: env for _ in range(4)]) # env = FrameStack(env, 4)
stack_wrap = VecFrameStack(vec_wrap, n_stack=4, channels_order='first')
model = modeL_type("MlpPolicy", env, verbose=1, seed=seed, device='cpu') model = modeL_type("MlpPolicy", env, verbose=1, seed=seed, device='cpu')
out_path = Path('debug_out') / f'{model.__class__.__name__}_{time_stamp}' out_path = Path('debug_out') / f'{model.__class__.__name__}_{time_stamp}'
identifier = f'{seed}_{model.__class__.__name__}_{time_stamp}' # identifier = f'{seed}_{model.__class__.__name__}_{time_stamp}'
identifier = f'{seed}_{str(coef).replace(".", "")}_{time_stamp}'
out_path /= identifier out_path /= identifier
callbacks = CallbackList( callbacks = CallbackList(

View File

@ -14,7 +14,7 @@ warnings.filterwarnings('ignore', category=UserWarning)
if __name__ == '__main__': if __name__ == '__main__':
model_name = 'A2C_1622571986' model_name = 'A2C_1622650432'
run_id = 0 run_id = 0
out_path = Path(__file__).parent / 'debug_out' out_path = Path(__file__).parent / 'debug_out'
model_path = out_path / model_name model_path = out_path / model_name