import torch
import torch.nn as nn
import torch.nn.functional as F


class AE(nn.Module):
    """Plain fully connected autoencoder with an 8-dimensional bottleneck."""

    def __init__(self, in_dim=320):
        super().__init__()
        self.net = nn.Sequential(
            # Encoder: in_dim -> 64 -> 64 -> 8.
            nn.Linear(in_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU(),
            nn.Linear(64, 8),
            nn.ReLU(),
            # Decoder: 8 -> 64 -> 64 -> in_dim.
            nn.Linear(8, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU(),
            nn.Linear(64, in_dim),
            nn.ReLU(),
        )

    def forward(self, data):
        return self.net(data)

    def init_weights(self):
        def _weight_init(m):
            if hasattr(m, 'weight') and isinstance(m.weight, torch.Tensor):
                nn.init.xavier_uniform_(m.weight,
                                        gain=nn.init.calculate_gain('relu'))
            if hasattr(m, 'bias') and isinstance(m.bias, torch.Tensor):
                m.bias.data.fill_(0.01)

        self.apply(_weight_init)
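
# A minimal smoke test for AE, kept as a comment so importing this module has
# no side effects; the batch size of 128 is an illustrative assumption, not a
# value from the original code.
#
#   model = AE(in_dim=320)
#   model.init_weights()
#   x = torch.randn(128, 320)            # stand-in batch
#   loss = F.mse_loss(model(x), x)       # typical reconstruction objective
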
class LCAE(nn.Module):
    """Autoencoder that decodes a convex combination of learned memory slots."""

    def __init__(self, in_dim=320):
        super().__init__()
        num_mem = 10    # number of memory slots
        mem_size = 8    # dimensionality of each slot
        self.num_mem = num_mem
        # Encoder emits a softmax distribution over the memory slots.
        self.encode = nn.Sequential(
            nn.Linear(in_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU(),
            nn.Linear(64, num_mem),
            nn.Softmax(-1),
        )

        self.decode = nn.Sequential(
            nn.Linear(mem_size, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU(),
            nn.Linear(64, in_dim),
            nn.ReLU(),
        )

        # Learnable memory bank: one mem_size-dimensional row per slot.
        self.M = nn.Parameter(torch.randn(num_mem, mem_size))

    def forward(self, data):
        # Soft addressing weights over the slots: (B, num_mem, 1).
        alphas = self.encode(data).unsqueeze(-1)
        # Entropy of the addressing distribution; the epsilon guards log(0).
        entropy_alphas = (alphas * -(alphas + 1e-13).log()).sum(1)
        # Broadcast the memory bank across the batch: (B, num_mem, mem_size).
        M = self.M.expand(data.shape[0], *self.M.shape)
        # 1 + ELU keeps the memory entries softly positive before mixing.
        elu = nn.ELU()
        weighted = alphas * (1 + elu(M))      # (B, num_mem, mem_size)
        summed = weighted.sum(1)              # (B, mem_size)
        decoded = self.decode(summed)
        # Largest mean addressing weight over the batch: 1/num_mem when slot
        # usage is perfectly uniform, 1.0 when a single slot dominates.
        diversity = (alphas.sum(dim=0) / data.shape[0]).max()
        return decoded, entropy_alphas, diversity

    def init_weights(self):
        def _weight_init(m):
            if hasattr(m, 'weight') and isinstance(m.weight, torch.Tensor):
                nn.init.xavier_uniform_(m.weight,
                                        gain=nn.init.calculate_gain('relu'))
            if hasattr(m, 'bias') and isinstance(m.bias, torch.Tensor):
                m.bias.data.fill_(0.01)

        self.apply(_weight_init)
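
# A minimal training sketch, assuming the two auxiliary outputs of
# LCAE.forward are used as regularizers; the loss weights and the synthetic
# batch below are illustrative assumptions, not values from the original code.
if __name__ == "__main__":
    model = LCAE(in_dim=320)
    model.init_weights()
    opt = torch.optim.Adam(model.parameters(), lr=1e-3)

    x = torch.randn(128, 320)                  # stand-in batch
    decoded, entropy, diversity = model(x)
    loss = (F.mse_loss(decoded, x)             # reconstruction error
            + 1e-4 * entropy.mean()            # sharpen addressing weights
            + 1e-2 * diversity)                # discourage slot collapse
    loss.backward()
    opt.step()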