yaml working

This commit is contained in:
steffen-illium 2021-06-04 18:11:23 +02:00
parent 8ce92d5db4
commit f1306d4d6f
5 changed files with 41 additions and 28 deletions

View File

@ -1,4 +1,4 @@
import pickle from argparse import Namespace
from pathlib import Path from pathlib import Path
from typing import List, Union, Iterable, NamedTuple from typing import List, Union, Iterable, NamedTuple
@ -89,7 +89,8 @@ class Actions(Register):
self.allow_diagonal_movement = movement_properties.allow_diagonal_movement self.allow_diagonal_movement = movement_properties.allow_diagonal_movement
self.allow_square_movement = movement_properties.allow_square_movement self.allow_square_movement = movement_properties.allow_square_movement
# FIXME: There is a bug in helpers because there actions are ints. and the order matters. # FIXME: There is a bug in helpers because there actions are ints. and the order matters.
assert not(self.allow_square_movement is False and self.allow_diagonal_movement is True), "There is a bug in helpers!!!" assert not(self.allow_square_movement is False and self.allow_diagonal_movement is True), \
"There is a bug in helpers!!!"
super(Actions, self).__init__() super(Actions, self).__init__()
if self.allow_square_movement: if self.allow_square_movement:
@ -109,11 +110,14 @@ class StateSlice(Register):
self.register_additional_items(['level', *[f'agent#{i}' for i in range(offset, n_agents+offset)]]) self.register_additional_items(['level', *[f'agent#{i}' for i in range(offset, n_agents+offset)]])
# noinspection PyAttributeOutsideInit
class BaseFactory(gym.Env): class BaseFactory(gym.Env):
# def __setattr__(self, key, value): def __setattr__(self, key, value):
# if isinstance(value, dict): if isinstance(value, dict):
super(BaseFactory, self).__setattr__(key, Namespace(**value))
else:
super(BaseFactory, self).__setattr__(key, value)
@property @property
def action_space(self): def action_space(self):
@ -199,8 +203,8 @@ class BaseFactory(gym.Env):
obs_padded = np.full((obs.shape[0], self.pomdp_radius ** 2 + 1, self.pomdp_radius ** 2 + 1), 1) obs_padded = np.full((obs.shape[0], self.pomdp_radius ** 2 + 1, self.pomdp_radius ** 2 + 1), 1)
a_pos = np.argwhere(obs[h.AGENT_START_IDX] == h.IS_OCCUPIED_CELL)[0] a_pos = np.argwhere(obs[h.AGENT_START_IDX] == h.IS_OCCUPIED_CELL)[0]
obs_padded[:, obs_padded[:,
abs(a_pos[0]-self.pomdp_radius):abs(a_pos[0]-self.pomdp_radius)+obs.shape[1], abs(a_pos[0]-self.pomdp_radius):abs(a_pos[0]-self.pomdp_radius)+obs.shape[1],
abs(a_pos[1]-self.pomdp_radius):abs(a_pos[1]-self.pomdp_radius)+obs.shape[2]] = obs abs(a_pos[1]-self.pomdp_radius):abs(a_pos[1]-self.pomdp_radius)+obs.shape[2]] = obs
obs = obs_padded obs = obs_padded
else: else:
obs = self._state obs = self._state
@ -323,9 +327,11 @@ class BaseFactory(gym.Env):
raise NotImplementedError raise NotImplementedError
def save_params(self, filepath: Path): def save_params(self, filepath: Path):
d = {key: val for key, val in self.__dict__.items() if not key.startswith('_') and not key.startswith('__')} # noinspection PyProtectedMember
d = {key: val._asdict() if hasattr(val, '_asdict') else val for key, val in self.__dict__.items()
if not key.startswith('_') and not key.startswith('__')}
filepath.parent.mkdir(parents=True, exist_ok=True) filepath.parent.mkdir(parents=True, exist_ok=True)
with filepath.open('wb') as f: with filepath.open('w') as f:
# yaml.dump(d, f) yaml.dump(d, f)
pickle.dump(d, f, protocol=pickle.HIGHEST_PROTOCOL) # pickle.dump(d, f, protocol=pickle.HIGHEST_PROTOCOL)

View File

@ -185,17 +185,21 @@ class SimpleFactory(BaseFactory):
if __name__ == '__main__': if __name__ == '__main__':
render = True render = True
import yaml
with Path(r'C:\Users\steff\projects\f_iks\debug_out\yaml.txt').open('r') as f:
env_kwargs = yaml.load(f)
factory = SimpleFactory(**env_kwargs)
dirt_props = DirtProperties() # dirt_props = DirtProperties()
move_props = MovementProperties(allow_diagonal_movement=False, allow_no_op=False) # move_props = MovementProperties(allow_diagonal_movement=False, allow_no_op=False)
factory = SimpleFactory(n_agents=2, dirt_properties=dirt_props, movement_properties=move_props, level='rooms', # factory = SimpleFactory(n_agents=2, dirt_properties=dirt_props, movement_properties=move_props, level='rooms',
pomdp_radius=2) # pomdp_radius=2)
n_actions = factory.action_space.n - 1 n_actions = factory.action_space.n - 1
for epoch in range(100): for epoch in range(100):
random_actions = [(random.randint(0, n_actions), random.randint(0, n_actions)) for _ in range(200)] random_actions = [[random.randint(0, n_actions) for _ in range(factory.n_agents)] for _ in range(200)]
env_state, this_reward, done_bool, _ = factory.reset() env_state = factory.reset()
for agent_i_action in random_actions: for agent_i_action in random_actions:
env_state, reward, done_bool, info_obj = factory.step(agent_i_action) env_state, reward, done_bool, info_obj = factory.step(agent_i_action)
if render: if render:

