deleted policy adaption, added IAC (independent actor-critic)
@@ -1,100 +1,133 @@
from environments.factory import make
from salina import Workspace, TAgent
from salina.agents.gyma import AutoResetGymAgent, GymAgent
from salina.agents.gyma import AutoResetGymAgent
from salina.agents import Agents, TemporalAgent
from salina.rl.functional import _index
from salina.rl.functional import _index, gae
import torch
import torch.nn as nn
from torch.nn.utils import spectral_norm
import torch.optim as optim
from torch.distributions import Categorical
from salina import TAgent, Workspace, get_arguments, get_class, instantiate_class
from pathlib import Path
import numpy as np
from tqdm import tqdm
import time
from algorithms.utils import add_env_props, load_yaml_file, CombineActionsAgent


class A2CAgent(TAgent):
    def __init__(self, observation_size, hidden_size, n_actions):
    def __init__(self, observation_size, hidden_size, n_actions, agent_id=-1, marl=False):
        super().__init__()
        observation_size = np.prod(observation_size)
        self.agent_id = agent_id
        self.marl = marl
        self.model = nn.Sequential(
            nn.Flatten(),
            nn.Linear(observation_size, hidden_size),
            nn.ELU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ELU(),
            nn.Linear(hidden_size, n_actions),
        )
        self.critic_model = nn.Sequential(
            nn.Flatten(),
            nn.Linear(observation_size, hidden_size),
            nn.ELU(),
            spectral_norm(nn.Linear(hidden_size, 1)),
            nn.Linear(hidden_size, hidden_size),
            nn.ELU()
        )
        self.action_head = nn.Linear(hidden_size, n_actions)
        self.critic_head = nn.Linear(hidden_size, 1)

    def get_obs(self, t):
        observation = self.get(("env/env_obs", t))
        if self.marl:
            observation = observation.permute(2, 0, 1, 3, 4, 5)
            observation = observation[self.agent_id]
        return observation

    def forward(self, t, stochastic, **kwargs):
        observation = self.get(("env/env_obs", t))
        scores = self.model(observation)
        observation = self.get_obs(t)
        features = self.model(observation)
        scores = self.action_head(features)
        probs = torch.softmax(scores, dim=-1)
        critic = self.critic_model(observation).squeeze(-1)
        critic = self.critic_head(features).squeeze(-1)
        if stochastic:
            action = torch.distributions.Categorical(probs).sample()
        else:
            action = probs.argmax(1)

        self.set(("action", t), action)
        self.set(("action_probs", t), probs)
        self.set(("critic", t), critic)
        agent_str = f'agent{self.agent_id}_'
        self.set((f'{agent_str}action', t), action)
        self.set((f'{agent_str}action_probs', t), probs)
        self.set((f'{agent_str}critic', t), critic)
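With this change the network is a single shared torso with two heads: action_head produces the logits of the categorical policy, critic_head the state-value estimate, and the results are written to the workspace under per-agent keys such as agent0_action and agent0_critic. A standalone shape check of that head split in plain torch, assuming the (3, 4, 5, 5) observation that the old single-agent setup below hard-codes as 3*4*5*5:

# Illustrative shape check only; the real observation shape comes from the env config.
import torch
import torch.nn as nn

obs = torch.zeros(1, 3, 4, 5, 5)                           # dummy [n_envs, *obs_shape] batch
torso = nn.Sequential(nn.Flatten(), nn.Linear(3 * 4 * 5 * 5, 96), nn.ELU(),
                      nn.Linear(96, 96), nn.ELU())
action_head = nn.Linear(96, 10)
critic_head = nn.Linear(96, 1)

features = torso(obs)
probs = torch.softmax(action_head(features), dim=-1)        # [1, 10] action distribution
value = critic_head(features).squeeze(-1)                   # [1] state value per env
action = torch.distributions.Categorical(probs).sample()    # [1] sampled action index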
if __name__ == '__main__':
    # Setup agents and workspace
    env_agent = AutoResetGymAgent(make, dict(env_str='DirtyFactory-v0'), n_envs=1)
    a2c_agent = A2CAgent(3*4*5*5, 96, 10)
    # Setup workspace
    uid = time.time()
    workspace = Workspace()
    n_agents = 1

    eval_agent = Agents(GymAgent(make, dict(env_str='DirtyFactory-v0'), n_envs=1), a2c_agent)
    for i in range(100):
        eval_agent(workspace, t=i, save_render=True, stochastic=True)
    # load config
    cfg = load_yaml_file(Path(__file__).parent / 'sat_mad.yaml')
    add_env_props(cfg)
    cfg['env'].update({'n_agents': n_agents})

    # instantiate agent and env
    env_agent = AutoResetGymAgent(
        get_class(cfg['env']),
        get_arguments(cfg['env']),
        n_envs=1
    )

    a2c_agents = [instantiate_class({**cfg['agent'],
                                     'agent_id': agent_id,
                                     'marl': n_agents > 1})
                  for agent_id in range(n_agents)]

    assert False
    # combine agents
    acquisition_agent = TemporalAgent(Agents(env_agent, a2c_agent))
    acquisition_agent.seed(0)
    acquisition_agent = TemporalAgent(Agents(env_agent, *a2c_agents, CombineActionsAgent()))
    acquisition_agent.seed(69)

    # optimizers & other parameters
    optimizer = optim.Adam(a2c_agent.parameters(), lr=1e-3)
    n_timesteps = 10
    cfg_optim = cfg['algorithm']['optimizer']
    optimizers = [get_class(cfg_optim)(a2c_agent.parameters(), **get_arguments(cfg_optim))
                  for a2c_agent in a2c_agents]
    n_timesteps = cfg['algorithm']['n_timesteps']
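get_class, get_arguments and instantiate_class resolve config entries into objects; sat_mad.yaml itself is not part of this diff. Under salina's usual convention each entry carries a classname key plus the constructor's keyword arguments, so a config shaped roughly like the sketch below would satisfy the calls above (the module path and all values are illustrative assumptions, not taken from the repository):

# Illustrative stand-in for the loaded cfg; the real sat_mad.yaml is not shown here.
# get_class() imports 'classname', get_arguments() returns the remaining keys as kwargs,
# instantiate_class() combines both; 'env' is assumed to follow the same pattern and to
# be completed by add_env_props(cfg) before use.
cfg = {
    'agent': {
        'classname': 'algorithms.sat_mad.A2CAgent',   # assumed module path
        'observation_size': (3, 4, 5, 5),
        'hidden_size': 96,
        'n_actions': 10,
    },
    'algorithm': {
        'n_timesteps': 10,
        'max_epochs': 100000,
        'optimizer': {'classname': 'torch.optim.Adam', 'lr': 1e-3},
    },
}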
    # Decision making loop
    for epoch in range(200000):
        workspace.zero_grad()
        if epoch > 0:
            workspace.copy_n_last_steps(1)
            acquisition_agent(workspace, t=1, n_steps=n_timesteps-1, stochastic=True)
        else:
            acquisition_agent(workspace, t=0, n_steps=n_timesteps, stochastic=True)
        #for k in workspace.keys():
        # print(f'{k} ==> {workspace[k].size()}')
        critic, done, action_probs, reward, action = workspace[
            "critic", "env/done", "action_probs", "env/reward", "action"
        ]
    best = -float('inf')
    with tqdm(range(int(cfg['algorithm']['max_epochs'] / n_timesteps))) as pbar:
        for epoch in pbar:
            workspace.zero_grad()
            if epoch > 0:
                workspace.copy_n_last_steps(1)
                acquisition_agent(workspace, t=1, n_steps=n_timesteps-1, stochastic=True)
            else:
                acquisition_agent(workspace, t=0, n_steps=n_timesteps, stochastic=True)

        target = reward[1:] + 0.99 * critic[1:].detach() * (1 - done[1:].float())
        td = target - critic[:-1]
        td_error = td ** 2
        critic_loss = td_error.mean()
        entropy_loss = Categorical(action_probs).entropy().mean()
        action_logp = _index(action_probs, action).log()
        a2c_loss = action_logp[:-1] * td.detach()
        a2c_loss = a2c_loss.mean()
        loss = (
            -0.001 * entropy_loss
            + 1.0 * critic_loss
            - 0.1 * a2c_loss
        )
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
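The single-agent block above and the per-agent block below optimize the same three-term objective; only the advantage estimate changes (one-step TD above, GAE below). In the notation of the code, with gamma = 0.99, V the critic and pi the policy:

\hat{A}_t = r_{t+1} + \gamma\,(1 - d_{t+1})\,V(s_{t+1}) - V(s_t)
\mathcal{L} = 1.0\,\mathbb{E}\big[\hat{A}_t^2\big] \;-\; 0.1\,\mathbb{E}\big[\log \pi(a_t \mid s_t)\,\hat{A}_t\big] \;-\; 0.001\,\mathcal{H}(\pi)

The advantage is detached in the policy term, so only the squared term trains the critic, and minimizing -0.001 * H(pi) acts as an entropy bonus that encourages exploration.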
            for agent_id in range(n_agents):
                critic, done, action_probs, reward, action = workspace[
                    f"agent{agent_id}_critic", "env/done",
                    f'agent{agent_id}_action_probs', "env/reward",
                    f"agent{agent_id}_action"
                ]
                td = gae(critic, reward, done, 0.99, 0.3)
                td_error = td ** 2
                critic_loss = td_error.mean()
                entropy_loss = Categorical(action_probs).entropy().mean()
                action_logp = _index(action_probs, action).log()
                a2c_loss = action_logp[:-1] * td.detach()
                a2c_loss = a2c_loss.mean()
                loss = (
                    -0.001 * entropy_loss
                    + 1.0 * critic_loss
                    - 0.1 * a2c_loss
                )
                optimizer = optimizers[agent_id]
                optimizer.zero_grad()
                loss.backward()
                #torch.nn.utils.clip_grad_norm_(a2c_agents[agent_id].parameters(), 2)
                optimizer.step()
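gae from salina.rl.functional replaces the hand-written one-step TD error with generalized advantage estimation; the call above passes gamma = 0.99 and lambda = 0.3. A pure-torch reference of the recursion it is assumed to implement, reusing the reward/done indexing convention of the single-agent target computation above:

import torch

def gae_reference(critic, reward, done, discount=0.99, gae_coef=0.3):
    # critic, reward, done: [T, n_envs] tensors read from the workspace.
    # delta_t follows the same indexing as the single-agent target above:
    # r_{t+1} + discount * (1 - d_{t+1}) * V_{t+1} - V_t.
    not_done = 1.0 - done.float()
    deltas = reward[1:] + discount * not_done[1:] * critic[1:] - critic[:-1]
    advantages = torch.zeros_like(deltas)
    running = torch.zeros_like(critic[0])
    for t in reversed(range(deltas.size(0))):
        # lambda-weighted sum of future TD errors, cut off at episode ends
        running = deltas[t] + discount * gae_coef * not_done[t + 1] * running
        advantages[t] = running
    return advantages  # [T - 1, n_envs], the shape the loss above expects for td

With gae_coef = 0 this reduces to the one-step TD error of the single-agent block; 0.3 keeps a short effective credit-assignment horizon.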
            # Compute the cumulated reward on final_state
            creward = workspace["env/cumulated_reward"]
            creward = creward[done]
            if creward.size()[0] > 0:
                cum_r = creward.mean().item()
                if cum_r > best:
                    # torch.save(a2c_agent.state_dict(), Path(__file__).parent / f'agent_{uid}.pt')
                    best = cum_r
                pbar.set_description(f"Cum. r: {cum_r:.2f}, Best r. so far: {best:.2f}", refresh=True)

        # Compute the cumulated reward on final_state
        creward = workspace["env/cumulated_reward"]
        creward = creward[done]
        if creward.size()[0] > 0:
            print(f"Cumulative reward at A2C step #{(1+epoch)*n_timesteps}: {creward.mean().item()}")