mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-05-22 23:06:43 +02:00
yaml working
This commit is contained in:
parent
8ce92d5db4
commit
f1306d4d6f
@ -1,4 +1,4 @@
|
|||||||
import pickle
|
from argparse import Namespace
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List, Union, Iterable, NamedTuple
|
from typing import List, Union, Iterable, NamedTuple
|
||||||
|
|
||||||
@ -89,7 +89,8 @@ class Actions(Register):
|
|||||||
self.allow_diagonal_movement = movement_properties.allow_diagonal_movement
|
self.allow_diagonal_movement = movement_properties.allow_diagonal_movement
|
||||||
self.allow_square_movement = movement_properties.allow_square_movement
|
self.allow_square_movement = movement_properties.allow_square_movement
|
||||||
# FIXME: There is a bug in helpers because there actions are ints. and the order matters.
|
# FIXME: There is a bug in helpers because there actions are ints. and the order matters.
|
||||||
assert not(self.allow_square_movement is False and self.allow_diagonal_movement is True), "There is a bug in helpers!!!"
|
assert not(self.allow_square_movement is False and self.allow_diagonal_movement is True), \
|
||||||
|
"There is a bug in helpers!!!"
|
||||||
super(Actions, self).__init__()
|
super(Actions, self).__init__()
|
||||||
|
|
||||||
if self.allow_square_movement:
|
if self.allow_square_movement:
|
||||||
@ -109,11 +110,14 @@ class StateSlice(Register):
|
|||||||
self.register_additional_items(['level', *[f'agent#{i}' for i in range(offset, n_agents+offset)]])
|
self.register_additional_items(['level', *[f'agent#{i}' for i in range(offset, n_agents+offset)]])
|
||||||
|
|
||||||
|
|
||||||
|
# noinspection PyAttributeOutsideInit
|
||||||
class BaseFactory(gym.Env):
|
class BaseFactory(gym.Env):
|
||||||
|
|
||||||
# def __setattr__(self, key, value):
|
def __setattr__(self, key, value):
|
||||||
# if isinstance(value, dict):
|
if isinstance(value, dict):
|
||||||
|
super(BaseFactory, self).__setattr__(key, Namespace(**value))
|
||||||
|
else:
|
||||||
|
super(BaseFactory, self).__setattr__(key, value)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def action_space(self):
|
def action_space(self):
|
||||||
@ -199,8 +203,8 @@ class BaseFactory(gym.Env):
|
|||||||
obs_padded = np.full((obs.shape[0], self.pomdp_radius ** 2 + 1, self.pomdp_radius ** 2 + 1), 1)
|
obs_padded = np.full((obs.shape[0], self.pomdp_radius ** 2 + 1, self.pomdp_radius ** 2 + 1), 1)
|
||||||
a_pos = np.argwhere(obs[h.AGENT_START_IDX] == h.IS_OCCUPIED_CELL)[0]
|
a_pos = np.argwhere(obs[h.AGENT_START_IDX] == h.IS_OCCUPIED_CELL)[0]
|
||||||
obs_padded[:,
|
obs_padded[:,
|
||||||
abs(a_pos[0]-self.pomdp_radius):abs(a_pos[0]-self.pomdp_radius)+obs.shape[1],
|
abs(a_pos[0]-self.pomdp_radius):abs(a_pos[0]-self.pomdp_radius)+obs.shape[1],
|
||||||
abs(a_pos[1]-self.pomdp_radius):abs(a_pos[1]-self.pomdp_radius)+obs.shape[2]] = obs
|
abs(a_pos[1]-self.pomdp_radius):abs(a_pos[1]-self.pomdp_radius)+obs.shape[2]] = obs
|
||||||
obs = obs_padded
|
obs = obs_padded
|
||||||
else:
|
else:
|
||||||
obs = self._state
|
obs = self._state
|
||||||
@ -323,9 +327,11 @@ class BaseFactory(gym.Env):
|
|||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def save_params(self, filepath: Path):
|
def save_params(self, filepath: Path):
|
||||||
d = {key: val for key, val in self.__dict__.items() if not key.startswith('_') and not key.startswith('__')}
|
# noinspection PyProtectedMember
|
||||||
|
d = {key: val._asdict() if hasattr(val, '_asdict') else val for key, val in self.__dict__.items()
|
||||||
|
if not key.startswith('_') and not key.startswith('__')}
|
||||||
filepath.parent.mkdir(parents=True, exist_ok=True)
|
filepath.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
with filepath.open('wb') as f:
|
with filepath.open('w') as f:
|
||||||
# yaml.dump(d, f)
|
yaml.dump(d, f)
|
||||||
pickle.dump(d, f, protocol=pickle.HIGHEST_PROTOCOL)
|
# pickle.dump(d, f, protocol=pickle.HIGHEST_PROTOCOL)
|
||||||
|
@ -185,17 +185,21 @@ class SimpleFactory(BaseFactory):
|
|||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
render = True
|
render = True
|
||||||
|
import yaml
|
||||||
|
with Path(r'C:\Users\steff\projects\f_iks\debug_out\yaml.txt').open('r') as f:
|
||||||
|
env_kwargs = yaml.load(f)
|
||||||
|
factory = SimpleFactory(**env_kwargs)
|
||||||
|
|
||||||
dirt_props = DirtProperties()
|
# dirt_props = DirtProperties()
|
||||||
move_props = MovementProperties(allow_diagonal_movement=False, allow_no_op=False)
|
# move_props = MovementProperties(allow_diagonal_movement=False, allow_no_op=False)
|
||||||
factory = SimpleFactory(n_agents=2, dirt_properties=dirt_props, movement_properties=move_props, level='rooms',
|
# factory = SimpleFactory(n_agents=2, dirt_properties=dirt_props, movement_properties=move_props, level='rooms',
|
||||||
pomdp_radius=2)
|
# pomdp_radius=2)
|
||||||
|
|
||||||
n_actions = factory.action_space.n - 1
|
n_actions = factory.action_space.n - 1
|
||||||
|
|
||||||
for epoch in range(100):
|
for epoch in range(100):
|
||||||
random_actions = [(random.randint(0, n_actions), random.randint(0, n_actions)) for _ in range(200)]
|
random_actions = [[random.randint(0, n_actions) for _ in range(factory.n_agents)] for _ in range(200)]
|
||||||
env_state, this_reward, done_bool, _ = factory.reset()
|
env_state = factory.reset()
|
||||||
for agent_i_action in random_actions:
|
for agent_i_action in random_actions:
|
||||||
env_state, reward, done_bool, info_obj = factory.step(agent_i_action)
|
env_state, reward, done_bool, info_obj = factory.step(agent_i_action)
|
||||||
if render:
|
if render:
|
||||||
|
7
main.py
7
main.py
@ -5,12 +5,10 @@ from os import PathLike
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import time
|
import time
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
from gym.wrappers import FrameStack
|
from gym.wrappers import FrameStack
|
||||||
|
|
||||||
from stable_baselines3.common.callbacks import CallbackList
|
from stable_baselines3.common.callbacks import CallbackList
|
||||||
from stable_baselines3.common.vec_env import VecFrameStack, DummyVecEnv
|
|
||||||
|
|
||||||
from environments.factory.base_factory import MovementProperties
|
from environments.factory.base_factory import MovementProperties
|
||||||
from environments.factory.simple_factory import DirtProperties, SimpleFactory
|
from environments.factory.simple_factory import DirtProperties, SimpleFactory
|
||||||
@ -102,13 +100,12 @@ if __name__ == '__main__':
|
|||||||
|
|
||||||
out_path = None
|
out_path = None
|
||||||
|
|
||||||
for modeL_type in [PPO, A2C, RegDQN, DQN]:
|
for modeL_type in [PPO]: # , A2C, RegDQN, DQN]:
|
||||||
for seed in range(3):
|
for seed in range(3):
|
||||||
|
|
||||||
env = SimpleFactory(n_agents=1, dirt_properties=dirt_props, pomdp_radius=None, max_steps=400,
|
env = SimpleFactory(n_agents=1, dirt_properties=dirt_props, pomdp_radius=None, max_steps=400,
|
||||||
movement_properties=move_props, level='rooms',
|
movement_properties=move_props, level='rooms',
|
||||||
omit_agent_slice_in_obs=True)
|
omit_agent_slice_in_obs=True)
|
||||||
env.save_params(Path('debug_out', 'yaml.txt'))
|
|
||||||
|
|
||||||
# env = FrameStack(env, 4)
|
# env = FrameStack(env, 4)
|
||||||
|
|
||||||
@ -125,7 +122,7 @@ if __name__ == '__main__':
|
|||||||
[MonitorCallback(filepath=out_path / f'monitor_{identifier}.pick', plotting=False)]
|
[MonitorCallback(filepath=out_path / f'monitor_{identifier}.pick', plotting=False)]
|
||||||
)
|
)
|
||||||
|
|
||||||
model.learn(total_timesteps=int(2e5), callback=callbacks)
|
model.learn(total_timesteps=int(5e5), callback=callbacks)
|
||||||
|
|
||||||
save_path = out_path / f'model_{identifier}.zip'
|
save_path = out_path / f'model_{identifier}.zip'
|
||||||
save_path.parent.mkdir(parents=True, exist_ok=True)
|
save_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
import pickle
|
|
||||||
import warnings
|
import warnings
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
import yaml
|
||||||
from natsort import natsorted
|
from natsort import natsorted
|
||||||
from stable_baselines3 import PPO
|
from stable_baselines3 import PPO
|
||||||
from stable_baselines3.common.evaluation import evaluate_policy
|
from stable_baselines3.common.evaluation import evaluate_policy
|
||||||
@ -19,9 +19,9 @@ if __name__ == '__main__':
|
|||||||
out_path = Path(__file__).parent / 'debug_out'
|
out_path = Path(__file__).parent / 'debug_out'
|
||||||
model_path = out_path / model_name
|
model_path = out_path / model_name
|
||||||
|
|
||||||
with (model_path / f'env_{model_name}.pick').open('rb') as f:
|
with Path(r'C:\Users\steff\projects\f_iks\debug_out\yaml.txt').open('r') as f:
|
||||||
env_kwargs = pickle.load(f)
|
env_kwargs = yaml.load(f)
|
||||||
env = SimpleFactory( **env_kwargs)
|
env = SimpleFactory(**env_kwargs)
|
||||||
|
|
||||||
# Edit THIS:
|
# Edit THIS:
|
||||||
model_files = list(natsorted((model_path / f'{run_id}_{model_name}').rglob('*.zip')))
|
model_files = list(natsorted((model_path / f'{run_id}_{model_name}').rglob('*.zip')))
|
||||||
|
@ -6,8 +6,8 @@ filelock==3.0.12
|
|||||||
kiwisolver==1.3.1
|
kiwisolver==1.3.1
|
||||||
matplotlib==3.4.1
|
matplotlib==3.4.1
|
||||||
numpy==1.20.2
|
numpy==1.20.2
|
||||||
pandas==1.2.4
|
pandas~=1.2.3
|
||||||
pygame==2.0.0
|
pygame~=2.0.1
|
||||||
pep517==0.10.0
|
pep517==0.10.0
|
||||||
Pillow==8.2.0
|
Pillow==8.2.0
|
||||||
pyparsing==2.4.7
|
pyparsing==2.4.7
|
||||||
@ -22,3 +22,9 @@ torchaudio==0.8.1
|
|||||||
torchvision==0.9.1
|
torchvision==0.9.1
|
||||||
typing-extensions==3.10.0.0
|
typing-extensions==3.10.0.0
|
||||||
virtualenv==20.4.6
|
virtualenv==20.4.6
|
||||||
|
|
||||||
|
gym~=0.18.0
|
||||||
|
PyYAML~=5.3.1
|
||||||
|
pyglet~=1.5.0
|
||||||
|
optuna~=2.7.0
|
||||||
|
natsort~=7.1.1
|
Loading…
x
Reference in New Issue
Block a user