2020-03-18 17:13:00 +01:00

37 lines
1.1 KiB
Python

import torch
import torch.nn as nn
import torch.functional as F
class AE(nn.Module):
    """Fully connected autoencoder.

    Architecture (every Linear is followed by a ReLU)::

        in_dim -> hidden_dim -> hidden_dim -> latent_dim
               -> hidden_dim -> hidden_dim -> in_dim

    Because a ReLU follows the bottleneck and the output layer as well, both
    the latent code and the reconstruction are non-negative.
    NOTE(review): the final ReLU presumably assumes non-negative inputs
    (e.g. scaled to [0, 1]); confirm against the data pipeline, otherwise
    negative targets can never be reconstructed.

    Args:
        in_dim: Number of features per sample after flattening.
        hidden_dim: Width of the two hidden layers on each side (default 64).
        latent_dim: Width of the bottleneck layer (default 8).
    """

    def __init__(self, in_dim=400, hidden_dim=64, latent_dim=8):
        super().__init__()
        # Keep this exact layer order: the ``net.<idx>.*`` state_dict keys
        # depend on it, so previously saved checkpoints still load.
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, latent_dim),
            nn.ReLU(),
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, in_dim),
            nn.ReLU(),
        )

    def forward(self, data):
        """Flatten ``data`` per sample and reconstruct it.

        Args:
            data: Tensor whose first dimension is the batch; all remaining
                dimensions are flattened and must multiply to ``in_dim``.

        Returns:
            Tuple ``(reconstruction, x)``, both of shape ``(batch, in_dim)``,
            where ``x`` is the flattened input (presumably returned so the
            caller can compare it against the reconstruction).
        """
        # reshape() instead of view(): identical result (and still a view)
        # for contiguous input, but also accepts non-contiguous tensors
        # (e.g. after transpose/permute), where view() raises.
        x = data.reshape(data.shape[0], -1)
        return self.net(x), x

    def init_weights(self):
        """Re-initialise every submodule that owns ``weight``/``bias`` tensors.

        Weights receive Xavier-uniform initialisation scaled by the ReLU
        gain; biases are set to the constant 0.01 (a small positive value,
        presumably to keep ReLU units active early in training — confirm).
        Modules without such tensors, or with ``bias=None``, are skipped.
        """
        relu_gain = nn.init.calculate_gain('relu')

        def _weight_init(m):
            weight = getattr(m, 'weight', None)
            if isinstance(weight, torch.Tensor):
                nn.init.xavier_uniform_(weight, gain=relu_gain)
            bias = getattr(m, 'bias', None)
            if isinstance(bias, torch.Tensor):
                nn.init.constant_(bias, 0.01)

        self.apply(_weight_init)