2020-03-18 13:09:39 +01:00

93 lines
2.8 KiB
Python

import torch
import torch.nn as nn
import torch.functional as F
class AE(nn.Module):
    """Fully connected autoencoder with an 8-unit bottleneck.

    Layout: in_dim -> 64 -> 64 -> 8 -> 64 -> 64 -> in_dim, with a ReLU after
    every linear layer (including the output layer, so reconstructions are
    non-negative).

    Args:
        in_dim: Size of the input's last (feature) dimension. The decoder
            reconstructs a vector of the same size.
    """

    def __init__(self, in_dim=320):
        super().__init__()
        self.net = nn.Sequential(
            # Encoder.
            nn.Linear(in_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU(),
            nn.Linear(64, 8),
            nn.ReLU(),
            # Decoder.
            nn.Linear(8, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU(),
            # Project back to the input size. This was hard-coded to 320, which
            # broke reconstruction for any other in_dim.
            nn.Linear(64, in_dim),
            # NOTE(review): the final ReLU clips reconstructions to >= 0;
            # presumably the inputs are non-negative — confirm.
            nn.ReLU(),
        )

    def forward(self, data):
        """Return the reconstruction of ``data``.

        The output has the same trailing size (``in_dim``) as the input.
        """
        return self.net(data)

    def init_weights(self):
        """Re-initialise all submodules that carry weight/bias tensors.

        Weights get Xavier-uniform initialisation with the ReLU gain. Biases
        are set to the constant 0.01.
        """
        gain = nn.init.calculate_gain('relu')

        def _weight_init(m):
            weight = getattr(m, 'weight', None)
            if isinstance(weight, torch.Tensor):
                nn.init.xavier_uniform_(weight, gain=gain)
            bias = getattr(m, 'bias', None)
            # bias may be None (e.g. Linear(bias=False)); skip it then.
            if isinstance(bias, torch.Tensor):
                nn.init.constant_(bias, 0.01)

        self.apply(_weight_init)
class LCAE(nn.Module):
    """Autoencoder whose latent code is a softmax-weighted sum of memory vectors.

    The encoder maps each input to ``num_mem`` softmax weights (``alphas``).
    The latent code is the weighted sum of the learned memory rows, each
    passed through ``1 + elu(.)``. The decoder maps that
    ``mem_size``-dimensional code back to ``in_dim`` features.

    Args:
        in_dim: Size of the input's last (feature) dimension. The decoder
            reconstructs a vector of the same size.
        num_mem: Number of memory vectors (rows of ``M``).
        mem_size: Length of each memory vector (size of the latent code).
    """

    def __init__(self, in_dim=320, num_mem=10, mem_size=8):
        super().__init__()
        self.num_mem = num_mem
        self.mem_size = mem_size
        self.encode = nn.Sequential(
            nn.Linear(in_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU(),
            nn.Linear(64, num_mem),
            nn.Softmax(-1),
        )
        self.decode = nn.Sequential(
            nn.Linear(mem_size, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU(),
            # Reconstruct the input size. This was hard-coded to 320, which
            # broke reconstruction for any other in_dim.
            nn.Linear(64, in_dim),
            # NOTE(review): the final ReLU clips reconstructions to >= 0;
            # presumably the inputs are non-negative — confirm.
            nn.ReLU(),
        )
        # Learned memory bank: one row per memory slot.
        self.M = nn.Parameter(torch.randn(num_mem, mem_size))

    def forward(self, data):
        """Encode, mix the memory vectors, and decode.

        Assumes ``data`` is 2-D ``(batch, in_dim)`` — TODO confirm with callers.

        Returns:
            decoded: Reconstruction, shape ``(batch, in_dim)``.
            entropy_alphas: Entropy of each sample's mixing weights,
                shape ``(batch, 1)``.
            diversity: Scalar. The largest batch-averaged weight that any
                single memory slot receives.
        """
        # (batch, num_mem, 1): mixing weights, ready to broadcast over mem_size.
        alphas = self.encode(data).unsqueeze(-1)

        # Clamp inside the log so that a weight which underflows to exactly 0
        # contributes 0 * finite = 0 instead of 0 * inf = NaN.
        tiny = torch.finfo(alphas.dtype).tiny
        entropy_alphas = (alphas * -alphas.clamp(min=tiny).log()).sum(1)

        # (num_mem, mem_size). 1 + elu(x) is non-negative because elu(x) > -1.
        # Broadcasting against alphas replaces the per-batch expand.
        memory = 1 + nn.functional.elu(self.M + 1e-13)
        summed = (alphas * memory).sum(1)  # (batch, mem_size)

        decoded = self.decode(summed)
        diversity = alphas.mean(dim=0).max()
        return decoded, entropy_alphas, diversity

    def init_weights(self):
        """Re-initialise all submodules that carry weight/bias tensors.

        Weights get Xavier-uniform initialisation with the ReLU gain. Biases
        are set to the constant 0.01. The memory bank ``M`` belongs to this
        module directly, not to a submodule with a ``weight`` attribute, so it
        is left unchanged.
        """
        gain = nn.init.calculate_gain('relu')

        def _weight_init(m):
            weight = getattr(m, 'weight', None)
            if isinstance(weight, torch.Tensor):
                nn.init.xavier_uniform_(weight, gain=gain)
            bias = getattr(m, 'bias', None)
            # bias may be None (e.g. Linear(bias=False)); skip it then.
            if isinstance(bias, torch.Tensor):
                nn.init.constant_(bias, 0.01)

        self.apply(_weight_init)