add more modularity
This commit is contained in:
11
models/ae.py
11
models/ae.py
@ -24,6 +24,17 @@ class AE(nn.Module):
|
||||
x = data.view(data.shape[0], -1)
|
||||
return self.net(x), x
|
||||
|
||||
def train_loss(self, data):
|
||||
criterion = nn.MSELoss()
|
||||
y_hat, y = self.forward(data)
|
||||
loss = criterion(y_hat, y)
|
||||
return loss
|
||||
|
||||
def test_loss(self, data):
|
||||
y_hat, y = self.forward(data)
|
||||
preds = torch.sum((y_hat - y) ** 2, dim=tuple(range(1, y_hat.dim())))
|
||||
return preds
|
||||
|
||||
def init_weights(self):
|
||||
def _weight_init(m):
|
||||
if hasattr(m, 'weight'):
|
||||
|
Reference in New Issue
Block a user