add more modularity

This commit is contained in:
Robert Müller
2020-03-18 17:53:52 +01:00
parent 55402d219c
commit f4606a7f6c
4 changed files with 58 additions and 51 deletions

View File

@ -24,6 +24,17 @@ class AE(nn.Module):
x = data.view(data.shape[0], -1)
return self.net(x), x
def train_loss(self, data):
criterion = nn.MSELoss()
y_hat, y = self.forward(data)
loss = criterion(y_hat, y)
return loss
def test_loss(self, data):
y_hat, y = self.forward(data)
preds = torch.sum((y_hat - y) ** 2, dim=tuple(range(1, y_hat.dim())))
return preds
def init_weights(self):
def _weight_init(m):
if hasattr(m, 'weight'):