from copy import deepcopy
from pathlib import Path

from torch.utils.data import DataLoader

from datasets.urban_8k import UrbanSound8K
from ml_lib.utils._basedatamodule import _BaseDataModule, DATA_OPTION_train, DATA_OPTION_devel
from ml_lib.utils.tools import add_argparse_args


class Urban8KLibrosaDatamodule(_BaseDataModule):
    """Data module that serves ``UrbanSound8K`` datasets through DataLoaders.

    Folds 1..10 are split into one held-out validation fold (``val_fold``)
    and a training set made of the remaining nine folds.
    """

    def __init__(self, batch_size, num_worker, data_root, sr, n_mels, n_fft, hop_length,
                 sampler=None, val_fold=9,
                 random_apply_chance=0.5, target_mel_length_in_seconds=1,
                 loudness_ratio=0.3, shift_ratio=0.3, noise_ratio=0.3, mask_ratio=0.3,
                 **kwargs):
        """Store the loader settings and collect the dataset options.

        ``batch_size``, ``num_worker`` and ``val_fold`` are kept as attributes.
        Every other named argument (and ``num_worker`` once more) is merged
        with the extra ``**kwargs`` into ``self.kwargs``, which is later
        forwarded verbatim to the ``UrbanSound8K`` constructor.
        """
        super().__init__()

        self.batch_size = batch_size
        self.num_worker = num_worker
        self.val_fold = val_fold

        # Named options are applied after the free-form kwargs, mirroring a
        # dict.update() on the caller-supplied extras.
        dataset_options = dict(
            data_root=data_root,
            num_worker=num_worker,
            sr=sr,
            n_mels=n_mels,
            n_fft=n_fft,
            hop_length=hop_length,
            sampler=sampler,
            random_apply_chance=random_apply_chance,
            target_mel_length_in_seconds=target_mel_length_in_seconds,
            loudness_ratio=loudness_ratio,
            shift_ratio=shift_ratio,
            noise_ratio=noise_ratio,
            mask_ratio=mask_ratio,
        )
        self.kwargs = {**kwargs, **dataset_options}

    @classmethod
    def add_argparse_args(cls, parent_parser):
        """Delegate to the ``ml_lib`` helper, passing ``UrbanSound8K`` and the parser."""
        # Inside a method body this name resolves to the module-level helper,
        # not to this classmethod.
        return add_argparse_args(UrbanSound8K, parent_parser)

    @classmethod
    def from_argparse_args(cls, args, **kwargs):
        """Call the parent ``from_argparse_args`` with ``val_fold`` defaulted to 10.

        An explicitly passed ``val_fold`` keyword is left untouched.
        """
        # NOTE(review): this fallback (10) differs from the __init__ default
        # (9) -- confirm that the mismatch is intentional.
        kwargs.setdefault('val_fold', 10)
        return super().from_argparse_args(args, **kwargs)

    def train_dataloader(self):
        """Return a shuffling DataLoader over the training dataset."""
        train_set = self.datasets[DATA_OPTION_train]
        return DataLoader(
            dataset=train_set,
            shuffle=True,
            batch_size=self.batch_size,
            pin_memory=True,
            num_workers=self.num_worker,
        )

    def val_dataloader(self):
        """Return a non-shuffling DataLoader over the devel (validation) dataset."""
        devel_set = self.datasets[DATA_OPTION_devel]
        return DataLoader(
            dataset=devel_set,
            shuffle=False,
            pin_memory=True,
            batch_size=self.batch_size,
            num_workers=self.num_worker,
        )

    def prepare_data(self, stag=None):
        """Build the train and devel datasets and store them in ``self.datasets``.

        ``stag`` is accepted but not used here.
        """
        held_out_fold = self.val_fold

        # Training dataset: every fold in 1..10 except the held-out one.
        training_folds = [fold for fold in range(1, 11) if fold != held_out_fold]
        self.datasets[DATA_OPTION_train] = UrbanSound8K(fold=training_folds, **self.kwargs)

        # Devel dataset: only the held-out fold.
        self.datasets[DATA_OPTION_devel] = UrbanSound8K(fold=held_out_fold, **self.kwargs)

    def manual_setup(self):
        """Instantiate ``UrbanSound8K`` over all ten folds with ``rebuild=True``.

        The created object is discarded; the call is made only for whatever
        the constructor does when ``rebuild`` is set.
        """
        every_fold = list(range(1, 11))
        UrbanSound8K(fold=every_fold, rebuild=True, **self.kwargs)