From 451f78f820301ddf5771c8a4dd5c019500e670ec Mon Sep 17 00:00:00 2001
From: Si11ium <steffen.illium@ifi.lmu.de>
Date: Mon, 4 May 2020 18:45:13 +0200
Subject: [PATCH] BandwiseBinaryClassifier is work in progress; TODO: Shape
 Piping.

---
 _paramters.py                        |  2 +-
 datasets/binar_masks.py              |  2 +-
 main_inference.py                    |  4 +-
 models/bandwise_binary_classifier.py | 97 ++++++++++++++++++++++++++++
 models/binary_classifier.py          | 67 +++++++++----------
 models/module_mixins.py              | 55 ++++++++++++++++
 util/config.py                       |  5 +-
 7 files changed, 190 insertions(+), 42 deletions(-)
 create mode 100644 models/bandwise_binary_classifier.py
 create mode 100644 models/module_mixins.py

diff --git a/_paramters.py b/_paramters.py
index 16c2852..b42846d 100644
--- a/_paramters.py
+++ b/_paramters.py
@@ -44,7 +44,7 @@ main_arg_parser.add_argument("--model_classes", type=int, default=2, help="")
 main_arg_parser.add_argument("--model_lat_dim", type=int, default=16, help="")
 main_arg_parser.add_argument("--model_bias", type=strtobool, default=True, help="")
 main_arg_parser.add_argument("--model_norm", type=strtobool, default=False, help="")
-main_arg_parser.add_argument("--model_dropout", type=float, default=0.00, help="")
+main_arg_parser.add_argument("--model_dropout", type=float, default=0.2, help="")
 
 # Project Parameters
 main_arg_parser.add_argument("--project_name", type=str, default=_ROOT.name, help="")
diff --git a/datasets/binar_masks.py b/datasets/binar_masks.py
index 8159ff2..1e77f3d 100644
--- a/datasets/binar_masks.py
+++ b/datasets/binar_masks.py
@@ -30,7 +30,7 @@ class BinaryMasksDataset(Dataset):
         self._labels = self._build_labels()
         self._wav_folder = self.data_root / 'wav'
         self._wav_files = list(sorted(self._labels.keys()))
-        self._mel_folder = self.data_root / 'raw_mel'
+        self._mel_folder = self.data_root / 'transformed'
 
     def _build_labels(self):
         with open(Path(self.data_root) / 'lab' / 'labels.csv', mode='r') as f:
diff --git a/main_inference.py b/main_inference.py
index 65ef2c2..e617b74 100644
--- a/main_inference.py
+++ b/main_inference.py
@@ -1,7 +1,7 @@
 from torch.utils.data import DataLoader, Dataset
 from torchvision.transforms import Compose, ToTensor
 
-from ml_lib.audio_toolset.audio_io import Melspectogram, NormalizeLocal
+from ml_lib.audio_toolset.audio_io import AudioToMel, NormalizeLocal
 
 # Dataset and Dataloaders
 # =============================================================================
@@ -11,7 +11,7 @@ from ml_lib.utils.model_io import SavedLightningModels
 from util.config import MConfig
 from util.logging import MLogger
 
-transforms = Compose([Melspectogram(), ToTensor(), NormalizeLocal()])
+transforms = Compose([AudioToMel(), ToTensor(), NormalizeLocal()])
 
 # Datasets
 from datasets.binar_masks import BinaryMasksDataset
