# Source listing: 81 lines, 2.5 KiB, Python.
if __name__ == '__main__':
    # Everything lives under the main guard so that DataLoader worker
    # processes (num_workers=NUM_WORKERS) that re-import this module do not
    # re-run the sweep.
    import numpy as np
    import random
    from tqdm import tqdm
    from cfg import *
    from mimii import MIMII
    from models.ae import AE, SubSpecCAE
    import pickle
    # torch is used directly below; import it explicitly instead of relying
    # on `from cfg import *` happening to re-export it.
    import torch
    import torch.optim as optim
    from models.layers import Subspectrogram

    def train(dataset_path, machine_id, band, norm, seed):
        """Train a SubSpecCAE on one MIMII machine and return its evaluation score.

        Args:
            dataset_path: Dataset root passed to ``MIMII``; its ``.name`` is
                used for logging.
            machine_id: Machine id forwarded to ``MIMII``.
            band: Band index forwarded to ``SubSpecCAE``.
            norm: Normalization name forwarded to ``SubSpecCAE`` (the sweep
                below uses 'instance' and 'batch').
            seed: Seed applied to torch (CPU and all CUDA devices), numpy and
                ``random`` before any data or model is created.

        Returns:
            The value returned by ``mimii.evaluate_model`` (logged as the AUC).
        """
        # Seed every RNG in play and ask cuDNN for deterministic kernels.
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        np.random.seed(seed)
        random.seed(seed)
        torch.backends.cudnn.deterministic = True
        # Autotuning can pick different kernels run to run; disable it
        # explicitly so determinism does not depend on the default or cfg.
        torch.backends.cudnn.benchmark = False

        print(f'Training on {dataset_path.name}')
        mimii = MIMII(dataset_path=dataset_path, machine_id=machine_id)
        mimii.to(DEVICE)

        tfms = Subspectrogram(SUB_SPEC_HEIGT, SUB_SPEC_HOP)
        dl = mimii.train_dataloader(
            segment_len=NUM_SEGMENTS,
            hop_len=NUM_SEGMENT_HOPS,
            batch_size=BATCH_SIZE,
            num_workers=NUM_WORKERS,
            shuffle=True,
            transform=tfms,
        )

        model = SubSpecCAE(norm=norm, band=band).to(DEVICE)
        model.init_weights()
        optimizer = optim.Adam(model.parameters(), lr=0.001)

        for epoch in range(NUM_EPOCHS):
            print(f'EPOCH #{epoch+1}')
            losses = []
            for data, _ in tqdm(dl):  # labels are not needed for reconstruction training
                # Per the original author's note, e.g. torch.Size([128, 4, 20, 80]):
                # batch x sub_specs x height x width.
                data = data.to(DEVICE)

                loss = model.train_loss(data)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                losses.append(loss.item())
            print(f'Loss: {np.mean(losses)}')

        # Evaluation uses segment length == hop length (non-overlapping
        # segments), unlike training which hops by NUM_SEGMENT_HOPS.
        auc = mimii.evaluate_model(model, NUM_SEGMENTS, NUM_SEGMENTS, transform=tfms)
        print(f'AUC: {auc}, Machine: {machine_id}, Band: {band}, Norm: {norm}, Seed: {seed}')
        return auc

    # Sweep: norm x seed x dataset x machine id x band index.
    for norm in ('instance', 'batch'):
        # Fresh list per norm so results_<norm>.pkl holds only that norm's runs.
        results = []
        for seed in SEEDS:
            for dataset_path in ALL_DATASET_PATHS:
                for machine_id in [0, 2, 4, 6]:
                    for band in range(7):
                        auc = train(dataset_path, machine_id, band, norm, seed)
                        results.append([dataset_path.name, machine_id, seed, band, norm, auc])
                        # Checkpoint after every run so a crash late in the
                        # sweep does not lose the runs already finished.
                        with open(f'results_{norm}.pkl', 'wb') as f:
                            pickle.dump(results, f)