big update
This commit is contained in:
24
main.py
24
main.py
@ -9,7 +9,8 @@ if __name__ == '__main__':
|
||||
import torch.optim as optim
|
||||
from models.layers import Subspectrogram
|
||||
|
||||
def train(dataset_path, machine_id, band, norm, seed):
|
||||
|
||||
def train(dataset_path, machine_id, band, norm='batch', loss_fn='mse', seed=42):
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed(seed)
|
||||
np.random.seed(seed)
|
||||
@ -32,12 +33,12 @@ if __name__ == '__main__':
|
||||
transform=tfms
|
||||
)
|
||||
|
||||
model = SubSpecCAE(norm=norm, band=band).to(DEVICE)
|
||||
model = SubSpecCAE(norm=norm, loss_fn=loss_fn, band=band).to(DEVICE)
|
||||
model.init_weights()
|
||||
|
||||
# print(model(torch.randn(128, 1, 20, 80).to(DEVICE)).shape)
|
||||
|
||||
optimizer = optim.Adam(model.parameters(), lr=0.001)
|
||||
optimizer = optim.Adam(model.parameters(), lr=0.0005)
|
||||
|
||||
|
||||
for epoch in range(NUM_EPOCHS):
|
||||
@ -60,17 +61,18 @@ if __name__ == '__main__':
|
||||
print(f'AUC: {auc}, Machine: {machine_id}, Band: {band}, Norm: {norm}, Seed: {seed}')
|
||||
return auc
|
||||
|
||||
|
||||
loss_fn = 'mse'
|
||||
results = []
|
||||
for norm in ('instance', 'batch'):
|
||||
for norm in ['batch']:
|
||||
for seed in SEEDS:
|
||||
for dataset_path in ALL_DATASET_PATHS:
|
||||
for machine_id in [0, 2, 4, 6]:
|
||||
for band in range(7):
|
||||
auc = train(dataset_path, machine_id, band, norm, seed)
|
||||
results.append([dataset_path.name, machine_id, seed, band, norm, auc])
|
||||
with open(f'results_{norm}.pkl', 'wb') as f:
|
||||
pickle.dump(results, f)
|
||||
if '-6_dB' in dataset_path.name:
|
||||
for machine_id in [4]:
|
||||
for band in range(7):
|
||||
auc = train(dataset_path, machine_id, band, norm, loss_fn, seed)
|
||||
results.append([dataset_path.name, machine_id, seed, band, norm, auc])
|
||||
with open(f'results2_hard_{norm}_{loss_fn}.pkl', 'wb') as f:
|
||||
pickle.dump(results, f)
|
||||
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user