diff --git a/models/bandwise_binary_classifier.py b/models/bandwise_binary_classifier.py
new file mode 100644
index 0000000..89a2e83
--- /dev/null
+++ b/models/bandwise_binary_classifier.py
@@ -0,0 +1,97 @@
+from argparse import Namespace
+
+from torch import nn
+from torch.nn import ModuleDict
+
+from torchvision.transforms import Compose, ToTensor
+
+from ml_lib.audio_toolset.audio_io import AudioToMel, NormalizeLocal, PowerToDB, MelToImage
+from ml_lib.modules.blocks import ConvModule
+from ml_lib.modules.utils import LightningBaseModule, Flatten, BaseModuleMixin_Dataloaders, HorizontalSplitter, \
+    HorizontalMerger
+from models.module_mixins import BaseOptimizerMixin, BaseTrainMixin, BaseValMixin
+
+
+class BandwiseBinaryClassifier(BaseModuleMixin_Dataloaders,
+                       BaseTrainMixin,
+                       BaseValMixin,
+                       BaseOptimizerMixin,
+                       LightningBaseModule
+                       ):
+
+    def __init__(self, hparams):
+        super(BandwiseBinaryClassifier, self).__init__(hparams)
+
+        # Dataset and Dataloaders
+        # =============================================================================
+        # Transforms: one pipeline instance is shared by the train, devel and test datasets below.
+        transforms = Compose([AudioToMel(), MelToImage(), ToTensor(), NormalizeLocal()])
+        # Datasets
+        from datasets.binar_masks import BinaryMasksDataset
+        self.dataset = Namespace(
+            **dict(
+                train_dataset=BinaryMasksDataset(self.params.root, setting='train', transforms=transforms),
+                val_dataset=BinaryMasksDataset(self.params.root, setting='devel', transforms=transforms),
+                test_dataset=BinaryMasksDataset(self.params.root, setting='test', transforms=transforms),
+            )
+        )
+
+        # Model Parameters
+        # =============================================================================
+        # Additional parameters
+        self.in_shape = self.dataset.train_dataset.sample_shape
+        self.conv_filters = self.params.filters
+        self.criterion = nn.BCELoss()
+        self.n_band_sections = 5
+
+        # Utility Modules
+        self.split = HorizontalSplitter(self.in_shape, self.n_band_sections)
+
+        # Modules with Parameters. NOTE(review): self.conv_1/conv_2/conv_3 are read below but not assigned in this __init__.
+        modules = {f"conv_1_{band_section}":
+                       ConvModule(self.in_shape, self.conv_filters[0], 3, conv_stride=2, **self.params.module_kwargs)
+                   for band_section in range(self.n_band_sections)}
+
+        modules.update({f"conv_2_{band_section}":
+                            ConvModule(self.conv_1.shape, self.conv_filters[1], 3, conv_stride=2,
+                                       **self.params.module_kwargs) for band_section in range(self.n_band_sections)}
+                       )
+        modules.update({f"conv_3_{band_section}":
+                            ConvModule(self.conv_2.shape, self.conv_filters[2], 3, conv_stride=2,
+                                       **self.params.module_kwargs)
+                        for band_section in range(self.n_band_sections)}
+                       )
+
+        self.full_1 = nn.Linear(self.flat.shape, self.params.lat_dim, self.params.bias)
+        self.full_2 = nn.Linear(self.full_1.out_features, self.full_1.out_features // 2, self.params.bias)
+
+        self.full_out = nn.Linear(self.full_2.out_features, 1, self.params.bias)
+
+        # Utility Modules. NOTE(review): self.flat is first assigned below, yet full_1 above already reads self.flat.shape.
+        self.merge = HorizontalMerger(self.split.shape, self.n_band_sections)
+        self.conv_dict = ModuleDict(modules=modules)
+        self.flat = Flatten(self.conv_3.shape)
+        self.dropout = nn.Dropout2d(self.params.dropout) if self.params.dropout else lambda x: x
+        self.activation = self.params.activation()
+        self.sigmoid = nn.Sigmoid()
+
+    def forward(self, batch, **kwargs):
+        # self.split hands back the per-band inputs (presumably one tensor per band section).
+        tensors = self.split(batch)
+        for stage in (1, 2, 3):
+            # Build a new list per stage so every band's conv output replaces its input;
+            # assigning to `tensor[idx]` would write into the band tensor, not into `tensors`.
+            tensors = [self.conv_dict[f"conv_{stage}_{idx}"](tensor)
+                       for idx, tensor in enumerate(tensors)]
+
+        tensor = self.merge(tensors)
+        tensor = self.flat(tensor)
+        tensor = self.full_1(tensor)
+        tensor = self.activation(tensor)
+        tensor = self.dropout(tensor)
+        tensor = self.full_2(tensor)
+        tensor = self.activation(tensor)
+        tensor = self.dropout(tensor)
+        tensor = self.full_out(tensor)
+        tensor = self.sigmoid(tensor)
+        return tensor
diff --git a/models/binary_classifier.py b/models/binary_classifier.py
index 0d663e1..cbc44fa 100644
--- a/models/binary_classifier.py
+++ b/models/binary_classifier.py
@@ -1,41 +1,21 @@
 from argparse import Namespace
 
-import torch
 from torch import nn
-from torch.optim import Adam
+
 from torchvision.transforms import Compose, ToTensor
 
-from ml_lib.audio_toolset.audio_io import Melspectogram, NormalizeLocal
+from ml_lib.audio_toolset.audio_io import AudioToMel, NormalizeLocal, PowerToDB, MelToImage
 from ml_lib.modules.blocks import ConvModule
 from ml_lib.modules.utils import LightningBaseModule, Flatten, BaseModuleMixin_Dataloaders
+from models.module_mixins import BaseOptimizerMixin, BaseTrainMixin, BaseValMixin
 
 
-class BinaryClassifier(BaseModuleMixin_Dataloaders, LightningBaseModule):
-
-    @classmethod
-    def name(cls):
-        return cls.__name__
-
-    def configure_optimizers(self):
-        return Adam(params=self.parameters(), lr=self.params.lr)
-
-    def training_step(self, batch_xy, batch_nb, *args, **kwargs):
-        batch_x, batch_y = batch_xy
-        y = self(batch_x)
-        loss = self.criterion(y, batch_y)
-        return dict(loss=loss)
-
-    def validation_step(self, batch_xy, batch_idx, *args, **kwargs):
-        batch_x, batch_y = batch_xy
-        y = self(batch_x)
-        val_loss = self.criterion(y, batch_y)
-        return dict(val_loss=val_loss, batch_idx=batch_idx)
-
-    def validation_epoch_end(self, outputs):
-        overall_val_loss = torch.mean(torch.stack([output['val_loss'] for output in outputs]))
-        return dict(log=dict(
-            mean_val_loss=overall_val_loss)
-                    )
+class BinaryClassifier(BaseModuleMixin_Dataloaders,
+                       BaseTrainMixin,
+                       BaseValMixin,
+                       BaseOptimizerMixin,
+                       LightningBaseModule
+                       ):
 
     def __init__(self, hparams):
         super(BinaryClassifier, self).__init__(hparams)
@@ -43,7 +23,7 @@ class BinaryClassifier(BaseModuleMixin_Dataloaders, LightningBaseModule):
         # Dataset and Dataloaders
         # =============================================================================
         # Transforms
-        transforms = Compose([Melspectogram(), ToTensor(), NormalizeLocal()])
+        transforms = Compose([AudioToMel(), MelToImage(), ToTensor(), NormalizeLocal()])
         # Datasets
         from datasets.binar_masks import BinaryMasksDataset
         self.dataset = Namespace(
@@ -58,29 +38,42 @@ class BinaryClassifier(BaseModuleMixin_Dataloaders, LightningBaseModule):
         # =============================================================================
         # Additional parameters
         self.in_shape = self.dataset.train_dataset.sample_shape
+        self.conv_filters = self.params.filters
         self.criterion = nn.BCELoss()
-        # Modules
-        self.conv_1 = ConvModule(self.in_shape, 32, 3, conv_stride=2, **self.params.module_kwargs)
-        self.conv_2 = ConvModule(self.conv_1.shape, 64, 5, conv_stride=2, **self.params.module_kwargs)
-        self.conv_3 = ConvModule(self.conv_2.shape, 128, 7, conv_stride=2, **self.params.module_kwargs)
 
-        self.flat = Flatten(self.conv_3.shape)
-        self.full_1 = nn.Linear(self.flat.shape, 32, self.params.bias)
+        # Modules with Parameters
+        self.conv_1 = ConvModule(self.in_shape, self.conv_filters[0], 3, conv_stride=2, **self.params.module_kwargs)
+        self.conv_1b = ConvModule(self.conv_1.shape, self.conv_filters[0], 1, conv_stride=1, **self.params.module_kwargs)
+        self.conv_2 = ConvModule(self.conv_1b.shape, self.conv_filters[1], 5, conv_stride=2, **self.params.module_kwargs)
+        self.conv_2b = ConvModule(self.conv_2.shape, self.conv_filters[1], 1, conv_stride=1, **self.params.module_kwargs)
+        self.conv_3 = ConvModule(self.conv_2b.shape, self.conv_filters[2], 7, conv_stride=2, **self.params.module_kwargs)
+        self.conv_3b = ConvModule(self.conv_3.shape, self.conv_filters[2], 1, conv_stride=1, **self.params.module_kwargs)
+
+        self.flat = Flatten(self.conv_3b.shape)
+        self.full_1 = nn.Linear(self.flat.shape, self.params.lat_dim, self.params.bias)
         self.full_2 = nn.Linear(self.full_1.out_features, self.full_1.out_features // 2, self.params.bias)
 
-        self.activation = self.params.activation()
         self.full_out = nn.Linear(self.full_2.out_features, 1, self.params.bias)
+
+        # Utility Modules
+        self.dropout = nn.Dropout2d(self.params.dropout) if self.params.dropout else lambda x: x
+        self.activation = self.params.activation()
         self.sigmoid = nn.Sigmoid()
 
     def forward(self, batch, **kwargs):
         tensor = self.conv_1(batch)
+        tensor = self.conv_1b(tensor)
         tensor = self.conv_2(tensor)
+        tensor = self.conv_2b(tensor)
         tensor = self.conv_3(tensor)
+        tensor = self.conv_3b(tensor)
         tensor = self.flat(tensor)
         tensor = self.full_1(tensor)
         tensor = self.activation(tensor)
+        tensor = self.dropout(tensor)
         tensor = self.full_2(tensor)
         tensor = self.activation(tensor)
+        tensor = self.dropout(tensor)
         tensor = self.full_out(tensor)
         tensor = self.sigmoid(tensor)
         return tensor
diff --git a/models/module_mixins.py b/models/module_mixins.py
new file mode 100644
index 0000000..df67460
--- /dev/null
+++ b/models/module_mixins.py
@@ -0,0 +1,55 @@
+import sklearn
+import torch
+import numpy as np
+from torch.nn import L1Loss
+from torch.optim import Adam
+
+
+class BaseOptimizerMixin:
+    """Supply an Adam optimizer over `self.parameters()` using the learning rate `self.params.lr`."""
+    def configure_optimizers(self):
+        return Adam(lr=self.params.lr, params=self.parameters())
+
+
+class BaseTrainMixin:
+    """Training hooks: criterion loss per batch and the mean of those losses per epoch."""
+    def training_step(self, batch_xy, batch_nb, *args, **kwargs):
+        inputs, targets = batch_xy
+        return {'loss': self.criterion(self(inputs), targets)}
+
+    def training_epoch_end(self, outputs):
+        # Stack the per-step losses and report their mean under the `log` key.
+        step_losses = [step['loss'] for step in outputs]
+        epoch_loss = torch.mean(torch.stack(step_losses))
+        return {'log': {'mean_train_loss': epoch_loss}}
+
+
+class BaseValMixin:
+    """Validation hooks: per-batch loss and absolute error, plus epoch-level unweighted average recall."""
+    absolute_loss = L1Loss()
+
+    def validation_step(self, batch_xy, batch_idx, *args, **kwargs):
+        batch_x, batch_y = batch_xy
+        y = self(batch_x)
+        val_loss = self.criterion(y, batch_y)
+        absolute_error = self.absolute_loss(y, batch_y)
+        return dict(val_loss=val_loss, absolute_error=absolute_error, batch_idx=batch_idx, y=y, batch_y=batch_y)
+
+    def validation_epoch_end(self, outputs):
+        # A bare `import sklearn` does not guarantee the `metrics` submodule is loaded; import it explicitly.
+        from sklearn.metrics import recall_score
+
+        overall_val_loss = torch.mean(torch.stack([output['val_loss'] for output in outputs]))
+        mean_absolute_error = torch.mean(torch.stack([output['absolute_error'] for output in outputs]))
+
+        # UnweightedAverageRecall: flatten both arrays to 1-D (also safe for a single sample), threshold at 0.5.
+        y_true = torch.cat([output['batch_y'] for output in outputs]).reshape(-1).cpu().numpy()
+        y_pred = torch.cat([output['y'] for output in outputs]).reshape(-1).cpu().numpy()
+        y_pred = (y_pred >= 0.5).astype(np.float32)
+
+        uar_score = recall_score(y_true, y_pred, labels=[0, 1], average='macro',
+                                 sample_weight=None, zero_division='warn')
+
+        return dict(log=dict(mean_val_loss=overall_val_loss,
+                             mean_absolute_error=mean_absolute_error,
+                             uar_score=uar_score))
diff --git a/util/config.py b/util/config.py
index b2b6c17..2ec6c47 100644
--- a/util/config.py
+++ b/util/config.py
@@ -1,9 +1,12 @@
 from ml_lib.utils.config import Config
 from models.binary_classifier import BinaryClassifier
+from models.bandwise_binary_classifier import BandwiseBinaryClassifier
 
 
 class MConfig(Config):
+    # TODO: There should be a way to automate this.
 
     @property
     def _model_map(self):
-        return dict(BinaryClassifier=BinaryClassifier)
+        return dict(BinaryClassifier=BinaryClassifier,
+                    BandwiseBinaryClassifier=BandwiseBinaryClassifier)