diff --git a/environments/factory/base_factory.py b/environments/factory/base_factory.py index 0256f81..bb0135a 100644 --- a/environments/factory/base_factory.py +++ b/environments/factory/base_factory.py @@ -1,4 +1,4 @@ -import pickle +from argparse import Namespace from pathlib import Path from typing import List, Union, Iterable, NamedTuple @@ -89,7 +89,8 @@ class Actions(Register): self.allow_diagonal_movement = movement_properties.allow_diagonal_movement self.allow_square_movement = movement_properties.allow_square_movement # FIXME: There is a bug in helpers because there actions are ints. and the order matters. - assert not(self.allow_square_movement is False and self.allow_diagonal_movement is True), "There is a bug in helpers!!!" + assert not(self.allow_square_movement is False and self.allow_diagonal_movement is True), \ + "There is a bug in helpers!!!" super(Actions, self).__init__() if self.allow_square_movement: @@ -109,11 +110,14 @@ class StateSlice(Register): self.register_additional_items(['level', *[f'agent#{i}' for i in range(offset, n_agents+offset)]]) +# noinspection PyAttributeOutsideInit class BaseFactory(gym.Env): - # def __setattr__(self, key, value): - # if isinstance(value, dict): - + def __setattr__(self, key, value): + if isinstance(value, dict): + super(BaseFactory, self).__setattr__(key, Namespace(**value)) + else: + super(BaseFactory, self).__setattr__(key, value) @property def action_space(self): @@ -199,8 +203,8 @@ class BaseFactory(gym.Env): obs_padded = np.full((obs.shape[0], self.pomdp_radius ** 2 + 1, self.pomdp_radius ** 2 + 1), 1) a_pos = np.argwhere(obs[h.AGENT_START_IDX] == h.IS_OCCUPIED_CELL)[0] obs_padded[:, - abs(a_pos[0]-self.pomdp_radius):abs(a_pos[0]-self.pomdp_radius)+obs.shape[1], - abs(a_pos[1]-self.pomdp_radius):abs(a_pos[1]-self.pomdp_radius)+obs.shape[2]] = obs + abs(a_pos[0]-self.pomdp_radius):abs(a_pos[0]-self.pomdp_radius)+obs.shape[1], + abs(a_pos[1]-self.pomdp_radius):abs(a_pos[1]-self.pomdp_radius)+obs.shape[2]] = obs obs = obs_padded else: obs = self._state @@ -323,9 +327,11 @@ class BaseFactory(gym.Env): raise NotImplementedError def save_params(self, filepath: Path): - d = {key: val for key, val in self.__dict__.items() if not key.startswith('_') and not key.startswith('__')} + # noinspection PyProtectedMember + d = {key: val._asdict() if hasattr(val, '_asdict') else val for key, val in self.__dict__.items() + if not key.startswith('_') and not key.startswith('__')} filepath.parent.mkdir(parents=True, exist_ok=True) - with filepath.open('wb') as f: - # yaml.dump(d, f) - pickle.dump(d, f, protocol=pickle.HIGHEST_PROTOCOL) + with filepath.open('w') as f: + yaml.dump(d, f) + # pickle.dump(d, f, protocol=pickle.HIGHEST_PROTOCOL) diff --git a/environments/factory/simple_factory.py b/environments/factory/simple_factory.py index 20969f5..c1f052d 100644 --- a/environments/factory/simple_factory.py +++ b/environments/factory/simple_factory.py @@ -185,17 +185,21 @@ class SimpleFactory(BaseFactory): if __name__ == '__main__': render = True + import yaml + with Path(r'C:\Users\steff\projects\f_iks\debug_out\yaml.txt').open('r') as f: + env_kwargs = yaml.load(f) + factory = SimpleFactory(**env_kwargs) - dirt_props = DirtProperties() - move_props = MovementProperties(allow_diagonal_movement=False, allow_no_op=False) - factory = SimpleFactory(n_agents=2, dirt_properties=dirt_props, movement_properties=move_props, level='rooms', - pomdp_radius=2) + # dirt_props = DirtProperties() + # move_props = MovementProperties(allow_diagonal_movement=False, allow_no_op=False) + # factory = SimpleFactory(n_agents=2, dirt_properties=dirt_props, movement_properties=move_props, level='rooms', + # pomdp_radius=2) n_actions = factory.action_space.n - 1 for epoch in range(100): - random_actions = [(random.randint(0, n_actions), random.randint(0, n_actions)) for _ in range(200)] - env_state, this_reward, done_bool, _ = factory.reset() + random_actions = [[random.randint(0, n_actions) for _ in range(factory.n_agents)] for _ in range(200)] + env_state = factory.reset() for agent_i_action in random_actions: env_state, reward, done_bool, info_obj = factory.step(agent_i_action) if render: diff --git a/main.py b/main.py index a9f8093..f05fd7f 100644 --- a/main.py +++ b/main.py @@ -5,12 +5,10 @@ from os import PathLike from pathlib import Path import time -import numpy as np import pandas as pd from gym.wrappers import FrameStack from stable_baselines3.common.callbacks import CallbackList -from stable_baselines3.common.vec_env import VecFrameStack, DummyVecEnv from environments.factory.base_factory import MovementProperties from environments.factory.simple_factory import DirtProperties, SimpleFactory @@ -102,13 +100,12 @@ if __name__ == '__main__': out_path = None - for modeL_type in [PPO, A2C, RegDQN, DQN]: + for modeL_type in [PPO]: # , A2C, RegDQN, DQN]: for seed in range(3): env = SimpleFactory(n_agents=1, dirt_properties=dirt_props, pomdp_radius=None, max_steps=400, movement_properties=move_props, level='rooms', omit_agent_slice_in_obs=True) - env.save_params(Path('debug_out', 'yaml.txt')) # env = FrameStack(env, 4) @@ -125,7 +122,7 @@ if __name__ == '__main__': [MonitorCallback(filepath=out_path / f'monitor_{identifier}.pick', plotting=False)] ) - model.learn(total_timesteps=int(2e5), callback=callbacks) + model.learn(total_timesteps=int(5e5), callback=callbacks) save_path = out_path / f'model_{identifier}.zip' save_path.parent.mkdir(parents=True, exist_ok=True) diff --git a/reload_agent.py b/reload_agent.py index 16155eb..0f2c46e 100644 --- a/reload_agent.py +++ b/reload_agent.py @@ -1,7 +1,7 @@ -import pickle import warnings from pathlib import Path +import yaml from natsort import natsorted from stable_baselines3 import PPO from stable_baselines3.common.evaluation import evaluate_policy @@ -19,9 +19,9 @@ if __name__ == '__main__': out_path = Path(__file__).parent / 'debug_out' model_path = out_path / model_name - with (model_path / f'env_{model_name}.pick').open('rb') as f: - env_kwargs = pickle.load(f) - env = SimpleFactory( **env_kwargs) + with Path(r'C:\Users\steff\projects\f_iks\debug_out\yaml.txt').open('r') as f: + env_kwargs = yaml.load(f) + env = SimpleFactory(**env_kwargs) # Edit THIS: model_files = list(natsorted((model_path / f'{run_id}_{model_name}').rglob('*.zip'))) diff --git a/requirements.txt b/requirements.txt index c49ccfa..96f4391 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,8 +6,8 @@ filelock==3.0.12 kiwisolver==1.3.1 matplotlib==3.4.1 numpy==1.20.2 -pandas==1.2.4 -pygame==2.0.0 +pandas~=1.2.3 +pygame~=2.0.1 pep517==0.10.0 Pillow==8.2.0 pyparsing==2.4.7 @@ -22,3 +22,9 @@ torchaudio==0.8.1 torchvision==0.9.1 typing-extensions==3.10.0.0 virtualenv==20.4.6 + +gym~=0.18.0 +PyYAML~=5.3.1 +pyglet~=1.5.0 +optuna~=2.7.0 +natsort~=7.1.1 \ No newline at end of file