big update

This commit is contained in:
Robert Müller
2020-04-06 14:46:26 +02:00
parent 0f325676e5
commit 482f45df87
17 changed files with 1027 additions and 32 deletions

24
main.py
View File

@ -9,7 +9,8 @@ if __name__ == '__main__':
import torch.optim as optim
from models.layers import Subspectrogram
def train(dataset_path, machine_id, band, norm, seed):
def train(dataset_path, machine_id, band, norm='batch', loss_fn='mse', seed=42):
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
np.random.seed(seed)
@ -32,12 +33,12 @@ if __name__ == '__main__':
transform=tfms
)
model = SubSpecCAE(norm=norm, band=band).to(DEVICE)
model = SubSpecCAE(norm=norm, loss_fn=loss_fn, band=band).to(DEVICE)
model.init_weights()
# print(model(torch.randn(128, 1, 20, 80).to(DEVICE)).shape)
optimizer = optim.Adam(model.parameters(), lr=0.001)
optimizer = optim.Adam(model.parameters(), lr=0.0005)
for epoch in range(NUM_EPOCHS):
@ -60,17 +61,18 @@ if __name__ == '__main__':
print(f'AUC: {auc}, Machine: {machine_id}, Band: {band}, Norm: {norm}, Seed: {seed}')
return auc
loss_fn = 'mse'
results = []
for norm in ('instance', 'batch'):
for norm in ['batch']:
for seed in SEEDS:
for dataset_path in ALL_DATASET_PATHS:
for machine_id in [0, 2, 4, 6]:
for band in range(7):
auc = train(dataset_path, machine_id, band, norm, seed)
results.append([dataset_path.name, machine_id, seed, band, norm, auc])
with open(f'results_{norm}.pkl', 'wb') as f:
pickle.dump(results, f)
if '-6_dB' in dataset_path.name:
for machine_id in [4]:
for band in range(7):
auc = train(dataset_path, machine_id, band, norm, loss_fn, seed)
results.append([dataset_path.name, machine_id, seed, band, norm, auc])
with open(f'results2_hard_{norm}_{loss_fn}.pkl', 'wb') as f:
pickle.dump(results, f)