bug in metric calculation
datasets/compare_base.py
@@ -32,7 +32,7 @@ class CompareBase(_BaseDataModule):
 
     @property
     def mel_folder(self):
-        return self.root / 'mel_folder'
+        return Path(f'{self.root}_mel_folder')
 
     @property
     def wav_folder(self):
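Review note: this hunk does more than restyle the path, it moves the mel cache from a subdirectory of `root` to a sibling directory, keeping generated spectrograms out of the original dataset tree. A quick illustration with a hypothetical root (the actual dataset path is not part of this diff):

```python
from pathlib import Path

root = Path('data/ComParE2020_Mask')   # hypothetical root, for illustration only
print(root / 'mel_folder')             # old: data/ComParE2020_Mask/mel_folder
print(Path(f'{root}_mel_folder'))      # new: data/ComParE2020_Mask_mel_folder
```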
@@ -58,7 +58,10 @@ class CompareBase(_BaseDataModule):
         self.sample_segment_length = target_frames // self.mel_kwargs['hop_length'] + 1
 
         # Utility
-        self.utility_transforms = Compose([NormalizeLocal(), ToTensor()])
+        self.utility_transforms = Compose([
+            NormalizeLocal(),
+            ToTensor()
+        ])
 
         # Data Augmentations
         self.random_apply_chance = random_apply_chance
@@ -85,8 +88,11 @@ class CompareBase(_BaseDataModule):
                           batch_size=self.batch_size, pin_memory=False,
                           num_workers=self.num_worker)
 
-    def _build_subdataset(self, row, build=False):
+    def _build_subdataset(self, row, build=False, data_option=None):
         slice_file_name, class_name = row.strip().split(',')
+        if data_option is not None:
+            if data_option not in slice_file_name:
+                return None, -1, 'no_file'
         class_id = self.class_names.get(class_name, -1)
         audio_file_path = self.wav_folder / slice_file_name
@@ -96,53 +102,54 @@ class CompareBase(_BaseDataModule):
         kwargs.update(mel_augmentations=self.utility_transforms)
         # DATA OPTION DIFFERENTIATION !!!!!!!!!!! - End
 
         kwargs.update(sample_segment_len=self.sample_segment_length, sample_hop_len=self.sample_segment_length//2)
         mel_dataset = LibrosaAudioToMelDataset(audio_file_path, class_id, **kwargs)
         if build:
             assert mel_dataset.build_mel()
         return mel_dataset, class_id, slice_file_name
 
-    def manual_setup(self, stag=None):
+    def manual_setup(self):
         datasets = dict()
-        for data_option in data_options:
-            with open(Path(self.root) / 'lab' / f'{data_option}.csv', mode='r') as f:
-                # Exclude the header
-                _ = next(f)
-                all_rows = list(f)
-            chunksize = len(all_rows) // max(self.num_worker, 1)
-            dataset = list()
-            with mp.Pool(processes=self.num_worker) as pool:
-                from itertools import repeat
-                results = pool.starmap_async(self._build_subdataset, zip(all_rows, repeat(True, len(all_rows))),
-                                             chunksize=chunksize)
-                for sub_dataset in results.get():
-                    dataset.append(sub_dataset[0])
+        with (Path(self.root) / 'lab' / 'labels.csv') as label_csv_file:
+            if label_csv_file.exists():
+                lab_file = label_csv_file.name
+            else:
+                lab_file = None
+
+        for data_option in data_options:
+            if lab_file is not None:
+                if any([x in lab_file for x in data_options]):
+                    lab_file = f'{data_option}.csv'
+            dataset = self._load_from_file(lab_file, data_option, rebuild=True)
             datasets[data_option] = ConcatDataset(dataset)
             print(f'{data_option}-dataset prepared.')
         self.datasets = datasets
         return datasets
 
-    def prepare_data(self, *args, rebuild=False, **kwargs):
+    def prepare_data(self, *args, rebuild=False, subsets=None, **kwargs):
         datasets = dict()
         samplers = dict()
         weights = dict()
 
-        for data_option in data_options:
-            with open(Path(self.root) / 'lab' / f'{data_option}.csv', mode='r') as f:
-                # Exclude the header
-                _ = next(f)
-                all_rows = list(f)
-            chunksize = len(all_rows) // max(self.num_worker, 1)
-            dataset = list()
-            with mp.Pool(processes=self.num_worker) as pool:
-                from itertools import repeat
-                results = pool.starmap_async(self._build_subdataset, zip(all_rows, repeat(rebuild, len(all_rows))),
-                                             chunksize=chunksize)
-                for sub_dataset in results.get():
-                    dataset.append(sub_dataset[0])
+        with (Path(self.root) / 'lab' / 'labels.csv') as label_csv_file:
+            if label_csv_file.exists():
+                lab_file = label_csv_file.name
+            else:
+                lab_file = None
+
+        for data_option in data_options:
+            if subsets is not None:
+                if data_option not in subsets:
+                    print(f'{data_option} skipped...')
+                    continue
+
+            if lab_file is not None:
+                if any([x in lab_file for x in data_options]):
+                    lab_file = f'{data_option}.csv'
+
+            dataset = self._load_from_file(lab_file, data_option, rebuild=rebuild)
             datasets[data_option] = ConcatDataset(dataset)
             print(f'{data_option}-dataset set up!')
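Both `manual_setup` and `prepare_data` now delegate the CSV parsing and the multiprocessing fan-out to the shared `_load_from_file` helper added in the next hunk. For reference, here is a self-contained sketch of that pattern, `starmap_async` over rows zipped with `repeat`-ed constants, using a toy stand-in for `_build_subdataset` (the real one constructs a `LibrosaAudioToMelDataset`; the row contents below are invented for illustration):

```python
import multiprocessing as mp
from itertools import repeat

def build_row(row, rebuild=False, data_option=None):
    # Toy stand-in for CompareBase._build_subdataset: reject rows whose
    # file name does not contain data_option, otherwise "build" the row.
    name, _, label = row.partition(',')
    if data_option is not None and data_option not in name:
        return None, -1, 'no_file'
    return f'built:{name}', label.strip(), name

if __name__ == '__main__':
    rows = ['train_001.wav,mask\n', 'devel_001.wav,clear\n', 'train_002.wav,mask\n']
    num_worker = 2
    chunksize = len(rows) // max(num_worker, 1)
    with mp.Pool(processes=num_worker) as pool:
        results = pool.starmap_async(
            build_row,
            zip(rows, repeat(True, len(rows)), repeat('train', len(rows))),
            chunksize=chunksize)
        # Drop rows rejected by the data_option filter, as the patch does.
        dataset = [r[0] for r in results.get() if r[0] is not None]
    print(dataset)  # ['built:train_001.wav', 'built:train_002.wav']
```

Note that only `_load_from_file` filters out the `(None, -1, 'no_file')` sentinel; the pre-refactor loops appended `sub_dataset[0]` unconditionally.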
@@ -172,6 +179,27 @@ class CompareBase(_BaseDataModule):
         print(f'Dataset {self.__class__.__name__} setup done.')
         return datasets
 
+    def _load_from_file(self, lab_file, data_option, rebuild=False):
+        with open(Path(self.root) / 'lab' / lab_file, mode='r') as f:
+            # Exclude the header
+            _ = next(f)
+            all_rows = list(f)
+        chunksize = len(all_rows) // max(self.num_worker, 1)
+        dataset = list()
+        with mp.Pool(processes=self.num_worker) as pool:
+            from itertools import repeat
+            results = pool.starmap_async(self._build_subdataset,
+                                         zip(all_rows,
+                                             repeat(rebuild, len(all_rows)),
+                                             repeat(data_option, len(all_rows))
+                                             ),
+                                         chunksize=chunksize)
+            for sub_dataset in results.get():
+                if sub_dataset[0] is not None:
+                    dataset.append(sub_dataset[0])
+        return dataset
+
     def purge(self):
         import shutil
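One review note on the labels.csv probe used in both `manual_setup` and `prepare_data`: `with (Path(self.root) / 'lab' / 'labels.csv') as label_csv_file:` leans on pathlib's context-manager support, which never opened the file, was deprecated in Python 3.9, and is removed in Python 3.13. A plain existence check is equivalent and portable; a minimal sketch, assuming a hypothetical root:

```python
from pathlib import Path

root = Path('data/ComParE2020_Mask')          # hypothetical root
label_csv_file = root / 'lab' / 'labels.csv'
# No `with` block needed: Path.exists() answers the question directly.
lab_file = label_csv_file.name if label_csv_file.exists() else None
```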
datasets/mask_librosa_datamodule.py (new file, 19 lines)
@@ -0,0 +1,19 @@
+from datasets.compare_base import CompareBase
+from ml_lib.utils.tools import add_argparse_args
+
+
+class MaskLibrosaDatamodule(CompareBase):
+
+    class_names = ['mask', 'clear']
+    sub_dataset_name = 'ComParE2020_Mask'
+
+    def __init__(self, *args, **kwargs):
+        super(MaskLibrosaDatamodule, self).__init__(*args, **kwargs)
+
+    @classmethod
+    def add_argparse_args(cls, parent_parser):
+        return add_argparse_args(CompareBase, parent_parser)
+
+    @classmethod
+    def from_argparse_args(cls, args, **kwargs):
+        return CompareBase.from_argparse_args(args, class_names=cls.class_names, sub_dataset_name=cls.sub_dataset_name)
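A sketch of how the new datamodule is presumably wired into a training script. The `subsets` values ('train', 'devel') are assumptions about `data_options`, which is referenced in the diff but never defined in the shown hunks, and the argparse round-trip relies on `ml_lib.utils.tools.add_argparse_args`, whose exact behavior is not visible here:

```python
from argparse import ArgumentParser
from datasets.mask_librosa_datamodule import MaskLibrosaDatamodule

parser = ArgumentParser()
# Registers CompareBase's constructor arguments on the parser
# (argument names depend on CompareBase.__init__, not shown in this diff).
parser = MaskLibrosaDatamodule.add_argparse_args(parser)
args = parser.parse_args()

# Builds a datamodule preconfigured with the Mask class names and
# sub-dataset name, then prepares only the requested (assumed) subsets.
datamodule = MaskLibrosaDatamodule.from_argparse_args(args)
datamodule.prepare_data(rebuild=False, subsets=['train', 'devel'])
```

Worth flagging: `from_argparse_args` delegates straight to `CompareBase` and drops the `**kwargs` it accepts, so extra keyword arguments are silently ignored and the returned object is whatever `CompareBase.from_argparse_args` constructs, not necessarily a `MaskLibrosaDatamodule`.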