import random
import pickle

import numpy as np
import torch
import torch.optim as optim
from tqdm import tqdm

from cfg import *
from mimii import MIMII
from models.ae import AE, SubSpecCAE
from models.layers import Subspectrogram
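# `from cfg import *` is assumed to provide the experiment constants used below:
# DEVICE, NUM_EPOCHS, BATCH_SIZE, NUM_WORKERS, NUM_SEGMENTS, NUM_SEGMENT_HOPS,
# SUB_SPEC_HEIGT, SUB_SPEC_HOP, SEEDS and ALL_DATASET_PATHS.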


def train(dataset_path, machine_id, band, norm='batch', loss_fn='mse', seed=42):
    # Fix every source of randomness so runs are reproducible.
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    np.random.seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    random.seed(seed)

    print(f'Training on {dataset_path.name}')
    mimii = MIMII(dataset_path=dataset_path, machine_id=machine_id)
    mimii.to(DEVICE)
    # mimii.preprocess(n_fft=1024, hop_length=256, n_mels=80, center=False, power=2.0)  # 80 x 80

    # Transform that slices each mel spectrogram into frequency sub-bands.
    tfms = Subspectrogram(SUB_SPEC_HEIGT, SUB_SPEC_HOP)

    dl = mimii.train_dataloader(
        segment_len=NUM_SEGMENTS,
        hop_len=NUM_SEGMENT_HOPS,
        batch_size=BATCH_SIZE,
        num_workers=NUM_WORKERS,
        shuffle=True,
        transform=tfms
    )
    # Convolutional autoencoder trained on a single frequency sub-band.
    model = SubSpecCAE(norm=norm, loss_fn=loss_fn, band=band).to(DEVICE)
    model.init_weights()
    # print(model(torch.randn(128, 1, 20, 80).to(DEVICE)).shape)

    optimizer = optim.Adam(model.parameters(), lr=0.0005)
    for epoch in range(NUM_EPOCHS):
        print(f'EPOCH #{epoch + 1}')
        losses = []
        for batch in tqdm(dl):
            data, labels = batch
            data = data.to(DEVICE)  # torch.Size([128, 4, 20, 80]): batch x sub_specs x height x width

            loss = model.train_loss(data)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            losses.append(loss.item())
        print(f'Loss: {np.mean(losses)}')

    auc = mimii.evaluate_model(model, NUM_SEGMENTS, NUM_SEGMENTS, transform=tfms)
    print(f'AUC: {auc}, Machine: {machine_id}, Band: {band}, Norm: {norm}, Seed: {seed}')
    return auc


if __name__ == '__main__':
    loss_fn = 'mse'
    results = []
    for norm in ['batch']:
        for seed in SEEDS:
            for dataset_path in ALL_DATASET_PATHS:
                # Only the hardest setting: -6 dB SNR recordings of machine id 4.
                if '-6_dB' in dataset_path.name:
                    for machine_id in [4]:
                        # Train one model per frequency sub-band.
                        for band in range(7):
                            auc = train(dataset_path, machine_id, band, norm, loss_fn, seed)
                            results.append([dataset_path.name, machine_id, seed, band, norm, auc])
                            # Dump the accumulated results after every run.
                            with open(f'results2_hard_{norm}_{loss_fn}.pkl', 'wb') as f:
                                pickle.dump(results, f)
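
# Analysis sketch: one way to inspect the pickled results afterwards
# (assumes pandas is available; the column names mirror the list appended above):
#
#   import pickle
#   import pandas as pd
#
#   with open('results2_hard_batch_mse.pkl', 'rb') as f:
#       rows = pickle.load(f)
#   df = pd.DataFrame(rows, columns=['dataset', 'machine_id', 'seed', 'band', 'norm', 'auc'])
#   print(df.groupby('band')['auc'].mean())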