First commit for our new MARL algorithms library; it contains working implementations of IAC, SNAC, and SEAC.
algorithms/marl/__init__.py (new file, 4 lines)
@@ -0,0 +1,4 @@
from algorithms.marl.base_ac import BaseActorCritic
from algorithms.marl.iac import LoopIAC
from algorithms.marl.snac import LoopSNAC
from algorithms.marl.seac import LoopSEAC
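The four exports above are the public surface of the library. As a quick orientation, a minimal driver could look like the sketch below; it is not part of this commit and assumes PyYAML is available and that the config follows algorithms/marl/example_config.yaml further down in this commit.

# Hypothetical usage sketch -- not part of this commit.
import yaml
from pathlib import Path
from algorithms.marl import LoopSEAC

with Path('algorithms/marl/example_config.yaml').open() as f:
    cfg = yaml.safe_load(f)

loop = LoopSEAC(cfg)                       # builds one actor-critic net per agent
loop.train_loop()                          # roll out and learn until max_steps
results = loop.eval_loop(n_episodes=10)    # long-format pandas DataFrame
print(results.groupby('agent')['reward'].mean())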
algorithms/marl/base_ac.py (new file, 176 lines)
@@ -0,0 +1,176 @@
import torch
from typing import Union, List
import numpy as np
from torch.distributions import Categorical
from algorithms.marl.memory import MARLActorCriticMemory
from algorithms.utils import add_env_props, instantiate_class
from pathlib import Path
import pandas as pd
from collections import deque

ListOrTensor = Union[List, torch.Tensor]


class BaseActorCritic:
    def __init__(self, cfg):
        add_env_props(cfg)
        self.__training = True
        self.cfg = cfg
        self.n_agents = cfg['env']['n_agents']
        self.setup()

    def setup(self):
        self.net = instantiate_class(self.cfg['agent'])
        self.optimizer = torch.optim.RMSprop(self.net.parameters(), lr=3e-4, eps=1e-5)

    @classmethod
    def _as_torch(cls, x):
        if isinstance(x, np.ndarray):
            return torch.from_numpy(x)
        elif isinstance(x, List):
            return torch.tensor(x)
        elif isinstance(x, (int, float)):
            return torch.tensor([x])
        return x

    def train(self):
        self.__training = True
        networks = [self.net] if not isinstance(self.net, List) else self.net
        for net in networks:
            net.train()

    def eval(self):
        self.__training = False
        networks = [self.net] if not isinstance(self.net, List) else self.net
        for net in networks:
            net.eval()

    def load_state_dict(self, path: Path):
        pass

    def get_actions(self, out) -> ListOrTensor:
        actions = [Categorical(logits=logits).sample().item() for logits in out['logits']]
        return actions

    def init_hidden(self) -> dict[ListOrTensor]:
        pass

    def forward(self,
                observations: ListOrTensor,
                actions: ListOrTensor,
                hidden_actor: ListOrTensor,
                hidden_critic: ListOrTensor
                ):
        pass

    @torch.no_grad()
    def train_loop(self, checkpointer=None):
        env = instantiate_class(self.cfg['env'])
        n_steps, max_steps = [self.cfg['algorithm'][k] for k in ['n_steps', 'max_steps']]
        global_steps = 0
        reward_queue = deque(maxlen=2000)
        while global_steps < max_steps:
            tm = MARLActorCriticMemory(self.n_agents)
            obs = env.reset()
            last_hiddens = self.init_hidden()
            last_action, reward = [-1] * self.n_agents, [0.] * self.n_agents
            done, rew_log = [False] * self.n_agents, 0
            tm.add(action=last_action, **last_hiddens)

            while not all(done):
                out = self.forward(obs, last_action, **last_hiddens)
                action = self.get_actions(out)
                next_obs, reward, done, info = env.step(action)
                if isinstance(done, bool): done = [done] * self.n_agents

                tm.add(observation=obs, action=action, reward=reward, done=done)
                obs = next_obs
                last_action = action
                last_hiddens = dict(hidden_actor=out.get('hidden_actor', None),
                                    hidden_critic=out.get('hidden_critic', None)
                                    )

                if len(tm) >= n_steps or all(done):
                    tm.add(observation=next_obs)
                    if self.__training:
                        with torch.inference_mode(False):  # re-enable autograd inside the no_grad rollout
                            self.learn(tm)
                    tm.reset()
                    tm.add(action=last_action, **last_hiddens)
                global_steps += 1
                rew_log += sum(reward)
                reward_queue.extend(reward)

                if checkpointer is not None:
                    checkpointer.step([
                        (f'agent#{i}', agent)
                        for i, agent in enumerate([self.net] if not isinstance(self.net, List) else self.net)
                    ])

                if global_steps >= max_steps: break
            print(f'Reward at step {global_steps}: {rew_log}')

    @torch.inference_mode(True)
    def eval_loop(self, n_episodes, render=False):
        env = instantiate_class(self.cfg['env'])
        episode, results = 0, []
        while episode < n_episodes:
            obs = env.reset()
            last_hiddens = self.init_hidden()
            last_action, reward = [-1] * self.n_agents, [0.] * self.n_agents
            done, rew_log, eps_rew = [False] * self.n_agents, 0, torch.zeros(self.n_agents)
            while not all(done):
                if render: env.render()

                out = self.forward(obs, last_action, **last_hiddens)
                action = self.get_actions(out)
                next_obs, reward, done, info = env.step(action)

                if isinstance(done, bool): done = [done] * obs.shape[0]
                obs = next_obs
                last_action = action
                last_hiddens = dict(hidden_actor=out.get('hidden_actor', None),
                                    hidden_critic=out.get('hidden_critic', None)
                                    )
                eps_rew += torch.tensor(reward)
            results.append(eps_rew.tolist() + [sum(eps_rew).item()] + [episode])
            episode += 1
        agent_columns = [f'agent#{i}' for i in range(self.cfg['env']['n_agents'])]
        results = pd.DataFrame(results, columns=agent_columns + ['sum', 'episode'])
        results = pd.melt(results, id_vars=['episode'], value_vars=agent_columns + ['sum'],
                          value_name='reward', var_name='agent')
        return results

    @staticmethod
    def compute_advantages(critic, reward, done, gamma):
        # one-step TD error: r_t + gamma * (1 - done_t) * V(s_{t+1}) - V(s_t)
        return (reward + gamma * (1.0 - done) * critic[:, 1:].detach()) - critic[:, :-1]

    def actor_critic(self, tm, network, gamma, entropy_coef, vf_coef, **kwargs):
        obs, actions, done, reward = tm.observation, tm.action, tm.done, tm.reward

        out = network(obs, actions, tm.hidden_actor, tm.hidden_critic)
        logits = out['logits'][:, :-1]  # last step is only needed for v_{t+1}
        critic = out['critic']

        entropy_loss = Categorical(logits=logits).entropy().mean(-1)
        advantages = self.compute_advantages(critic, reward, done, gamma)
        value_loss = advantages.pow(2).mean(-1)  # per agent

        # policy loss
        log_ap = torch.log_softmax(logits, -1)
        log_ap = torch.gather(log_ap, dim=-1, index=actions[:, 1:].unsqueeze(-1)).squeeze()
        a2c_loss = -(advantages.detach() * log_ap).mean(-1)
        # weighted loss
        loss = a2c_loss + vf_coef * value_loss - entropy_coef * entropy_loss

        return loss.mean()

    def learn(self, tm: MARLActorCriticMemory, **kwargs):
        loss = self.actor_critic(tm, self.net, **self.cfg['algorithm'], **kwargs)
        # remove next_obs, will be added in next iter
        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.net.parameters(), 0.5)
        self.optimizer.step()
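For reference, the advantage used by actor_critic above is the per-step TD error. A minimal shape check with made-up sizes (illustrative only, not part of this commit):

# Illustrative shape check for BaseActorCritic.compute_advantages.
import torch

n_agents, t = 2, 5
reward = torch.zeros(n_agents, t)        # agents x timesteps
done = torch.zeros(n_agents, t)          # 1.0 where the episode terminated
critic = torch.randn(n_agents, t + 1)    # agents x timesteps+1, includes the bootstrap value
gamma = 0.99

adv = (reward + gamma * (1.0 - done) * critic[:, 1:].detach()) - critic[:, :-1]
assert adv.shape == (n_agents, t)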
algorithms/marl/example_config.yaml (new file, 24 lines)
@@ -0,0 +1,24 @@
agent:
  classname: algorithms.marl.networks.RecurrentAC
  n_agents: 2
  obs_emb_size: 96
  action_emb_size: 16
  hidden_size_actor: 64
  hidden_size_critic: 64
  use_agent_embedding: False
env:
  classname: environments.factory.make
  env_name: "DirtyFactory-v0"
  n_agents: 2
  max_steps: 250
  pomdp_r: 2
  stack_n_frames: 0
  individual_rewards: True
method: algorithms.marl.LoopSEAC
algorithm:
  gamma: 0.99
  entropy_coef: 0.01
  vf_coef: 0.5
  n_steps: 5
  max_steps: 1000000
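The classname and method entries are dotted import paths. They are resolved by algorithms.utils.instantiate_class, which is not part of this commit; the snippet below is only a guess at that behaviour for illustration, with locate as a hypothetical helper.

# Hypothetical resolver for the dotted paths in the YAML above (illustration only).
import importlib

def locate(dotted_path: str):
    module_name, _, attr = dotted_path.rpartition('.')
    return getattr(importlib.import_module(module_name), attr)

# e.g. locate('algorithms.marl.LoopSEAC') -> the LoopSEAC class exported by __init__.py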
algorithms/marl/iac.py (new file, 58 lines)
@@ -0,0 +1,58 @@
import torch
from algorithms.marl.base_ac import BaseActorCritic
from algorithms.utils import instantiate_class
from pathlib import Path
from natsort import natsorted
from algorithms.marl.memory import MARLActorCriticMemory


class LoopIAC(BaseActorCritic):

    def __init__(self, cfg):
        super(LoopIAC, self).__init__(cfg)

    def setup(self):
        self.net = [
            instantiate_class(self.cfg['agent']) for _ in range(self.n_agents)
        ]
        self.optimizer = [
            torch.optim.RMSprop(self.net[ag_i].parameters(), lr=3e-4, eps=1e-5) for ag_i in range(self.n_agents)
        ]

    def load_state_dict(self, path: Path):
        paths = natsorted(list(path.glob('*.pt')))
        print(list(paths))
        for weight_path, net in zip(paths, self.net):
            net.load_state_dict(torch.load(weight_path))

    @staticmethod
    def merge_dicts(ds):  # todo: could be made recursive for deeper hierarchies
        d = {}
        for k in ds[0].keys():
            d[k] = [d_[k] for d_ in ds]
        return d

    def init_hidden(self):
        ha = [net.init_hidden_actor() for net in self.net]
        hc = [net.init_hidden_critic() for net in self.net]
        return dict(hidden_actor=ha, hidden_critic=hc)

    def forward(self, observations, actions, hidden_actor, hidden_critic):
        outputs = [
            net(
                self._as_torch(observations[ag_i]).unsqueeze(0).unsqueeze(0),  # agents x time
                self._as_torch(actions[ag_i]).unsqueeze(0),
                hidden_actor[ag_i],
                hidden_critic[ag_i]
            ) for ag_i, net in enumerate(self.net)
        ]
        return self.merge_dicts(outputs)

    def learn(self, tms: MARLActorCriticMemory, **kwargs):
        for ag_i in range(self.n_agents):
            tm, net = tms(ag_i), self.net[ag_i]
            loss = self.actor_critic(tm, net, **self.cfg['algorithm'], **kwargs)
            self.optimizer[ag_i].zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(net.parameters(), 0.5)
            self.optimizer[ag_i].step()
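merge_dicts collates the per-agent network outputs into one dict of lists. A tiny standalone example of the expected behaviour, with placeholder values (not part of this commit):

# Illustrative behaviour of LoopIAC.merge_dicts.
outs = [
    {'logits': 'logits_agent0', 'critic': 'critic_agent0'},
    {'logits': 'logits_agent1', 'critic': 'critic_agent1'},
]
merged = {k: [o[k] for o in outs] for k in outs[0]}
# merged == {'logits': ['logits_agent0', 'logits_agent1'],
#            'critic': ['critic_agent0', 'critic_agent1']}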
algorithms/marl/memory.py (new file, 131 lines)
@@ -0,0 +1,131 @@
import torch
from typing import Union, List
from torch import Tensor
import numpy as np


class ActorCriticMemory(object):
    def __init__(self):
        self.reset()

    def reset(self):
        self.__states = []
        self.__actions = []
        self.__rewards = []
        self.__dones = []
        self.__hiddens_actor = []
        self.__hiddens_critic = []

    def __len__(self):
        return len(self.__states)

    @property
    def observation(self):
        return torch.stack(self.__states, 0).unsqueeze(0)  # 1 x timesteps x hidden dim

    @property
    def hidden_actor(self):
        if len(self.__hiddens_actor) == 1:
            return self.__hiddens_actor[0]
        return torch.stack(self.__hiddens_actor, 0)  # layers x timesteps x hidden dim

    @property
    def hidden_critic(self):
        if len(self.__hiddens_critic) == 1:
            return self.__hiddens_critic[0]
        return torch.stack(self.__hiddens_critic, 0)  # layers x timesteps x hidden dim

    @property
    def reward(self):
        return torch.tensor(self.__rewards).float().unsqueeze(0)  # 1 x timesteps

    @property
    def action(self):
        return torch.tensor(self.__actions).long().unsqueeze(0)  # 1 x timesteps+1

    @property
    def done(self):
        return torch.tensor(self.__dones).float().unsqueeze(0)  # 1 x timesteps

    def add_observation(self, state: Union[Tensor, np.ndarray]):
        self.__states.append(state if isinstance(state, Tensor) else torch.from_numpy(state))

    def add_hidden_actor(self, hidden: Tensor):
        # 1 x layers x hidden dim
        if len(hidden.shape) < 3: hidden = hidden.unsqueeze(0)
        self.__hiddens_actor.append(hidden)

    def add_hidden_critic(self, hidden: Tensor):
        # 1 x layers x hidden dim
        if len(hidden.shape) < 3: hidden = hidden.unsqueeze(0)
        self.__hiddens_critic.append(hidden)

    def add_action(self, action: int):
        self.__actions.append(action)

    def add_reward(self, reward: float):
        self.__rewards.append(reward)

    def add_done(self, done: bool):
        self.__dones.append(done)

    def add(self, **kwargs):
        for k, v in kwargs.items():
            func = getattr(ActorCriticMemory, f'add_{k}')
            func(self, v)


class MARLActorCriticMemory(object):
    def __init__(self, n_agents):
        self.n_agents = n_agents
        self.memories = [
            ActorCriticMemory() for _ in range(n_agents)
        ]

    def __call__(self, agent_i):
        return self.memories[agent_i]

    def __len__(self):
        return len(self.memories[0])  # todo add assertion check!

    def reset(self):
        for mem in self.memories:
            mem.reset()

    def add(self, **kwargs):
        # todo try catch - print all possible functions
        for agent_i in range(self.n_agents):
            for k, v in kwargs.items():
                func = getattr(ActorCriticMemory, f'add_{k}')
                func(self.memories[agent_i], v[agent_i])

    @property
    def observation(self):
        all_obs = [mem.observation for mem in self.memories]
        return torch.cat(all_obs, 0)  # agents x timesteps+1 x ...

    @property
    def action(self):
        all_actions = [mem.action for mem in self.memories]
        return torch.cat(all_actions, 0)  # agents x timesteps+1 x ...

    @property
    def done(self):
        all_dones = [mem.done for mem in self.memories]
        return torch.cat(all_dones, 0).float()  # agents x timesteps x ...

    @property
    def reward(self):
        all_rewards = [mem.reward for mem in self.memories]
        return torch.cat(all_rewards, 0).float()  # agents x timesteps x ...

    @property
    def hidden_actor(self):
        all_ha = [mem.hidden_actor for mem in self.memories]
        return torch.cat(all_ha, 0)  # agents x layers x timesteps x hidden dim

    @property
    def hidden_critic(self):
        all_hc = [mem.hidden_critic for mem in self.memories]
        return torch.cat(all_hc, 0)  # agents x layers x timesteps x hidden dim
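The rollout buffer keeps one ActorCriticMemory per agent and stacks them on demand. A small illustrative session with made-up shapes (not part of this commit):

# Illustrative use of MARLActorCriticMemory.
import torch
from algorithms.marl.memory import MARLActorCriticMemory

tm = MARLActorCriticMemory(n_agents=2)
tm.add(action=[-1, -1])                      # seed with the dummy "previous" action
for _ in range(5):
    tm.add(observation=torch.zeros(2, 3),    # one row per agent
           action=[0, 1],
           reward=[0.1, -0.2],
           done=[False, False])

print(len(tm))               # 5 stored transitions
print(tm.observation.shape)  # agents x timesteps x obs dim -> torch.Size([2, 5, 3])
print(tm.action.shape)       # agents x timesteps+1 -> torch.Size([2, 6])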
algorithms/marl/networks.py (new file, 91 lines)
@@ -0,0 +1,91 @@
import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
from torch.nn.utils import spectral_norm


class RecurrentAC(nn.Module):
    def __init__(self, observation_size, n_actions, obs_emb_size,
                 action_emb_size, hidden_size_actor, hidden_size_critic,
                 n_agents, use_agent_embedding=True):
        super(RecurrentAC, self).__init__()
        observation_size = np.prod(observation_size)
        self.n_layers = 1
        self.use_agent_embedding = use_agent_embedding
        self.hidden_size_actor = hidden_size_actor
        self.hidden_size_critic = hidden_size_critic
        self.action_emb_size = action_emb_size
        self.obs_proj = nn.Linear(observation_size, obs_emb_size)
        self.action_emb = nn.Embedding(n_actions + 1, action_emb_size, padding_idx=0)
        self.agent_emb = nn.Embedding(n_agents, action_emb_size)
        mix_in_size = obs_emb_size + action_emb_size if not use_agent_embedding else obs_emb_size + n_agents * action_emb_size
        self.mix = nn.Sequential(nn.Tanh(),
                                 nn.Linear(mix_in_size, obs_emb_size),
                                 nn.Tanh(),
                                 nn.Linear(obs_emb_size, obs_emb_size)
                                 )
        self.gru_actor = nn.GRU(obs_emb_size, hidden_size_actor, batch_first=True, num_layers=self.n_layers)
        self.gru_critic = nn.GRU(obs_emb_size, hidden_size_critic, batch_first=True, num_layers=self.n_layers)
        self.action_head = nn.Sequential(
            spectral_norm(nn.Linear(hidden_size_actor, hidden_size_actor)),
            nn.Tanh(),
            nn.Linear(hidden_size_actor, n_actions)
        )
        self.critic_head = nn.Sequential(
            nn.Linear(hidden_size_critic, hidden_size_critic),
            nn.Tanh(),
            nn.Linear(hidden_size_critic, 1)
        )
        # self.action_head[-1].weight.data.uniform_(-3e-3, 3e-3)
        # self.action_head[-1].bias.data.uniform_(-3e-3, 3e-3)

    def init_hidden_actor(self):
        return torch.zeros(1, self.n_layers, self.hidden_size_actor)

    def init_hidden_critic(self):
        return torch.zeros(1, self.n_layers, self.hidden_size_critic)

    def forward(self, observations, actions, hidden_actor=None, hidden_critic=None):
        n_agents, t, *_ = observations.shape
        obs_emb = self.obs_proj(observations.view(n_agents, t, -1).float())
        action_emb = self.action_emb(actions + 1)  # shift by one due to padding idx
        agent_emb = self.agent_emb(
            torch.cat([torch.arange(0, n_agents, 1).view(-1, 1)] * t, 1)
        )
        x_t = torch.cat((obs_emb, action_emb), -1) \
            if not self.use_agent_embedding else torch.cat((obs_emb, agent_emb, action_emb), -1)

        mixed_x_t = self.mix(x_t)
        output_p, _ = self.gru_actor(input=mixed_x_t, hx=hidden_actor.swapaxes(1, 0))
        output_c, _ = self.gru_critic(input=mixed_x_t, hx=hidden_critic.swapaxes(1, 0))

        logits = self.action_head(output_p)
        critic = self.critic_head(output_c).squeeze(-1)
        return dict(logits=logits, critic=critic, hidden_actor=output_p, hidden_critic=output_c)


class NormalizedLinear(nn.Linear):
    def __init__(self, in_features: int, out_features: int,
                 device=None, dtype=None, trainable_magnitude=False):
        super(NormalizedLinear, self).__init__(in_features, out_features, False, device, dtype)
        self.d_sqrt = in_features ** 0.5
        self.trainable_magnitude = trainable_magnitude
        self.scale = nn.Parameter(torch.tensor([1.]), requires_grad=trainable_magnitude)

    def forward(self, input):
        normalized_input = F.normalize(input, dim=-1, p=2, eps=1e-5)
        normalized_weight = F.normalize(self.weight, dim=-1, p=2, eps=1e-5)
        return F.linear(normalized_input, normalized_weight) * self.d_sqrt * self.scale


class L2Norm(nn.Module):
    def __init__(self, in_features, trainable_magnitude=False):
        super(L2Norm, self).__init__()
        self.d_sqrt = in_features ** 0.5
        self.scale = nn.Parameter(torch.tensor([1.]), requires_grad=trainable_magnitude)

    def forward(self, x):
        return F.normalize(x, dim=-1, p=2, eps=1e-5) * self.d_sqrt * self.scale
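RecurrentAC expects agent-major inputs of shape (agents, timesteps, ...). The dummy forward pass below (not part of this commit; n_actions and the observation size are placeholders, the embedding and hidden sizes mirror example_config.yaml) shows the expected output shapes:

# Illustrative forward pass through RecurrentAC with made-up sizes.
import torch
from algorithms.marl.networks import RecurrentAC

n_agents, t, n_actions = 2, 5, 10        # n_actions normally comes from the env
net = RecurrentAC(observation_size=(3, 5, 5), n_actions=n_actions,
                  obs_emb_size=96, action_emb_size=16,
                  hidden_size_actor=64, hidden_size_critic=64,
                  n_agents=n_agents, use_agent_embedding=False)

obs = torch.zeros(n_agents, t, 3, 5, 5)                   # agents x time x obs dims
acts = torch.full((n_agents, t), -1, dtype=torch.long)    # -1 = "no previous action"
h_a = torch.cat([net.init_hidden_actor()] * n_agents, 0)  # agents x layers x hidden
h_c = torch.cat([net.init_hidden_critic()] * n_agents, 0)

out = net(obs, acts, h_a, h_c)
print(out['logits'].shape)   # torch.Size([2, 5, 10])
print(out['critic'].shape)   # torch.Size([2, 5])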
algorithms/marl/seac.py (new file, 55 lines)
@@ -0,0 +1,55 @@
import torch
from torch.distributions import Categorical
from algorithms.marl.iac import LoopIAC
from algorithms.marl.memory import MARLActorCriticMemory


class LoopSEAC(LoopIAC):
    def __init__(self, cfg):
        super(LoopSEAC, self).__init__(cfg)

    def actor_critic(self, tm, networks, gamma, entropy_coef, vf_coef, **kwargs):
        obs, actions, done, reward = tm.observation, tm.action, tm.done, tm.reward
        outputs = [net(obs, actions, tm.hidden_actor, tm.hidden_critic) for net in networks]

        with torch.inference_mode(True):
            true_action_logp = torch.stack([
                torch.log_softmax(out['logits'][ag_i, :-1], -1)
                .gather(index=actions[ag_i, 1:, None], dim=-1)
                for ag_i, out in enumerate(outputs)
            ], 0).squeeze()

        losses = []

        for ag_i, out in enumerate(outputs):
            logits = out['logits'][:, :-1]  # last step is only needed for v_{t+1}
            critic = out['critic']

            entropy_loss = Categorical(logits=logits[ag_i]).entropy().mean()
            advantages = self.compute_advantages(critic, reward, done, gamma)

            # policy loss
            log_ap = torch.log_softmax(logits, -1)
            log_ap = torch.gather(log_ap, dim=-1, index=actions[:, 1:].unsqueeze(-1)).squeeze()

            # importance weights
            iw = (log_ap - true_action_logp).exp().detach()

            a2c_loss = (-iw * log_ap * advantages.detach()).mean(-1)

            value_loss = (iw * advantages.pow(2)).mean(-1)  # per agent

            # weighted loss
            loss = (a2c_loss + vf_coef * value_loss - entropy_coef * entropy_loss).mean()
            losses.append(loss)

        return losses

    def learn(self, tms: MARLActorCriticMemory, **kwargs):
        losses = self.actor_critic(tms, self.net, **self.cfg['algorithm'], **kwargs)
        for ag_i, loss in enumerate(losses):
            self.optimizer[ag_i].zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.net[ag_i].parameters(), 0.5)
            self.optimizer[ag_i].step()
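The off-policy correction above weighs each agent's shared experience by the ratio between the learning agent's policy and the acting agent's policy. On toy log-probabilities this reduces to the following (illustrative values only, not part of this commit):

# Illustrative computation of the SEAC importance weights.
import torch

log_ap = torch.tensor([[-0.7, -0.9], [-1.2, -0.4]])            # log pi_i(a_j | s_j), learner i
true_action_logp = torch.tensor([[-0.7, -0.9], [-0.9, -0.6]])  # log pi_j(a_j | s_j), acting agent j

iw = (log_ap - true_action_logp).exp()
# the row where the learner generated the data gets weight 1;
# other rows are up- or down-weighted by the policy ratio.
print(iw)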
algorithms/marl/snac.py (new file, 32 lines)
@@ -0,0 +1,32 @@
from algorithms.marl.base_ac import BaseActorCritic
import torch
from torch.distributions import Categorical
from pathlib import Path


class LoopSNAC(BaseActorCritic):
    def __init__(self, cfg):
        super().__init__(cfg)

    def load_state_dict(self, path: Path):
        path2weights = list(path.glob('*.pt'))
        assert len(path2weights) == 1, f'Expected a single set of weights but got {len(path2weights)}'
        self.net.load_state_dict(torch.load(path2weights[0]))

    def init_hidden(self):
        hidden_actor = self.net.init_hidden_actor()
        hidden_critic = self.net.init_hidden_critic()
        return dict(hidden_actor=torch.cat([hidden_actor] * self.n_agents, 0),
                    hidden_critic=torch.cat([hidden_critic] * self.n_agents, 0)
                    )

    def get_actions(self, out):
        actions = Categorical(logits=out['logits']).sample().squeeze()
        return actions

    def forward(self, observations, actions, hidden_actor, hidden_critic):
        out = self.net(self._as_torch(observations).unsqueeze(1),
                       self._as_torch(actions).unsqueeze(1),
                       hidden_actor, hidden_critic
                       )
        return out