mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-05-23 07:16:44 +02:00
Merge remote-tracking branch 'origin/main' into main
This commit is contained in:
commit
1425f75643
@ -72,7 +72,8 @@ class MonitorCallback(BaseCallback):
|
|||||||
self._monitor_dict[self.num_timesteps] = {key: val for key, val in info.items()
|
self._monitor_dict[self.num_timesteps] = {key: val for key, val in info.items()
|
||||||
if key not in ['terminal_observation', 'episode']}
|
if key not in ['terminal_observation', 'episode']}
|
||||||
|
|
||||||
for env_idx, done in enumerate(self.locals.get('dones', [])):
|
for env_idx, done in list(enumerate(self.locals.get('dones', []))) + \
|
||||||
|
list(enumerate(self.locals.get('done', []))):
|
||||||
if done:
|
if done:
|
||||||
env_monitor_df = pd.DataFrame.from_dict(self._monitor_dict, orient='index')
|
env_monitor_df = pd.DataFrame.from_dict(self._monitor_dict, orient='index')
|
||||||
self._monitor_dict = dict()
|
self._monitor_dict = dict()
|
||||||
|
6
main.py
6
main.py
@ -41,8 +41,6 @@ def combine_runs(run_path: Union[str, PathLike]):
|
|||||||
value_vars=columns, var_name="Measurement",
|
value_vars=columns, var_name="Measurement",
|
||||||
value_name="Score")
|
value_name="Score")
|
||||||
df_melted = df_melted[df_melted['Episode'] % skip_n == 0]
|
df_melted = df_melted[df_melted['Episode'] % skip_n == 0]
|
||||||
#df_melted['Episode'] = df_melted['Episode'] * skip_n # only needed for old version
|
|
||||||
|
|
||||||
|
|
||||||
prepare_plot(run_path / f'{run_path.name}_monitor_lineplot.png', df_melted)
|
prepare_plot(run_path / f'{run_path.name}_monitor_lineplot.png', df_melted)
|
||||||
print('Plotting done.')
|
print('Plotting done.')
|
||||||
@ -51,6 +49,8 @@ def combine_runs(run_path: Union[str, PathLike]):
|
|||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
||||||
from stable_baselines3 import PPO, DQN, A2C
|
from stable_baselines3 import PPO, DQN, A2C
|
||||||
|
from algorithms.dqn_reg import RegDQN
|
||||||
|
|
||||||
dirt_props = DirtProperties()
|
dirt_props = DirtProperties()
|
||||||
time_stamp = int(time.time())
|
time_stamp = int(time.time())
|
||||||
|
|
||||||
@ -58,7 +58,7 @@ if __name__ == '__main__':
|
|||||||
combine_runs(Path(__file__).parent / 'debug_out'/ 'A2C_1622571986')
|
combine_runs(Path(__file__).parent / 'debug_out'/ 'A2C_1622571986')
|
||||||
exit()
|
exit()
|
||||||
|
|
||||||
for modeL_type in [A2C, PPO, DQN]:
|
for modeL_type in [RegDQN, DQN]:
|
||||||
for seed in range(5):
|
for seed in range(5):
|
||||||
|
|
||||||
env = SimpleFactory(n_agents=1, dirt_properties=dirt_props, pomdp_radius=2, max_steps=400,
|
env = SimpleFactory(n_agents=1, dirt_properties=dirt_props, pomdp_radius=2, max_steps=400,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user