ResidualModule and New Parameters, Speed Manipulation

This commit is contained in:
Si11ium 2020-05-12 12:37:26 +02:00
parent 3fbc98dfa3
commit 28bfcfdce3
8 changed files with 181 additions and 78 deletions

View File

@ -25,37 +25,40 @@ main_arg_parser.add_argument("--data_normalized", type=strtobool, default=True,
main_arg_parser.add_argument("--data_use_preprocessed", type=strtobool, default=True, help="")
main_arg_parser.add_argument("--data_n_mels", type=int, default=64, help="")
main_arg_parser.add_argument("--data_sr", type=int, default=16000, help="")
main_arg_parser.add_argument("--data_hop_length", type=int, default=62, help="")
main_arg_parser.add_argument("--data_hop_length", type=int, default=256, help="")
main_arg_parser.add_argument("--data_n_fft", type=int, default=512, help="")
main_arg_parser.add_argument("--data_mixup", type=strtobool, default=False, help="")
# Transformation Parameters
main_arg_parser.add_argument("--data_loudness_ratio", type=float, default=0.2, help="")
main_arg_parser.add_argument("--data_shift_ratio", type=float, default=0.4, help="")
main_arg_parser.add_argument("--data_noise_ratio", type=float, default=0.15, help="")
main_arg_parser.add_argument("--data_loudness_ratio", type=float, default=0, help="")
main_arg_parser.add_argument("--data_shift_ratio", type=float, default=0, help="")
main_arg_parser.add_argument("--data_noise_ratio", type=float, default=0, help="")
main_arg_parser.add_argument("--data_mask_ratio", type=float, default=0, help="")
main_arg_parser.add_argument("--data_speed_ratio", type=float, default=0.5, help="")
main_arg_parser.add_argument("--data_speed_factor", type=float, default=0.7, help="")
# Training Parameters
main_arg_parser.add_argument("--train_outpath", type=str, default="output", help="")
main_arg_parser.add_argument("--train_version", type=strtobool, required=False, help="")
# FIXME: Stochastic weight Avaraging is not good, maybe its my implementation?
main_arg_parser.add_argument("--train_sto_weight_avg", type=strtobool, default=False, help="")
main_arg_parser.add_argument("--train_opt_reset_interval", type=int, default=300, help="")
main_arg_parser.add_argument("--train_epochs", type=int, default=600, help="")
main_arg_parser.add_argument("--train_sto_weight_avg", type=strtobool, default=True, help="")
main_arg_parser.add_argument("--train_opt_reset_interval", type=int, default=0, help="")
main_arg_parser.add_argument("--train_epochs", type=int, default=100, help="")
main_arg_parser.add_argument("--train_batch_size", type=int, default=250, help="")
main_arg_parser.add_argument("--train_lr", type=float, default=1e-4, help="")
main_arg_parser.add_argument("--train_num_sanity_val_steps", type=int, default=0, help="")
# Model Parameters
main_arg_parser.add_argument("--model_type", type=str, default="ConvClassifier", help="")
main_arg_parser.add_argument("--model_secondary_type", type=str, default="BandwiseConvMultiheadClassifier", help="")
main_arg_parser.add_argument("--model_type", type=str, default="CC", help="")
main_arg_parser.add_argument("--model_secondary_type", type=str, default="CC", help="")
main_arg_parser.add_argument("--model_weight_init", type=str, default="xavier_normal_", help="")
main_arg_parser.add_argument("--model_activation", type=str, default="leaky_relu", help="")
main_arg_parser.add_argument("--model_filters", type=str, default="[16, 32, 64, 128, 64]", help="")
main_arg_parser.add_argument("--model_filters", type=str, default="[32, 64, 128, 256, 16]", help="")
main_arg_parser.add_argument("--model_classes", type=int, default=2, help="")
main_arg_parser.add_argument("--model_lat_dim", type=int, default=8, help="")
main_arg_parser.add_argument("--model_lat_dim", type=int, default=128, help="")
main_arg_parser.add_argument("--model_bias", type=strtobool, default=True, help="")
main_arg_parser.add_argument("--model_norm", type=strtobool, default=True, help="")
main_arg_parser.add_argument("--model_dropout", type=float, default=0.0, help="")
main_arg_parser.add_argument("--model_dropout", type=float, default=0.2, help="")
# Project Parameters
main_arg_parser.add_argument("--project_name", type=str, default=_ROOT.name, help="")

View File

@ -19,7 +19,8 @@ class BinaryMasksDataset(Dataset):
def sample_shape(self):
return self[0][0].shape
def __init__(self, data_root, setting, mel_transforms, transforms=None, mixup=False):
def __init__(self, data_root, setting, mel_transforms, transforms=None, mixup=False, stretch_dataset=True):
self.stretch = stretch_dataset
assert isinstance(setting, str), f'Setting has to be a string, but was: {type(setting)}.'
assert setting in V.DATA_OPTIONS, f'Setting must match one of: {V.DATA_OPTIONS}.'
super(BinaryMasksDataset, self).__init__()
@ -29,11 +30,11 @@ class BinaryMasksDataset(Dataset):
self.mixup = mixup
self.container_ext = '.pik'
self._mel_transform = mel_transforms
self._transforms = transforms or F_x()
self._labels = self._build_labels()
self._wav_folder = self.data_root / 'wav'
self._wav_files = list(sorted(self._labels.keys()))
self._mel_folder = self.data_root / 'mel'
self._transforms = transforms or F_x(in_shape=None)
def _build_labels(self):
with open(Path(self.data_root) / 'lab' / 'labels.csv', mode='r') as f:
@ -45,6 +46,8 @@ class BinaryMasksDataset(Dataset):
continue
filename, label = row.strip().split(',')
labeldict[filename] = self._to_label[label.lower()] if not self.setting == 'test' else filename
if self.stretch:
labeldict.update({f'X_{key}': val for key, val in labeldict.items()})
return labeldict
def __len__(self):
@ -52,7 +55,7 @@ class BinaryMasksDataset(Dataset):
def _compute_or_retrieve(self, filename):
if not (self._mel_folder / (filename + self.container_ext)).exists():
raw_sample, sr = librosa.core.load(self._wav_folder / (filename + '.wav'))
raw_sample, sr = librosa.core.load(self._wav_folder / (filename.replace('X_', '') + '.wav'))
mel_sample = self._mel_transform(raw_sample)
self._mel_folder.mkdir(exist_ok=True, parents=True)
with (self._mel_folder / (filename + self.container_ext)).open(mode='wb') as f:
@ -65,8 +68,9 @@ class BinaryMasksDataset(Dataset):
is_mixed = item >= len(self._labels)
if is_mixed:
item = item - len(self._labels)
key = self._wav_files[item]
filename = key[:-4]
key: str = list(self._labels.keys())[item]
filename = key.replace('.wav', '')
mel_sample = self._compute_or_retrieve(filename)
label = self._labels[key]

View File

@ -1,11 +1,10 @@
from argparse import Namespace
from torch import nn
from torch.nn import ModuleDict, ModuleList
from torch.nn import ModuleList
from ml_lib.modules.blocks import ConvModule, LinearModule
from ml_lib.modules.utils import (LightningBaseModule, HorizontalSplitter,
HorizontalMerger, F_x)
from ml_lib.modules.utils import (LightningBaseModule, HorizontalSplitter, HorizontalMerger)
from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, BinaryMaskDatasetFunction,
BaseDataloadersMixin)
@ -30,44 +29,39 @@ class BandwiseConvClassifier(BinaryMaskDatasetFunction,
# Additional parameters
self.in_shape = self.dataset.train_dataset.sample_shape
self.conv_filters = self.params.filters
self.criterion = nn.BCELoss()
self.n_band_sections = 4
# Modules
# =============================================================================
self.split = HorizontalSplitter(self.in_shape, self.n_band_sections)
self.conv_dict = ModuleDict()
self.conv_dict.update({f"conv_1_{band_section}":
ConvModule(self.split.shape, self.conv_filters[0], 3, conv_stride=1, **self.params.module_kwargs)
for band_section in range(self.n_band_sections)}
)
self.conv_dict.update({f"conv_2_{band_section}":
ConvModule(self.conv_dict['conv_1_1'].shape, self.conv_filters[1], 3, conv_stride=1,
**self.params.module_kwargs) for band_section in range(self.n_band_sections)}
)
self.conv_dict.update({f"conv_3_{band_section}":
ConvModule(self.conv_dict['conv_2_1'].shape, self.conv_filters[2], 3, conv_stride=1,
**self.params.module_kwargs)
for band_section in range(self.n_band_sections)}
)
k = 3
self.band_list = ModuleList()
for band in range(self.n_band_sections):
last_shape = self.split.shape
conv_list = ModuleList()
for filters in self.conv_filters:
conv_list.append(ConvModule(last_shape, filters, (k,k), conv_stride=(2, 2), conv_padding=2,
**self.params.module_kwargs))
last_shape = conv_list[-1].shape
# self.conv_list.append(ConvModule(last_shape, 1, 1, conv_stride=1, **self.params.module_kwargs))
# last_shape = self.conv_list[-1].shape
self.band_list.append(conv_list)
self.merge = HorizontalMerger(self.conv_dict['conv_3_1'].shape, self.n_band_sections)
self.merge = HorizontalMerger(self.band_list[-1][-1].shape, self.n_band_sections)
self.full_1 = LinearModule(self.flat.shape, self.params.lat_dim, **self.params.module_kwargs)
self.full_1 = LinearModule(self.merge.shape, self.params.lat_dim, **self.params.module_kwargs)
self.full_2 = LinearModule(self.full_1.shape, self.full_1.shape * 2, **self.params.module_kwargs)
self.full_3 = LinearModule(self.full_2.shape, self.full_2.out_features // 2, **self.params.module_kwargs)
self.full_3 = LinearModule(self.full_2.shape, self.full_2.shape // 2, **self.params.module_kwargs)
self.full_out = LinearModule(self.full_3.shape, 1, bias=self.params.bias, activation=nn.Sigmoid)
def forward(self, batch, **kwargs):
tensors = self.split(batch)
for idx, tensor in enumerate(tensors):
tensors[idx] = self.conv_dict[f"conv_1_{idx}"](tensor)
for idx, tensor in enumerate(tensors):
tensors[idx] = self.conv_dict[f"conv_2_{idx}"](tensor)
for idx, tensor in enumerate(tensors):
tensors[idx] = self.conv_dict[f"conv_3_{idx}"](tensor)
for idx, (tensor, convs) in enumerate(zip(tensors, self.band_list)):
for conv in convs:
tensor = conv(tensor)
tensors[idx] = tensor
tensor = self.merge(tensors)
tensor = self.full_1(tensor)

View File

@ -22,24 +22,32 @@ class BandwiseConvMultiheadClassifier(BinaryMaskDatasetFunction,
batch_x, batch_y = batch_xy
y = self(batch_x)
y, bands_y = y.main_out, y.bands
bands_y_losses = [self.criterion(band_y, batch_y) for band_y in bands_y]
bands_y_losses = [self.bce_loss(band_y, batch_y) for band_y in bands_y]
return_dict = {f'band_{band_idx}_loss': band_y for band_idx, band_y in enumerate(bands_y_losses)}
overall_loss = self.criterion(y, batch_y)
combined_loss = overall_loss + torch.stack(bands_y_losses).sum()
return_dict.update(loss=combined_loss, overall_loss=overall_loss)
last_bce_loss = self.bce_loss(y, batch_y)
return_dict.update(last_bce_loss=last_bce_loss)
bands_y_losses.append(last_bce_loss)
combined_loss = torch.stack(bands_y_losses).mean()
return_dict.update(loss=combined_loss)
return return_dict
def validation_step(self, batch_xy, batch_idx, *args, **kwargs):
batch_x, batch_y = batch_xy
y = self(batch_x)
y, bands_y = y.main_out, y.bands
bands_y_losses = [self.criterion(band_y, batch_y) for band_y in bands_y]
bands_y_losses = [self.bce_loss(band_y, batch_y) for band_y in bands_y]
return_dict = {f'band_{band_idx}_val_loss': band_y for band_idx, band_y in enumerate(bands_y_losses)}
overall_loss = self.criterion(y, batch_y)
combined_loss = overall_loss + torch.stack(bands_y_losses).sum()
val_abs_loss = self.absolute_loss(y, batch_y)
return_dict.update(val_bce_loss=combined_loss, val_abs_loss=val_abs_loss,
last_bce_loss = self.bce_loss(y, batch_y)
return_dict.update(last_bce_loss=last_bce_loss)
bands_y_losses.append(last_bce_loss)
combined_loss = torch.stack(bands_y_losses).mean()
return_dict.update(val_bce_loss=combined_loss,
batch_idx=batch_idx, y=y, batch_y=batch_y
)
return return_dict
@ -56,7 +64,6 @@ class BandwiseConvMultiheadClassifier(BinaryMaskDatasetFunction,
# Additional parameters
self.in_shape = self.dataset.train_dataset.sample_shape
self.conv_filters = self.params.filters
self.criterion = nn.BCELoss()
self.n_band_sections = 4
k = 3 # Base Kernel Value
@ -69,7 +76,7 @@ class BandwiseConvMultiheadClassifier(BinaryMaskDatasetFunction,
last_shape = self.split.shape
conv_list = ModuleList()
for filters in self.conv_filters:
conv_list.append(ConvModule(last_shape, filters, (k, k*4), conv_stride=(1, 2),
conv_list.append(ConvModule(last_shape, filters, (k,k), conv_stride=(1, 1),
**self.params.module_kwargs))
last_shape = conv_list[-1].shape
# self.conv_list.append(ConvModule(last_shape, 1, 1, conv_stride=1, **self.params.module_kwargs))

View File

@ -29,23 +29,23 @@ class ConvClassifier(BinaryMaskDatasetFunction,
# Additional parameters
self.in_shape = self.dataset.train_dataset.sample_shape
self.conv_filters = self.params.filters
self.criterion = nn.BCELoss()
# Modules with Parameters
self.conv_list = ModuleList()
last_shape = self.in_shape
k = 3 # Base Kernel Value
for filters in self.conv_filters:
self.conv_list.append(ConvModule(last_shape, filters, (k, k*2), conv_stride=2, **self.params.module_kwargs))
self.conv_list.append(ConvModule(last_shape, filters, (k,k), conv_stride=(2, 2), conv_padding=2,
**self.params.module_kwargs))
last_shape = self.conv_list[-1].shape
# self.conv_list.append(ConvModule(last_shape, 1, 1, conv_stride=1, **self.params.module_kwargs))
# last_shape = self.conv_list[-1].shape
self.full_1 = LinearModule(self.flat.shape, self.params.lat_dim, **self.params.module_kwargs)
self.full_2 = LinearModule(self.full_1.out_features, self.full_1.out_features * 2, self.params.bias)
self.full_3 = LinearModule(self.full_2.out_features, self.full_2.out_features // 2, self.params.bias)
self.full_1 = LinearModule(self.conv_list[-1].shape, self.params.lat_dim, **self.params.module_kwargs)
self.full_2 = LinearModule(self.full_1.shape, self.full_1.shape * 2, **self.params.module_kwargs)
self.full_3 = LinearModule(self.full_2.shape, self.full_2.shape // 2, **self.params.module_kwargs)
self.full_out = LinearModule(self.full_3.out_features, 1, bias=self.params.bias, activation=nn.Sigmoid)
self.full_out = LinearModule(self.full_3.shape, 1, bias=self.params.bias, activation=nn.Sigmoid)
def forward(self, batch, **kwargs):
tensor = batch

View File

@ -0,0 +1,64 @@
from argparse import Namespace
from torch import nn
from torch.nn import ModuleList
from ml_lib.modules.blocks import ConvModule, LinearModule, ResidualModule
from ml_lib.modules.utils import LightningBaseModule
from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, BinaryMaskDatasetFunction,
BaseDataloadersMixin)
class ResidualConvClassifier(BinaryMaskDatasetFunction,
BaseDataloadersMixin,
BaseTrainMixin,
BaseValMixin,
BaseOptimizerMixin,
LightningBaseModule
):
def __init__(self, hparams):
super(ResidualConvClassifier, self).__init__(hparams)
# Dataset
# =============================================================================
self.dataset = self.build_dataset()
# Model Paramters
# =============================================================================
# Additional parameters
self.in_shape = self.dataset.train_dataset.sample_shape
self.conv_filters = self.params.filters
# Modules with Parameters
self.conv_list = ModuleList()
last_shape = self.in_shape
k = 3 # Base Kernel Value
conv_module_params = self.params.module_kwargs
conv_module_params.update(conv_kernel=(k, k), conv_stride=(1, 1), conv_padding=1)
self.conv_list.append(ConvModule(last_shape, self.conv_filters[0], (k, k), conv_stride=(2, 2), conv_padding=1,
**self.params.module_kwargs))
last_shape = self.conv_list[-1].shape
for filters in self.conv_filters:
conv_module_params.update(conv_filters=filters)
self.conv_list.append(ResidualModule(last_shape, ConvModule, 3, **conv_module_params))
last_shape = self.conv_list[-1].shape
self.conv_list.append(ConvModule(last_shape, filters, (k, k), conv_stride=(2, 2), conv_padding=2,
**self.params.module_kwargs))
last_shape = self.conv_list[-1].shape
self.full_1 = LinearModule(self.conv_list[-1].shape, self.params.lat_dim, **self.params.module_kwargs)
self.full_2 = LinearModule(self.full_1.shape, self.full_1.shape * 2, **self.params.module_kwargs)
self.full_3 = LinearModule(self.full_2.shape, self.full_2.shape // 2, **self.params.module_kwargs)
self.full_out = LinearModule(self.full_3.shape, 1, bias=self.params.bias, activation=nn.Sigmoid)
def forward(self, batch, **kwargs):
tensor = batch
for conv in self.conv_list:
tensor = conv(tensor)
tensor = self.full_1(tensor)
tensor = self.full_2(tensor)
tensor = self.full_3(tensor)
tensor = self.full_out(tensor)
return Namespace(main_out=tensor)

View File

@ -3,6 +3,7 @@ from models.conv_classifier import ConvClassifier
from models.bandwise_conv_classifier import BandwiseConvClassifier
from models.bandwise_conv_multihead_classifier import BandwiseConvMultiheadClassifier
from models.ensemble import Ensemble
from models.residual_conv_classifier import ResidualConvClassifier
class MConfig(Config):
@ -11,7 +12,13 @@ class MConfig(Config):
@property
def _model_map(self):
return dict(ConvClassifier=ConvClassifier,
CC=ConvClassifier,
BandwiseConvClassifier=BandwiseConvClassifier,
BCC=BandwiseConvClassifier,
BandwiseConvMultiheadClassifier=BandwiseConvMultiheadClassifier,
BCMC=BandwiseConvMultiheadClassifier,
Ensemble=Ensemble,
E=Ensemble,
ResidualConvClassifier=ResidualConvClassifier,
RCC=ResidualConvClassifier
)

View File

@ -6,13 +6,14 @@ from argparse import Namespace
import sklearn
import torch
import numpy as np
from torch.nn import L1Loss
from torch import nn
from torch.optim import Adam
from torch.utils.data import DataLoader
from torch.utils.data import DataLoader, RandomSampler
from torchcontrib.optim import SWA
from torchvision.transforms import Compose, RandomApply
from ml_lib.audio_toolset.audio_augmentation import NoiseInjection, LoudnessManipulator, ShiftTime
from ml_lib.audio_toolset.audio_augmentation import Speed
from ml_lib.audio_toolset.mel_augmentation import NoiseInjection, LoudnessManipulator, ShiftTime, MaskAug
from ml_lib.audio_toolset.audio_io import AudioToMel, MelToImage, NormalizeLocal
from ml_lib.modules.utils import LightningBaseModule
from ml_lib.utils.transforms import ToTensor
@ -24,17 +25,19 @@ class BaseOptimizerMixin:
def configure_optimizers(self):
assert isinstance(self, LightningBaseModule)
opt = Adam(params=self.parameters(), lr=self.params.lr)
opt = Adam(params=self.parameters(), lr=self.params.lr, weight_decay=0.04)
if self.params.sto_weight_avg:
opt = SWA(opt, swa_start=10, swa_freq=5, swa_lr=0.05)
return opt
def on_train_end(self):
assert isinstance(self, LightningBaseModule)
for opt in self.trainer.optimizers:
if isinstance(opt, SWA):
opt.swap_swa_sgd()
def on_epoch_end(self):
assert isinstance(self, LightningBaseModule)
if self.params.opt_reset_interval:
if self.current_epoch % self.params.opt_reset_interval == 0:
for opt in self.trainer.optimizers:
@ -43,14 +46,19 @@ class BaseOptimizerMixin:
class BaseTrainMixin:
absolute_loss = nn.L1Loss()
nll_loss = nn.NLLLoss()
bce_loss = nn.BCELoss()
def training_step(self, batch_xy, batch_nb, *args, **kwargs):
assert isinstance(self, LightningBaseModule)
batch_x, batch_y = batch_xy
y = self(batch_x).main_out
loss = self.criterion(y, batch_y)
return dict(loss=loss)
bce_loss = self.bce_loss(y, batch_y)
return dict(loss=bce_loss)
def training_epoch_end(self, outputs):
assert isinstance(self, LightningBaseModule)
keys = list(outputs[0].keys())
summary_dict = dict(log={f'mean_{key}': torch.mean(torch.stack([output[key]
@ -61,18 +69,20 @@ class BaseTrainMixin:
class BaseValMixin:
absolute_loss = L1Loss()
absolute_loss = nn.L1Loss()
nll_loss = nn.NLLLoss()
bce_loss = nn.BCELoss()
def validation_step(self, batch_xy, batch_idx, dataloader_idx, *args, **kwargs):
assert isinstance(self, LightningBaseModule)
batch_x, batch_y = batch_xy
y = self(batch_x).main_out
val_bce_loss = self.criterion(y, batch_y)
val_abs_loss = self.absolute_loss(y, batch_y)
return dict(val_bce_loss=val_bce_loss, val_abs_loss=val_abs_loss,
batch_idx=batch_idx, y=y, batch_y=batch_y
)
val_bce_loss = self.bce_loss(y, batch_y)
return dict(val_bce_loss=val_bce_loss,
batch_idx=batch_idx, y=y, batch_y=batch_y)
def validation_epoch_end(self, outputs, *args, **kwargs):
assert isinstance(self, LightningBaseModule)
summary_dict = dict(log=dict())
for output_idx, output in enumerate(outputs):
keys = list(output[0].keys())
@ -103,6 +113,12 @@ class BinaryMaskDatasetFunction:
# Dataset
# =============================================================================
# Mel Transforms
mel_transforms_train = Compose([
# Audio to Mel Transformations
Speed(speed_factor=self.params.speed_factor, max_ratio=self.params.speed_ratio),
AudioToMel(sr=self.params.sr, n_mels=self.params.n_mels, n_fft=self.params.n_fft,
hop_length=self.params.hop_length),
MelToImage()])
mel_transforms = Compose([
# Audio to Mel Transformations
AudioToMel(sr=self.params.sr, n_mels=self.params.n_mels, n_fft=self.params.n_fft,
@ -112,25 +128,28 @@ class BinaryMaskDatasetFunction:
RandomApply([
NoiseInjection(self.params.noise_ratio),
LoudnessManipulator(self.params.loudness_ratio),
ShiftTime(self.params.shift_ratio)], p=0.5),
ShiftTime(self.params.shift_ratio),
MaskAug(self.params.mask_ratio),
], p=0.6),
# Utility
NormalizeLocal(), ToTensor()
])
val_transforms = Compose([NormalizeLocal(), ToTensor()])
# sampler = RandomSampler(train_dataset, True, len(train_dataset)) if params['bootstrap'] else None
# Datasets
from datasets.binar_masks import BinaryMasksDataset
dataset = Namespace(
**dict(
# TRAIN DATASET
train_dataset=BinaryMasksDataset(self.params.root, setting=V.DATA_OPTIONS.train,
mixup=self.params.mixup,
mel_transforms=mel_transforms, transforms=aug_transforms),
mel_transforms=mel_transforms_train, transforms=aug_transforms),
# VALIDATION DATASET
val_train_dataset=BinaryMasksDataset(self.params.root, setting=V.DATA_OPTIONS.train,
mel_transforms=mel_transforms, transforms=val_transforms),
val_dataset=BinaryMasksDataset(self.params.root, setting=V.DATA_OPTIONS.devel,
mel_transforms=mel_transforms, transforms=val_transforms),
# TEST DATASET
test_dataset=BinaryMasksDataset(self.params.root, setting=V.DATA_OPTIONS.test,
mel_transforms=mel_transforms, transforms=val_transforms),
)
@ -144,18 +163,23 @@ class BaseDataloadersMixin(ABC):
# ================================================================================
# Train Dataloader
def train_dataloader(self):
return DataLoader(dataset=self.dataset.train_dataset, shuffle=True,
assert isinstance(self, LightningBaseModule)
# sampler = RandomSampler(self.dataset.train_dataset, True, len(self.dataset.train_dataset))
sampler = None
return DataLoader(dataset=self.dataset.train_dataset, shuffle=True if not sampler else None, sampler=sampler,
batch_size=self.params.batch_size,
num_workers=self.params.worker)
# Test Dataloader
def test_dataloader(self):
assert isinstance(self, LightningBaseModule)
return DataLoader(dataset=self.dataset.test_dataset, shuffle=False,
batch_size=self.params.batch_size,
num_workers=self.params.worker)
# Validation Dataloader
def val_dataloader(self):
assert isinstance(self, LightningBaseModule)
val_dataloader = DataLoader(dataset=self.dataset.val_dataset, shuffle=True,
batch_size=self.params.batch_size, num_workers=self.params.worker)