Adapt base_ac.py and utils.py to be compatible with refactored environment

This commit is contained in:
Julian Schönberger
2024-03-27 17:04:14 +01:00
parent 1e4ec254f4
commit 086a921929
4 changed files with 66 additions and 23 deletions

View File

@ -3,6 +3,8 @@ from pathlib import Path
import numpy as np
import yaml
from marl_factory_grid import Factory
def load_class(classname):
from importlib import import_module
@ -55,9 +57,17 @@ def load_yaml_file(path: Path):
def add_env_props(cfg):
env = instantiate_class(cfg['environment'].copy())
cfg['agent'].update(dict(observation_size=list(env.observation_space.shape),
n_actions=env.action_space.n))
# Path to config File
env_path = Path(f'../marl_factory_grid/configs/{cfg["env"]["env_name"]}.yaml')
# Env Init
factory = Factory(env_path)
_ = factory.reset()
# Agent Init
cfg['agent'].update(dict(observation_size=list(factory.observation_space[0].shape),
n_actions=factory.action_space[0].n))
return factory
class Checkpointer(object):