added first working MAPPO implementation
@@ -12,6 +12,7 @@ class RecurrentAC(nn.Module):
        super(RecurrentAC, self).__init__()
        observation_size = np.prod(observation_size)
        self.n_layers = 1
        self.n_actions = n_actions
        self.use_agent_embedding = use_agent_embedding
        self.hidden_size_actor = hidden_size_actor
        self.hidden_size_critic = hidden_size_critic
@@ -25,13 +26,14 @@ class RecurrentAC(nn.Module):
            nn.Tanh(),
            nn.Linear(obs_emb_size, obs_emb_size)
        )
        self.gru_actor = nn.GRU(obs_emb_size, hidden_size_actor, batch_first=True, num_layers=self.n_layers)
        self.gru_critic = nn.GRU(obs_emb_size, hidden_size_critic, batch_first=True, num_layers=self.n_layers)
        self.action_head = nn.Sequential(
            spectral_norm(nn.Linear(hidden_size_actor, hidden_size_actor)),
            nn.Linear(hidden_size_actor, hidden_size_actor),
            nn.Tanh(),
            nn.Linear(hidden_size_actor, n_actions)
        )
        # spectral_norm(nn.Linear(hidden_size_actor, hidden_size_actor)),
        self.critic_head = nn.Sequential(
            nn.Linear(hidden_size_critic, hidden_size_critic),
            nn.Tanh(),
@@ -50,12 +52,14 @@ class RecurrentAC(nn.Module):
        n_agents, t, *_ = observations.shape
        obs_emb = self.obs_proj(observations.view(n_agents, t, -1).float())
        action_emb = self.action_emb(actions+1)  # shift by one due to padding idx
        agent_emb = self.agent_emb(
            torch.cat([torch.arange(0, n_agents, 1).view(-1, 1)]*t, 1)
        )
        x_t = torch.cat((obs_emb, action_emb), -1) \
            if not self.use_agent_embedding else torch.cat((obs_emb, agent_emb, action_emb), -1)

        if not self.use_agent_embedding:
            x_t = torch.cat((obs_emb, action_emb), -1)
        else:
            agent_emb = self.agent_emb(
                torch.cat([torch.arange(0, n_agents, 1).view(-1, 1)] * t, 1)
            )
            x_t = torch.cat((obs_emb, agent_emb, action_emb), -1)

        mixed_x_t = self.mix(x_t)
        output_p, _ = self.gru_actor(input=mixed_x_t, hx=hidden_actor.swapaxes(1, 0))
@@ -66,6 +70,15 @@ class RecurrentAC(nn.Module):
        return dict(logits=logits, critic=critic, hidden_actor=output_p, hidden_critic=output_c)


class RecurrentACL2(RecurrentAC):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.action_head = nn.Sequential(
            nn.Linear(self.hidden_size_actor, self.hidden_size_actor),
            nn.Tanh(),
            NormalizedLinear(self.hidden_size_actor, self.n_actions, trainable_magnitude=True)
        )


class NormalizedLinear(nn.Linear):
    def __init__(self, in_features: int, out_features: int,
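
The definition of NormalizedLinear is cut off at the end of this hunk. A minimal sketch of what an L2-weight-normalized linear layer with an optionally trainable output magnitude could look like; the magnitude parameter and scaling scheme below are assumptions, not the code from this commit:

import torch
import torch.nn as nn
import torch.nn.functional as F


class NormalizedLinear(nn.Linear):
    # Sketch (assumption): weight rows are L2-normalized before the affine map,
    # and the output is scaled by a single magnitude parameter that is only
    # trained when trainable_magnitude=True.
    def __init__(self, in_features: int, out_features: int,
                 trainable_magnitude: bool = False, **kwargs):
        super().__init__(in_features, out_features, **kwargs)
        self.magnitude = nn.Parameter(torch.ones(1), requires_grad=trainable_magnitude)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        w = F.normalize(self.weight, dim=-1)  # unit L2 norm per output unit
        return F.linear(x, w, self.bias) * self.magnitude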
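
The forward pass feeds the GRUs with hidden_actor.swapaxes(1, 0), which suggests hidden states are stored agent-major. A small illustrative snippet of that shape convention (all sizes here are made up for the example):

import torch
import torch.nn as nn

n_agents, t, obs_emb_size, hidden_size, n_layers = 4, 10, 32, 64, 1

gru = nn.GRU(obs_emb_size, hidden_size, batch_first=True, num_layers=n_layers)

# With batch_first=True the input is (batch=n_agents, seq=t, features),
# but the initial hidden state h_0 is still (n_layers, batch, hidden),
# hence the swapaxes(1, 0) when the state is kept as (n_agents, n_layers, hidden).
x = torch.randn(n_agents, t, obs_emb_size)
hidden = torch.zeros(n_agents, n_layers, hidden_size)

output, h_n = gru(x, hidden.swapaxes(1, 0))
print(output.shape)  # torch.Size([4, 10, 64])
print(h_n.shape)     # torch.Size([1, 4, 64])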