40 lines
1.1 KiB
Python
40 lines
1.1 KiB
Python
from salina.agents import Agents, TemporalAgent
|
|
import torch
|
|
from salina import Workspace, get_arguments, get_class, instantiate_class
|
|
from pathlib import Path
|
|
from salina.agents.gyma import GymAgent
|
|
import time
|
|
from algorithms.utils import load_yaml_file, add_env_props
|
|
|
|
if __name__ == '__main__':
|
|
# Setup workspace
|
|
uid = time.time()
|
|
workspace = Workspace()
|
|
weights = Path('/Users/romue/PycharmProjects/EDYS/studies/agent_1636994369.145843.pt')
|
|
|
|
cfg = load_yaml_file(Path(__file__).parent / 'sat_mad.yaml')
|
|
add_env_props(cfg)
|
|
cfg['env'].update({'n_agents': 2})
|
|
|
|
# instantiate agent and env
|
|
env_agent = GymAgent(
|
|
get_class(cfg['env']),
|
|
get_arguments(cfg['env']),
|
|
n_envs=1
|
|
)
|
|
|
|
agents = []
|
|
for _ in range(2):
|
|
a2c_agent = instantiate_class(cfg['agent'])
|
|
if weights:
|
|
a2c_agent.load_state_dict(torch.load(weights))
|
|
agents.append(a2c_agent)
|
|
|
|
# combine agents
|
|
acquisition_agent = TemporalAgent(Agents(env_agent, *agents))
|
|
acquisition_agent.seed(42)
|
|
|
|
acquisition_agent(workspace, t=0, n_steps=400, stochastic=False, save_render=True)
|
|
|
|
|