Mirror of https://github.com/illiumst/marl-factory-grid.git (synced 2025-05-22 14:56:43 +02:00)
Work in progress, no checkout please!
This commit is contained in:
parent 1b98171f3a
commit 36fe59c95c
@@ -1,4 +1,4 @@
-from typing import List, Union, Iterable
+from typing import List, Union, Iterable, TypedDict
 
 import gym
 from gym import spaces
@@ -32,6 +32,37 @@ class AgentState:
         raise AttributeError(f'"{key}" cannot be updated, this attr is not a part of {self.__class__.__name__}')
 
 
+class Actions:
+
+    def __init__(self, allow_square_movement=True, allow_diagonal_movement=True, allow_no_OP=True):
+        self.allow_no_OP = allow_no_OP
+        self.allow_diagonal_movement = allow_diagonal_movement
+        self.allow_square_movement = allow_square_movement
+        self._registered_actions = dict()
+        if allow_square_movement:
+            self + {key: val for key, val in enumerate(['north', 'east', 'south', 'west'])}
+        if allow_diagonal_movement:
+            self + {key: val for key, val in enumerate(['north-east', 'south-east', 'south-west', 'north-west'])}
+
+        self._movement_actions = self._registered_actions.copy()
+        if self.allow_no_OP:
+            self + {0: 'no-op'}
+
+    def __len__(self):
+        return len(self._registered_actions)
+
+    def __add__(self, other: dict):
+        assert all([isinstance(x, int) for x in other.keys()]), f'All action keys have to be of type {int}.'
+        assert all([isinstance(x, str) for x in other.values()]), f'All action values have to be of type {str}.'
+        self._registered_actions.update({key + len(self._registered_actions): value for key, value in other.items()})
+        return self
+
+    def register_additional_actions(self, other: dict):
+        self_with_additional_actions = self + other
+        return self_with_additional_actions
+
+
 class BaseFactory(gym.Env):
 
     @property
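Note: the new Actions container registers actions via __add__, shifting every incoming key by the number of actions already registered. A standalone sketch of that behaviour with plain dicts (illustrative names only, not part of the commit):

# Illustrative stand-in for Actions.__add__: incoming keys are shifted by the
# current registry size, so action indices stay unique and contiguous.
def add_actions(registered: dict, other: dict) -> dict:
    registered.update({key + len(registered): value for key, value in other.items()})
    return registered

registry = {}
add_actions(registry, {key: val for key, val in enumerate(['north', 'east', 'south', 'west'])})
add_actions(registry, {key: val for key, val in enumerate(['north-east', 'south-east', 'south-west', 'north-west'])})
add_actions(registry, {0: 'no-op'})
print(registry)  # {0: 'north', ..., 7: 'north-west', 8: 'no-op'}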
@@ -44,7 +75,16 @@ class BaseFactory(gym.Env):
 
     @property
     def movement_actions(self):
-        return (int(self.allow_square_movement) + int(self.allow_diagonal_movement)) * 4
+        if self._movement_actions is None:
+            self._movement_actions = dict()
+            if self.allow_square_movement:
+                self._movement_actions.update(
+                    {key: val for key, val in enumerate(['north', 'east', 'south', 'west'])}
+                )
+            if self.allow_diagonal_movement:
+                self._movement_actions.update(
+                    {key: val for key, val in zip(range(4, 8), ['north-east', 'south-east', 'south-west', 'north-west'])}
+                )
+        return self._movement_actions
 
 
     @property
     def string_slices(self):
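Note: the property now returns a dict of named movement actions instead of a bare count; with both movement types enabled its length matches the old (1 + 1) * 4 formula. A quick illustrative check (hand-written dict, not taken from the commit):

# Eight named movement actions replace the old integer count of 8.
movement_actions = {0: 'north', 1: 'east', 2: 'south', 3: 'west',
                    4: 'north-east', 5: 'south-east', 6: 'south-west', 7: 'north-west'}
assert len(movement_actions) == (int(True) + int(True)) * 4  # == 8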
@@ -53,18 +93,17 @@ class BaseFactory(gym.Env):
     def __init__(self, level='simple', n_agents=1, max_steps=int(2e2)):
         self.n_agents = n_agents
         self.max_steps = max_steps
         self.allow_square_movement = True
         self.allow_diagonal_movement = True
         self.allow_no_OP = True
         self.done_at_collision = False
-        self._registered_actions = self.movement_actions + int(self.allow_no_OP) + self.register_additional_actions()
+        self._actions = Actions(allow_square_movement=True, allow_diagonal_movement=True, allow_no_OP=False)
 
         self.level = h.one_hot_level(
             h.parse_level(Path(__file__).parent / h.LEVELS_DIR / f'{level}.txt')
         )
         self.slice_strings = {0: 'level', **{i: f'agent#{i}' for i in range(1, self.n_agents+1)}}
         self.reset()
 
-    def register_additional_actions(self) -> int:
+    def register_additional_actions(self) -> dict:
         raise NotImplementedError('Please register additional actions ')
 
     def reset(self) -> (np.ndarray, int, bool, dict):
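Note: BaseFactory now owns an Actions instance, and Actions defines __len__, so the total action count is len(self._actions). One hypothetical way that could feed a gym action space; this wiring is an assumption and is not part of the commit:

# Hypothetical follow-up (not in this commit): 4 square + 4 diagonal moves and
# no no-op give len(self._actions) == 8; expose that as a discrete gym space.
from gym import spaces

n_registered_actions = 8  # stands in for len(self._actions)
action_space = spaces.Discrete(n_registered_actions)
print(action_space)              # Discrete(8)
print(action_space.contains(7))  # True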
@@ -123,10 +162,10 @@ class BaseFactory(gym.Env):
         return self.state, reward, done, info
 
     def _is_moving_action(self, action):
-        return action < self.movement_actions
+        return self._registered_actions[action] in self.movement_actions
 
     def _is_no_op(self, action):
-        return self.allow_no_OP and (action - self.movement_actions) == 0
+        return self._registered_actions[action] == 'no-op'
 
     def check_all_collisions(self, agent_states: List[AgentState], collisions: int) -> np.ndarray:
         collision_vecs = np.zeros((len(agent_states), collisions))  # n_agents x n_slices
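Note: both predicates switch from index arithmetic to a name lookup. A standalone sketch of that dispatch with plain dicts (illustrative values only); since `name in some_dict` tests the keys, a membership test against the movement names goes through .values():

# Name-based action dispatch, sketched with plain dicts.
registered_actions = {0: 'north', 1: 'east', 2: 'south', 3: 'west', 4: 'no-op'}
movement_actions = {0: 'north', 1: 'east', 2: 'south', 3: 'west'}

def is_moving_action(action: int) -> bool:
    # membership is tested against the names, hence .values()
    return registered_actions[action] in movement_actions.values()

def is_no_op(action: int) -> bool:
    return registered_actions[action] == 'no-op'

assert is_moving_action(2) and not is_moving_action(4)
assert is_no_op(4)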
@@ -30,9 +30,8 @@ def plot(filepath, ext='png', tag='monitor', **kwargs):
 
 
 def prepare_plot(filepath, results_df, ext='png', tag=''):
     # %%
-    _ = sns.lineplot(data=results_df, ci='sd', x='step')
+    _ = sns.lineplot(data=results_df, x='Episode', y='Score', hue='Measurement', ci='sd')
 
     # %%
     sns.set_theme(palette=PALETTE, style='whitegrid')
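Note: prepare_plot now expects a long-format frame with Episode, Score and Measurement columns. A minimal sketch of such a frame and the lineplot call (fabricated numbers, for illustration only):

# Long-format input as the updated lineplot call expects it (data is made up).
import pandas as pd
import seaborn as sns
from matplotlib import pyplot as plt

results_df = pd.DataFrame({
    'Episode':     [0, 0, 1, 1, 2, 2],
    'Measurement': ['step_reward', 'dirt_amount'] * 3,
    'Score':       [0.1, 5.0, 0.3, 4.2, 0.6, 3.9],
})
_ = sns.lineplot(data=results_df, x='Episode', y='Score', hue='Measurement', ci='sd')
plt.show()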
main.py (38 lines changed)
@@ -5,6 +5,7 @@ from os import PathLike
 from pathlib import Path
 import time
+import pandas as pd
 from natsort import natsorted
 
 from stable_baselines3.common.callbacks import CallbackList
@@ -25,8 +26,8 @@ def combine_runs(run_path: Union[str, PathLike]):
             monitor_list = pickle.load(f)
 
         for m_idx in range(len(monitor_list)):
-            monitor_list[m_idx]['episode'] = str(m_idx)
-            monitor_list[m_idx]['run'] = str(run)
+            monitor_list[m_idx]['episode'] = m_idx
+            monitor_list[m_idx]['run'] = run
 
         df = pd.concat(monitor_list, ignore_index=True)
         df['train_step'] = range(df.shape[0])
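Note: episode and run are now stored as numbers rather than strings; the 50-episode bucketing further down relies on integer division, which a string column would reject. Illustrative check on a made-up series (not from the commit):

# Integer division for episode buckets only works on numeric dtypes.
import pandas as pd

episode_int = pd.Series([0, 49, 50, 120])
print(episode_int // 50)  # 0, 0, 1, 2

episode_str = pd.Series(['0', '49', '50', '120'])
# episode_str // 50 would raise a TypeError on the object/string dtype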
@@ -42,31 +43,30 @@ def combine_runs(run_path: Union[str, PathLike]):
         df_list.append(df)
     df = pd.concat(df_list, ignore_index=True)
-    df = df.fillna(0)
+    df = df.fillna(0).rename(columns={'episode': 'Episode', 'run': 'Run'})
 
-    df_group = df.groupby(['episode', 'run']).aggregate({col: 'mean' if col in ['dirt_amount',
+    df_group = df.groupby(['Episode', 'Run']).aggregate({col: 'mean' if col in ['dirt_amount',
                                                                                 'dirty_tiles'] else 'sum'
-                                                          for col in df.columns if col not in ['episode', 'run']
-                                                          }).reset_index()
+                                                          for col in df.columns if
+                                                          col not in ['Episode', 'Run', 'train_step']
+                                                          })
+    non_overlapp_window = df_group.groupby(['Run', (df_group.index.get_level_values('Episode') // 50)]).mean()
 
-    import seaborn as sns
-    from matplotlib import pyplot as plt
-    df_melted = df_group.melt(id_vars=['episode', 'run'],
-                              value_vars=['agent_0_vs_level', 'dirt_amount',
-                                          'dirty_tiles', 'step_reward',
-                                          'failed_cleanup_attempt',
-                                          'dirt_cleaned'], var_name="Variable",
-                              value_name="Score")
+    df_melted = non_overlapp_window.reset_index().melt(id_vars=['Episode', 'Run'],
+                                                       value_vars=['agent_0_vs_level', 'dirt_amount',
+                                                                   'dirty_tiles', 'step_reward',
+                                                                   'failed_cleanup_attempt',
+                                                                   'dirt_cleaned'], var_name="Measurement",
+                                                       value_name="Score")
 
-    sns.lineplot(data=df_melted, x='episode', y='Score', hue='Variable', ci='sd')
-    plt.show()
+    prepare_plot(run_path / f'{run_path.name}_monitor_lineplot.png', df_melted)
     print('Plotting done.')
 
 
 if __name__ == '__main__':
 
-    combine_runs('debug_out/PPO_1622120377')
+    combine_runs('debug_out/PPO_1622128912')
     exit()
 
     from stable_baselines3 import DQN, PPO
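Note: the new aggregation pools every 50 consecutive episodes into one non-overlapping bucket per run by grouping on (Episode // 50). A self-contained sketch of that step on fabricated monitor data (column names mirror the diff, values are made up):

# Non-overlapping 50-episode windows via integer division of the index level.
import numpy as np
import pandas as pd

df = pd.DataFrame({
    'Episode': np.tile(np.arange(200), 2),
    'Run': np.repeat([0, 1], 200),
    'step_reward': np.random.randn(400),
})
df_group = df.groupby(['Episode', 'Run']).sum()  # MultiIndex (Episode, Run)

window = 50
non_overlapp_window = df_group.groupby(
    ['Run', df_group.index.get_level_values('Episode') // window]
).mean()
print(non_overlapp_window)  # one row per (Run, 50-episode bucket)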
@@ -82,7 +82,7 @@ if __name__ == '__main__':
 
     model = PPO("MlpPolicy", env, verbose=1, ent_coef=0.0, seed=seed, device='cpu')
 
-    out_path = Path('../debug_out') / f'{model.__class__.__name__}_{time_stamp}'
+    out_path = Path('debug_out') / f'{model.__class__.__name__}_{time_stamp}'
 
     identifier = f'{seed}_{model.__class__.__name__}_{time_stamp}'
     out_path /= identifier
@@ -92,7 +92,7 @@ if __name__ == '__main__':
                   MonitorCallback(env, filepath=out_path / f'monitor_{identifier}.pick', plotting=False)]
     )
 
-    model.learn(total_timesteps=int(5e5), callback=callbacks)
+    model.learn(total_timesteps=int(2e6), callback=callbacks)
 
     save_path = out_path / f'model_{identifier}.zip'
     save_path.parent.mkdir(parents=True, exist_ok=True)
reload_agent.py (new file, 32 lines)
@@ -0,0 +1,32 @@
+import warnings
+from pathlib import Path
+import time
+
+from natsort import natsorted
+from stable_baselines3 import PPO
+from stable_baselines3.common.base_class import BaseAlgorithm
+from stable_baselines3.common.callbacks import CallbackList
+from stable_baselines3.common.evaluation import evaluate_policy
+
+from environments.factory.simple_factory import DirtProperties, SimpleFactory
+from environments.logging.monitor import MonitorCallback
+from environments.logging.training import TraningMonitor
+
+warnings.filterwarnings('ignore', category=FutureWarning)
+warnings.filterwarnings('ignore', category=UserWarning)
+
+
+if __name__ == '__main__':
+    dirt_props = DirtProperties()
+    env = SimpleFactory(n_agents=1, dirt_properties=dirt_props)
+
+    out_path = Path('debug_out')
+    model_files = list(natsorted(out_path.rglob('*.zip')))
+    this_model = model_files[0]
+
+    model = PPO.load(this_model)
+    evaluation_result = evaluate_policy(model, env, n_eval_episodes=100, deterministic=False,
+                                        render=True)
+    print(evaluation_result)
+
+    env.close()
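Note on the new script: natsorted sorts ascending, so model_files[0] picks the oldest checkpoint under debug_out, while model_files[-1] would pick the newest. For reference, stable-baselines3's evaluate_policy returns a (mean_reward, std_reward) tuple unless return_episode_rewards=True is passed, so the tail of the script could also unpack it; this sketch reuses model and env from the file above:

# Reuses `model` and `env` from reload_agent.py above (not a standalone script).
mean_reward, std_reward = evaluate_policy(model, env, n_eval_episodes=100, deterministic=False)
print(f'mean reward: {mean_reward:.2f} +/- {std_reward:.2f}')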