big update

This commit is contained in:
Robert Müller
2020-04-06 14:46:26 +02:00
parent 0f325676e5
commit 482f45df87
17 changed files with 1027 additions and 32 deletions

@ -71,11 +71,7 @@ class MIMII(object):
)
return DataLoader(ConcatDataset(ds), **kwargs)
def test_dataloader(self, *args, **kwargs):
raise NotImplementedError('test_dataloader is not supported')
def evaluate_model(self, f, segment_len=20, hop_len=5, transform=None):
f.eval()
def test_datasets(self, segment_len=20, hop_len=5, transform=None):
datasets = []
for p, l in zip(self.test_paths, self.test_labels):
datasets.append(
@ -83,6 +79,11 @@ class MIMII(object):
segment_len=segment_len,
hop=hop_len, transform=transform)
)
return datasets
def evaluate_model(self, f, segment_len=20, hop_len=5, transform=None):
f.eval()
datasets = self.test_datasets(segment_len, hop_len, transform)
y_true, y_score = [], []
with torch.no_grad():
for dataset in tqdm(datasets):