63 lines
2.6 KiB
Python
63 lines
2.6 KiB
Python
from copy import deepcopy
|
|
from pathlib import Path
|
|
|
|
from torch.utils.data import DataLoader
|
|
|
|
from datasets.urban_8k import UrbanSound8K
|
|
from ml_lib.utils._basedatamodule import _BaseDataModule, DATA_OPTION_train, DATA_OPTION_devel
|
|
from ml_lib.utils.tools import add_argparse_args
|
|
|
|
|
|
class Urban8KLibrosaDatamodule(_BaseDataModule):
    """UrbanSound8K data module with a held-out validation fold.

    Training uses every fold in ``_ALL_FOLDS`` except ``val_fold``; validation
    uses ``val_fold`` alone. All feature/augmentation settings passed to the
    constructor are stored in ``self.kwargs`` and forwarded verbatim to every
    ``UrbanSound8K`` instance this module creates.

    NOTE(review): ``self.datasets`` is never assigned in this class; it is
    assumed to be created by ``_BaseDataModule.__init__`` -- verify.
    """

    # Fold ids the module splits across (1..10 inclusive).
    _ALL_FOLDS = tuple(range(1, 11))

    def __init__(self, batch_size, num_worker, data_root, sr, n_mels, n_fft, hop_length,
                 sampler=None, val_fold=9,
                 random_apply_chance=0.5, target_mel_length_in_seconds=1,
                 loudness_ratio=0.3, shift_ratio=0.3, noise_ratio=0.3, mask_ratio=0.3, **kwargs):
        """Store loader settings and the keyword arguments forwarded to ``UrbanSound8K``.

        Args:
            batch_size: Batch size used by both the training and validation loaders.
            num_worker: Worker count for both loaders; also forwarded to the dataset.
            data_root, sr, n_mels, n_fft, hop_length, sampler, random_apply_chance,
            target_mel_length_in_seconds, loudness_ratio, shift_ratio, noise_ratio,
            mask_ratio: Forwarded verbatim to ``UrbanSound8K``. Note that
                ``sampler`` goes to the dataset, not to the ``DataLoader``s.
            val_fold: Fold id held out for validation (default 9). A list of fold
                ids is also accepted, in which case every listed fold is excluded
                from the training split.
            **kwargs: Any further keyword arguments, forwarded to ``UrbanSound8K``.
        """
        super(Urban8KLibrosaDatamodule, self).__init__()
        self.batch_size = batch_size
        self.num_worker = num_worker
        self.val_fold = val_fold

        # Everything UrbanSound8K needs, gathered once so every split is built
        # with identical settings.
        self.kwargs = kwargs
        self.kwargs.update(data_root=data_root, num_worker=num_worker,
                           sr=sr, n_mels=n_mels, n_fft=n_fft, hop_length=hop_length, sampler=sampler,
                           random_apply_chance=random_apply_chance,
                           target_mel_length_in_seconds=target_mel_length_in_seconds,
                           loudness_ratio=loudness_ratio, shift_ratio=shift_ratio, noise_ratio=noise_ratio,
                           mask_ratio=mask_ratio)

    @classmethod
    def add_argparse_args(cls, parent_parser):
        """Add CLI arguments derived from ``UrbanSound8K`` to ``parent_parser``.

        Delegates to ``ml_lib.utils.tools.add_argparse_args`` with the dataset
        class (not this module), and returns whatever that helper returns.
        """
        return add_argparse_args(UrbanSound8K, parent_parser)

    @classmethod
    def from_argparse_args(cls, args, **kwargs):
        """Construct the module from parsed CLI ``args`` plus keyword overrides.

        ``val_fold`` is resolved in this order: an explicit ``val_fold`` keyword;
        otherwise a non-None ``val_fold`` attribute on ``args``; otherwise 10.

        NOTE(review): the fallback of 10 differs from ``__init__``'s default of 9;
        kept as-is to preserve existing behaviour -- confirm which is intended.
        """
        if 'val_fold' not in kwargs:
            # Keyword arguments override ``args`` in the parent implementation, so
            # forcing a default here would silently discard a CLI-supplied value.
            cli_val_fold = getattr(args, 'val_fold', None)
            kwargs['val_fold'] = 10 if cli_val_fold is None else cli_val_fold
        return super(Urban8KLibrosaDatamodule, cls).from_argparse_args(args, **kwargs)

    def train_dataloader(self):
        """Return a shuffling ``DataLoader`` over the training split."""
        return DataLoader(dataset=self.datasets[DATA_OPTION_train], shuffle=True,
                          batch_size=self.batch_size, pin_memory=True,
                          num_workers=self.num_worker)

    def val_dataloader(self):
        """Return a non-shuffling ``DataLoader`` over the validation split."""
        return DataLoader(dataset=self.datasets[DATA_OPTION_devel], shuffle=False, pin_memory=True,
                          batch_size=self.batch_size, num_workers=self.num_worker)

    def prepare_data(self, stag=None):
        """Build the training and validation ``UrbanSound8K`` datasets.

        Args:
            stag: Unused; accepted for call-site compatibility. (The name looks
                like a typo of ``stage`` but is kept so keyword callers still work.)
        """
        # NOTE(review): if _BaseDataModule is a Lightning ``LightningDataModule``,
        # Lightning runs prepare_data on a single process and advises against
        # assigning state here; ``setup`` may be the safer hook -- verify.
        self.datasets[DATA_OPTION_train] = UrbanSound8K(
            **self._dataset_kwargs(fold=self._train_folds()))
        self.datasets[DATA_OPTION_devel] = UrbanSound8K(
            **self._dataset_kwargs(fold=self.val_fold))

    def manual_setup(self):
        """Construct ``UrbanSound8K`` over all folds with ``rebuild=True``.

        The instance is discarded, so this runs only for the side effects of
        construction (presumably regenerating cached features -- verify).
        """
        UrbanSound8K(**self._dataset_kwargs(fold=list(self._ALL_FOLDS), rebuild=True))

    def _dataset_kwargs(self, **overrides):
        """Return a copy of ``self.kwargs`` with ``overrides`` taking precedence."""
        # Merging instead of ``UrbanSound8K(fold=..., **self.kwargs)`` avoids a
        # "got multiple values for keyword argument" TypeError when self.kwargs
        # already carries a 'fold' or 'rebuild' entry.
        return {**self.kwargs, **overrides}

    def _train_folds(self):
        """Return the training fold ids: every fold not held out by ``val_fold``."""
        if isinstance(self.val_fold, (list, tuple, set, frozenset, range)):
            held_out = tuple(self.val_fold)
        else:
            held_out = (self.val_fold,)
        # ``in`` on a tuple compares with ==, matching the original ``x != val_fold``
        # test for a single fold without requiring val_fold to be hashable.
        return [fold for fold in self._ALL_FOLDS if fold not in held_out]
|
|
|
|
|
|
|