Merge remote-tracking branch 'origin/rl_plotting' into marl_refactor

This commit is contained in:
Chanumask
2024-05-21 11:39:51 +02:00
31 changed files with 530 additions and 134 deletions

View File

@@ -18,6 +18,7 @@ from collections import deque
from marl_factory_grid.environment.actions import Noop
from marl_factory_grid.modules import Clean, DoorUse
from marl_factory_grid.utils.plotting.plot_single_runs import plot_action_maps
class Names:
@@ -583,8 +584,7 @@ class A2C:
if self.cfg[nms.ENV]["save_and_log"]:
self.create_info_maps(env, used_actions)
self.save_agent_models()
plot_action_maps(env, [self], self.results_path)
@torch.inference_mode(True)
def eval_loop(self, n_episodes, render=False):