63 lines
2.6 KiB
Python

from copy import deepcopy
from pathlib import Path
from torch.utils.data import DataLoader
from datasets.urban_8k import UrbanSound8K
from ml_lib.utils._basedatamodule import _BaseDataModule, DATA_OPTION_train, DATA_OPTION_devel
from ml_lib.utils.tools import add_argparse_args
class Urban8KLibrosaDatamodule(_BaseDataModule):
    """Data module that builds UrbanSound8K train/validation datasets and loaders.

    One fold (``val_fold``) is held out as the validation ("devel") split;
    all remaining folds in ``_ALL_FOLDS`` form the training split. Every
    dataset-related option is collected in ``self.kwargs`` and forwarded
    unchanged to ``UrbanSound8K``.
    """

    # Fold ids 1..10 (inclusive) that this module iterates over.
    _ALL_FOLDS = tuple(range(1, 11))

    def __init__(self, batch_size, num_worker, data_root, sr, n_mels, n_fft, hop_length,
                 sampler=None, val_fold=9,
                 random_apply_chance=0.5, target_mel_length_in_seconds=1,
                 loudness_ratio=0.3, shift_ratio=0.3, noise_ratio=0.3, mask_ratio=0.3, **kwargs):
        """Store loader settings and assemble the keyword set for ``UrbanSound8K``.

        Args:
            batch_size: Batch size used by both data loaders.
            num_worker: Worker count used by both data loaders; it is also
                forwarded to ``UrbanSound8K``.
            val_fold: Fold id held out for validation.
            data_root, sr, n_mels, n_fft, hop_length, sampler,
            random_apply_chance, target_mel_length_in_seconds,
            loudness_ratio, shift_ratio, noise_ratio, mask_ratio:
                Forwarded as-is to ``UrbanSound8K``; their meaning is
                defined by that class.
            **kwargs: Any additional options, also forwarded to ``UrbanSound8K``.
        """
        super().__init__()
        self.batch_size = batch_size
        self.num_worker = num_worker
        self.val_fold = val_fold

        # ``kwargs`` is already a fresh dict owned by this call, so extending
        # it in place does not leak into the caller's data.
        self.kwargs = kwargs
        self.kwargs.update(
            data_root=data_root, num_worker=num_worker,
            sr=sr, n_mels=n_mels, n_fft=n_fft, hop_length=hop_length, sampler=sampler,
            random_apply_chance=random_apply_chance,
            target_mel_length_in_seconds=target_mel_length_in_seconds,
            loudness_ratio=loudness_ratio, shift_ratio=shift_ratio, noise_ratio=noise_ratio,
            mask_ratio=mask_ratio,
        )

    @classmethod
    def add_argparse_args(cls, parent_parser):
        """Delegate to the module-level ``add_argparse_args`` helper for ``UrbanSound8K``.

        Returns whatever that helper returns for ``parent_parser``.
        """
        # Inside the method body the bare name resolves to the imported
        # module-level helper, not to this classmethod.
        return add_argparse_args(UrbanSound8K, parent_parser)

    @classmethod
    def from_argparse_args(cls, args, **kwargs):
        """Build the module via the parent implementation, defaulting ``val_fold``.

        A caller-supplied ``val_fold`` in ``kwargs`` is kept as-is.
        """
        # NOTE(review): the fallback here is 10 while ``__init__`` defaults to 9;
        # left unchanged on purpose - confirm which value is intended.
        kwargs.setdefault('val_fold', 10)
        return super().from_argparse_args(args, **kwargs)

    def train_dataloader(self):
        """Return a shuffling ``DataLoader`` over the training dataset."""
        return DataLoader(dataset=self.datasets[DATA_OPTION_train], shuffle=True,
                          batch_size=self.batch_size, pin_memory=True,
                          num_workers=self.num_worker)

    def val_dataloader(self):
        """Return a non-shuffling ``DataLoader`` over the validation dataset."""
        return DataLoader(dataset=self.datasets[DATA_OPTION_devel], shuffle=False, pin_memory=True,
                          batch_size=self.batch_size, num_workers=self.num_worker)

    def prepare_data(self, stag=None):
        """Instantiate the train and validation datasets into ``self.datasets``.

        ``stag`` is accepted but not used by this method.
        """
        # ``self.datasets`` is presumably provided by ``_BaseDataModule`` - confirm.
        # Train dataset: every fold except the held-out validation fold.
        train_folds = [fold for fold in self._ALL_FOLDS if fold != self.val_fold]
        self.datasets[DATA_OPTION_train] = UrbanSound8K(fold=train_folds, **self.kwargs)

        # Devel dataset: the held-out fold only.
        self.datasets[DATA_OPTION_devel] = UrbanSound8K(fold=self.val_fold, **self.kwargs)

    def manual_setup(self):
        """Construct ``UrbanSound8K`` over all folds with ``rebuild=True``.

        The instance is discarded; the call is made for its side effects only
        (presumably regenerating the dataset's cached features - confirm
        against ``UrbanSound8K``).
        """
        UrbanSound8K(fold=list(self._ALL_FOLDS), rebuild=True, **self.kwargs)