from collections import defaultdict
from abc import ABC
from argparse import Namespace

import sklearn.metrics
import torch
import numpy as np
from torch import nn
from torch.optim import Adam
from torch.utils.data import DataLoader, RandomSampler
from torchcontrib.optim import SWA
from torchvision.transforms import Compose, RandomApply

from ml_lib.audio_toolset.audio_augmentation import Speed
from ml_lib.audio_toolset.mel_augmentation import NoiseInjection, LoudnessManipulator, ShiftTime, MaskAug
from ml_lib.audio_toolset.audio_io import AudioToMel, MelToImage, NormalizeLocal
from ml_lib.modules.utils import LightningBaseModule
from ml_lib.utils.transforms import ToTensor

import variables as V


class BaseOptimizerMixin:

    def configure_optimizers(self):
        assert isinstance(self, LightningBaseModule)
        opt = Adam(params=self.parameters(), lr=self.params.lr, weight_decay=1e-7)
        if self.params.sto_weight_avg:
            # Wrap the base optimizer for Stochastic Weight Averaging
            opt = SWA(opt, swa_start=10, swa_freq=5, swa_lr=0.05)
        return opt

    def on_train_end(self):
        assert isinstance(self, LightningBaseModule)
        for opt in self.trainer.optimizers:
            if isinstance(opt, SWA):
                # Copy the averaged SWA weights back into the model
                opt.swap_swa_sgd()

    def on_epoch_end(self):
        assert isinstance(self, LightningBaseModule)
        if self.params.opt_reset_interval:
            if self.current_epoch % self.params.opt_reset_interval == 0:
                # Periodically reset the optimizer state (e.g. Adam's running moments)
                for opt in self.trainer.optimizers:
                    opt.state = defaultdict(dict)


class BaseTrainMixin:

    absolute_loss = nn.L1Loss()
    nll_loss = nn.NLLLoss()
    bce_loss = nn.BCELoss()

    def training_step(self, batch_xy, batch_nb, *args, **kwargs):
        assert isinstance(self, LightningBaseModule)
        batch_x, batch_y = batch_xy
        y = self(batch_x).main_out
        bce_loss = self.bce_loss(y, batch_y)
        return dict(loss=bce_loss)

    def training_epoch_end(self, outputs):
        assert isinstance(self, LightningBaseModule)
        keys = list(outputs[0].keys())

        # Log the epoch mean of every loss reported by the training steps
        summary_dict = dict(log={f'mean_{key}': torch.mean(torch.stack([output[key] for output in outputs]))
                                 for key in keys if 'loss' in key})
        return summary_dict


class BaseValMixin:

    absolute_loss = nn.L1Loss()
    nll_loss = nn.NLLLoss()
    bce_loss = nn.BCELoss()

    def validation_step(self, batch_xy, batch_idx, dataloader_idx, *args, **kwargs):
        assert isinstance(self, LightningBaseModule)
        batch_x, batch_y = batch_xy
        y = self(batch_x).main_out
        val_bce_loss = self.bce_loss(y, batch_y)
        return dict(val_bce_loss=val_bce_loss, batch_idx=batch_idx, y=y, batch_y=batch_y)

    def validation_epoch_end(self, outputs, *args, **kwargs):
        assert isinstance(self, LightningBaseModule)
        summary_dict = dict(log=dict())
        # `outputs` holds one list of step results per validation dataloader
        # (index 0: devel set, index 1: the training set evaluated without augmentation)
        for output_idx, dataloader_outputs in enumerate(outputs):
            keys = list(dataloader_outputs[0].keys())
            ident = '' if output_idx == 0 else '_train'
            summary_dict['log'].update(
                {f'mean{ident}_{key}': torch.mean(torch.stack([output[key] for output in dataloader_outputs]))
                 for key in keys if 'loss' in key}
            )

            # Unweighted Average Recall (UAR)
            y_true = torch.cat([output['batch_y'] for output in dataloader_outputs]).cpu().numpy()
            y_pred = torch.cat([output['y'] for output in dataloader_outputs]).squeeze().cpu().numpy()
            y_pred = (y_pred >= 0.5).astype(np.float32)
            uar_score = sklearn.metrics.recall_score(y_true, y_pred, labels=[0, 1], average='macro',
                                                     sample_weight=None, zero_division='warn')
            summary_dict['log'].update({f'uar{ident}_score': uar_score})
        return summary_dict
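

# Standalone sketch (illustration only, not part of the original pipeline): the UAR logged by
# `BaseValMixin.validation_epoch_end` above is the macro-averaged recall over both classes
# after thresholding the sigmoid outputs at 0.5. The helper below and its dummy values are
# hypothetical and exist only to make the metric computation concrete.
def _uar_sketch():
    y_true = np.array([0, 0, 1, 1, 1])
    y_pred = (np.array([0.1, 0.7, 0.8, 0.4, 0.9]) >= 0.5).astype(np.float32)
    # per-class recall: class 0 -> 1/2, class 1 -> 2/3; macro average (= UAR) ~ 0.583
    return sklearn.metrics.recall_score(y_true, y_pred, labels=[0, 1], average='macro')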


class BinaryMaskDatasetMixin:

    def build_dataset(self):
        assert isinstance(self, LightningBaseModule)

        # Dataset
        # =============================================================================
        # Mel Transforms
        mel_transforms_train = Compose([
            # Audio-to-mel transformations (with speed augmentation on the raw audio)
            Speed(speed_factor=self.params.speed_factor, max_ratio=self.params.speed_ratio),
            AudioToMel(sr=self.params.sr, n_mels=self.params.n_mels, n_fft=self.params.n_fft,
                       hop_length=self.params.hop_length),
            MelToImage()])
        mel_transforms = Compose([
            # Audio-to-mel transformations (no augmentation)
            AudioToMel(sr=self.params.sr, n_mels=self.params.n_mels, n_fft=self.params.n_fft,
                       hop_length=self.params.hop_length),
            MelToImage()])

        # Data Augmentations
        aug_transforms = Compose([
            RandomApply([
                NoiseInjection(self.params.noise_ratio),
                LoudnessManipulator(self.params.loudness_ratio),
                ShiftTime(self.params.shift_ratio),
                MaskAug(self.params.mask_ratio),
            ], p=0.6),
            # Utility
            NormalizeLocal(),
            ToTensor()
        ])
        val_transforms = Compose([NormalizeLocal(), ToTensor()])

        # Datasets
        from datasets.binar_masks import BinaryMasksDataset
        dataset = Namespace(
            **dict(
                # TRAIN DATASET
                train_dataset=BinaryMasksDataset(self.params.root, setting=V.DATA_OPTIONS.train,
                                                 mixup=self.params.mixup,
                                                 mel_transforms=mel_transforms_train,
                                                 transforms=aug_transforms),
                # VALIDATION DATASET
                val_train_dataset=BinaryMasksDataset(self.params.root, setting=V.DATA_OPTIONS.train,
                                                     mel_transforms=mel_transforms, transforms=val_transforms),
                val_dataset=BinaryMasksDataset(self.params.root, setting=V.DATA_OPTIONS.devel,
                                               mel_transforms=mel_transforms, transforms=val_transforms),
                # TEST DATASET
                test_dataset=BinaryMasksDataset(self.params.root, setting=V.DATA_OPTIONS.test,
                                                mel_transforms=mel_transforms, transforms=val_transforms),
            )
        )
        return dataset


class BaseDataloadersMixin(ABC):

    # Dataloaders
    # ================================================================================
    # Train Dataloader
    def train_dataloader(self):
        assert isinstance(self, LightningBaseModule)
        # sampler = RandomSampler(self.dataset.train_dataset, True, len(self.dataset.train_dataset))
        sampler = None
        return DataLoader(dataset=self.dataset.train_dataset, shuffle=sampler is None, sampler=sampler,
                          batch_size=self.params.batch_size, num_workers=self.params.worker)

    # Test Dataloader
    def test_dataloader(self):
        assert isinstance(self, LightningBaseModule)
        return DataLoader(dataset=self.dataset.test_dataset, shuffle=False,
                          batch_size=self.params.batch_size, num_workers=self.params.worker)

    # Validation Dataloaders (devel set first, then the un-augmented training set)
    def val_dataloader(self):
        assert isinstance(self, LightningBaseModule)
        val_dataloader = DataLoader(dataset=self.dataset.val_dataset, shuffle=True,
                                    batch_size=self.params.batch_size, num_workers=self.params.worker)
        train_dataloader = DataLoader(self.dataset.val_train_dataset, num_workers=self.params.worker,
                                      batch_size=self.params.batch_size, shuffle=False)
        return [val_dataloader, train_dataloader]
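

# Usage sketch (hypothetical, for illustration only): the mixins in this module are meant to be
# stacked onto a concrete LightningBaseModule subclass that owns the network and exposes its
# prediction under `.main_out`. The class name, the layer choice and the exact
# LightningBaseModule constructor signature below are assumptions, not the repository's actual
# models, so the sketch is left commented out:
#
#   class ExampleMaskClassifier(BaseOptimizerMixin, BaseTrainMixin, BaseValMixin,
#                               BinaryMaskDatasetMixin, BaseDataloadersMixin,
#                               LightningBaseModule):
#
#       def __init__(self, hparams):
#           super().__init__(hparams)            # LightningBaseModule exposes hparams as self.params
#           self.dataset = self.build_dataset()  # provided by BinaryMaskDatasetMixin
#           self.net = nn.Sequential(nn.Flatten(), nn.Linear(<flattened mel size>, 1), nn.Sigmoid())
#
#       def forward(self, batch_x):
#           # BaseTrainMixin / BaseValMixin read the prediction from the `main_out` attribute
#           return Namespace(main_out=self.net(batch_x).squeeze(-1))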