add more modularity
This commit is contained in:
parent
55402d219c
commit
f4606a7f6c
4
cfg.py
4
cfg.py
@ -2,8 +2,8 @@ from pathlib import Path
|
||||
import torch
|
||||
|
||||
BATCH_SIZE = 128
|
||||
NUM_EPOCHS = 1
|
||||
NUM_WORKERS = 0
|
||||
NUM_EPOCHS = 10
|
||||
NUM_WORKERS = 4
|
||||
NUM_SEGMENTS = 5
|
||||
NUM_SEGMENT_HOPS = 2
|
||||
SEEDS = [42, 1337]
|
||||
|
9
main.py
9
main.py
@ -1,3 +1,4 @@
|
||||
if __name__ == '__main__':
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
from cfg import *
|
||||
@ -29,13 +30,10 @@ dl = mimii.train_dataloader(
|
||||
|
||||
model = AE(400).to(DEVICE)
|
||||
model.init_weights()
|
||||
criterion = nn.MSELoss()
|
||||
|
||||
optimizer = optim.Adam(model.parameters(), lr=0.001)
|
||||
|
||||
|
||||
beta_1 = 0.00
|
||||
beta_2 = 0.0
|
||||
|
||||
for epoch in range(NUM_EPOCHS):
|
||||
print(f'EPOCH #{epoch+1}')
|
||||
losses = []
|
||||
@ -43,8 +41,7 @@ for epoch in range(NUM_EPOCHS):
|
||||
data, labels = batch
|
||||
data = data.to(DEVICE)
|
||||
|
||||
y_hat, y = model(data)
|
||||
loss = criterion(y_hat, y)
|
||||
loss = model.train_loss(data)
|
||||
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
|
3
mimii.py
3
mimii.py
@ -92,8 +92,7 @@ class MIMII(object):
|
||||
data, labels = batch
|
||||
data = data.to(self.device)
|
||||
|
||||
y_hat, y = f(data)
|
||||
preds = torch.sum((y_hat - y) ** 2, dim=tuple(range(1, y_hat.dim())))
|
||||
preds = f.test_loss(data)
|
||||
|
||||
file_preds += preds.cpu().data.tolist()
|
||||
y_true.append(labels.max().item())
|
||||
|
11
models/ae.py
11
models/ae.py
@ -24,6 +24,17 @@ class AE(nn.Module):
|
||||
x = data.view(data.shape[0], -1)
|
||||
return self.net(x), x
|
||||
|
||||
def train_loss(self, data):
|
||||
criterion = nn.MSELoss()
|
||||
y_hat, y = self.forward(data)
|
||||
loss = criterion(y_hat, y)
|
||||
return loss
|
||||
|
||||
def test_loss(self, data):
|
||||
y_hat, y = self.forward(data)
|
||||
preds = torch.sum((y_hat - y) ** 2, dim=tuple(range(1, y_hat.dim())))
|
||||
return preds
|
||||
|
||||
def init_weights(self):
|
||||
def _weight_init(m):
|
||||
if hasattr(m, 'weight'):
|
||||
|
Loading…
x
Reference in New Issue
Block a user