mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-05-22 23:06:43 +02:00
correct plotting and reloading
This commit is contained in:
parent
4862407526
commit
dbfa97aaba
@ -138,11 +138,12 @@ class BaseFactory(gym.Env):
|
||||
def movement_actions(self):
|
||||
return self._actions.movement_actions
|
||||
|
||||
def __init__(self, level='simple', n_agents=1, max_steps=int(5e2), pomdp_radius: Union[None, int] = None,
|
||||
def __init__(self, level_name='simple', n_agents=1, max_steps=int(5e2), pomdp_radius: Union[None, int] = None,
|
||||
movement_properties: MovementProperties = MovementProperties(),
|
||||
omit_agent_slice_in_obs=False, **kwargs):
|
||||
|
||||
self.movement_properties = movement_properties
|
||||
self.level_name = level_name
|
||||
|
||||
self.n_agents = n_agents
|
||||
self.max_steps = max_steps
|
||||
@ -154,7 +155,7 @@ class BaseFactory(gym.Env):
|
||||
self._actions = _actions + self.additional_actions
|
||||
|
||||
self._level = h.one_hot_level(
|
||||
h.parse_level(Path(__file__).parent / h.LEVELS_DIR / f'{level}.txt')
|
||||
h.parse_level(Path(__file__).parent / h.LEVELS_DIR / f'{self.level_name}.txt')
|
||||
)
|
||||
self._state_slices = StateSlice(n_agents)
|
||||
if 'additional_slices' in kwargs:
|
||||
@ -328,8 +329,8 @@ class BaseFactory(gym.Env):
|
||||
|
||||
def save_params(self, filepath: Path):
|
||||
# noinspection PyProtectedMember
|
||||
d = {key: val._asdict() if hasattr(val, '_asdict') else val for key, val in self.__dict__.items()
|
||||
if not key.startswith('_') and not key.startswith('__')}
|
||||
# d = {key: val._asdict() if hasattr(val, '_asdict') else val for key, val in self.__dict__.items()
|
||||
d = {key: val for key, val in self.__dict__.items() if not key.startswith('_') and not key.startswith('__')}
|
||||
filepath.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
with filepath.open('w') as f:
|
||||
|
14
main.py
14
main.py
@ -60,7 +60,7 @@ def compare_runs(run_path: Path, run_identifier: int, parameter: Union[str, List
|
||||
monitor_df = pickle.load(f)
|
||||
|
||||
monitor_df['run'] = run
|
||||
monitor_df['model'] = path.name.split('_')[1]
|
||||
monitor_df['model'] = path.name.split('_')[0]
|
||||
monitor_df = monitor_df.fillna(0)
|
||||
df_list.append(monitor_df)
|
||||
|
||||
@ -85,8 +85,8 @@ def compare_runs(run_path: Path, run_identifier: int, parameter: Union[str, List
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
# compare_runs(Path('debug_out') / 'PPO_1622800949', 1622800949, 'step_reward')
|
||||
# exit()
|
||||
compare_runs(Path('debug_out'), 1623052687, ['agent_0_vs_level'])
|
||||
exit()
|
||||
|
||||
from stable_baselines3 import PPO, DQN, A2C
|
||||
from algorithms.reg_dqn import RegDQN
|
||||
@ -104,7 +104,7 @@ if __name__ == '__main__':
|
||||
for seed in range(3):
|
||||
|
||||
env = SimpleFactory(n_agents=1, dirt_properties=dirt_props, pomdp_radius=3, max_steps=400,
|
||||
movement_properties=move_props, level='rooms',
|
||||
movement_properties=move_props, level_name='rooms',
|
||||
omit_agent_slice_in_obs=True)
|
||||
|
||||
# env = FrameStack(env, 4)
|
||||
@ -112,10 +112,10 @@ if __name__ == '__main__':
|
||||
kwargs = dict(ent_coef=0.01) if isinstance(modeL_type, (PPO, A2C)) else {}
|
||||
model = modeL_type("MlpPolicy", env, verbose=1, seed=seed, device='cpu', **kwargs)
|
||||
|
||||
out_path = Path('debug_out') / f'{modeL_type.__class__.__name__}_{time_stamp}'
|
||||
out_path = Path('debug_out') / f'{model.__class__.__name__}_{time_stamp}'
|
||||
|
||||
# identifier = f'{seed}_{model.__class__.__name__}_{time_stamp}'
|
||||
identifier = f'{seed}_{modeL_type.__class__.__name__}_{time_stamp}'
|
||||
identifier = f'{seed}_{model.__class__.__name__}_{time_stamp}'
|
||||
out_path /= identifier
|
||||
|
||||
callbacks = CallbackList(
|
||||
@ -127,7 +127,7 @@ if __name__ == '__main__':
|
||||
save_path = out_path / f'model_{identifier}.zip'
|
||||
save_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
model.save(save_path)
|
||||
env.save_params(out_path.parent / f'env_{model.__class__.__name__}_{time_stamp}.pick')
|
||||
env.save_params(out_path.parent / f'env_{model.__class__.__name__}_{time_stamp}.yaml')
|
||||
|
||||
if out_path:
|
||||
combine_runs(out_path.parent)
|
||||
|
@ -14,21 +14,21 @@ warnings.filterwarnings('ignore', category=UserWarning)
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
model_name = 'A2C_1622650432'
|
||||
model_name = 'PPO_1623052687'
|
||||
run_id = 0
|
||||
out_path = Path(__file__).parent / 'debug_out'
|
||||
model_path = out_path / model_name
|
||||
|
||||
with Path(r'C:\Users\steff\projects\f_iks\debug_out\yaml.txt').open('r') as f:
|
||||
env_kwargs = yaml.load(f)
|
||||
env = SimpleFactory(**env_kwargs)
|
||||
with (model_path / f'env_{model_name}.yaml').open('r') as f:
|
||||
env_kwargs = yaml.load(f, Loader=yaml.FullLoader)
|
||||
env = SimpleFactory(level_name='rooms', **env_kwargs)
|
||||
|
||||
# Edit THIS:
|
||||
model_files = list(natsorted((model_path / f'{run_id}_{model_name}').rglob('*.zip')))
|
||||
model_files = list(natsorted((model_path / f'{run_id}_{model_name}').rglob('model_*.zip')))
|
||||
this_model = model_files[0]
|
||||
|
||||
model = PPO.load(this_model)
|
||||
evaluation_result = evaluate_policy(model, env, n_eval_episodes=100, deterministic=False, render=True)
|
||||
evaluation_result = evaluate_policy(model, env, n_eval_episodes=100, deterministic=True, render=True)
|
||||
print(evaluation_result)
|
||||
|
||||
env.close()
|
||||
|
Loading…
x
Reference in New Issue
Block a user