2020-03-18 17:53:52 +01:00

48 lines
1.4 KiB
Python

import torch
import torch.nn as nn
import torch.functional as F
class AE(nn.Module):
    """Fully connected autoencoder with an 8-unit bottleneck.

    The network maps ``in_dim -> 64 -> 64 -> 8 -> 64 -> 64 -> in_dim`` with a
    ReLU after every linear layer, including the last one.

    Args:
        in_dim: Number of features per sample after flattening.
    """

    def __init__(self, in_dim: int = 400) -> None:
        super().__init__()
        self.net = nn.Sequential(
            # Encoder.
            nn.Linear(in_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU(),
            nn.Linear(64, 8),
            nn.ReLU(),
            # Decoder.
            nn.Linear(8, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU(),
            nn.Linear(64, in_dim),
            # NOTE(review): this final ReLU clamps reconstructions to >= 0, so
            # negative input values can never be reproduced. Presumably the
            # inputs are non-negative -- confirm against the data pipeline.
            nn.ReLU(),
        )

    def forward(self, data: torch.Tensor):
        """Flatten each sample and reconstruct it.

        Args:
            data: Batch tensor; dimension 0 is the batch and the remaining
                dimensions must multiply out to ``in_dim``.

        Returns:
            A ``(reconstruction, flattened_input)`` tuple, both of shape
            ``(batch, in_dim)``.
        """
        # reshape (rather than view) also accepts non-contiguous inputs such
        # as transposed or sliced batches; for contiguous inputs it still
        # returns a view, so nothing changes there.
        x = data.reshape(data.shape[0], -1)
        return self.net(x), x

    def train_loss(self, data: torch.Tensor) -> torch.Tensor:
        """Return the scalar mean-squared reconstruction error for a batch."""
        criterion = nn.MSELoss()
        y_hat, y = self.forward(data)
        return criterion(y_hat, y)

    def test_loss(self, data: torch.Tensor) -> torch.Tensor:
        """Return the per-sample sum of squared reconstruction errors.

        The result has one entry per sample (shape ``(batch,)``); unlike
        ``train_loss`` it is not averaged over the batch or the features.
        """
        y_hat, y = self.forward(data)
        # Sum over every non-batch dimension.
        feature_dims = tuple(range(1, y_hat.dim()))
        return torch.sum((y_hat - y) ** 2, dim=feature_dims)

    def init_weights(self) -> None:
        """Re-initialise all sub-module parameters in place.

        Weights get Xavier-uniform initialisation with the ReLU gain and
        biases are set to the constant 0.01. This method is not called by
        ``__init__``; callers must invoke it explicitly.
        """
        relu_gain = nn.init.calculate_gain('relu')

        def _weight_init(m: nn.Module) -> None:
            # Only modules that actually own tensor parameters are touched;
            # containers and activations have neither attribute.
            weight = getattr(m, 'weight', None)
            if isinstance(weight, torch.Tensor):
                nn.init.xavier_uniform_(weight, gain=relu_gain)
            bias = getattr(m, 'bias', None)
            if isinstance(bias, torch.Tensor):
                nn.init.constant_(bias, 0.01)

        self.apply(_weight_init)