correct plotting and reloading

This commit is contained in:
steffen-illium 2021-06-07 13:08:41 +02:00
parent 4862407526
commit dbfa97aaba
3 changed files with 18 additions and 17 deletions

View File

@ -138,11 +138,12 @@ class BaseFactory(gym.Env):
def movement_actions(self): def movement_actions(self):
return self._actions.movement_actions return self._actions.movement_actions
def __init__(self, level='simple', n_agents=1, max_steps=int(5e2), pomdp_radius: Union[None, int] = None, def __init__(self, level_name='simple', n_agents=1, max_steps=int(5e2), pomdp_radius: Union[None, int] = None,
movement_properties: MovementProperties = MovementProperties(), movement_properties: MovementProperties = MovementProperties(),
omit_agent_slice_in_obs=False, **kwargs): omit_agent_slice_in_obs=False, **kwargs):
self.movement_properties = movement_properties self.movement_properties = movement_properties
self.level_name = level_name
self.n_agents = n_agents self.n_agents = n_agents
self.max_steps = max_steps self.max_steps = max_steps
@ -154,7 +155,7 @@ class BaseFactory(gym.Env):
self._actions = _actions + self.additional_actions self._actions = _actions + self.additional_actions
self._level = h.one_hot_level( self._level = h.one_hot_level(
h.parse_level(Path(__file__).parent / h.LEVELS_DIR / f'{level}.txt') h.parse_level(Path(__file__).parent / h.LEVELS_DIR / f'{self.level_name}.txt')
) )
self._state_slices = StateSlice(n_agents) self._state_slices = StateSlice(n_agents)
if 'additional_slices' in kwargs: if 'additional_slices' in kwargs:
@ -328,8 +329,8 @@ class BaseFactory(gym.Env):
def save_params(self, filepath: Path): def save_params(self, filepath: Path):
# noinspection PyProtectedMember # noinspection PyProtectedMember
d = {key: val._asdict() if hasattr(val, '_asdict') else val for key, val in self.__dict__.items() # d = {key: val._asdict() if hasattr(val, '_asdict') else val for key, val in self.__dict__.items()
if not key.startswith('_') and not key.startswith('__')} d = {key: val for key, val in self.__dict__.items() if not key.startswith('_') and not key.startswith('__')}
filepath.parent.mkdir(parents=True, exist_ok=True) filepath.parent.mkdir(parents=True, exist_ok=True)
with filepath.open('w') as f: with filepath.open('w') as f:

14
main.py
View File

@ -60,7 +60,7 @@ def compare_runs(run_path: Path, run_identifier: int, parameter: Union[str, List
monitor_df = pickle.load(f) monitor_df = pickle.load(f)
monitor_df['run'] = run monitor_df['run'] = run
monitor_df['model'] = path.name.split('_')[1] monitor_df['model'] = path.name.split('_')[0]
monitor_df = monitor_df.fillna(0) monitor_df = monitor_df.fillna(0)
df_list.append(monitor_df) df_list.append(monitor_df)
@ -85,8 +85,8 @@ def compare_runs(run_path: Path, run_identifier: int, parameter: Union[str, List
if __name__ == '__main__': if __name__ == '__main__':
# compare_runs(Path('debug_out') / 'PPO_1622800949', 1622800949, 'step_reward') compare_runs(Path('debug_out'), 1623052687, ['agent_0_vs_level'])
# exit() exit()
from stable_baselines3 import PPO, DQN, A2C from stable_baselines3 import PPO, DQN, A2C
from algorithms.reg_dqn import RegDQN from algorithms.reg_dqn import RegDQN
@ -104,7 +104,7 @@ if __name__ == '__main__':
for seed in range(3): for seed in range(3):
env = SimpleFactory(n_agents=1, dirt_properties=dirt_props, pomdp_radius=3, max_steps=400, env = SimpleFactory(n_agents=1, dirt_properties=dirt_props, pomdp_radius=3, max_steps=400,
movement_properties=move_props, level='rooms', movement_properties=move_props, level_name='rooms',
omit_agent_slice_in_obs=True) omit_agent_slice_in_obs=True)
# env = FrameStack(env, 4) # env = FrameStack(env, 4)
@ -112,10 +112,10 @@ if __name__ == '__main__':
kwargs = dict(ent_coef=0.01) if isinstance(modeL_type, (PPO, A2C)) else {} kwargs = dict(ent_coef=0.01) if isinstance(modeL_type, (PPO, A2C)) else {}
model = modeL_type("MlpPolicy", env, verbose=1, seed=seed, device='cpu', **kwargs) model = modeL_type("MlpPolicy", env, verbose=1, seed=seed, device='cpu', **kwargs)
out_path = Path('debug_out') / f'{modeL_type.__class__.__name__}_{time_stamp}' out_path = Path('debug_out') / f'{model.__class__.__name__}_{time_stamp}'
# identifier = f'{seed}_{model.__class__.__name__}_{time_stamp}' # identifier = f'{seed}_{model.__class__.__name__}_{time_stamp}'
identifier = f'{seed}_{modeL_type.__class__.__name__}_{time_stamp}' identifier = f'{seed}_{model.__class__.__name__}_{time_stamp}'
out_path /= identifier out_path /= identifier
callbacks = CallbackList( callbacks = CallbackList(
@ -127,7 +127,7 @@ if __name__ == '__main__':
save_path = out_path / f'model_{identifier}.zip' save_path = out_path / f'model_{identifier}.zip'
save_path.parent.mkdir(parents=True, exist_ok=True) save_path.parent.mkdir(parents=True, exist_ok=True)
model.save(save_path) model.save(save_path)
env.save_params(out_path.parent / f'env_{model.__class__.__name__}_{time_stamp}.pick') env.save_params(out_path.parent / f'env_{model.__class__.__name__}_{time_stamp}.yaml')
if out_path: if out_path:
combine_runs(out_path.parent) combine_runs(out_path.parent)

View File

@ -14,21 +14,21 @@ warnings.filterwarnings('ignore', category=UserWarning)
if __name__ == '__main__': if __name__ == '__main__':
model_name = 'A2C_1622650432' model_name = 'PPO_1623052687'
run_id = 0 run_id = 0
out_path = Path(__file__).parent / 'debug_out' out_path = Path(__file__).parent / 'debug_out'
model_path = out_path / model_name model_path = out_path / model_name
with Path(r'C:\Users\steff\projects\f_iks\debug_out\yaml.txt').open('r') as f: with (model_path / f'env_{model_name}.yaml').open('r') as f:
env_kwargs = yaml.load(f) env_kwargs = yaml.load(f, Loader=yaml.FullLoader)
env = SimpleFactory(**env_kwargs) env = SimpleFactory(level_name='rooms', **env_kwargs)
# Edit THIS: # Edit THIS:
model_files = list(natsorted((model_path / f'{run_id}_{model_name}').rglob('*.zip'))) model_files = list(natsorted((model_path / f'{run_id}_{model_name}').rglob('model_*.zip')))
this_model = model_files[0] this_model = model_files[0]
model = PPO.load(this_model) model = PPO.load(this_model)
evaluation_result = evaluate_policy(model, env, n_eval_episodes=100, deterministic=False, render=True) evaluation_result = evaluate_policy(model, env, n_eval_episodes=100, deterministic=True, render=True)
print(evaluation_result) print(evaluation_result)
env.close() env.close()