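# Multi-agent Advantage Actor-Critic (A2C) training script built on the SaLinA
# agent/workspace framework: each agent owns its own actor-critic network and
# optimizer, while a shared TemporalAgent collects joint rollouts.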
from salina.agents.gyma import AutoResetGymAgent
from salina.agents import Agents, TemporalAgent
from salina.rl.functional import _index, gae
import torch
import torch.nn as nn
from torch.distributions import Categorical
from salina import TAgent, Workspace, get_arguments, get_class, instantiate_class
from pathlib import Path
import numpy as np
from tqdm import tqdm
import time
from algorithms.utils import (
    add_env_props,
    load_yaml_file,
    CombineActionsAgent,
    AutoResetGymMultiAgent,
    access_str,
    AGENT_PREFIX, REWARD, CUMU_REWARD, OBS, SEP
)


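# A2CAgent: one actor-critic network per agent. A shared MLP trunk produces
# features, action_head yields action logits and critic_head the state value.
# Reading and writing happens through the SaLinA workspace via self.get/self.set.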
class A2CAgent(TAgent):
    def __init__(self, observation_size, hidden_size, n_actions, agent_id):
        super().__init__()
        # Flatten the (possibly multi-dimensional) observation shape; cast the
        # numpy integer returned by np.prod to a plain int for nn.Linear.
        observation_size = int(np.prod(observation_size))
        self.agent_id = agent_id
        # Shared feature extractor feeding both the actor and the critic head.
        self.model = nn.Sequential(
            nn.Flatten(),
            nn.Linear(observation_size, hidden_size),
            nn.ELU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ELU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ELU()
        )
        self.action_head = nn.Linear(hidden_size, n_actions)
        self.critic_head = nn.Linear(hidden_size, 1)

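    # Read this agent's observation from the workspace. access_str (from
    # algorithms.utils) presumably assembles the per-agent key suffix, so the
    # full key is f'env/{access_str(agent_id, OBS)}'.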
    def get_obs(self, t):
        observation = self.get((f'env/{access_str(self.agent_id, OBS)}', t))
        return observation

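    # One decision step at time t: compute action probabilities and the state
    # value, sample an action (or take the argmax when stochastic=False), and
    # write action, action_probs and critic back under this agent's keys.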
    def forward(self, t, stochastic, **kwargs):
        observation = self.get_obs(t)
        features = self.model(observation)
        scores = self.action_head(features)
        probs = torch.softmax(scores, dim=-1)
        critic = self.critic_head(features).squeeze(-1)
        if stochastic:
            action = Categorical(probs).sample()
        else:
            action = probs.argmax(dim=-1)
        self.set((f'{access_str(self.agent_id, "action")}', t), action)
        self.set((f'{access_str(self.agent_id, "action_probs")}', t), probs)
        self.set((f'{access_str(self.agent_id, "critic")}', t), critic)


if __name__ == '__main__':
    # Setup workspace
    uid = time.time()
    workspace = Workspace()
    n_agents = 2

    # load config
    cfg = load_yaml_file(Path(__file__).parent / 'sat_mad.yaml')
    add_env_props(cfg)
    cfg['env'].update({'n_agents': n_agents})

    # instantiate agent and env
    env_agent = AutoResetGymMultiAgent(
        get_class(cfg['env']),
        get_arguments(cfg['env']),
        n_envs=1
    )

    a2c_agents = [instantiate_class({**cfg['agent'],
                                     'agent_id': agent_id})
                  for agent_id in range(n_agents)]

    # combine agents
    acquisition_agent = TemporalAgent(Agents(env_agent, *a2c_agents, CombineActionsAgent()))
    acquisition_agent.seed(69)

    # optimizers & other parameters
    cfg_optim = cfg['algorithm']['optimizer']
    optimizers = [get_class(cfg_optim)(a2c_agent.parameters(), **get_arguments(cfg_optim))
                  for a2c_agent in a2c_agents]
    n_timesteps = cfg['algorithm']['n_timesteps']

    # Decision making loop
    best = -float('inf')
    with tqdm(range(int(cfg['algorithm']['max_epochs'] / n_timesteps))) as pbar:
        for epoch in pbar:
            workspace.zero_grad()
            if epoch > 0:
                # Keep the last step of the previous rollout so trajectories
                # stay contiguous, then fill the remaining n_timesteps - 1 steps.
                workspace.copy_n_last_steps(1)
                acquisition_agent(workspace, t=1, n_steps=n_timesteps - 1, stochastic=True)
            else:
                acquisition_agent(workspace, t=0, n_steps=n_timesteps, stochastic=True)

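            # Per-agent A2C update: gae() produces advantage estimates (one per
            # transition, i.e. n_timesteps - 1 values), whose square is used as
            # the critic loss; an entropy bonus encourages exploration, and the
            # policy term weights the log-probabilities of the taken actions
            # (truncated to the advantage length) by the detached advantages.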
            for agent_id in range(n_agents):
                critic, done, action_probs, reward, action = workspace[
                    access_str(agent_id, 'critic'),
                    "env/done",
                    access_str(agent_id, 'action_probs'),
                    access_str(agent_id, 'reward', 'env/'),
                    access_str(agent_id, 'action')
                ]
                td = gae(critic, reward, done, 0.98, 0.25)
                td_error = td ** 2
                critic_loss = td_error.mean()
                entropy_loss = Categorical(action_probs).entropy().mean()
                action_logp = _index(action_probs, action).log()
                a2c_loss = action_logp[:-1] * td.detach()
                a2c_loss = a2c_loss.mean()
                loss = (
                    -0.001 * entropy_loss
                    + 1.0 * critic_loss
                    - 0.1 * a2c_loss
                )
                optimizer = optimizers[agent_id]
                optimizer.zero_grad()
                loss.backward()
                # torch.nn.utils.clip_grad_norm_(a2c_agents[agent_id].parameters(), .5)
                optimizer.step()

            # Compute the cumulated reward on the final state.
            # NOTE: `done` still holds "env/done" from the last unpacking in the
            # update loop above; it is identical for all agents.
            rews = ''
            for agent_i in range(n_agents):
                creward = workspace['env/' + access_str(agent_i, CUMU_REWARD)]
                creward = creward[done]
                if creward.size()[0] > 0:
                    rews += f'{AGENT_PREFIX}{agent_i}: {creward.mean().item():.2f} | '
                    # Checkpointing is currently disabled; `cum_r` and `a2c_agent`
                    # are not defined in this scope as written.
                    """if cum_r > best:
                        torch.save(a2c_agent.state_dict(), Path(__file__).parent / f'agent_{uid}.pt')
                        best = cum_r"""
            pbar.set_description(rews, refresh=True)