mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-07-05 09:01:36 +02:00
Adapt base_ac.py and utils.py to be compatible with refactored environment
This commit is contained in:
@ -3,6 +3,8 @@ from pathlib import Path
|
||||
import numpy as np
|
||||
import yaml
|
||||
|
||||
from marl_factory_grid import Factory
|
||||
|
||||
|
||||
def load_class(classname):
|
||||
from importlib import import_module
|
||||
@ -55,9 +57,17 @@ def load_yaml_file(path: Path):
|
||||
|
||||
|
||||
def add_env_props(cfg):
|
||||
env = instantiate_class(cfg['environment'].copy())
|
||||
cfg['agent'].update(dict(observation_size=list(env.observation_space.shape),
|
||||
n_actions=env.action_space.n))
|
||||
# Path to config File
|
||||
env_path = Path(f'../marl_factory_grid/configs/{cfg["env"]["env_name"]}.yaml')
|
||||
|
||||
# Env Init
|
||||
factory = Factory(env_path)
|
||||
_ = factory.reset()
|
||||
|
||||
# Agent Init
|
||||
cfg['agent'].update(dict(observation_size=list(factory.observation_space[0].shape),
|
||||
n_actions=factory.action_space[0].n))
|
||||
return factory
|
||||
|
||||
|
||||
class Checkpointer(object):
|
||||
|
Reference in New Issue
Block a user