View File

@ -5,12 +5,10 @@ from os import PathLike
from pathlib import Path from pathlib import Path
import time import time
import numpy as np
import pandas as pd import pandas as pd
from gym.wrappers import FrameStack from gym.wrappers import FrameStack
from stable_baselines3.common.callbacks import CallbackList from stable_baselines3.common.callbacks import CallbackList
from stable_baselines3.common.vec_env import VecFrameStack, DummyVecEnv
from environments.factory.base_factory import MovementProperties from environments.factory.base_factory import MovementProperties
from environments.factory.simple_factory import DirtProperties, SimpleFactory from environments.factory.simple_factory import DirtProperties, SimpleFactory
@ -102,13 +100,12 @@ if __name__ == '__main__':
out_path = None out_path = None
for modeL_type in [PPO, A2C, RegDQN, DQN]: for modeL_type in [PPO]: # , A2C, RegDQN, DQN]:
for seed in range(3): for seed in range(3):
env = SimpleFactory(n_agents=1, dirt_properties=dirt_props, pomdp_radius=None, max_steps=400, env = SimpleFactory(n_agents=1, dirt_properties=dirt_props, pomdp_radius=None, max_steps=400,
movement_properties=move_props, level='rooms', movement_properties=move_props, level='rooms',
omit_agent_slice_in_obs=True) omit_agent_slice_in_obs=True)
env.save_params(Path('debug_out', 'yaml.txt'))
# env = FrameStack(env, 4) # env = FrameStack(env, 4)
@ -125,7 +122,7 @@ if __name__ == '__main__':
[MonitorCallback(filepath=out_path / f'monitor_{identifier}.pick', plotting=False)] [MonitorCallback(filepath=out_path / f'monitor_{identifier}.pick', plotting=False)]
) )
model.learn(total_timesteps=int(2e5), callback=callbacks) model.learn(total_timesteps=int(5e5), callback=callbacks)
save_path = out_path / f'model_{identifier}.zip' save_path = out_path / f'model_{identifier}.zip'
save_path.parent.mkdir(parents=True, exist_ok=True) save_path.parent.mkdir(parents=True, exist_ok=True)

View File

@ -1,7 +1,7 @@
import pickle
import warnings import warnings
from pathlib import Path from pathlib import Path
import yaml
from natsort import natsorted from natsort import natsorted
from stable_baselines3 import PPO from stable_baselines3 import PPO
from stable_baselines3.common.evaluation import evaluate_policy from stable_baselines3.common.evaluation import evaluate_policy
@ -19,9 +19,9 @@ if __name__ == '__main__':
out_path = Path(__file__).parent / 'debug_out' out_path = Path(__file__).parent / 'debug_out'
model_path = out_path / model_name model_path = out_path / model_name
with (model_path / f'env_{model_name}.pick').open('rb') as f: with Path(r'C:\Users\steff\projects\f_iks\debug_out\yaml.txt').open('r') as f:
env_kwargs = pickle.load(f) env_kwargs = yaml.load(f)
env = SimpleFactory( **env_kwargs) env = SimpleFactory(**env_kwargs)
# Edit THIS: # Edit THIS:
model_files = list(natsorted((model_path / f'{run_id}_{model_name}').rglob('*.zip'))) model_files = list(natsorted((model_path / f'{run_id}_{model_name}').rglob('*.zip')))

View File

@ -6,8 +6,8 @@ filelock==3.0.12
kiwisolver==1.3.1 kiwisolver==1.3.1
matplotlib==3.4.1 matplotlib==3.4.1
numpy==1.20.2 numpy==1.20.2
pandas==1.2.4 pandas~=1.2.3
pygame==2.0.0 pygame~=2.0.1
pep517==0.10.0 pep517==0.10.0
Pillow==8.2.0 Pillow==8.2.0
pyparsing==2.4.7 pyparsing==2.4.7
@ -22,3 +22,9 @@ torchaudio==0.8.1
torchvision==0.9.1 torchvision==0.9.1
typing-extensions==3.10.0.0 typing-extensions==3.10.0.0
virtualenv==20.4.6 virtualenv==20.4.6
gym~=0.18.0
PyYAML~=5.3.1
pyglet~=1.5.0
optuna~=2.7.0
natsort~=7.1.1