BandwiseBinaryClassifier is work in progress; TODO: Shape Piping.

This commit is contained in:
Si11ium 2020-05-04 18:45:13 +02:00
parent e4f6506a4b
commit 451f78f820
7 changed files with 190 additions and 42 deletions

View File

@ -44,7 +44,7 @@ main_arg_parser.add_argument("--model_classes", type=int, default=2, help="")
main_arg_parser.add_argument("--model_lat_dim", type=int, default=16, help="") main_arg_parser.add_argument("--model_lat_dim", type=int, default=16, help="")
main_arg_parser.add_argument("--model_bias", type=strtobool, default=True, help="") main_arg_parser.add_argument("--model_bias", type=strtobool, default=True, help="")
main_arg_parser.add_argument("--model_norm", type=strtobool, default=False, help="") main_arg_parser.add_argument("--model_norm", type=strtobool, default=False, help="")
main_arg_parser.add_argument("--model_dropout", type=float, default=0.00, help="") main_arg_parser.add_argument("--model_dropout", type=float, default=0.2, help="")
# Project Parameters # Project Parameters
main_arg_parser.add_argument("--project_name", type=str, default=_ROOT.name, help="") main_arg_parser.add_argument("--project_name", type=str, default=_ROOT.name, help="")

View File

@ -30,7 +30,7 @@ class BinaryMasksDataset(Dataset):
self._labels = self._build_labels() self._labels = self._build_labels()
self._wav_folder = self.data_root / 'wav' self._wav_folder = self.data_root / 'wav'
self._wav_files = list(sorted(self._labels.keys())) self._wav_files = list(sorted(self._labels.keys()))
self._mel_folder = self.data_root / 'raw_mel' self._mel_folder = self.data_root / 'transformed'
def _build_labels(self): def _build_labels(self):
with open(Path(self.data_root) / 'lab' / 'labels.csv', mode='r') as f: with open(Path(self.data_root) / 'lab' / 'labels.csv', mode='r') as f:

View File

@ -1,7 +1,7 @@
from torch.utils.data import DataLoader, Dataset from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import Compose, ToTensor from torchvision.transforms import Compose, ToTensor
from ml_lib.audio_toolset.audio_io import Melspectogram, NormalizeLocal from ml_lib.audio_toolset.audio_io import AudioToMel, NormalizeLocal
# Dataset and Dataloaders # Dataset and Dataloaders
# ============================================================================= # =============================================================================
@ -11,7 +11,7 @@ from ml_lib.utils.model_io import SavedLightningModels
from util.config import MConfig from util.config import MConfig
from util.logging import MLogger from util.logging import MLogger
transforms = Compose([Melspectogram(), ToTensor(), NormalizeLocal()]) transforms = Compose([AudioToMel(), ToTensor(), NormalizeLocal()])
# Datasets # Datasets
from datasets.binar_masks import BinaryMasksDataset from datasets.binar_masks import BinaryMasksDataset

View File

