if __name__ == '__main__':
    # All imports and work live under the main guard so that importing this
    # module (e.g. from a test or notebook) has no side effects.
    import pickle
    import random

    import numpy as np
    import torch  # used directly below; don't rely on `from cfg import *` providing it
    import torch.optim as optim
    from tqdm import tqdm

    # NOTE(review): the star import supplies DEVICE, NUM_SEGMENTS,
    # NUM_SEGMENT_HOPS, BATCH_SIZE, NUM_WORKERS, NUM_EPOCHS, SEEDS,
    # ALL_DATASET_PATHS, SUB_SPEC_HEIGT, SUB_SPEC_HOP (names as used below).
    from cfg import *
    from mimii import MIMII
    from models.ae import AE, SubSpecCAE
    from models.layers import Subspectrogram

    def _seed_everything(seed):
        """Seed Python, NumPy and PyTorch RNGs and make cuDNN deterministic."""
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        # cuDNN autotuning may choose different algorithms run-to-run, which
        # defeats the deterministic flag above; disable it for reproducibility.
        torch.backends.cudnn.benchmark = False

    def _save_results(results, path):
        """Pickle the accumulated result rows to `path` (overwrites)."""
        with open(path, 'wb') as f:
            pickle.dump(results, f)

    def train(dataset_path, machine_id, band, norm='batch', loss_fn='mse', seed=42):
        """Train a SubSpecCAE on one MIMII machine and return its evaluation AUC.

        Args:
            dataset_path: dataset directory; only `.name` is read here, the
                object itself is passed through to `MIMII`.
            machine_id: machine id forwarded to `MIMII`.
            band: sub-spectrogram band index forwarded to `SubSpecCAE`.
            norm: normalization kind forwarded to `SubSpecCAE`.
            loss_fn: loss name forwarded to `SubSpecCAE`.
            seed: RNG seed applied to random/NumPy/PyTorch before anything else.

        Returns:
            Whatever `mimii.evaluate_model` returns (printed as the AUC).
        """
        _seed_everything(seed)
        print(f'Training on {dataset_path.name}')

        mimii = MIMII(dataset_path=dataset_path, machine_id=machine_id)
        mimii.to(DEVICE)
        # Alternative preprocessing kept for reference (80 x 80 mel spectrograms):
        # mimii.preprocess(n_fft=1024, hop_length=256, n_mels=80, center=False, power=2.0)

        tfms = Subspectrogram(SUB_SPEC_HEIGT, SUB_SPEC_HOP)
        dl = mimii.train_dataloader(
            segment_len=NUM_SEGMENTS,
            hop_len=NUM_SEGMENT_HOPS,
            batch_size=BATCH_SIZE,
            num_workers=NUM_WORKERS,
            shuffle=True,
            transform=tfms,
        )

        model = SubSpecCAE(norm=norm, loss_fn=loss_fn, band=band).to(DEVICE)
        model.init_weights()
        optimizer = optim.Adam(model.parameters(), lr=0.0005)

        for epoch in range(NUM_EPOCHS):
            print(f'EPOCH #{epoch + 1}')
            losses = []
            for data, _labels in tqdm(dl):
                # Author-observed shape: torch.Size([128, 4, 20, 80]),
                # i.e. batch x sub_specs x height x width.
                data = data.to(DEVICE)
                loss = model.train_loss(data)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                losses.append(loss.item())
            # Avoid NumPy's "mean of empty slice" warning on an empty loader.
            mean_loss = np.mean(losses) if losses else float('nan')
            print(f'Loss: {mean_loss}')

        # NOTE(review): segment length and hop are both NUM_SEGMENTS here
        # (training uses NUM_SEGMENT_HOPS as the hop) — presumably
        # non-overlapping evaluation windows; confirm against evaluate_model.
        auc = mimii.evaluate_model(model, NUM_SEGMENTS, NUM_SEGMENTS, transform=tfms)
        print(f'AUC: {auc}, Machine: {machine_id}, Band: {band}, Norm: {norm}, Seed: {seed}')
        return auc

    loss_fn = 'mse'
    for norm in ['batch']:
        # One results list and one output file per norm, so the filename
        # always matches the rows it contains.
        results = []
        out_path = f'results2_hard_{norm}_{loss_fn}.pkl'
        for seed in SEEDS:
            for dataset_path in ALL_DATASET_PATHS:
                if '-6_dB' not in dataset_path.name:
                    continue
                for machine_id in [4]:
                    for band in range(7):
                        auc = train(dataset_path, machine_id, band, norm, loss_fn, seed)
                        results.append([dataset_path.name, machine_id, seed, band, norm, auc])
                        # Checkpoint after every run so a crash late in the
                        # sweep doesn't discard the runs already completed.
                        _save_results(results, out_path)
        # Final write also guarantees the file exists when nothing matched.
        _save_results(results, out_path)