From 451f78f820301ddf5771c8a4dd5c019500e670ec Mon Sep 17 00:00:00 2001
From: Si11ium <steffen.illium@ifi.lmu.de>
Date: Mon, 4 May 2020 18:45:13 +0200
Subject: [PATCH] BandwiseBinaryClassifier is work in progress; TODO: Shape Piping.

---
 _paramters.py                        |  2 +-
 datasets/binar_masks.py              |  2 +-
 main_inference.py                    |  4 +-
 models/bandwise_binary_classifier.py | 97 ++++++++++++++++++++++++++++
 models/binary_classifier.py          | 67 +++++++++----------
 models/module_mixins.py              | 55 ++++++++++++++++
 util/config.py                       |  5 +-
 7 files changed, 190 insertions(+), 42 deletions(-)
 create mode 100644 models/bandwise_binary_classifier.py
 create mode 100644 models/module_mixins.py

diff --git a/_paramters.py b/_paramters.py
index 16c2852..b42846d 100644
--- a/_paramters.py
+++ b/_paramters.py
@@ -44,7 +44,7 @@ main_arg_parser.add_argument("--model_classes", type=int, default=2, help="")
 main_arg_parser.add_argument("--model_lat_dim", type=int, default=16, help="")
 main_arg_parser.add_argument("--model_bias", type=strtobool, default=True, help="")
 main_arg_parser.add_argument("--model_norm", type=strtobool, default=False, help="")
-main_arg_parser.add_argument("--model_dropout", type=float, default=0.00, help="")
+main_arg_parser.add_argument("--model_dropout", type=float, default=0.2, help="")
 
 # Project Parameters
 main_arg_parser.add_argument("--project_name", type=str, default=_ROOT.name, help="")
diff --git a/datasets/binar_masks.py b/datasets/binar_masks.py
index 8159ff2..1e77f3d 100644
--- a/datasets/binar_masks.py
+++ b/datasets/binar_masks.py
@@ -30,7 +30,7 @@ class BinaryMasksDataset(Dataset):
         self._labels = self._build_labels()
         self._wav_folder = self.data_root / 'wav'
         self._wav_files = list(sorted(self._labels.keys()))
-        self._mel_folder = self.data_root / 'raw_mel'
+        self._mel_folder = self.data_root / 'transformed'
 
     def _build_labels(self):
         with open(Path(self.data_root) / 'lab' / 'labels.csv', mode='r') as f:
diff --git a/main_inference.py b/main_inference.py
index 65ef2c2..e617b74 100644
--- a/main_inference.py
+++ b/main_inference.py
@@ -1,7 +1,7 @@
 from torch.utils.data import DataLoader, Dataset
 from torchvision.transforms import Compose, ToTensor
 
-from ml_lib.audio_toolset.audio_io import Melspectogram, NormalizeLocal
+from ml_lib.audio_toolset.audio_io import AudioToMel, NormalizeLocal
 
 # Dataset and Dataloaders
 # =============================================================================
@@ -11,7 +11,7 @@ from ml_lib.utils.model_io import SavedLightningModels
 from util.config import MConfig
 from util.logging import MLogger
 
-transforms = Compose([Melspectogram(), ToTensor(), NormalizeLocal()])
+transforms = Compose([AudioToMel(), ToTensor(), NormalizeLocal()])
 
 # Datasets
 from datasets.binar_masks import BinaryMasksDataset
diff --git a/models/bandwise_binary_classifier.py b/models/bandwise_binary_classifier.py
new file mode 100644
index 0000000..89a2e83
--- /dev/null
+++ b/models/bandwise_binary_classifier.py
@@ -0,0 +1,97 @@
+from argparse import Namespace
+
+from torch import nn
+from torch.nn import ModuleDict
+
+from torchvision.transforms import Compose, ToTensor
+
+from ml_lib.audio_toolset.audio_io import AudioToMel, NormalizeLocal, PowerToDB, MelToImage
+from ml_lib.modules.blocks import ConvModule
+from ml_lib.modules.utils import LightningBaseModule, Flatten, BaseModuleMixin_Dataloaders, HorizontalSplitter, \
+    HorizontalMerger
+from models.module_mixins import BaseOptimizerMixin, BaseTrainMixin, BaseValMixin
+
+
+class BandwiseBinaryClassifier(BaseModuleMixin_Dataloaders,
+                               BaseTrainMixin,
+                               BaseValMixin,
+                               BaseOptimizerMixin,
+                               LightningBaseModule
+                               ):
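+    """Binary classifier that splits the incoming mel spectrogram into
+    horizontal frequency bands, runs a separate convolution stack per band,
+    merges the band features and classifies them with a fully connected head."""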
+
+    def __init__(self, hparams):
+        super(BandwiseBinaryClassifier, self).__init__(hparams)
+
+        # Dataset and Dataloaders
+        # =============================================================================
+        # Transforms
+        transforms = Compose([AudioToMel(), MelToImage(), ToTensor(), NormalizeLocal()])
+        # Datasets
+        from datasets.binar_masks import BinaryMasksDataset
+        self.dataset = Namespace(
+            **dict(
+                train_dataset=BinaryMasksDataset(self.params.root, setting='train', transforms=transforms),
+                val_dataset=BinaryMasksDataset(self.params.root, setting='devel', transforms=transforms),
+                test_dataset=BinaryMasksDataset(self.params.root, setting='test', transforms=transforms),
+            )
+        )
+
+        # Model Parameters
+        # =============================================================================
+        # Additional parameters
+        self.in_shape = self.dataset.train_dataset.sample_shape
+        self.conv_filters = self.params.filters
+        self.criterion = nn.BCELoss()
+        self.n_band_sections = 5
+
+        # Utility Modules
+        self.split = HorizontalSplitter(self.in_shape, self.n_band_sections)
+
+        # Modules with Parameters
+        # Every band section gets its own convolution stack; each stack sees the
+        # per-band shape produced by the splitter.
+        modules = {f"conv_1_{band_section}":
+                   ConvModule(self.split.shape, self.conv_filters[0], 3, conv_stride=2, **self.params.module_kwargs)
+                   for band_section in range(self.n_band_sections)}
+
+        modules.update({f"conv_2_{band_section}":
+                        ConvModule(modules["conv_1_0"].shape, self.conv_filters[1], 3, conv_stride=2,
+                                   **self.params.module_kwargs) for band_section in range(self.n_band_sections)}
+                       )
+        modules.update({f"conv_3_{band_section}":
+                        ConvModule(modules["conv_2_0"].shape, self.conv_filters[2], 3, conv_stride=2,
+                                   **self.params.module_kwargs)
+                        for band_section in range(self.n_band_sections)}
+                       )
+        self.conv_dict = ModuleDict(modules=modules)
+
+        # Utility Modules
+        # NOTE: the merger joins the per-band conv_3 outputs; this assumes
+        # HorizontalMerger exposes its output shape as `.shape`, like the splitter.
+        self.merge = HorizontalMerger(modules["conv_3_0"].shape, self.n_band_sections)
+        self.flat = Flatten(self.merge.shape)
+
+        self.full_1 = nn.Linear(self.flat.shape, self.params.lat_dim, self.params.bias)
+        self.full_2 = nn.Linear(self.full_1.out_features, self.full_1.out_features // 2, self.params.bias)
+        self.full_out = nn.Linear(self.full_2.out_features, 1, self.params.bias)
+
+        self.dropout = nn.Dropout2d(self.params.dropout) if self.params.dropout else lambda x: x
+        self.activation = self.params.activation()
+        self.sigmoid = nn.Sigmoid()
+
+    def forward(self, batch, **kwargs):
+        tensors = self.split(batch)
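+        # Run every frequency band through its own convolution stack before the
+        # band features are merged again.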
+        for idx, tensor in enumerate(tensors):
+            tensors[idx] = self.conv_dict[f"conv_1_{idx}"](tensor)
+        for idx, tensor in enumerate(tensors):
+            tensors[idx] = self.conv_dict[f"conv_2_{idx}"](tensor)
+        for idx, tensor in enumerate(tensors):
+            tensors[idx] = self.conv_dict[f"conv_3_{idx}"](tensor)
+
+        tensor = self.merge(tensors)
+        tensor = self.flat(tensor)
+        tensor = self.full_1(tensor)
+        tensor = self.activation(tensor)
+        tensor = self.dropout(tensor)
+        tensor = self.full_2(tensor)
+        tensor = self.activation(tensor)
+        tensor = self.dropout(tensor)
+        tensor = self.full_out(tensor)
+        tensor = self.sigmoid(tensor)
+        return tensor
diff --git a/models/binary_classifier.py b/models/binary_classifier.py
index 0d663e1..cbc44fa 100644
--- a/models/binary_classifier.py
+++ b/models/binary_classifier.py
@@ -1,41 +1,21 @@
 from argparse import Namespace
 
-import torch
 from torch import nn
-from torch.optim import Adam
+
 from torchvision.transforms import Compose, ToTensor
 
-from ml_lib.audio_toolset.audio_io import Melspectogram, NormalizeLocal
+from ml_lib.audio_toolset.audio_io import AudioToMel, NormalizeLocal, PowerToDB, MelToImage
 from ml_lib.modules.blocks import ConvModule
 from ml_lib.modules.utils import LightningBaseModule, Flatten, BaseModuleMixin_Dataloaders
+from models.module_mixins import BaseOptimizerMixin, BaseTrainMixin, BaseValMixin
 
 
-class BinaryClassifier(BaseModuleMixin_Dataloaders, LightningBaseModule):
-
-    @classmethod
-    def name(cls):
-        return cls.__name__
-
-    def configure_optimizers(self):
-        return Adam(params=self.parameters(), lr=self.params.lr)
-
-    def training_step(self, batch_xy, batch_nb, *args, **kwargs):
-        batch_x, batch_y = batch_xy
-        y = self(batch_x)
-        loss = self.criterion(y, batch_y)
-        return dict(loss=loss)
-
-    def validation_step(self, batch_xy, batch_idx, *args, **kwargs):
-        batch_x, batch_y = batch_xy
-        y = self(batch_x)
-        val_loss = self.criterion(y, batch_y)
-        return dict(val_loss=val_loss, batch_idx=batch_idx)
-
-    def validation_epoch_end(self, outputs):
-        overall_val_loss = torch.mean(torch.stack([output['val_loss'] for output in outputs]))
-        return dict(log=dict(
-            mean_val_loss=overall_val_loss)
-        )
+class BinaryClassifier(BaseModuleMixin_Dataloaders,
+                       BaseTrainMixin,
+                       BaseValMixin,
+                       BaseOptimizerMixin,
+                       LightningBaseModule
+                       ):
 
     def __init__(self, hparams):
         super(BinaryClassifier, self).__init__(hparams)
@@ -43,7 +23,7 @@ class BinaryClassifier(BaseModuleMixin_Dataloaders, LightningBaseModule):
         # Dataset and Dataloaders
         # =============================================================================
         # Transforms
-        transforms = Compose([Melspectogram(), ToTensor(), NormalizeLocal()])
+        transforms = Compose([AudioToMel(), MelToImage(), ToTensor(), NormalizeLocal()])
         # Datasets
         from datasets.binar_masks import BinaryMasksDataset
         self.dataset = Namespace(
@@ -58,29 +38,42 @@ class BinaryClassifier(BaseModuleMixin_Dataloaders, LightningBaseModule):
         # =============================================================================
         # Additional parameters
         self.in_shape = self.dataset.train_dataset.sample_shape
+        self.conv_filters = self.params.filters
         self.criterion = nn.BCELoss()
 
-        # Modules
-        self.conv_1 = ConvModule(self.in_shape, 32, 3, conv_stride=2, **self.params.module_kwargs)
-        self.conv_2 = ConvModule(self.conv_1.shape, 64, 5, conv_stride=2, **self.params.module_kwargs)
-        self.conv_3 = ConvModule(self.conv_2.shape, 128, 7, conv_stride=2, **self.params.module_kwargs)
-        self.flat = Flatten(self.conv_3.shape)
-        self.full_1 = nn.Linear(self.flat.shape, 32, self.params.bias)
+        # Modules with Parameters
+        self.conv_1 = ConvModule(self.in_shape, self.conv_filters[0], 3, conv_stride=2, **self.params.module_kwargs)
+        self.conv_1b = ConvModule(self.conv_1.shape, self.conv_filters[0], 1, conv_stride=1, **self.params.module_kwargs)
+        self.conv_2 = ConvModule(self.conv_1b.shape, self.conv_filters[1], 5, conv_stride=2, **self.params.module_kwargs)
+        self.conv_2b = ConvModule(self.conv_2.shape, self.conv_filters[1], 1, conv_stride=1, **self.params.module_kwargs)
+        self.conv_3 = ConvModule(self.conv_2b.shape, self.conv_filters[2], 7, conv_stride=2, **self.params.module_kwargs)
+        self.conv_3b = ConvModule(self.conv_3.shape, self.conv_filters[2], 1, conv_stride=1, **self.params.module_kwargs)
+
+        self.flat = Flatten(self.conv_3b.shape)
+        self.full_1 = nn.Linear(self.flat.shape, self.params.lat_dim, self.params.bias)
         self.full_2 = nn.Linear(self.full_1.out_features, self.full_1.out_features // 2, self.params.bias)
-        self.activation = self.params.activation()
         self.full_out = nn.Linear(self.full_2.out_features, 1, self.params.bias)
+
+        # Utility Modules
+        self.dropout = nn.Dropout2d(self.params.dropout) if self.params.dropout else lambda x: x
+        self.activation = self.params.activation()
         self.sigmoid = nn.Sigmoid()
 
     def forward(self, batch, **kwargs):
         tensor = self.conv_1(batch)
+        tensor = self.conv_1b(tensor)
         tensor = self.conv_2(tensor)
+        tensor = self.conv_2b(tensor)
         tensor = self.conv_3(tensor)
+        tensor = self.conv_3b(tensor)
         tensor = self.flat(tensor)
         tensor = self.full_1(tensor)
         tensor = self.activation(tensor)
+        tensor = self.dropout(tensor)
         tensor = self.full_2(tensor)
         tensor = self.activation(tensor)
+        tensor = self.dropout(tensor)
         tensor = self.full_out(tensor)
         tensor = self.sigmoid(tensor)
         return tensor
diff --git a/models/module_mixins.py b/models/module_mixins.py
new file mode 100644
index 0000000..df67460
--- /dev/null
+++ b/models/module_mixins.py
@@ -0,0 +1,55 @@
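+"""Shared Lightning mixins: optimizer configuration plus the training and
+validation steps used by both binary classifier models."""
+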
+import sklearn.metrics
+import torch
+import numpy as np
+from torch.nn import L1Loss
+from torch.optim import Adam
+
+
+class BaseOptimizerMixin:
+
+    def configure_optimizers(self):
+        return Adam(params=self.parameters(), lr=self.params.lr)
+
+
+class BaseTrainMixin:
+
+    def training_step(self, batch_xy, batch_nb, *args, **kwargs):
+        batch_x, batch_y = batch_xy
+        y = self(batch_x)
+        loss = self.criterion(y, batch_y)
+        return dict(loss=loss)
+
+    def training_epoch_end(self, outputs):
+        mean_train_loss = torch.mean(torch.stack([output['loss'] for output in outputs]))
+        return dict(log=dict(mean_train_loss=mean_train_loss))
+
+
+class BaseValMixin:
+
+    absolute_loss = L1Loss()
+
+    def validation_step(self, batch_xy, batch_idx, *args, **kwargs):
+        batch_x, batch_y = batch_xy
+        y = self(batch_x)
+        val_loss = self.criterion(y, batch_y)
+        absolute_error = self.absolute_loss(y, batch_y)
+        return dict(val_loss=val_loss, absolute_error=absolute_error, batch_idx=batch_idx, y=y, batch_y=batch_y)
+
+    def validation_epoch_end(self, outputs):
+        overall_val_loss = torch.mean(torch.stack([output['val_loss'] for output in outputs]))
+        mean_absolute_error = torch.mean(torch.stack([output['absolute_error'] for output in outputs]))
+
+        # Unweighted Average Recall (UAR): macro-averaged recall over both
+        # classes, computed on predictions thresholded at 0.5.
+        y_true = torch.cat([output['batch_y'] for output in outputs]).cpu().numpy()
+        y_pred = torch.cat([output['y'] for output in outputs]).squeeze().cpu().numpy()
+
+        y_pred = (y_pred >= 0.5).astype(np.float32)
+
+        uar_score = sklearn.metrics.recall_score(y_true, y_pred, labels=[0, 1], average='macro',
+                                                 sample_weight=None, zero_division='warn')
+
+        return dict(
+            log=dict(mean_val_loss=overall_val_loss,
+                     mean_absolute_error=mean_absolute_error,
+                     uar_score=uar_score)
+        )
diff --git a/util/config.py b/util/config.py
index b2b6c17..2ec6c47 100644
--- a/util/config.py
+++ b/util/config.py
@@ -1,9 +1,12 @@
 from ml_lib.utils.config import Config
 
 from models.binary_classifier import BinaryClassifier
+from models.bandwise_binary_classifier import BandwiseBinaryClassifier
 
 
 class MConfig(Config):
+    # TODO: There should be a way to automate this.
     @property
     def _model_map(self):
-        return dict(BinaryClassifier=BinaryClassifier)
+        return dict(BinaryClassifier=BinaryClassifier,
+                    BandwiseBinaryClassifier=BandwiseBinaryClassifier)