frame stack
This commit is contained in:
@ -6,6 +6,8 @@ import gym
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from gym import spaces
|
from gym import spaces
|
||||||
|
|
||||||
|
import yaml
|
||||||
|
|
||||||
from environments import helpers as h
|
from environments import helpers as h
|
||||||
|
|
||||||
|
|
||||||
@ -191,6 +193,7 @@ class BaseFactory(gym.Env):
|
|||||||
abs(a_pos[1]-self.pomdp_radius):abs(a_pos[1]-self.pomdp_radius)+obs.shape[2]] = obs
|
abs(a_pos[1]-self.pomdp_radius):abs(a_pos[1]-self.pomdp_radius)+obs.shape[2]] = obs
|
||||||
obs = obs_padded
|
obs = obs_padded
|
||||||
else:
|
else:
|
||||||
|
assert not self.omit_agent_slice_in_obs
|
||||||
obs = self._state
|
obs = self._state
|
||||||
if self.omit_agent_slice_in_obs:
|
if self.omit_agent_slice_in_obs:
|
||||||
if obs.shape != (3, 5, 5):
|
if obs.shape != (3, 5, 5):
|
||||||
@ -315,7 +318,9 @@ class BaseFactory(gym.Env):
|
|||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def save_params(self, filepath: Path):
|
def save_params(self, filepath: Path):
|
||||||
d = {key: val for key, val in self.__dict__.items() if not key.startswith('_') or not key.startswith('__')}
|
d = {key: val for key, val in self.__dict__.items() if not key.startswith('_') and not key.startswith('__')}
|
||||||
filepath.parent.mkdir(parents=True, exist_ok=True)
|
filepath.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
with filepath.open('wb') as f:
|
with filepath.open('wb') as f:
|
||||||
|
# yaml.dump(d, f)
|
||||||
pickle.dump(d, f, protocol=pickle.HIGHEST_PROTOCOL)
|
pickle.dump(d, f, protocol=pickle.HIGHEST_PROTOCOL)
|
||||||
|
@ -14,14 +14,15 @@ from environments.factory.renderer import Renderer, Entity
|
|||||||
DIRT_INDEX = -1
|
DIRT_INDEX = -1
|
||||||
CLEAN_UP_ACTION = 'clean_up'
|
CLEAN_UP_ACTION = 'clean_up'
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class DirtProperties:
|
class DirtProperties:
|
||||||
clean_amount = 2 # How much does the robot clean with one action.
|
clean_amount: int = 2 # How much does the robot clean with one action.
|
||||||
max_spawn_ratio = 0.2 # On max how much tiles does the dirt spawn in percent.
|
max_spawn_ratio: float = 0.2 # On max how much tiles does the dirt spawn in percent.
|
||||||
gain_amount = 0.5 # How much dirt does spawn per tile
|
gain_amount: float = 0.5 # How much dirt does spawn per tile
|
||||||
spawn_frequency = 5 # Spawn Frequency in Steps
|
spawn_frequency: int = 5 # Spawn Frequency in Steps
|
||||||
max_local_amount = 1 # Max dirt amount per tile.
|
max_local_amount: int = 1 # Max dirt amount per tile.
|
||||||
max_global_amount = 20 # Max dirt amount in the whole environment.
|
max_global_amount: int = 20 # Max dirt amount in the whole environment.
|
||||||
|
|
||||||
|
|
||||||
class SimpleFactory(BaseFactory):
|
class SimpleFactory(BaseFactory):
|
||||||
@ -93,11 +94,11 @@ class SimpleFactory(BaseFactory):
|
|||||||
|
|
||||||
def step(self, actions):
|
def step(self, actions):
|
||||||
_, r, done, info = super(SimpleFactory, self).step(actions)
|
_, r, done, info = super(SimpleFactory, self).step(actions)
|
||||||
if not self.next_dirt_spawn:
|
if not self._next_dirt_spawn:
|
||||||
self.spawn_dirt()
|
self.spawn_dirt()
|
||||||
self.next_dirt_spawn = self.dirt_properties.spawn_frequency
|
self._next_dirt_spawn = self.dirt_properties.spawn_frequency
|
||||||
else:
|
else:
|
||||||
self.next_dirt_spawn -= 1
|
self._next_dirt_spawn -= 1
|
||||||
obs = self._return_state()
|
obs = self._return_state()
|
||||||
return obs, r, done, info
|
return obs, r, done, info
|
||||||
|
|
||||||
@ -117,7 +118,7 @@ class SimpleFactory(BaseFactory):
|
|||||||
dirt_slice = np.zeros((1, *self._state.shape[1:]))
|
dirt_slice = np.zeros((1, *self._state.shape[1:]))
|
||||||
self._state = np.concatenate((self._state, dirt_slice)) # dirt is now the last slice
|
self._state = np.concatenate((self._state, dirt_slice)) # dirt is now the last slice
|
||||||
self.spawn_dirt()
|
self.spawn_dirt()
|
||||||
self.next_dirt_spawn = self.dirt_properties.spawn_frequency
|
self._next_dirt_spawn = self.dirt_properties.spawn_frequency
|
||||||
obs = self._return_state()
|
obs = self._return_state()
|
||||||
return obs
|
return obs
|
||||||
|
|
||||||
|
@ -32,7 +32,7 @@ def prepare_plot(filepath, results_df, ext='png', hue='Measurement', style=None)
|
|||||||
hue_order = sorted(list(df[hue].unique()))
|
hue_order = sorted(list(df[hue].unique()))
|
||||||
try:
|
try:
|
||||||
sns.set(rc={'text.usetex': True}, style='whitegrid')
|
sns.set(rc={'text.usetex': True}, style='whitegrid')
|
||||||
sns.lineplot(data=df, x='Episode', y='Score', ci=95, palette=PALETTE,
|
_ = sns.lineplot(data=df, x='Episode', y='Score', ci=95, palette=PALETTE,
|
||||||
hue_order=hue_order, hue=hue, style=style)
|
hue_order=hue_order, hue=hue, style=style)
|
||||||
plot(filepath, ext=ext) # plot raises errors not lineplot!
|
plot(filepath, ext=ext) # plot raises errors not lineplot!
|
||||||
except (FileNotFoundError, RuntimeError):
|
except (FileNotFoundError, RuntimeError):
|
||||||
|
22
main.py
22
main.py
@ -4,7 +4,10 @@ from typing import Union, List
|
|||||||
from os import PathLike
|
from os import PathLike
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import time
|
import time
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
from gym.wrappers import FrameStack
|
||||||
|
|
||||||
from stable_baselines3.common.callbacks import CallbackList
|
from stable_baselines3.common.callbacks import CallbackList
|
||||||
from stable_baselines3.common.vec_env import VecFrameStack, DummyVecEnv
|
from stable_baselines3.common.vec_env import VecFrameStack, DummyVecEnv
|
||||||
@ -50,7 +53,7 @@ def combine_runs(run_path: Union[str, PathLike]):
|
|||||||
def compare_runs(run_path: Path, run_identifier: int, parameter: Union[str, List[str]]):
|
def compare_runs(run_path: Path, run_identifier: int, parameter: Union[str, List[str]]):
|
||||||
run_path = Path(run_path)
|
run_path = Path(run_path)
|
||||||
df_list = list()
|
df_list = list()
|
||||||
parameter = list(parameter) if isinstance(parameter, str) else parameter
|
parameter = [parameter] if isinstance(parameter, str) else parameter
|
||||||
for path in run_path.iterdir():
|
for path in run_path.iterdir():
|
||||||
if path.is_dir() and str(run_identifier) in path.name:
|
if path.is_dir() and str(run_identifier) in path.name:
|
||||||
for run, monitor_file in enumerate(path.rglob('monitor_*.pick')):
|
for run, monitor_file in enumerate(path.rglob('monitor_*.pick')):
|
||||||
@ -83,29 +86,36 @@ def compare_runs(run_path: Path, run_identifier: int, parameter: Union[str, List
|
|||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
||||||
|
# compare_runs(Path('debug_out'), 1622650432, 'step_reward')
|
||||||
|
# exit()
|
||||||
|
|
||||||
from stable_baselines3 import PPO, DQN, A2C
|
from stable_baselines3 import PPO, DQN, A2C
|
||||||
from algorithms.dqn_reg import RegDQN
|
from algorithms.dqn_reg import RegDQN
|
||||||
|
# from sb3_contrib import QRDQN
|
||||||
|
|
||||||
dirt_props = DirtProperties()
|
dirt_props = DirtProperties()
|
||||||
time_stamp = int(time.time())
|
time_stamp = int(time.time())
|
||||||
|
|
||||||
out_path = None
|
out_path = None
|
||||||
|
|
||||||
for modeL_type in [PPO, A2C, RegDQN, DQN]:
|
# for modeL_type in [PPO, A2C, RegDQN, DQN]:
|
||||||
for seed in range(5):
|
modeL_type = PPO
|
||||||
|
for coef in [0.01, 0.1, 0.25]:
|
||||||
|
for seed in range(3):
|
||||||
|
|
||||||
env = SimpleFactory(n_agents=1, dirt_properties=dirt_props, pomdp_radius=2, max_steps=400,
|
env = SimpleFactory(n_agents=1, dirt_properties=dirt_props, pomdp_radius=2, max_steps=400,
|
||||||
allow_diagonal_movement=True, allow_no_op=False, verbose=False,
|
allow_diagonal_movement=True, allow_no_op=False, verbose=False,
|
||||||
omit_agent_slice_in_obs=True)
|
omit_agent_slice_in_obs=True)
|
||||||
|
env.save_params(Path('debug_out', 'yaml.txt'))
|
||||||
|
|
||||||
vec_wrap = DummyVecEnv([lambda: env for _ in range(4)])
|
# env = FrameStack(env, 4)
|
||||||
stack_wrap = VecFrameStack(vec_wrap, n_stack=4, channels_order='first')
|
|
||||||
|
|
||||||
model = modeL_type("MlpPolicy", env, verbose=1, seed=seed, device='cpu')
|
model = modeL_type("MlpPolicy", env, verbose=1, seed=seed, device='cpu')
|
||||||
|
|
||||||
out_path = Path('debug_out') / f'{model.__class__.__name__}_{time_stamp}'
|
out_path = Path('debug_out') / f'{model.__class__.__name__}_{time_stamp}'
|
||||||
|
|
||||||
identifier = f'{seed}_{model.__class__.__name__}_{time_stamp}'
|
# identifier = f'{seed}_{model.__class__.__name__}_{time_stamp}'
|
||||||
|
identifier = f'{seed}_{str(coef).replace(".", "")}_{time_stamp}'
|
||||||
out_path /= identifier
|
out_path /= identifier
|
||||||
|
|
||||||
callbacks = CallbackList(
|
callbacks = CallbackList(
|
||||||
|
@ -14,7 +14,7 @@ warnings.filterwarnings('ignore', category=UserWarning)
|
|||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
||||||
model_name = 'A2C_1622571986'
|
model_name = 'A2C_1622650432'
|
||||||
run_id = 0
|
run_id = 0
|
||||||
out_path = Path(__file__).parent / 'debug_out'
|
out_path = Path(__file__).parent / 'debug_out'
|
||||||
model_path = out_path / model_name
|
model_path = out_path / model_name
|
||||||
|
Reference in New Issue
Block a user