added first working MAPPO implementation

This commit is contained in:
Robert Müller
2022-01-28 11:07:25 +01:00
parent ffc47752a7
commit b09c461754
11 changed files with 194 additions and 61 deletions

View File

@ -12,6 +12,7 @@ class RecurrentAC(nn.Module):
super(RecurrentAC, self).__init__()
observation_size = np.prod(observation_size)
self.n_layers = 1
self.n_actions = n_actions
self.use_agent_embedding = use_agent_embedding
self.hidden_size_actor = hidden_size_actor
self.hidden_size_critic = hidden_size_critic
@ -25,13 +26,14 @@ class RecurrentAC(nn.Module):
nn.Tanh(),
nn.Linear(obs_emb_size, obs_emb_size)
)
self.gru_actor = nn.GRU(obs_emb_size, hidden_size_actor, batch_first=True, num_layers=self.n_layers)
self.gru_actor = nn.GRU(obs_emb_size, hidden_size_actor, batch_first=True, num_layers=self.n_layers)
self.gru_critic = nn.GRU(obs_emb_size, hidden_size_critic, batch_first=True, num_layers=self.n_layers)
self.action_head = nn.Sequential(
spectral_norm(nn.Linear(hidden_size_actor, hidden_size_actor)),
nn.Linear(hidden_size_actor, hidden_size_actor),
nn.Tanh(),
nn.Linear(hidden_size_actor, n_actions)
)
# spectral_norm(nn.Linear(hidden_size_actor, hidden_size_actor)),
self.critic_head = nn.Sequential(
nn.Linear(hidden_size_critic, hidden_size_critic),
nn.Tanh(),
@ -50,12 +52,14 @@ class RecurrentAC(nn.Module):
n_agents, t, *_ = observations.shape
obs_emb = self.obs_proj(observations.view(n_agents, t, -1).float())
action_emb = self.action_emb(actions+1) # shift by one due to padding idx
agent_emb = self.agent_emb(
torch.cat([torch.arange(0, n_agents, 1).view(-1, 1)]*t, 1)
)
x_t = torch.cat((obs_emb, action_emb), -1) \
if not self.use_agent_embedding else torch.cat((obs_emb, agent_emb, action_emb), -1)
if not self.use_agent_embedding:
x_t = torch.cat((obs_emb, action_emb), -1)
else:
agent_emb = self.agent_emb(
torch.cat([torch.arange(0, n_agents, 1).view(-1, 1)] * t, 1)
)
x_t = torch.cat((obs_emb, agent_emb, action_emb), -1)
mixed_x_t = self.mix(x_t)
output_p, _ = self.gru_actor(input=mixed_x_t, hx=hidden_actor.swapaxes(1, 0))
@ -66,6 +70,15 @@ class RecurrentAC(nn.Module):
return dict(logits=logits, critic=critic, hidden_actor=output_p, hidden_critic=output_c)
class RecurrentACL2(RecurrentAC):
    """RecurrentAC variant whose action head ends in a NormalizedLinear layer."""

    def __init__(self, *args, **kwargs):
        """Build the parent network, then swap in a different action head.

        All positional and keyword arguments are forwarded unchanged to
        ``RecurrentAC.__init__``.
        """
        super().__init__(*args, **kwargs)
        actor_width = self.hidden_size_actor
        # Layers are instantiated in the order they appear in the head.
        hidden_layer = nn.Linear(actor_width, actor_width)
        activation = nn.Tanh()
        output_layer = NormalizedLinear(actor_width, self.n_actions, trainable_magnitude=True)
        # Overrides the action head assigned by the parent constructor.
        self.action_head = nn.Sequential(hidden_layer, activation, output_layer)
class NormalizedLinear(nn.Linear):
def __init__(self, in_features: int, out_features: int,