import torch
import torch.nn as nn
import torch.nn.functional as F


class AE(nn.Module):
    """Fully connected autoencoder with an 8-dimensional bottleneck."""

    def __init__(self, in_dim=400):
        super(AE, self).__init__()
        # Encoder: in_dim -> 64 -> 64 -> 8; decoder mirrors it back to in_dim.
        self.net = nn.Sequential(
            nn.Linear(in_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU(),
            nn.Linear(64, 8),
            nn.ReLU(),
            nn.Linear(8, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU(),
            nn.Linear(64, in_dim),
            nn.ReLU(),
        )

    def forward(self, data):
        # Flatten everything but the batch dimension; return (reconstruction, flattened input).
        x = data.view(data.shape[0], -1)
        return self.net(x), x

    def train_loss(self, data):
        # Mean squared reconstruction error over the batch.
        criterion = nn.MSELoss()
        y_hat, y = self.forward(data)
        loss = criterion(y_hat, y)
        return loss

    def test_loss(self, data):
        # Per-sample squared reconstruction error, summed over all feature dims.
        y_hat, y = self.forward(data)
        preds = torch.sum((y_hat - y) ** 2, dim=tuple(range(1, y_hat.dim())))
        return preds

    def init_weights(self):
        # Xavier-uniform weights (ReLU gain) and small constant biases.
        def _weight_init(m):
            if hasattr(m, 'weight'):
                if isinstance(m.weight, torch.Tensor):
                    torch.nn.init.xavier_uniform_(m.weight,
                                                  gain=nn.init.calculate_gain('relu'))
            if hasattr(m, 'bias'):
                if isinstance(m.bias, torch.Tensor):
                    m.bias.data.fill_(0.01)

        self.apply(_weight_init)
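

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module): trains the AE on a
# random batch and scores per-sample reconstruction error. The batch size,
# learning rate, step count, and choice of Adam are illustrative assumptions,
# not part of the original code; real data would come from a DataLoader.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    model = AE(in_dim=400)
    model.init_weights()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    data = torch.rand(32, 400)  # stand-in batch of 32 samples in [0, 1)
    for _ in range(10):         # a few gradient steps on the MSE objective
        optimizer.zero_grad()
        loss = model.train_loss(data)
        loss.backward()
        optimizer.step()

    with torch.no_grad():
        scores = model.test_loss(data)  # per-sample squared reconstruction error
    print(scores.shape)                 # torch.Size([32])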