diff --git a/_paramters.py b/_paramters.py index f8437ac..7a79b94 100644 --- a/_paramters.py +++ b/_paramters.py @@ -25,37 +25,40 @@ main_arg_parser.add_argument("--data_normalized", type=strtobool, default=True, main_arg_parser.add_argument("--data_use_preprocessed", type=strtobool, default=True, help="") main_arg_parser.add_argument("--data_n_mels", type=int, default=64, help="") main_arg_parser.add_argument("--data_sr", type=int, default=16000, help="") -main_arg_parser.add_argument("--data_hop_length", type=int, default=62, help="") +main_arg_parser.add_argument("--data_hop_length", type=int, default=256, help="") main_arg_parser.add_argument("--data_n_fft", type=int, default=512, help="") main_arg_parser.add_argument("--data_mixup", type=strtobool, default=False, help="") # Transformation Parameters -main_arg_parser.add_argument("--data_loudness_ratio", type=float, default=0.2, help="") -main_arg_parser.add_argument("--data_shift_ratio", type=float, default=0.4, help="") -main_arg_parser.add_argument("--data_noise_ratio", type=float, default=0.15, help="") +main_arg_parser.add_argument("--data_loudness_ratio", type=float, default=0, help="") +main_arg_parser.add_argument("--data_shift_ratio", type=float, default=0, help="") +main_arg_parser.add_argument("--data_noise_ratio", type=float, default=0, help="") +main_arg_parser.add_argument("--data_mask_ratio", type=float, default=0, help="") +main_arg_parser.add_argument("--data_speed_ratio", type=float, default=0.5, help="") +main_arg_parser.add_argument("--data_speed_factor", type=float, default=0.7, help="") # Training Parameters main_arg_parser.add_argument("--train_outpath", type=str, default="output", help="") main_arg_parser.add_argument("--train_version", type=strtobool, required=False, help="") # FIXME: Stochastic weight Avaraging is not good, maybe its my implementation? -main_arg_parser.add_argument("--train_sto_weight_avg", type=strtobool, default=False, help="") -main_arg_parser.add_argument("--train_opt_reset_interval", type=int, default=300, help="") -main_arg_parser.add_argument("--train_epochs", type=int, default=600, help="") +main_arg_parser.add_argument("--train_sto_weight_avg", type=strtobool, default=True, help="") +main_arg_parser.add_argument("--train_opt_reset_interval", type=int, default=0, help="") +main_arg_parser.add_argument("--train_epochs", type=int, default=100, help="") main_arg_parser.add_argument("--train_batch_size", type=int, default=250, help="") main_arg_parser.add_argument("--train_lr", type=float, default=1e-4, help="") main_arg_parser.add_argument("--train_num_sanity_val_steps", type=int, default=0, help="") # Model Parameters -main_arg_parser.add_argument("--model_type", type=str, default="ConvClassifier", help="") -main_arg_parser.add_argument("--model_secondary_type", type=str, default="BandwiseConvMultiheadClassifier", help="") +main_arg_parser.add_argument("--model_type", type=str, default="CC", help="") +main_arg_parser.add_argument("--model_secondary_type", type=str, default="CC", help="") main_arg_parser.add_argument("--model_weight_init", type=str, default="xavier_normal_", help="") main_arg_parser.add_argument("--model_activation", type=str, default="leaky_relu", help="") -main_arg_parser.add_argument("--model_filters", type=str, default="[16, 32, 64, 128, 64]", help="") +main_arg_parser.add_argument("--model_filters", type=str, default="[32, 64, 128, 256, 16]", help="") main_arg_parser.add_argument("--model_classes", type=int, default=2, help="") -main_arg_parser.add_argument("--model_lat_dim", type=int, default=8, help="") +main_arg_parser.add_argument("--model_lat_dim", type=int, default=128, help="") main_arg_parser.add_argument("--model_bias", type=strtobool, default=True, help="") main_arg_parser.add_argument("--model_norm", type=strtobool, default=True, help="") -main_arg_parser.add_argument("--model_dropout", type=float, default=0.0, help="") +main_arg_parser.add_argument("--model_dropout", type=float, default=0.2, help="") # Project Parameters main_arg_parser.add_argument("--project_name", type=str, default=_ROOT.name, help="") diff --git a/datasets/binar_masks.py b/datasets/binar_masks.py index 1843238..20a2072 100644 --- a/datasets/binar_masks.py +++ b/datasets/binar_masks.py @@ -19,7 +19,8 @@ class BinaryMasksDataset(Dataset): def sample_shape(self): return self[0][0].shape - def __init__(self, data_root, setting, mel_transforms, transforms=None, mixup=False): + def __init__(self, data_root, setting, mel_transforms, transforms=None, mixup=False, stretch_dataset=True): + self.stretch = stretch_dataset assert isinstance(setting, str), f'Setting has to be a string, but was: {type(setting)}.' assert setting in V.DATA_OPTIONS, f'Setting must match one of: {V.DATA_OPTIONS}.' super(BinaryMasksDataset, self).__init__() @@ -29,11 +30,11 @@ class BinaryMasksDataset(Dataset): self.mixup = mixup self.container_ext = '.pik' self._mel_transform = mel_transforms - self._transforms = transforms or F_x() self._labels = self._build_labels() self._wav_folder = self.data_root / 'wav' self._wav_files = list(sorted(self._labels.keys())) self._mel_folder = self.data_root / 'mel' + self._transforms = transforms or F_x(in_shape=None) def _build_labels(self): with open(Path(self.data_root) / 'lab' / 'labels.csv', mode='r') as f: @@ -45,6 +46,8 @@ class BinaryMasksDataset(Dataset): continue filename, label = row.strip().split(',') labeldict[filename] = self._to_label[label.lower()] if not self.setting == 'test' else filename + if self.stretch: + labeldict.update({f'X_{key}': val for key, val in labeldict.items()}) return labeldict def __len__(self): @@ -52,7 +55,7 @@ class BinaryMasksDataset(Dataset): def _compute_or_retrieve(self, filename): if not (self._mel_folder / (filename + self.container_ext)).exists(): - raw_sample, sr = librosa.core.load(self._wav_folder / (filename + '.wav')) + raw_sample, sr = librosa.core.load(self._wav_folder / (filename.replace('X_', '') + '.wav')) mel_sample = self._mel_transform(raw_sample) self._mel_folder.mkdir(exist_ok=True, parents=True) with (self._mel_folder / (filename + self.container_ext)).open(mode='wb') as f: @@ -65,8 +68,9 @@ class BinaryMasksDataset(Dataset): is_mixed = item >= len(self._labels) if is_mixed: item = item - len(self._labels) - key = self._wav_files[item] - filename = key[:-4] + + key: str = list(self._labels.keys())[item] + filename = key.replace('.wav', '') mel_sample = self._compute_or_retrieve(filename) label = self._labels[key] diff --git a/models/bandwise_conv_classifier.py b/models/bandwise_conv_classifier.py index 216fc3b..1a1ba43 100644 --- a/models/bandwise_conv_classifier.py +++ b/models/bandwise_conv_classifier.py @@ -1,11 +1,10 @@ from argparse import Namespace from torch import nn -from torch.nn import ModuleDict, ModuleList +from torch.nn import ModuleList from ml_lib.modules.blocks import ConvModule, LinearModule -from ml_lib.modules.utils import (LightningBaseModule, HorizontalSplitter, - HorizontalMerger, F_x) +from ml_lib.modules.utils import (LightningBaseModule, HorizontalSplitter, HorizontalMerger) from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, BinaryMaskDatasetFunction, BaseDataloadersMixin) @@ -30,44 +29,39 @@ class BandwiseConvClassifier(BinaryMaskDatasetFunction, # Additional parameters self.in_shape = self.dataset.train_dataset.sample_shape self.conv_filters = self.params.filters - self.criterion = nn.BCELoss() self.n_band_sections = 4 # Modules # ============================================================================= self.split = HorizontalSplitter(self.in_shape, self.n_band_sections) - self.conv_dict = ModuleDict() - self.conv_dict.update({f"conv_1_{band_section}": - ConvModule(self.split.shape, self.conv_filters[0], 3, conv_stride=1, **self.params.module_kwargs) - for band_section in range(self.n_band_sections)} - ) - self.conv_dict.update({f"conv_2_{band_section}": - ConvModule(self.conv_dict['conv_1_1'].shape, self.conv_filters[1], 3, conv_stride=1, - **self.params.module_kwargs) for band_section in range(self.n_band_sections)} - ) - self.conv_dict.update({f"conv_3_{band_section}": - ConvModule(self.conv_dict['conv_2_1'].shape, self.conv_filters[2], 3, conv_stride=1, - **self.params.module_kwargs) - for band_section in range(self.n_band_sections)} - ) + k = 3 + self.band_list = ModuleList() + for band in range(self.n_band_sections): + last_shape = self.split.shape + conv_list = ModuleList() + for filters in self.conv_filters: + conv_list.append(ConvModule(last_shape, filters, (k,k), conv_stride=(2, 2), conv_padding=2, + **self.params.module_kwargs)) + last_shape = conv_list[-1].shape + # self.conv_list.append(ConvModule(last_shape, 1, 1, conv_stride=1, **self.params.module_kwargs)) + # last_shape = self.conv_list[-1].shape + self.band_list.append(conv_list) - self.merge = HorizontalMerger(self.conv_dict['conv_3_1'].shape, self.n_band_sections) + self.merge = HorizontalMerger(self.band_list[-1][-1].shape, self.n_band_sections) - self.full_1 = LinearModule(self.flat.shape, self.params.lat_dim, **self.params.module_kwargs) + self.full_1 = LinearModule(self.merge.shape, self.params.lat_dim, **self.params.module_kwargs) self.full_2 = LinearModule(self.full_1.shape, self.full_1.shape * 2, **self.params.module_kwargs) - self.full_3 = LinearModule(self.full_2.shape, self.full_2.out_features // 2, **self.params.module_kwargs) + self.full_3 = LinearModule(self.full_2.shape, self.full_2.shape // 2, **self.params.module_kwargs) self.full_out = LinearModule(self.full_3.shape, 1, bias=self.params.bias, activation=nn.Sigmoid) def forward(self, batch, **kwargs): tensors = self.split(batch) - for idx, tensor in enumerate(tensors): - tensors[idx] = self.conv_dict[f"conv_1_{idx}"](tensor) - for idx, tensor in enumerate(tensors): - tensors[idx] = self.conv_dict[f"conv_2_{idx}"](tensor) - for idx, tensor in enumerate(tensors): - tensors[idx] = self.conv_dict[f"conv_3_{idx}"](tensor) + for idx, (tensor, convs) in enumerate(zip(tensors, self.band_list)): + for conv in convs: + tensor = conv(tensor) + tensors[idx] = tensor tensor = self.merge(tensors) tensor = self.full_1(tensor) diff --git a/models/bandwise_conv_multihead_classifier.py b/models/bandwise_conv_multihead_classifier.py index 8271136..d104a86 100644 --- a/models/bandwise_conv_multihead_classifier.py +++ b/models/bandwise_conv_multihead_classifier.py @@ -22,24 +22,32 @@ class BandwiseConvMultiheadClassifier(BinaryMaskDatasetFunction, batch_x, batch_y = batch_xy y = self(batch_x) y, bands_y = y.main_out, y.bands - bands_y_losses = [self.criterion(band_y, batch_y) for band_y in bands_y] + bands_y_losses = [self.bce_loss(band_y, batch_y) for band_y in bands_y] return_dict = {f'band_{band_idx}_loss': band_y for band_idx, band_y in enumerate(bands_y_losses)} - overall_loss = self.criterion(y, batch_y) - combined_loss = overall_loss + torch.stack(bands_y_losses).sum() - return_dict.update(loss=combined_loss, overall_loss=overall_loss) + + last_bce_loss = self.bce_loss(y, batch_y) + return_dict.update(last_bce_loss=last_bce_loss) + + bands_y_losses.append(last_bce_loss) + combined_loss = torch.stack(bands_y_losses).mean() + + return_dict.update(loss=combined_loss) return return_dict def validation_step(self, batch_xy, batch_idx, *args, **kwargs): batch_x, batch_y = batch_xy y = self(batch_x) y, bands_y = y.main_out, y.bands - bands_y_losses = [self.criterion(band_y, batch_y) for band_y in bands_y] + bands_y_losses = [self.bce_loss(band_y, batch_y) for band_y in bands_y] return_dict = {f'band_{band_idx}_val_loss': band_y for band_idx, band_y in enumerate(bands_y_losses)} - overall_loss = self.criterion(y, batch_y) - combined_loss = overall_loss + torch.stack(bands_y_losses).sum() - val_abs_loss = self.absolute_loss(y, batch_y) - return_dict.update(val_bce_loss=combined_loss, val_abs_loss=val_abs_loss, + last_bce_loss = self.bce_loss(y, batch_y) + return_dict.update(last_bce_loss=last_bce_loss) + + bands_y_losses.append(last_bce_loss) + combined_loss = torch.stack(bands_y_losses).mean() + + return_dict.update(val_bce_loss=combined_loss, batch_idx=batch_idx, y=y, batch_y=batch_y ) return return_dict @@ -56,7 +64,6 @@ class BandwiseConvMultiheadClassifier(BinaryMaskDatasetFunction, # Additional parameters self.in_shape = self.dataset.train_dataset.sample_shape self.conv_filters = self.params.filters - self.criterion = nn.BCELoss() self.n_band_sections = 4 k = 3 # Base Kernel Value @@ -69,7 +76,7 @@ class BandwiseConvMultiheadClassifier(BinaryMaskDatasetFunction, last_shape = self.split.shape conv_list = ModuleList() for filters in self.conv_filters: - conv_list.append(ConvModule(last_shape, filters, (k, k*4), conv_stride=(1, 2), + conv_list.append(ConvModule(last_shape, filters, (k,k), conv_stride=(1, 1), **self.params.module_kwargs)) last_shape = conv_list[-1].shape # self.conv_list.append(ConvModule(last_shape, 1, 1, conv_stride=1, **self.params.module_kwargs)) diff --git a/models/conv_classifier.py b/models/conv_classifier.py index f16f799..ca4b92f 100644 --- a/models/conv_classifier.py +++ b/models/conv_classifier.py @@ -29,23 +29,23 @@ class ConvClassifier(BinaryMaskDatasetFunction, # Additional parameters self.in_shape = self.dataset.train_dataset.sample_shape self.conv_filters = self.params.filters - self.criterion = nn.BCELoss() # Modules with Parameters self.conv_list = ModuleList() last_shape = self.in_shape k = 3 # Base Kernel Value for filters in self.conv_filters: - self.conv_list.append(ConvModule(last_shape, filters, (k, k*2), conv_stride=2, **self.params.module_kwargs)) + self.conv_list.append(ConvModule(last_shape, filters, (k,k), conv_stride=(2, 2), conv_padding=2, + **self.params.module_kwargs)) last_shape = self.conv_list[-1].shape # self.conv_list.append(ConvModule(last_shape, 1, 1, conv_stride=1, **self.params.module_kwargs)) # last_shape = self.conv_list[-1].shape - self.full_1 = LinearModule(self.flat.shape, self.params.lat_dim, **self.params.module_kwargs) - self.full_2 = LinearModule(self.full_1.out_features, self.full_1.out_features * 2, self.params.bias) - self.full_3 = LinearModule(self.full_2.out_features, self.full_2.out_features // 2, self.params.bias) + self.full_1 = LinearModule(self.conv_list[-1].shape, self.params.lat_dim, **self.params.module_kwargs) + self.full_2 = LinearModule(self.full_1.shape, self.full_1.shape * 2, **self.params.module_kwargs) + self.full_3 = LinearModule(self.full_2.shape, self.full_2.shape // 2, **self.params.module_kwargs) - self.full_out = LinearModule(self.full_3.out_features, 1, bias=self.params.bias, activation=nn.Sigmoid) + self.full_out = LinearModule(self.full_3.shape, 1, bias=self.params.bias, activation=nn.Sigmoid) def forward(self, batch, **kwargs): tensor = batch diff --git a/models/residual_conv_classifier.py b/models/residual_conv_classifier.py new file mode 100644 index 0000000..de8209b --- /dev/null +++ b/models/residual_conv_classifier.py @@ -0,0 +1,64 @@ +from argparse import Namespace + +from torch import nn +from torch.nn import ModuleList + +from ml_lib.modules.blocks import ConvModule, LinearModule, ResidualModule +from ml_lib.modules.utils import LightningBaseModule +from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, BinaryMaskDatasetFunction, + BaseDataloadersMixin) + + +class ResidualConvClassifier(BinaryMaskDatasetFunction, + BaseDataloadersMixin, + BaseTrainMixin, + BaseValMixin, + BaseOptimizerMixin, + LightningBaseModule + ): + + def __init__(self, hparams): + super(ResidualConvClassifier, self).__init__(hparams) + + # Dataset + # ============================================================================= + self.dataset = self.build_dataset() + + # Model Paramters + # ============================================================================= + # Additional parameters + self.in_shape = self.dataset.train_dataset.sample_shape + self.conv_filters = self.params.filters + + # Modules with Parameters + self.conv_list = ModuleList() + last_shape = self.in_shape + k = 3 # Base Kernel Value + conv_module_params = self.params.module_kwargs + conv_module_params.update(conv_kernel=(k, k), conv_stride=(1, 1), conv_padding=1) + self.conv_list.append(ConvModule(last_shape, self.conv_filters[0], (k, k), conv_stride=(2, 2), conv_padding=1, + **self.params.module_kwargs)) + last_shape = self.conv_list[-1].shape + for filters in self.conv_filters: + conv_module_params.update(conv_filters=filters) + self.conv_list.append(ResidualModule(last_shape, ConvModule, 3, **conv_module_params)) + last_shape = self.conv_list[-1].shape + self.conv_list.append(ConvModule(last_shape, filters, (k, k), conv_stride=(2, 2), conv_padding=2, + **self.params.module_kwargs)) + last_shape = self.conv_list[-1].shape + + self.full_1 = LinearModule(self.conv_list[-1].shape, self.params.lat_dim, **self.params.module_kwargs) + self.full_2 = LinearModule(self.full_1.shape, self.full_1.shape * 2, **self.params.module_kwargs) + self.full_3 = LinearModule(self.full_2.shape, self.full_2.shape // 2, **self.params.module_kwargs) + + self.full_out = LinearModule(self.full_3.shape, 1, bias=self.params.bias, activation=nn.Sigmoid) + + def forward(self, batch, **kwargs): + tensor = batch + for conv in self.conv_list: + tensor = conv(tensor) + tensor = self.full_1(tensor) + tensor = self.full_2(tensor) + tensor = self.full_3(tensor) + tensor = self.full_out(tensor) + return Namespace(main_out=tensor) diff --git a/util/config.py b/util/config.py index 0426ae9..14c1ac4 100644 --- a/util/config.py +++ b/util/config.py @@ -3,6 +3,7 @@ from models.conv_classifier import ConvClassifier from models.bandwise_conv_classifier import BandwiseConvClassifier from models.bandwise_conv_multihead_classifier import BandwiseConvMultiheadClassifier from models.ensemble import Ensemble +from models.residual_conv_classifier import ResidualConvClassifier class MConfig(Config): @@ -11,7 +12,13 @@ class MConfig(Config): @property def _model_map(self): return dict(ConvClassifier=ConvClassifier, + CC=ConvClassifier, BandwiseConvClassifier=BandwiseConvClassifier, + BCC=BandwiseConvClassifier, BandwiseConvMultiheadClassifier=BandwiseConvMultiheadClassifier, + BCMC=BandwiseConvMultiheadClassifier, Ensemble=Ensemble, + E=Ensemble, + ResidualConvClassifier=ResidualConvClassifier, + RCC=ResidualConvClassifier ) diff --git a/util/module_mixins.py b/util/module_mixins.py index 9854720..5a0c6bc 100644 --- a/util/module_mixins.py +++ b/util/module_mixins.py @@ -6,13 +6,14 @@ from argparse import Namespace import sklearn import torch import numpy as np -from torch.nn import L1Loss +from torch import nn from torch.optim import Adam -from torch.utils.data import DataLoader +from torch.utils.data import DataLoader, RandomSampler from torchcontrib.optim import SWA from torchvision.transforms import Compose, RandomApply -from ml_lib.audio_toolset.audio_augmentation import NoiseInjection, LoudnessManipulator, ShiftTime +from ml_lib.audio_toolset.audio_augmentation import Speed +from ml_lib.audio_toolset.mel_augmentation import NoiseInjection, LoudnessManipulator, ShiftTime, MaskAug from ml_lib.audio_toolset.audio_io import AudioToMel, MelToImage, NormalizeLocal from ml_lib.modules.utils import LightningBaseModule from ml_lib.utils.transforms import ToTensor @@ -24,17 +25,19 @@ class BaseOptimizerMixin: def configure_optimizers(self): assert isinstance(self, LightningBaseModule) - opt = Adam(params=self.parameters(), lr=self.params.lr) + opt = Adam(params=self.parameters(), lr=self.params.lr, weight_decay=0.04) if self.params.sto_weight_avg: opt = SWA(opt, swa_start=10, swa_freq=5, swa_lr=0.05) return opt def on_train_end(self): + assert isinstance(self, LightningBaseModule) for opt in self.trainer.optimizers: if isinstance(opt, SWA): opt.swap_swa_sgd() def on_epoch_end(self): + assert isinstance(self, LightningBaseModule) if self.params.opt_reset_interval: if self.current_epoch % self.params.opt_reset_interval == 0: for opt in self.trainer.optimizers: @@ -43,14 +46,19 @@ class BaseOptimizerMixin: class BaseTrainMixin: + absolute_loss = nn.L1Loss() + nll_loss = nn.NLLLoss() + bce_loss = nn.BCELoss() + def training_step(self, batch_xy, batch_nb, *args, **kwargs): assert isinstance(self, LightningBaseModule) batch_x, batch_y = batch_xy y = self(batch_x).main_out - loss = self.criterion(y, batch_y) - return dict(loss=loss) + bce_loss = self.bce_loss(y, batch_y) + return dict(loss=bce_loss) def training_epoch_end(self, outputs): + assert isinstance(self, LightningBaseModule) keys = list(outputs[0].keys()) summary_dict = dict(log={f'mean_{key}': torch.mean(torch.stack([output[key] @@ -61,18 +69,20 @@ class BaseTrainMixin: class BaseValMixin: - absolute_loss = L1Loss() + absolute_loss = nn.L1Loss() + nll_loss = nn.NLLLoss() + bce_loss = nn.BCELoss() def validation_step(self, batch_xy, batch_idx, dataloader_idx, *args, **kwargs): + assert isinstance(self, LightningBaseModule) batch_x, batch_y = batch_xy y = self(batch_x).main_out - val_bce_loss = self.criterion(y, batch_y) - val_abs_loss = self.absolute_loss(y, batch_y) - return dict(val_bce_loss=val_bce_loss, val_abs_loss=val_abs_loss, - batch_idx=batch_idx, y=y, batch_y=batch_y - ) + val_bce_loss = self.bce_loss(y, batch_y) + return dict(val_bce_loss=val_bce_loss, + batch_idx=batch_idx, y=y, batch_y=batch_y) def validation_epoch_end(self, outputs, *args, **kwargs): + assert isinstance(self, LightningBaseModule) summary_dict = dict(log=dict()) for output_idx, output in enumerate(outputs): keys = list(output[0].keys()) @@ -103,6 +113,12 @@ class BinaryMaskDatasetFunction: # Dataset # ============================================================================= # Mel Transforms + mel_transforms_train = Compose([ + # Audio to Mel Transformations + Speed(speed_factor=self.params.speed_factor, max_ratio=self.params.speed_ratio), + AudioToMel(sr=self.params.sr, n_mels=self.params.n_mels, n_fft=self.params.n_fft, + hop_length=self.params.hop_length), + MelToImage()]) mel_transforms = Compose([ # Audio to Mel Transformations AudioToMel(sr=self.params.sr, n_mels=self.params.n_mels, n_fft=self.params.n_fft, @@ -112,25 +128,28 @@ class BinaryMaskDatasetFunction: RandomApply([ NoiseInjection(self.params.noise_ratio), LoudnessManipulator(self.params.loudness_ratio), - ShiftTime(self.params.shift_ratio)], p=0.5), + ShiftTime(self.params.shift_ratio), + MaskAug(self.params.mask_ratio), + ], p=0.6), # Utility NormalizeLocal(), ToTensor() ]) val_transforms = Compose([NormalizeLocal(), ToTensor()]) - # sampler = RandomSampler(train_dataset, True, len(train_dataset)) if params['bootstrap'] else None - # Datasets from datasets.binar_masks import BinaryMasksDataset dataset = Namespace( **dict( + # TRAIN DATASET train_dataset=BinaryMasksDataset(self.params.root, setting=V.DATA_OPTIONS.train, mixup=self.params.mixup, - mel_transforms=mel_transforms, transforms=aug_transforms), + mel_transforms=mel_transforms_train, transforms=aug_transforms), + # VALIDATION DATASET val_train_dataset=BinaryMasksDataset(self.params.root, setting=V.DATA_OPTIONS.train, mel_transforms=mel_transforms, transforms=val_transforms), val_dataset=BinaryMasksDataset(self.params.root, setting=V.DATA_OPTIONS.devel, mel_transforms=mel_transforms, transforms=val_transforms), + # TEST DATASET test_dataset=BinaryMasksDataset(self.params.root, setting=V.DATA_OPTIONS.test, mel_transforms=mel_transforms, transforms=val_transforms), ) @@ -144,18 +163,23 @@ class BaseDataloadersMixin(ABC): # ================================================================================ # Train Dataloader def train_dataloader(self): - return DataLoader(dataset=self.dataset.train_dataset, shuffle=True, + assert isinstance(self, LightningBaseModule) + # sampler = RandomSampler(self.dataset.train_dataset, True, len(self.dataset.train_dataset)) + sampler = None + return DataLoader(dataset=self.dataset.train_dataset, shuffle=True if not sampler else None, sampler=sampler, batch_size=self.params.batch_size, num_workers=self.params.worker) # Test Dataloader def test_dataloader(self): + assert isinstance(self, LightningBaseModule) return DataLoader(dataset=self.dataset.test_dataset, shuffle=False, batch_size=self.params.batch_size, num_workers=self.params.worker) # Validation Dataloader def val_dataloader(self): + assert isinstance(self, LightningBaseModule) val_dataloader = DataLoader(dataset=self.dataset.val_dataset, shuffle=True, batch_size=self.params.batch_size, num_workers=self.params.worker)