from collections import defaultdict
from abc import ABC
from argparse import Namespace

import sklearn.metrics
import torch
import numpy as np
from torch import nn
from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
from torch.utils.data import DataLoader
from torchcontrib.optim import SWA
from torchvision.transforms import Compose, RandomApply

from ml_lib.audio_toolset.audio_augmentation import Speed
from ml_lib.audio_toolset.mel_augmentation import NoiseInjection, LoudnessManipulator, ShiftTime, MaskAug
from ml_lib.audio_toolset.audio_io import AudioToMel, MelToImage, NormalizeLocal
from ml_lib.modules.util import LightningBaseModule
from ml_lib.utils.transforms import ToTensor

import variables as V


class BaseOptimizerMixin:

    def configure_optimizers(self):
        assert isinstance(self, LightningBaseModule)
        optimizer_dict = dict(
            # 'optimizer': optimizer,      # The Optimizer
            # 'lr_scheduler': scheduler,   # The LR scheduler
            frequency=1,                   # The frequency of the scheduler
            interval='epoch',              # The unit of the scheduler's step size
            # 'reduce_on_plateau': False,  # For ReduceLROnPlateau scheduler
            # 'monitor': 'mean_val_loss'   # Metric to monitor
        )
        optimizer = Adam(params=self.parameters(), lr=self.params.lr, weight_decay=self.params.weight_decay)
        if self.params.sto_weight_avg:
            # Wrap the base optimizer in Stochastic Weight Averaging (torchcontrib).
            optimizer = SWA(optimizer, swa_start=10, swa_freq=5, swa_lr=0.05)
        optimizer_dict.update(optimizer=optimizer)
        if self.params.lr_warmup_steps:
            scheduler = CosineAnnealingWarmRestarts(optimizer, self.params.lr_warmup_steps)
            optimizer_dict.update(lr_scheduler=scheduler)
        return optimizer_dict

    def on_train_end(self):
        assert isinstance(self, LightningBaseModule)
        # Swap in the averaged SWA weights before the model is evaluated or saved.
        for opt in self.trainer.optimizers:
            if isinstance(opt, SWA):
                opt.swap_swa_sgd()

    def on_epoch_end(self):
        assert isinstance(self, LightningBaseModule)
        if self.params.opt_reset_interval:
            # Periodically clear the optimizer state (e.g. Adam's running moments).
            if self.current_epoch % self.params.opt_reset_interval == 0:
                for opt in self.trainer.optimizers:
                    opt.state = defaultdict(dict)


class BaseTrainMixin:

    absolute_loss = nn.L1Loss()
    nll_loss = nn.NLLLoss()
    bce_loss = nn.BCELoss()

    def training_step(self, batch_xy, batch_nb, *args, **kwargs):
        assert isinstance(self, LightningBaseModule)
        batch_x, batch_y = batch_xy
        y = self(batch_x).main_out
        bce_loss = self.bce_loss(y.squeeze(), batch_y)
        return dict(loss=bce_loss)

    def training_epoch_end(self, outputs):
        assert isinstance(self, LightningBaseModule)
        keys = list(outputs[0].keys())
        # Average every logged loss over the epoch.
        summary_dict = {f'mean_{key}': torch.mean(torch.stack([output[key] for output in outputs]))
                        for key in keys if 'loss' in key}
        for key in summary_dict.keys():
            self.log(key, summary_dict[key])


class BaseValMixin:

    absolute_loss = nn.L1Loss()
    nll_loss = nn.NLLLoss()
    bce_loss = nn.BCELoss()

    def validation_step(self, batch_xy, batch_idx, dataloader_idx, *args, **kwargs):
        assert isinstance(self, LightningBaseModule)
        batch_x, batch_y = batch_xy
        y = self(batch_x).main_out
        val_bce_loss = self.bce_loss(y.squeeze(), batch_y)
        return dict(val_bce_loss=val_bce_loss,
                    batch_idx=batch_idx, y=y, batch_y=batch_y)

    def validation_epoch_end(self, outputs, *_, **__):
        assert isinstance(self, LightningBaseModule)
        summary_dict = dict()
        # `outputs` holds one list of step dicts per validation dataloader
        # (index 0: devel set, index 1: train set re-evaluated without augmentation).
        for dataloader_idx, dataloader_outputs in enumerate(outputs):
            keys = list(dataloader_outputs[0].keys())
            ident = '' if dataloader_idx == 0 else '_train'
            summary_dict.update({f'mean{ident}_{key}': torch.mean(torch.stack([output[key]
                                                                               for output in dataloader_outputs]))
                                 for key in keys if 'loss' in key}
                                )

            # UnweightedAverageRecall
            y_true = torch.cat([output['batch_y'] for output in dataloader_outputs]).cpu().numpy()
            y_pred = torch.cat([output['y'] for output in dataloader_outputs]).squeeze().cpu().numpy()

            y_pred = (y_pred >= 0.5).astype(np.float32)

            uar_score = sklearn.metrics.recall_score(y_true, y_pred, labels=[0, 1], average='macro',
                                                     sample_weight=None, zero_division='warn')
            uar_score = torch.as_tensor(uar_score)
            summary_dict.update({f'uar{ident}_score': uar_score})

        for key in summary_dict.keys():
            self.log(key, summary_dict[key])
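
# ---------------------------------------------------------------------------
# Illustration (added sketch, not part of the training pipeline): the UAR
# logged above is macro-averaged recall, i.e. the unweighted mean of the
# per-class recalls. A minimal, self-contained sanity check:
#
#     import sklearn.metrics
#     # class 0: 1 of 2 hits (recall 0.5); class 1: 2 of 2 hits (recall 1.0)
#     uar = sklearn.metrics.recall_score([0, 0, 1, 1], [0, 1, 1, 1],
#                                        labels=[0, 1], average='macro')
#     assert abs(uar - 0.75) < 1e-9
# ---------------------------------------------------------------------------
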
class BaseTestMixin:

    absolute_loss = nn.L1Loss()
    nll_loss = nn.NLLLoss()
    bce_loss = nn.BCELoss()

    def test_step(self, batch_xy, batch_idx, dataloader_idx, *args, **kwargs):
        assert isinstance(self, LightningBaseModule)
        batch_x, batch_y = batch_xy
        y = self(batch_x).main_out
        test_bce_loss = self.bce_loss(y.squeeze(), batch_y)
        return dict(test_bce_loss=test_bce_loss,
                    batch_idx=batch_idx, y=y, batch_y=batch_y)

    def test_epoch_end(self, outputs, *_, **__):
        assert isinstance(self, LightningBaseModule)
        summary_dict = dict()
        keys = list(outputs[0].keys())
        summary_dict.update({f'mean_{key}': torch.mean(torch.stack([output[key] for output in outputs]))
                             for key in keys if 'loss' in key}
                            )

        # UnweightedAverageRecall
        y_true = torch.cat([output['batch_y'] for output in outputs]).cpu().numpy()
        y_pred = torch.cat([output['y'] for output in outputs]).squeeze().cpu().numpy()

        y_pred = (y_pred >= 0.5).astype(np.float32)

        uar_score = sklearn.metrics.recall_score(y_true, y_pred, labels=[0, 1], average='macro',
                                                 sample_weight=None, zero_division='warn')
        uar_score = torch.as_tensor(uar_score)
        summary_dict.update({'uar_score': uar_score})

        for key in summary_dict.keys():
            self.log(key, summary_dict[key])


class DatasetMixin:

    def build_dataset(self):
        assert isinstance(self, LightningBaseModule)

        # Dataset
        # =============================================================================
        # Mel Transforms
        mel_transforms = Compose([
            # Audio to Mel Transformations
            AudioToMel(sr=self.params.sr,
                       n_mels=self.params.n_mels,
                       n_fft=self.params.n_fft,
                       hop_length=self.params.hop_length),
            MelToImage()])

        mel_transforms_train = Compose([
            # Audio to Mel Transformations
            Speed(max_amount=self.params.speed_amount,
                  speed_min=self.params.speed_min,
                  speed_max=self.params.speed_max
                  ),
            mel_transforms])

        # Utility
        util_transforms = Compose([NormalizeLocal(), ToTensor()])

        # Data Augmentations (applied to the training set only)
        aug_transforms = Compose([
            RandomApply([
                NoiseInjection(self.params.noise_ratio),
                LoudnessManipulator(self.params.loudness_ratio),
                ShiftTime(self.params.shift_ratio),
                MaskAug(self.params.mask_ratio),
            ], p=0.6),
            util_transforms])

        # Datasets
        dataset = Namespace(
            **dict(
                # TRAIN DATASET
                train_dataset=self.dataset_class(self.params.root, setting=V.DATA_OPTIONS.train,
                                                 use_preprocessed=self.params.use_preprocessed,
                                                 stretch_dataset=self.params.stretch,
                                                 mel_transforms=mel_transforms_train, transforms=aug_transforms),
                # VALIDATION DATASET
                val_train_dataset=self.dataset_class(self.params.root, setting=V.DATA_OPTIONS.train,
                                                     mel_transforms=mel_transforms, transforms=util_transforms),
                val_dataset=self.dataset_class(self.params.root, setting=V.DATA_OPTIONS.devel,
                                               mel_transforms=mel_transforms, transforms=util_transforms),
                # TEST DATASET
                test_dataset=self.dataset_class(self.params.root, setting=V.DATA_OPTIONS.test,
                                                mel_transforms=mel_transforms, transforms=util_transforms),
            )
        )
        return dataset
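
# ---------------------------------------------------------------------------
# Usage sketch (hypothetical `BinaryClassifier`, shown for illustration only):
# a concrete model is expected to combine these mixins with
# LightningBaseModule and to provide `self.params`, `self.dataset_class`, and
# a forward pass whose result exposes `.main_out`:
#
#     class BinaryClassifier(BaseOptimizerMixin, BaseTrainMixin, BaseValMixin,
#                            BaseTestMixin, DatasetMixin, BaseDataloadersMixin,
#                            LightningBaseModule):
#
#         def forward(self, batch_x):
#             ...  # must return an object with a `main_out` attribute
# ---------------------------------------------------------------------------
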
class BaseDataloadersMixin(ABC):

    # Dataloaders
    # ================================================================================
    # Train Dataloader
    def train_dataloader(self):
        assert isinstance(self, LightningBaseModule)
        # sampler = RandomSampler(self.dataset.train_dataset, True, len(self.dataset.train_dataset))
        sampler = None
        return DataLoader(dataset=self.dataset.train_dataset,
                          shuffle=sampler is None, sampler=sampler,
                          batch_size=self.params.batch_size,
                          pin_memory=True, num_workers=self.params.worker)

    # Test Dataloader
    def test_dataloader(self):
        assert isinstance(self, LightningBaseModule)
        return DataLoader(dataset=self.dataset.test_dataset, shuffle=False,
                          batch_size=self.params.batch_size,
                          pin_memory=True, num_workers=self.params.worker)

    # Validation Dataloader
    def val_dataloader(self):
        assert isinstance(self, LightningBaseModule)
        val_dataloader = DataLoader(dataset=self.dataset.val_dataset, shuffle=False, pin_memory=True,
                                    batch_size=self.params.batch_size, num_workers=self.params.worker)
        # The train set is evaluated a second time, without augmentation, so the
        # validation metrics can be compared directly against training performance.
        train_dataloader = DataLoader(self.dataset.val_train_dataset, num_workers=self.params.worker,
                                      pin_memory=True, batch_size=self.params.batch_size, shuffle=False)
        return [val_dataloader, train_dataloader]
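
# ---------------------------------------------------------------------------
# Configuration sketch: the mixins read all settings from `self.params`. A
# Namespace covering every attribute referenced above (the values here are
# illustrative placeholders, not project defaults):
#
#     from argparse import Namespace
#     params = Namespace(
#         # BaseOptimizerMixin
#         lr=1e-3, weight_decay=0.0, sto_weight_avg=False,
#         lr_warmup_steps=0, opt_reset_interval=0,
#         # DatasetMixin: mel spectrogram and augmentation settings
#         sr=16000, n_mels=64, n_fft=512, hop_length=256,
#         speed_amount=0.3, speed_min=0.9, speed_max=1.1,
#         noise_ratio=0.4, loudness_ratio=0.4, shift_ratio=0.3, mask_ratio=0.2,
#         root='data', use_preprocessed=True, stretch=False,
#         # BaseDataloadersMixin
#         batch_size=32, worker=4,
#     )
# ---------------------------------------------------------------------------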