@ -0,0 +1,97 @@
from argparse import Namespace
from torch import nn
from torch.nn import ModuleDict
from torchvision.transforms import Compose, ToTensor
from ml_lib.audio_toolset.audio_io import AudioToMel, NormalizeLocal, PowerToDB, MelToImage
from ml_lib.modules.blocks import ConvModule
from ml_lib.modules.utils import LightningBaseModule, Flatten, BaseModuleMixin_Dataloaders, HorizontalSplitter, \
HorizontalMerger
from models.module_mixins import BaseOptimizerMixin, BaseTrainMixin, BaseValMixin
class BandwiseBinaryClassifier(BaseModuleMixin_Dataloaders,
                               BaseTrainMixin,
                               BaseValMixin,
                               BaseOptimizerMixin,
                               LightningBaseModule
                               ):
    """Binary classifier that runs one convolution stack per band section.

    The input batch is cut into ``n_band_sections`` pieces by a
    ``HorizontalSplitter``. Every piece is processed by its own chain of three
    ``ConvModule`` layers, the results are recombined by a
    ``HorizontalMerger``, flattened, and passed through three linear layers
    followed by a sigmoid.
    """

    def __init__(self, hparams):
        """Build datasets, per-band convolution stacks and the linear head.

        Args:
            hparams: Hyper-parameters forwarded to the base class. This
                constructor afterwards reads ``root``, ``filters``,
                ``module_kwargs``, ``lat_dim``, ``bias``, ``dropout`` and
                ``activation`` from ``self.params``.
        """
        super(BandwiseBinaryClassifier, self).__init__(hparams)

        # Dataset and Dataloaders
        # =============================================================================
        # Transforms
        transforms = Compose([AudioToMel(), MelToImage(), ToTensor(), NormalizeLocal()])
        # Datasets
        from datasets.binar_masks import BinaryMasksDataset
        self.dataset = Namespace(
            **dict(
                train_dataset=BinaryMasksDataset(self.params.root, setting='train', transforms=transforms),
                val_dataset=BinaryMasksDataset(self.params.root, setting='devel', transforms=transforms),
                test_dataset=BinaryMasksDataset(self.params.root, setting='test', transforms=transforms),
            )
        )

        # Model Parameters
        # =============================================================================
        # Additional parameters
        self.in_shape = self.dataset.train_dataset.sample_shape
        self.conv_filters = self.params.filters
        self.criterion = nn.BCELoss()
        self.n_band_sections = 5

        # Utility Modules
        self.split = HorizontalSplitter(self.in_shape, self.n_band_sections)

        # Modules with Parameters
        # Shape piping: every stage is built from the output shape of the
        # previous stage. All bands of one stage share the same input shape,
        # so band 0 is used as the shape reference for the next stage.
        # NOTE(review): assumes ``HorizontalSplitter.shape`` is the shape of a
        # single band section -- confirm against ml_lib.
        modules = dict()
        stage_in_shape = self.split.shape
        for stage in (1, 2, 3):
            n_filters = self.conv_filters[stage - 1]
            for band_section in range(self.n_band_sections):
                modules[f"conv_{stage}_{band_section}"] = ConvModule(
                    stage_in_shape, n_filters, 3, conv_stride=2, **self.params.module_kwargs
                )
            stage_in_shape = modules[f"conv_{stage}_0"].shape
        self.conv_dict = ModuleDict(modules=modules)

        # Utility Modules
        # The merger receives the per-band shape AFTER the last convolution,
        # not the shape produced by the splitter.
        self.merge = HorizontalMerger(stage_in_shape, self.n_band_sections)
        # NOTE(review): assumes ``HorizontalMerger`` exposes the merged shape
        # as ``.shape`` -- confirm against ml_lib.
        self.flat = Flatten(self.merge.shape)

        # Linear head; must be created after ``self.flat`` because it reads its shape.
        self.full_1 = nn.Linear(self.flat.shape, self.params.lat_dim, self.params.bias)
        self.full_2 = nn.Linear(self.full_1.out_features, self.full_1.out_features // 2, self.params.bias)
        self.full_out = nn.Linear(self.full_2.out_features, 1, self.params.bias)

        # Falls back to the identity when the configured dropout is falsy (e.g. 0.0).
        self.dropout = nn.Dropout2d(self.params.dropout) if self.params.dropout else lambda x: x
        self.activation = self.params.activation()
        self.sigmoid = nn.Sigmoid()

    def forward(self, batch, **kwargs):
        """Run the band-wise convolutions and the linear head on ``batch``.

        Args:
            batch: Input handed to the ``HorizontalSplitter``.
            **kwargs: Accepted for interface compatibility; unused here.

        Returns:
            The sigmoid-activated output of the last linear layer.
        """
        tensors = self.split(batch)
        # Each band keeps its own module per stage. The results are collected
        # in a fresh list so every stage consumes the previous stage's output.
        for stage in (1, 2, 3):
            tensors = [self.conv_dict[f"conv_{stage}_{idx}"](band)
                       for idx, band in enumerate(tensors)]

        tensor = self.merge(tensors)
        tensor = self.flat(tensor)
        tensor = self.full_1(tensor)
        tensor = self.activation(tensor)
        tensor = self.dropout(tensor)
        tensor = self.full_2(tensor)
        tensor = self.activation(tensor)
        tensor = self.dropout(tensor)
        tensor = self.full_out(tensor)
        tensor = self.sigmoid(tensor)
        return tensor

View File

@ -1,41 +1,21 @@
from argparse import Namespace from argparse import Namespace
import torch
from torch import nn from torch import nn
from torch.optim import Adam
from torchvision.transforms import Compose, ToTensor from torchvision.transforms import Compose, ToTensor
from ml_lib.audio_toolset.audio_io import Melspectogram, NormalizeLocal from ml_lib.audio_toolset.audio_io import AudioToMel, NormalizeLocal, PowerToDB, MelToImage
from ml_lib.modules.blocks import ConvModule from ml_lib.modules.blocks import ConvModule
from ml_lib.modules.utils import LightningBaseModule, Flatten, BaseModuleMixin_Dataloaders from ml_lib.modules.utils import LightningBaseModule, Flatten, BaseModuleMixin_Dataloaders
from models.module_mixins import BaseOptimizerMixin, BaseTrainMixin, BaseValMixin
class BinaryClassifier(BaseModuleMixin_Dataloaders, LightningBaseModule): class BinaryClassifier(BaseModuleMixin_Dataloaders,
BaseTrainMixin,
@classmethod BaseValMixin,
def name(cls): BaseOptimizerMixin,
return cls.__name__ LightningBaseModule
):
def configure_optimizers(self):
return Adam(params=self.parameters(), lr=self.params.lr)
def training_step(self, batch_xy, batch_nb, *args, **kwargs):
batch_x, batch_y = batch_xy
y = self(batch_x)
loss = self.criterion(y, batch_y)
return dict(loss=loss)
def validation_step(self, batch_xy, batch_idx, *args, **kwargs):
batch_x, batch_y = batch_xy
y = self(batch_x)
val_loss = self.criterion(y, batch_y)
return dict(val_loss=val_loss, batch_idx=batch_idx)
def validation_epoch_end(self, outputs):
overall_val_loss = torch.mean(torch.stack([output['val_loss'] for output in outputs]))
return dict(log=dict(
mean_val_loss=overall_val_loss)
)
def __init__(self, hparams): def __init__(self, hparams):
super(BinaryClassifier, self).__init__(hparams) super(BinaryClassifier, self).__init__(hparams)
@ -43,7 +23,7 @@ class BinaryClassifier(BaseModuleMixin_Dataloaders, LightningBaseModule):
# Dataset and Dataloaders # Dataset and Dataloaders
# ============================================================================= # =============================================================================
# Transforms # Transforms
transforms = Compose([Melspectogram(), ToTensor(), NormalizeLocal()]) transforms = Compose([AudioToMel(), MelToImage(), ToTensor(), NormalizeLocal()])
# Datasets # Datasets
from datasets.binar_masks import BinaryMasksDataset from datasets.binar_masks import BinaryMasksDataset
self.dataset = Namespace( self.dataset = Namespace(
@ -58,29 +38,42 @@ class BinaryClassifier(BaseModuleMixin_Dataloaders, LightningBaseModule):
# ============================================================================= # =============================================================================
# Additional parameters # Additional parameters
self.in_shape = self.dataset.train_dataset.sample_shape self.in_shape = self.dataset.train_dataset.sample_shape
self.conv_filters = self.params.filters
self.criterion = nn.BCELoss() self.criterion = nn.BCELoss()
# Modules
self.conv_1 = ConvModule(self.in_shape, 32, 3, conv_stride=2, **self.params.module_kwargs)
self.conv_2 = ConvModule(self.conv_1.shape, 64, 5, conv_stride=2, **self.params.module_kwargs)
self.conv_3 = ConvModule(self.conv_2.shape, 128, 7, conv_stride=2, **self.params.module_kwargs)
self.flat = Flatten(self.conv_3.shape) # Modules with Parameters
self.full_1 = nn.Linear(self.flat.shape, 32, self.params.bias) self.conv_1 = ConvModule(self.in_shape, self.conv_filters[0], 3, conv_stride=2, **self.params.module_kwargs)
self.conv_1b = ConvModule(self.conv_1.shape, self.conv_filters[0], 1, conv_stride=1, **self.params.module_kwargs)
self.conv_2 = ConvModule(self.conv_1b.shape, self.conv_filters[1], 5, conv_stride=2, **self.params.module_kwargs)
self.conv_2b = ConvModule(self.conv_2.shape, self.conv_filters[1], 1, conv_stride=1, **self.params.module_kwargs)
self.conv_3 = ConvModule(self.conv_2b.shape, self.conv_filters[2], 7, conv_stride=2, **self.params.module_kwargs)
self.conv_3b = ConvModule(self.conv_3.shape, self.conv_filters[2], 1, conv_stride=1, **self.params.module_kwargs)
self.flat = Flatten(self.conv_3b.shape)
self.full_1 = nn.Linear(self.flat.shape, self.params.lat_dim, self.params.bias)
self.full_2 = nn.Linear(self.full_1.out_features, self.full_1.out_features // 2, self.params.bias) self.full_2 = nn.Linear(self.full_1.out_features, self.full_1.out_features // 2, self.params.bias)
self.activation = self.params.activation()
self.full_out = nn.Linear(self.full_2.out_features, 1, self.params.bias) self.full_out = nn.Linear(self.full_2.out_features, 1, self.params.bias)
# Utility Modules
self.dropout = nn.Dropout2d(self.params.dropout) if self.params.dropout else lambda x: x
self.activation = self.params.activation()
self.sigmoid = nn.Sigmoid() self.sigmoid = nn.Sigmoid()
def forward(self, batch, **kwargs): def forward(self, batch, **kwargs):
tensor = self.conv_1(batch) tensor = self.conv_1(batch)
tensor = self.conv_1b(tensor)
tensor = self.conv_2(tensor) tensor = self.conv_2(tensor)
tensor = self.conv_2b(tensor)
tensor = self.conv_3(tensor) tensor = self.conv_3(tensor)
tensor = self.conv_3b(tensor)
tensor = self.flat(tensor) tensor = self.flat(tensor)
tensor = self.full_1(tensor) tensor = self.full_1(tensor)
tensor = self.activation(tensor) tensor = self.activation(tensor)
tensor = self.dropout(tensor)
tensor = self.full_2(tensor) tensor = self.full_2(tensor)
tensor = self.activation(tensor) tensor = self.activation(tensor)
tensor = self.dropout(tensor)
tensor = self.full_out(tensor) tensor = self.full_out(tensor)
tensor = self.sigmoid(tensor) tensor = self.sigmoid(tensor)
return tensor return tensor

55
models/module_mixins.py Normal file
View File

@ -0,0 +1,55 @@
import sklearn
import torch
import numpy as np
from torch.nn import L1Loss
from torch.optim import Adam
class BaseOptimizerMixin:
    """Mixin that supplies the optimizer configuration hook."""

    def configure_optimizers(self):
        """Return an Adam optimizer over ``self.parameters()`` using ``self.params.lr``."""
        trainable = self.parameters()
        learning_rate = self.params.lr
        return Adam(params=trainable, lr=learning_rate)
class BaseTrainMixin:
    """Mixin that supplies the training step and the epoch-end aggregation."""

    def training_step(self, batch_xy, batch_nb, *args, **kwargs):
        """Compute the criterion loss for one ``(inputs, targets)`` batch.

        Returns:
            A dict with the single key ``loss``.
        """
        inputs, targets = batch_xy
        prediction = self(inputs)
        return {'loss': self.criterion(prediction, targets)}

    def training_epoch_end(self, outputs):
        """Average the per-step ``loss`` entries and expose them for logging.

        Returns:
            A dict of the form ``{'log': {'mean_train_loss': <tensor>}}``.
        """
        step_losses = [step_result['loss'] for step_result in outputs]
        epoch_loss = torch.stack(step_losses).mean()
        return {'log': {'mean_train_loss': epoch_loss}}
class BaseValMixin:
    """Mixin that supplies the validation step and the epoch-end metrics."""

    # L1 (mean absolute error) criterion, defined once on the class.
    absolute_loss = L1Loss()

    def validation_step(self, batch_xy, batch_idx, *args, **kwargs):
        """Evaluate one ``(inputs, targets)`` validation batch.

        Returns:
            A dict with ``val_loss`` (``self.criterion``), ``absolute_error``
            (L1 loss), ``batch_idx``, the predictions ``y`` and the targets
            ``batch_y``. The last two are collected again at epoch end.
        """
        batch_x, batch_y = batch_xy
        y = self(batch_x)
        val_loss = self.criterion(y, batch_y)
        absolute_error = self.absolute_loss(y, batch_y)
        return dict(val_loss=val_loss, absolute_error=absolute_error, batch_idx=batch_idx, y=y, batch_y=batch_y)

    def validation_epoch_end(self, outputs):
        """Aggregate the step results into mean losses and the UAR score.

        Args:
            outputs: The dicts returned by ``validation_step``.

        Returns:
            A dict of the form ``{'log': {...}}`` holding ``mean_val_loss``,
            ``mean_absolute_error`` and ``uar_score``.
        """
        # Imported explicitly: a bare ``import sklearn`` does not guarantee
        # that the ``sklearn.metrics`` submodule has been loaded.
        from sklearn.metrics import recall_score

        overall_val_loss = torch.mean(torch.stack([output['val_loss'] for output in outputs]))
        mean_absolute_error = torch.mean(torch.stack([output['absolute_error'] for output in outputs]))

        # UnweightedAverageRecall
        # Both arrays are flattened to 1-D so a single-sample epoch does not
        # collapse to a 0-d array and targets/predictions always line up.
        y_true = torch.cat([output['batch_y'] for output in outputs]).detach().reshape(-1).cpu().numpy()
        y_pred = torch.cat([output['y'] for output in outputs]).detach().reshape(-1).cpu().numpy()
        # Threshold the predictions at 0.5 to obtain hard 0/1 labels.
        y_pred = (y_pred >= 0.5).astype(np.float32)

        uar_score = recall_score(y_true, y_pred, labels=[0, 1], average='macro',
                                 sample_weight=None, zero_division='warn')

        return dict(
            log=dict(mean_val_loss=overall_val_loss,
                     mean_absolute_error=mean_absolute_error,
                     uar_score=uar_score)
        )

View File

@ -1,9 +1,12 @@
from ml_lib.utils.config import Config from ml_lib.utils.config import Config
from models.binary_classifier import BinaryClassifier from models.binary_classifier import BinaryClassifier
from models.bandwise_binary_classifier import BandwiseBinaryClassifier
class MConfig(Config): class MConfig(Config):
# TODO: There should be a way to automate this.
@property @property
def _model_map(self): def _model_map(self):
return dict(BinaryClassifier=BinaryClassifier) return dict(BinaryClassifier=BinaryClassifier,
BandwiseBinaryClassifier=BandwiseBinaryClassifier)