diff --git a/environments/factory/base_factory.py b/environments/factory/base_factory.py index 00cb173..565f5a6 100644 --- a/environments/factory/base_factory.py +++ b/environments/factory/base_factory.py @@ -138,11 +138,12 @@ class BaseFactory(gym.Env): def movement_actions(self): return self._actions.movement_actions - def __init__(self, level='simple', n_agents=1, max_steps=int(5e2), pomdp_radius: Union[None, int] = None, + def __init__(self, level_name='simple', n_agents=1, max_steps=int(5e2), pomdp_radius: Union[None, int] = None, movement_properties: MovementProperties = MovementProperties(), omit_agent_slice_in_obs=False, **kwargs): self.movement_properties = movement_properties + self.level_name = level_name self.n_agents = n_agents self.max_steps = max_steps @@ -154,7 +155,7 @@ class BaseFactory(gym.Env): self._actions = _actions + self.additional_actions self._level = h.one_hot_level( - h.parse_level(Path(__file__).parent / h.LEVELS_DIR / f'{level}.txt') + h.parse_level(Path(__file__).parent / h.LEVELS_DIR / f'{self.level_name}.txt') ) self._state_slices = StateSlice(n_agents) if 'additional_slices' in kwargs: @@ -328,8 +329,8 @@ class BaseFactory(gym.Env): def save_params(self, filepath: Path): # noinspection PyProtectedMember - d = {key: val._asdict() if hasattr(val, '_asdict') else val for key, val in self.__dict__.items() - if not key.startswith('_') and not key.startswith('__')} + # d = {key: val._asdict() if hasattr(val, '_asdict') else val for key, val in self.__dict__.items() + d = {key: val for key, val in self.__dict__.items() if not key.startswith('_') and not key.startswith('__')} filepath.parent.mkdir(parents=True, exist_ok=True) with filepath.open('w') as f: diff --git a/main.py b/main.py index ff88543..4a983fd 100644 --- a/main.py +++ b/main.py @@ -60,7 +60,7 @@ def compare_runs(run_path: Path, run_identifier: int, parameter: Union[str, List monitor_df = pickle.load(f) monitor_df['run'] = run - monitor_df['model'] = path.name.split('_')[1] + monitor_df['model'] = path.name.split('_')[0] monitor_df = monitor_df.fillna(0) df_list.append(monitor_df) @@ -85,8 +85,8 @@ def compare_runs(run_path: Path, run_identifier: int, parameter: Union[str, List if __name__ == '__main__': - # compare_runs(Path('debug_out') / 'PPO_1622800949', 1622800949, 'step_reward') - # exit() + compare_runs(Path('debug_out'), 1623052687, ['agent_0_vs_level']) + exit() from stable_baselines3 import PPO, DQN, A2C from algorithms.reg_dqn import RegDQN @@ -104,7 +104,7 @@ if __name__ == '__main__': for seed in range(3): env = SimpleFactory(n_agents=1, dirt_properties=dirt_props, pomdp_radius=3, max_steps=400, - movement_properties=move_props, level='rooms', + movement_properties=move_props, level_name='rooms', omit_agent_slice_in_obs=True) # env = FrameStack(env, 4) @@ -112,10 +112,10 @@ if __name__ == '__main__': kwargs = dict(ent_coef=0.01) if isinstance(modeL_type, (PPO, A2C)) else {} model = modeL_type("MlpPolicy", env, verbose=1, seed=seed, device='cpu', **kwargs) - out_path = Path('debug_out') / f'{modeL_type.__class__.__name__}_{time_stamp}' + out_path = Path('debug_out') / f'{model.__class__.__name__}_{time_stamp}' # identifier = f'{seed}_{model.__class__.__name__}_{time_stamp}' - identifier = f'{seed}_{modeL_type.__class__.__name__}_{time_stamp}' + identifier = f'{seed}_{model.__class__.__name__}_{time_stamp}' out_path /= identifier callbacks = CallbackList( @@ -127,7 +127,7 @@ if __name__ == '__main__': save_path = out_path / f'model_{identifier}.zip' save_path.parent.mkdir(parents=True, exist_ok=True) model.save(save_path) - env.save_params(out_path.parent / f'env_{model.__class__.__name__}_{time_stamp}.pick') + env.save_params(out_path.parent / f'env_{model.__class__.__name__}_{time_stamp}.yaml') if out_path: combine_runs(out_path.parent) diff --git a/reload_agent.py b/reload_agent.py index 0f2c46e..cfa3383 100644 --- a/reload_agent.py +++ b/reload_agent.py @@ -14,21 +14,21 @@ warnings.filterwarnings('ignore', category=UserWarning) if __name__ == '__main__': - model_name = 'A2C_1622650432' + model_name = 'PPO_1623052687' run_id = 0 out_path = Path(__file__).parent / 'debug_out' model_path = out_path / model_name - with Path(r'C:\Users\steff\projects\f_iks\debug_out\yaml.txt').open('r') as f: - env_kwargs = yaml.load(f) - env = SimpleFactory(**env_kwargs) + with (model_path / f'env_{model_name}.yaml').open('r') as f: + env_kwargs = yaml.load(f, Loader=yaml.FullLoader) + env = SimpleFactory(level_name='rooms', **env_kwargs) # Edit THIS: - model_files = list(natsorted((model_path / f'{run_id}_{model_name}').rglob('*.zip'))) + model_files = list(natsorted((model_path / f'{run_id}_{model_name}').rglob('model_*.zip'))) this_model = model_files[0] model = PPO.load(this_model) - evaluation_result = evaluate_policy(model, env, n_eval_episodes=100, deterministic=False, render=True) + evaluation_result = evaluate_policy(model, env, n_eval_episodes=100, deterministic=True, render=True) print(evaluation_result) env.close()