Parameter Adjustments and Ensemble Model Implementation

This commit is contained in:
Si11ium 2020-05-08 16:30:55 +02:00
parent c2860b0aed
commit 5e6b0e598f
16 changed files with 648 additions and 313 deletions

View File

@ -14,37 +14,44 @@ main_arg_parser = ArgumentParser(description="parser for fast-neural-style")
# Main Parameters
main_arg_parser.add_argument("--main_debug", type=strtobool, default=False, help="")
main_arg_parser.add_argument("--main_eval", type=strtobool, default=False, help="")
main_arg_parser.add_argument("--main_eval", type=strtobool, default=True, help="")
main_arg_parser.add_argument("--main_seed", type=int, default=69, help="")
# Data Parameters
main_arg_parser.add_argument("--data_worker", type=int, default=10, help="")
main_arg_parser.add_argument("--data_worker", type=int, default=11, help="")
main_arg_parser.add_argument("--data_root", type=str, default='data', help="")
main_arg_parser.add_argument("--data_class_name", type=str, default='BinaryMasksDataset', help="")
main_arg_parser.add_argument("--data_normalized", type=strtobool, default=True, help="")
main_arg_parser.add_argument("--data_use_preprocessed", type=strtobool, default=True, help="")
main_arg_parser.add_argument("--data_n_mels", type=int, default=64, help="")
main_arg_parser.add_argument("--data_mixup", type=strtobool, default=False, help="")
# Transformation Parameters
main_arg_parser.add_argument("--transformations_to_tensor", type=strtobool, default=False, help="")
main_arg_parser.add_argument("--data_loudness_ratio", type=float, default=0.08, help="")
main_arg_parser.add_argument("--data_shift_ratio", type=float, default=0.2, help="")
main_arg_parser.add_argument("--data_noise_ratio", type=float, default=0.15, help="")
# Training Parameters
main_arg_parser.add_argument("--train_outpath", type=str, default="output", help="")
main_arg_parser.add_argument("--train_version", type=strtobool, required=False, help="")
main_arg_parser.add_argument("--train_epochs", type=int, default=500, help="")
main_arg_parser.add_argument("--train_batch_size", type=int, default=200, help="")
# FIXME: Stochastic weight averaging is not good, maybe it's my implementation?
main_arg_parser.add_argument("--train_sto_weight_avg", type=strtobool, default=False, help="")
main_arg_parser.add_argument("--train_epochs", type=int, default=600, help="")
main_arg_parser.add_argument("--train_batch_size", type=int, default=250, help="")
main_arg_parser.add_argument("--train_lr", type=float, default=1e-3, help="")
main_arg_parser.add_argument("--train_num_sanity_val_steps", type=int, default=0, help="")
# Model Parameters
main_arg_parser.add_argument("--model_type", type=str, default="BinaryClassifier", help="")
main_arg_parser.add_argument("--model_type", type=str, default="ConvClassifier", help="")
main_arg_parser.add_argument("--model_secondary_type", type=str, default="BandwiseConvMultiheadClassifier", help="")
main_arg_parser.add_argument("--model_weight_init", type=str, default="xavier_normal_", help="")
main_arg_parser.add_argument("--model_activation", type=str, default="leaky_relu", help="")
main_arg_parser.add_argument("--model_filters", type=str, default="[16, 32, 64]", help="")
main_arg_parser.add_argument("--model_filters", type=str, default="[16, 32, 64, 128]", help="")
main_arg_parser.add_argument("--model_classes", type=int, default=2, help="")
main_arg_parser.add_argument("--model_lat_dim", type=int, default=16, help="")
main_arg_parser.add_argument("--model_lat_dim", type=int, default=8, help="")
main_arg_parser.add_argument("--model_bias", type=strtobool, default=True, help="")
main_arg_parser.add_argument("--model_norm", type=strtobool, default=False, help="")
main_arg_parser.add_argument("--model_dropout", type=float, default=0.2, help="")
main_arg_parser.add_argument("--model_norm", type=strtobool, default=True, help="")
main_arg_parser.add_argument("--model_dropout", type=float, default=0.25, help="")
# Project Parameters
main_arg_parser.add_argument("--project_name", type=str, default=_ROOT.name, help="")

View File

@ -1,6 +1,7 @@
import pickle
from collections import defaultdict
from pathlib import Path
import random
import librosa as librosa
from torch.utils.data import Dataset
@ -18,19 +19,21 @@ class BinaryMasksDataset(Dataset):
def sample_shape(self):
return self[0][0].shape
def __init__(self, data_root, setting, transforms=None):
def __init__(self, data_root, setting, mel_transforms, transforms=None, mixup=False):
assert isinstance(setting, str), f'Setting has to be a string, but was: {type(setting)}.'
assert setting in V.DATA_OPTIONS, f'Setting must match one of: {V.DATA_OPTIONS}.'
assert callable(transforms) or None, f'Transforms has to be callable, but was: {type(transforms)}'
super(BinaryMasksDataset, self).__init__()
self.data_root = Path(data_root)
self.setting = setting
self.mixup = mixup
self.container_ext = '.pik'
self._mel_transform = mel_transforms
self._transforms = transforms or F_x()
self._labels = self._build_labels()
self._wav_folder = self.data_root / 'wav'
self._wav_files = list(sorted(self._labels.keys()))
self._transformed_folder = self.data_root / 'transformed'
self._mel_folder = self.data_root / 'mel'
def _build_labels(self):
with open(Path(self.data_root) / 'lab' / 'labels.csv', mode='r') as f:
@ -41,23 +44,45 @@ class BinaryMasksDataset(Dataset):
if self.setting not in row:
continue
filename, label = row.strip().split(',')
labeldict[filename] = self._to_label[label.lower()]
labeldict[filename] = self._to_label[label.lower()] if not self.setting == 'test' else filename
return labeldict
def __len__(self):
return len(self._labels)
return len(self._labels) * 2 if self.mixup else len(self._labels)
def _compute_or_retrieve(self, filename):
if not (self._mel_folder / (filename + self.container_ext)).exists():
raw_sample, sr = librosa.core.load(self._wav_folder / (filename + '.wav'))
mel_sample = self._mel_transform(raw_sample)
self._mel_folder.mkdir(exist_ok=True, parents=True)
with (self._mel_folder / (filename + self.container_ext)).open(mode='wb') as f:
pickle.dump(mel_sample, f, protocol=pickle.HIGHEST_PROTOCOL)
with (self._mel_folder / (filename + self.container_ext)).open(mode='rb') as f:
mel_sample = pickle.load(f, fix_imports=True)
return mel_sample
def __getitem__(self, item):
is_mixed = item >= len(self._labels)
if is_mixed:
item = item - len(self._labels)
key = self._wav_files[item]
filename = key[:-4] + '.pik'
filename = key[:-4]
mel_sample = self._compute_or_retrieve(filename)
label = self._labels[key]
if not (self._transformed_folder / filename).exists():
raw_sample, sr = librosa.core.load(self._wav_folder / self._wav_files[item])
transformed_sample = self._transforms(raw_sample)
self._transformed_folder.mkdir(exist_ok=True, parents=True)
with (self._transformed_folder / filename).open(mode='wb') as f:
pickle.dump(transformed_sample, f, protocol=pickle.HIGHEST_PROTOCOL)
with (self._transformed_folder / filename).open(mode='rb') as f:
sample = pickle.load(f, fix_imports=True)
label = torch.as_tensor(self._labels[key], dtype=torch.float)
return sample, label
if is_mixed:
label_sec = -1
while label_sec != self._labels[key]:
key_sec = random.choice(list(self._labels.keys()))
label_sec = self._labels[key_sec]
# noinspection PyUnboundLocalVariable
filename_sec = key_sec[:-4]
mel_sample_sec = self._compute_or_retrieve(filename_sec)
mix_in_border = int(random.random() * mel_sample.shape[-1]) * random.choice([1, -1])
mel_sample[:, :mix_in_border] = mel_sample_sec[:, :mix_in_border]
transformed_samples = self._transforms(mel_sample)
if not self.setting == 'test':
label = torch.as_tensor(label, dtype=torch.float)
return transformed_samples, label

69
main.py
View File

@ -1,6 +1,8 @@
# Imports
# =============================================================================
from pathlib import Path
from tqdm import tqdm
import warnings
import torch
@ -8,10 +10,11 @@ from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from ml_lib.modules.utils import LightningBaseModule
from ml_lib.utils.logging import Logger
# Project Specific Config and Logger SubClasses
# Project Specific Logger SubClasses
from util.config import MConfig
from util.logging import MLogger
warnings.filterwarnings('ignore', category=FutureWarning)
warnings.filterwarnings('ignore', category=UserWarning)
@ -22,29 +25,25 @@ def run_lightning_loop(config_obj):
# Logging
# ================================================================================
# Logger
with MLogger(config_obj) as logger:
with Logger(config_obj) as logger:
# Callbacks
# =============================================================================
# Checkpoint Saving
checkpoint_callback = ModelCheckpoint(
monitor='uar_score',
filepath=str(logger.log_dir / 'ckpt_weights'),
verbose=True, save_top_k=0,
verbose=False,
save_top_k=5,
)
# =============================================================================
# Early Stopping
# TODO: For This to work, set a validation step and End Eval and Score
early_stopping_callback = EarlyStopping(
monitor='val_loss',
min_delta=0.0,
patience=0,
monitor='uar_score',
min_delta=0.01,
patience=10,
)
# Model
# =============================================================================
# Build and Init its Weights
model: LightningBaseModule = config_obj.build_and_init_model()
# Trainer
# =============================================================================
trainer = Trainer(max_epochs=config_obj.train.epochs,
@ -61,7 +60,16 @@ def run_lightning_loop(config_obj):
early_stop_callback=None
)
# Model
# =============================================================================
# Build and Init its Weights
model: LightningBaseModule = config_obj.build_and_init_model()
# Log parameters
pytorch_total_params = sum(p.numel() for p in model.parameters())
logger.log_text('n_parameters', pytorch_total_params)
# Train It
if config_obj.model.type.lower() != 'ensemble':
trainer.fit(model)
# Save the last state & all parameters
@ -70,8 +78,41 @@ def run_lightning_loop(config_obj):
# Evaluate It
if config_obj.main.eval:
trainer.test()
model.eval()
if torch.cuda.is_available():
model.cuda()
outputs = []
from tqdm import tqdm
for idx, batch in enumerate(tqdm(model.val_dataloader())):
batch_x, label = batch
outputs.append(
model.validation_step((batch_x.to(device='cuda' if model.on_gpu else 'cpu'), label), idx)
)
summary_dict = model.validation_epoch_end(outputs)
print(summary_dict['log']['uar_score'])
# trainer.test()
outpath = Path(config_obj.train.outpath)
model_type = config_obj.model.type
parameters = logger.name
version = f'version_{logger.version}'
inference_out = f'{parameters}_test_out.csv'
from main_inference import prepare_dataloader
test_dataloader = prepare_dataloader(config)
with (outpath / model_type / parameters / version / inference_out).open(mode='w') as outfile:
outfile.write(f'file_name,prediction\n')
from tqdm import tqdm
for batch in tqdm(test_dataloader, total=len(test_dataloader)):
batch_x, file_name = batch
batch_x = batch_x.unsqueeze(0).to(device='cuda' if model.on_gpu else 'cpu')
y = model(batch_x).main_out
prediction = (y.squeeze() >= 0.5).int().item()
import variables as V
prediction = 'clear' if prediction == V.CLEAR else 'mask'
outfile.write(f'{file_name},{prediction}\n')
return model

View File

@ -1,44 +1,70 @@
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import Compose, ToTensor
from pathlib import Path
from ml_lib.audio_toolset.audio_io import AudioToMel, NormalizeLocal
import torch
from tqdm import tqdm
import variables as V
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import Compose
from ml_lib.audio_toolset.audio_io import AudioToMel, NormalizeLocal, MelToImage
# Dataset and Dataloaders
# =============================================================================
# Transforms
from ml_lib.utils.logging import Logger
from ml_lib.utils.model_io import SavedLightningModels
from ml_lib.utils.transforms import ToTensor
from util.config import MConfig
from util.logging import MLogger
transforms = Compose([AudioToMel(), ToTensor(), NormalizeLocal()])
# Datasets
from datasets.binar_masks import BinaryMasksDataset
def prepare_dataset(config_obj):
dataset: Dataset = BinaryMasksDataset(config_obj.data.root, setting='test', transforms=transforms)
return DataLoader(dataset=dataset,
batch_size=None,
worker=config_obj.data.worker,
shuffle=False)
def prepare_dataloader(config_obj):
mel_transforms = Compose([AudioToMel(n_mels=config_obj.data.n_mels), MelToImage()])
transforms = Compose([NormalizeLocal(), ToTensor()])
dataset: Dataset = BinaryMasksDataset(config_obj.data.root, setting='test',
mel_transforms=mel_transforms, transforms=transforms
)
# noinspection PyTypeChecker
return DataLoader(dataset, batch_size=None, num_workers=0, shuffle=False)
def restore_logger_and_model(config_obj):
logger = MLogger(config_obj)
model = SavedLightningModels().load_checkpoint(models_root_path=logger.log_dir)
logger = Logger(config_obj)
model = SavedLightningModels.load_checkpoint(models_root_path=logger.log_dir, n=-2)
model = model.restore()
if torch.cuda.is_available():
model.cuda()
else:
model.cpu()
return model
if __name__ == '__main__':
from _paramters import main_arg_parser
outpath = Path('output')
model_type = 'BandwiseConvMultiheadClassifier'
parameters = 'BCMC_9c70168a5711c269b33701f1650adfb9/'
version = 'version_1'
config_filename = 'config.ini'
inference_out = 'manual_test_out.csv'
config = MConfig().read_argparser(main_arg_parser)
test_dataset = prepare_dataset(config)
config = MConfig()
config.read_file((outpath / model_type / parameters / version / config_filename).open('r'))
test_dataloader = prepare_dataloader(config)
loaded_model = restore_logger_and_model(config)
print("run model here and find a format to store the output")
loaded_model.eval()
with (outpath / model_type / parameters / version / inference_out).open(mode='w') as outfile:
outfile.write(f'file_name,prediction\n')
for batch in tqdm(test_dataloader, total=len(test_dataloader)):
batch_x, file_name = batch
y = loaded_model(batch_x.unsqueeze(0).to(device='cuda' if torch.cuda.is_available() else 'cpu')).main_out
prediction = (y.squeeze() >= 0.5).int().item()
prediction = 'clear' if prediction == V.CLEAR else 'mask'
outfile.write(f'{file_name},{prediction}\n')
print('Done')

View File

@ -0,0 +1,3 @@
from models.bandwise_conv_multihead_classifier import BandwiseConvMultiheadClassifier
from models.bandwise_conv_classifier import BandwiseConvClassifier
from models.conv_classifier import ConvClassifier

View File

@ -1,99 +0,0 @@
from argparse import Namespace
from torch import nn
from torch.nn import ModuleDict
from torchvision.transforms import Compose, ToTensor
from ml_lib.audio_toolset.audio_io import AudioToMel, NormalizeLocal, PowerToDB, MelToImage
from ml_lib.modules.blocks import ConvModule
from ml_lib.modules.utils import LightningBaseModule, Flatten, BaseModuleMixin_Dataloaders, HorizontalSplitter, \
HorizontalMerger
from models.module_mixins import BaseOptimizerMixin, BaseTrainMixin, BaseValMixin
class BandwiseBinaryClassifier(BaseModuleMixin_Dataloaders,
                               BaseTrainMixin,
                               BaseValMixin,
                               BaseOptimizerMixin,
                               LightningBaseModule
                               ):
    """Binary classifier that slices the mel spectrogram into horizontal
    frequency bands, runs its own three-stage ConvModule stack per band,
    merges the band features and scores them with a fully connected head.
    """

    def __init__(self, hparams):
        super(BandwiseBinaryClassifier, self).__init__(hparams)

        # Dataset and Dataloaders
        # =============================================================================
        # Raw audio -> mel image -> tensor, locally normalized.
        sample_transforms = Compose([AudioToMel(n_mels=32), MelToImage(), ToTensor(), NormalizeLocal()])

        # Datasets
        from datasets.binar_masks import BinaryMasksDataset
        self.dataset = Namespace(
            **dict(
                train_dataset=BinaryMasksDataset(self.params.root, setting='train', transforms=sample_transforms),
                val_dataset=BinaryMasksDataset(self.params.root, setting='devel', transforms=sample_transforms),
                test_dataset=BinaryMasksDataset(self.params.root, setting='test', transforms=sample_transforms),
            )
        )

        # Model Parameters
        # =============================================================================
        self.in_shape = self.dataset.train_dataset.sample_shape
        self.conv_filters = self.params.filters
        self.criterion = nn.BCELoss()
        self.n_band_sections = 5

        # Modules
        # =============================================================================
        self.split = HorizontalSplitter(self.in_shape, self.n_band_sections)
        # Three conv stages; each band owns one module per stage. All bands of
        # a stage share the same input shape, so band 1's output shape feeds
        # the next stage for every band.
        self.conv_dict = ModuleDict()
        last_shape = self.split.shape
        for stage, n_filters in enumerate(self.conv_filters[:3], start=1):
            self.conv_dict.update({f"conv_{stage}_{band}":
                                   ConvModule(last_shape, n_filters, 3, conv_stride=1, **self.params.module_kwargs)
                                   for band in range(self.n_band_sections)}
                                  )
            last_shape = self.conv_dict[f'conv_{stage}_1'].shape
        self.merge = HorizontalMerger(self.conv_dict['conv_3_1'].shape, self.n_band_sections)
        self.flat = Flatten(self.merge.shape)
        self.full_1 = nn.Linear(self.flat.shape, self.params.lat_dim, self.params.bias)
        self.full_2 = nn.Linear(self.full_1.out_features, self.full_1.out_features // 2, self.params.bias)
        self.full_out = nn.Linear(self.full_2.out_features, 1, self.params.bias)

        # Utility Modules
        self.dropout = nn.Dropout2d(self.params.dropout) if self.params.dropout else lambda x: x
        self.activation = self.params.activation()
        self.sigmoid = nn.Sigmoid()

    def forward(self, batch, **kwargs):
        """Split -> per-band conv stack -> merge -> MLP head -> sigmoid score."""
        band_tensors = self.split(batch)
        for band_idx, band_tensor in enumerate(band_tensors):
            for stage in (1, 2, 3):
                band_tensor = self.conv_dict[f"conv_{stage}_{band_idx}"](band_tensor)
            band_tensors[band_idx] = band_tensor
        tensor = self.flat(self.merge(band_tensors))
        for fully_connected in (self.full_1, self.full_2):
            tensor = self.dropout(self.activation(fully_connected(tensor)))
        return self.sigmoid(self.full_out(tensor))

View File

@ -0,0 +1,86 @@
from argparse import Namespace
from torch import nn
from torch.nn import ModuleDict, ModuleList
from ml_lib.modules.blocks import ConvModule
from ml_lib.modules.utils import (LightningBaseModule, HorizontalSplitter,
HorizontalMerger)
from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, BinaryMaskDatasetFunction,
BaseDataloadersMixin)
class BandwiseConvClassifier(BinaryMaskDatasetFunction,
                             BaseDataloadersMixin,
                             BaseTrainMixin,
                             BaseValMixin,
                             BaseOptimizerMixin,
                             LightningBaseModule
                             ):
    """Binary mask classifier that splits the mel spectrogram into horizontal
    frequency bands, convolves each band with its own three-stage ConvModule
    stack, merges the band features and classifies with an MLP head.
    """

    def __init__(self, hparams):
        super(BandwiseConvClassifier, self).__init__(hparams)

        # Dataset
        # =============================================================================
        self.dataset = self.build_dataset()

        # Model Parameters
        # =============================================================================
        # Additional parameters
        self.in_shape = self.dataset.train_dataset.sample_shape
        self.conv_filters = self.params.filters
        self.criterion = nn.BCELoss()
        self.n_band_sections = 4

        # Modules
        # =============================================================================
        self.split = HorizontalSplitter(self.in_shape, self.n_band_sections)
        self.conv_dict = ModuleDict()
        # All bands of a stage share the same shape, so band 1's output shape
        # is used as the input shape of the following stage for every band.
        self.conv_dict.update({f"conv_1_{band_section}":
                               ConvModule(self.split.shape, self.conv_filters[0], 3, conv_stride=1,
                                          **self.params.module_kwargs)
                               for band_section in range(self.n_band_sections)}
                              )
        self.conv_dict.update({f"conv_2_{band_section}":
                               ConvModule(self.conv_dict['conv_1_1'].shape, self.conv_filters[1], 3, conv_stride=1,
                                          **self.params.module_kwargs)
                               for band_section in range(self.n_band_sections)}
                              )
        self.conv_dict.update({f"conv_3_{band_section}":
                               ConvModule(self.conv_dict['conv_2_1'].shape, self.conv_filters[2], 3, conv_stride=1,
                                          **self.params.module_kwargs)
                               for band_section in range(self.n_band_sections)}
                              )
        self.merge = HorizontalMerger(self.conv_dict['conv_3_1'].shape, self.n_band_sections)

        # BUGFIX: `self.flat` was referenced below and in forward() but never
        # defined, so __init__ crashed with an AttributeError. Flatten is not
        # in this module's top-level imports, hence the local import.
        from ml_lib.modules.utils import Flatten
        self.flat = Flatten(self.merge.shape)

        self.full_1 = nn.Linear(self.flat.shape, self.params.lat_dim, self.params.bias)
        self.full_2 = nn.Linear(self.full_1.out_features, self.full_1.out_features // 2, self.params.bias)
        self.full_out = nn.Linear(self.full_2.out_features, 1, self.params.bias)

        # Utility Modules
        self.dropout = nn.Dropout2d(self.params.dropout) if self.params.dropout else lambda x: x
        self.activation = self.params.activation()
        self.sigmoid = nn.Sigmoid()

    def forward(self, batch, **kwargs):
        """Split -> per-band conv stack -> merge -> flatten -> MLP head.

        Returns a Namespace with `main_out` holding the sigmoid probability.
        """
        tensors = self.split(batch)
        for idx, tensor in enumerate(tensors):
            tensors[idx] = self.conv_dict[f"conv_1_{idx}"](tensor)
        for idx, tensor in enumerate(tensors):
            tensors[idx] = self.conv_dict[f"conv_2_{idx}"](tensor)
        for idx, tensor in enumerate(tensors):
            tensors[idx] = self.conv_dict[f"conv_3_{idx}"](tensor)
        tensor = self.merge(tensors)
        tensor = self.flat(tensor)
        tensor = self.full_1(tensor)
        tensor = self.activation(tensor)
        tensor = self.dropout(tensor)
        tensor = self.full_2(tensor)
        tensor = self.activation(tensor)
        tensor = self.dropout(tensor)
        tensor = self.full_out(tensor)
        tensor = self.sigmoid(tensor)
        return Namespace(main_out=tensor)

View File

@ -0,0 +1,107 @@
from argparse import Namespace
from collections import defaultdict
import torch
from torch import nn
from torch.nn import ModuleDict, ModuleList
from torchcontrib.optim import SWA
from ml_lib.modules.blocks import ConvModule
from ml_lib.modules.utils import (LightningBaseModule, Flatten, HorizontalSplitter)
from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, BinaryMaskDatasetFunction,
BaseDataloadersMixin)
class BandwiseConvMultiheadClassifier(BinaryMaskDatasetFunction,
                                      BaseDataloadersMixin,
                                      BaseTrainMixin,
                                      BaseValMixin,
                                      BaseOptimizerMixin,
                                      LightningBaseModule
                                      ):
    """Band-split classifier with one sigmoid head per frequency band plus a
    final linear head combining the per-band scores into the main output.

    Overrides the mixin training/validation steps so that every band head
    receives its own BCE loss in addition to the combined-output loss.
    """

    def training_step(self, batch_xy, batch_nb, *args, **kwargs):
        """Return the combined loss: overall BCE plus the sum of per-band BCEs."""
        batch_x, batch_y = batch_xy
        y = self(batch_x)
        y, bands_y = y.main_out, y.bands
        # One BCE loss per band head, each scored against the same target.
        bands_y_losses = [self.criterion(band_y, batch_y) for band_y in bands_y]
        return_dict = {f'band_{band_idx}_loss': band_y for band_idx, band_y in enumerate(bands_y_losses)}
        overall_loss = self.criterion(y, batch_y)
        # 'loss' is what Lightning backpropagates; band losses are extras.
        combined_loss = overall_loss + torch.stack(bands_y_losses).sum()
        return_dict.update(loss=combined_loss, overall_loss=overall_loss)
        return return_dict

    def validation_step(self, batch_xy, batch_idx, *args, **kwargs):
        """Mirror of training_step; additionally reports the L1 loss and the raw
        outputs (y / batch_y) for epoch-end metric computation."""
        batch_x, batch_y = batch_xy
        y = self(batch_x)
        y, bands_y = y.main_out, y.bands
        bands_y_losses = [self.criterion(band_y, batch_y) for band_y in bands_y]
        return_dict = {f'band_{band_idx}_val_loss': band_y for band_idx, band_y in enumerate(bands_y_losses)}
        overall_loss = self.criterion(y, batch_y)
        # NOTE(review): 'val_bce_loss' is the *combined* loss (overall + bands),
        # not the plain BCE of the main head — confirm this is intended.
        combined_loss = overall_loss + torch.stack(bands_y_losses).sum()
        val_abs_loss = self.absolute_loss(y, batch_y)
        return_dict.update(val_bce_loss=combined_loss, val_abs_loss=val_abs_loss,
                           batch_idx=batch_idx, y=y, batch_y=batch_y
                           )
        return return_dict

    def __init__(self, hparams):
        super(BandwiseConvMultiheadClassifier, self).__init__(hparams)

        # Dataset
        # =============================================================================
        self.dataset = self.build_dataset()

        # Model Parameters
        # =============================================================================
        # Additional parameters
        self.in_shape = self.dataset.train_dataset.sample_shape
        self.conv_filters = self.params.filters
        self.criterion = nn.BCELoss()
        self.n_band_sections = 8

        # Modules
        # =============================================================================
        self.split = HorizontalSplitter(self.in_shape, self.n_band_sections)
        # All bands of a stage share the same shape, so band 1's output shape
        # feeds the next stage's modules for every band.
        self.conv_dict = ModuleDict()
        self.conv_dict.update({f"conv_1_{band_section}":
                               ConvModule(self.split.shape, self.conv_filters[0], 3, conv_stride=1,
                                          **self.params.module_kwargs)
                               for band_section in range(self.n_band_sections)}
                              )
        self.conv_dict.update({f"conv_2_{band_section}":
                               ConvModule(self.conv_dict['conv_1_1'].shape, self.conv_filters[1], 3, conv_stride=1,
                                          **self.params.module_kwargs) for band_section in range(self.n_band_sections)}
                              )
        self.conv_dict.update({f"conv_3_{band_section}":
                               ConvModule(self.conv_dict['conv_2_1'].shape, self.conv_filters[2], 3, conv_stride=1,
                                          **self.params.module_kwargs)
                               for band_section in range(self.n_band_sections)}
                              )
        # Per-band latent projection and per-band sigmoid classification head.
        self.flat = Flatten(self.conv_dict['conv_3_1'].shape)
        self.bandwise_latent_list = ModuleList([
            nn.Linear(self.flat.shape, self.params.lat_dim, self.params.bias) for _ in range(self.n_band_sections)])
        self.bandwise_classifier_list = ModuleList([nn.Linear(self.params.lat_dim, 1, self.params.bias)
                                                    for _ in range(self.n_band_sections)])
        # Combines the n_band_sections per-band scores into one main output.
        self.full_out = nn.Linear(self.n_band_sections, 1, self.params.bias)

        # Utility Modules
        self.sigmoid = nn.Sigmoid()

    def forward(self, batch, **kwargs):
        """Per-band conv stack and head; returns Namespace(main_out, bands)."""
        tensors = self.split(batch)
        for idx, tensor in enumerate(tensors):
            tensor = self.conv_dict[f"conv_1_{idx}"](tensor)
            tensor = self.conv_dict[f"conv_2_{idx}"](tensor)
            tensor = self.conv_dict[f"conv_3_{idx}"](tensor)
            tensor = self.flat(tensor)
            tensor = self.bandwise_latent_list[idx](tensor)
            tensor = self.bandwise_classifier_list[idx](tensor)
            tensors[idx] = self.sigmoid(tensor)
        tensor = torch.cat(tensors, dim=1)
        tensor = self.full_out(tensor)
        tensor = self.sigmoid(tensor)
        return Namespace(main_out=tensor, bands=tensors)

View File

@ -1,79 +0,0 @@
from argparse import Namespace
from torch import nn
from torchvision.transforms import Compose, ToTensor
from ml_lib.audio_toolset.audio_io import AudioToMel, NormalizeLocal, PowerToDB, MelToImage
from ml_lib.modules.blocks import ConvModule
from ml_lib.modules.utils import LightningBaseModule, Flatten, BaseModuleMixin_Dataloaders
from models.module_mixins import BaseOptimizerMixin, BaseTrainMixin, BaseValMixin
class BinaryClassifier(BaseModuleMixin_Dataloaders,
                       BaseTrainMixin,
                       BaseValMixin,
                       BaseOptimizerMixin,
                       LightningBaseModule
                       ):
    """Plain convolutional binary classifier over mel-spectrogram images.

    Three strided ConvModule stages, each followed by a 1x1 refinement layer,
    feed a flattened three-layer fully connected head with a sigmoid output.
    """

    def __init__(self, hparams):
        super(BinaryClassifier, self).__init__(hparams)

        # Dataset and Dataloaders
        # =============================================================================
        # Raw audio -> mel image -> tensor, locally normalized.
        sample_transforms = Compose([AudioToMel(), MelToImage(), ToTensor(), NormalizeLocal()])

        # Datasets
        from datasets.binar_masks import BinaryMasksDataset
        self.dataset = Namespace(
            **dict(
                train_dataset=BinaryMasksDataset(self.params.root, setting='train', transforms=sample_transforms),
                val_dataset=BinaryMasksDataset(self.params.root, setting='devel', transforms=sample_transforms),
                test_dataset=BinaryMasksDataset(self.params.root, setting='test', transforms=sample_transforms),
            )
        )

        # Model Parameters
        # =============================================================================
        self.in_shape = self.dataset.train_dataset.sample_shape
        self.conv_filters = self.params.filters
        self.criterion = nn.BCELoss()

        # Modules with Parameters: strided conv plus 1x1 follow-up per stage,
        # kernel size growing 3 -> 5 -> 7 with depth.
        self.conv_1 = ConvModule(self.in_shape, self.conv_filters[0], 3, conv_stride=2, **self.params.module_kwargs)
        self.conv_1b = ConvModule(self.conv_1.shape, self.conv_filters[0], 1, conv_stride=1, **self.params.module_kwargs)
        self.conv_2 = ConvModule(self.conv_1b.shape, self.conv_filters[1], 5, conv_stride=2, **self.params.module_kwargs)
        self.conv_2b = ConvModule(self.conv_2.shape, self.conv_filters[1], 1, conv_stride=1, **self.params.module_kwargs)
        self.conv_3 = ConvModule(self.conv_2b.shape, self.conv_filters[2], 7, conv_stride=2, **self.params.module_kwargs)
        self.conv_3b = ConvModule(self.conv_3.shape, self.conv_filters[2], 1, conv_stride=1, **self.params.module_kwargs)

        self.flat = Flatten(self.conv_3b.shape)
        self.full_1 = nn.Linear(self.flat.shape, self.params.lat_dim, self.params.bias)
        self.full_2 = nn.Linear(self.full_1.out_features, self.full_1.out_features // 2, self.params.bias)
        self.full_out = nn.Linear(self.full_2.out_features, 1, self.params.bias)

        # Utility Modules
        self.dropout = nn.Dropout2d(self.params.dropout) if self.params.dropout else lambda x: x
        self.activation = self.params.activation()
        self.sigmoid = nn.Sigmoid()

    def forward(self, batch, **kwargs):
        """Conv stack -> flatten -> MLP head -> sigmoid probability tensor."""
        tensor = batch
        for conv_layer in (self.conv_1, self.conv_1b, self.conv_2,
                           self.conv_2b, self.conv_3, self.conv_3b):
            tensor = conv_layer(tensor)
        tensor = self.flat(tensor)
        for fully_connected in (self.full_1, self.full_2):
            tensor = self.dropout(self.activation(fully_connected(tensor)))
        return self.sigmoid(self.full_out(tensor))

75
models/conv_classifier.py Normal file
View File

@ -0,0 +1,75 @@
from argparse import Namespace
from torch import nn
from torch.nn import ModuleList
from ml_lib.modules.blocks import ConvModule
from ml_lib.modules.utils import LightningBaseModule, Flatten
from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, BinaryMaskDatasetFunction,
BaseDataloadersMixin)
class ConvClassifier(BinaryMaskDatasetFunction,
                     BaseDataloadersMixin,
                     BaseTrainMixin,
                     BaseValMixin,
                     BaseOptimizerMixin,
                     LightningBaseModule
                     ):
    """Convolutional binary classifier built from a dynamically sized conv
    stack (one strided stage per configured filter count) followed by a
    four-layer fully connected head with a sigmoid output.
    """

    def __init__(self, hparams):
        super(ConvClassifier, self).__init__(hparams)

        # Dataset
        # =============================================================================
        self.dataset = self.build_dataset()

        # Model Parameters
        # =============================================================================
        # Additional parameters
        self.in_shape = self.dataset.train_dataset.sample_shape
        self.conv_filters = self.params.filters
        self.criterion = nn.BCELoss()

        # Modules with Parameters
        self.conv_list = ModuleList()
        last_shape = self.in_shape
        k = 3  # Base kernel value; grows by 2 with every stage (3, 5, 7, ...)
        for filters in self.conv_filters:
            # BUGFIX: the second and third calls below used `appen`, which
            # raised AttributeError as soon as the model was constructed.
            self.conv_list.append(ConvModule(last_shape, filters, (k, k*2), conv_stride=2, **self.params.module_kwargs))
            last_shape = self.conv_list[-1].shape
            self.conv_list.append(ConvModule(last_shape, filters, 1, conv_stride=1, **self.params.module_kwargs))
            last_shape = self.conv_list[-1].shape
            # NOTE(review): collapsing to a single channel after every stage is
            # unusual (bottleneck between stages?) — confirm it is intentional.
            self.conv_list.append(ConvModule(last_shape, 1, 1, conv_stride=1, **self.params.module_kwargs))
            last_shape = self.conv_list[-1].shape
            k = k+2

        self.flat = Flatten(self.conv_list[-1].shape)
        self.full_1 = nn.Linear(self.flat.shape, self.params.lat_dim, self.params.bias)
        self.full_2 = nn.Linear(self.full_1.out_features, self.full_1.out_features * 2, self.params.bias)
        self.full_3 = nn.Linear(self.full_2.out_features, self.full_2.out_features // 2, self.params.bias)
        self.full_out = nn.Linear(self.full_3.out_features, 1, self.params.bias)

        # Utility Modules
        self.dropout = nn.Dropout2d(self.params.dropout) if self.params.dropout else lambda x: x
        self.activation = self.params.activation()
        self.sigmoid = nn.Sigmoid()

    def forward(self, batch, **kwargs):
        """Conv stack -> flatten -> MLP head.

        Returns a Namespace with `main_out` holding the sigmoid probability.
        """
        tensor = batch
        for conv in self.conv_list:
            tensor = conv(tensor)
        tensor = self.flat(tensor)
        tensor = self.full_1(tensor)
        tensor = self.activation(tensor)
        tensor = self.dropout(tensor)
        tensor = self.full_2(tensor)
        tensor = self.activation(tensor)
        tensor = self.dropout(tensor)
        tensor = self.full_3(tensor)
        tensor = self.activation(tensor)
        tensor = self.dropout(tensor)
        tensor = self.full_out(tensor)
        tensor = self.sigmoid(tensor)
        return Namespace(main_out=tensor)

55
models/ensemble.py Normal file
View File

@ -0,0 +1,55 @@
from argparse import Namespace
from pathlib import Path
import torch
from torch import nn
from torch.nn import ModuleList
from ml_lib.modules.utils import LightningBaseModule
from ml_lib.utils.config import Config
from ml_lib.utils.model_io import SavedLightningModels
from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, BinaryMaskDatasetFunction,
BaseDataloadersMixin)
class Ensemble(BinaryMaskDatasetFunction,
               BaseDataloadersMixin,
               BaseTrainMixin,
               BaseValMixin,
               BaseOptimizerMixin,
               LightningBaseModule
               ):
    """Prediction-averaging ensemble built from previously trained checkpoints.

    Loads every experiment under ``output/<secondary_type>`` matching a fixed
    experiment fingerprint, restores each saved model, and averages their
    ``main_out`` predictions in forward().
    """

    def __init__(self, hparams):
        super(Ensemble, self).__init__(hparams)

        # Dataset
        # =============================================================================
        self.dataset = self.build_dataset()

        # Model Parameters
        # =============================================================================
        # Additional parameters (criterion/filters are consumed by the
        # train/val mixins, not by forward() itself).
        self.in_shape = self.dataset.train_dataset.sample_shape
        self.conv_filters = self.params.filters
        self.criterion = nn.BCELoss()

        # Pre_trained_models
        out_path = Path('output') / self.params.secondary_type
        # exp_paths = list(out_path.rglob(f'*{self.params.exp_fingerprint}'))
        # NOTE(review): the experiment fingerprint is hard-coded; the
        # parameterized variant above is commented out — confirm before reuse.
        exp_paths = list(out_path.rglob('*e87b8f455ba134504b1ae17114ac2a2a'))
        # One config.ini per found experiment version.
        config_ini_files = sum([list(exp_path.rglob('config.ini')) for exp_path in exp_paths], [])
        self.model_list = ModuleList()
        configs = [Config() for _ in range(len(config_ini_files))]
        for config, ini_file in zip(configs, config_ini_files):
            # NOTE(review): the handle from ini_file.open('r') is never closed.
            config.read_file(ini_file.open('r'))
            model = SavedLightningModels.load_checkpoint(models_root_path=config.exp_path / config.version).restore()
            self.model_list.append(model)

    def forward(self, batch, **kwargs):
        """Average the sub-models' main outputs; returns Namespace(main_out=...)."""
        ys = [model(batch).main_out for model in self.model_list]
        tensor = torch.stack(ys).mean(dim=0)
        return Namespace(main_out=tensor)

View File

@ -1,55 +0,0 @@
import sklearn
import torch
import numpy as np
from torch.nn import L1Loss
from torch.optim import Adam
class BaseOptimizerMixin:
    """Mixin supplying the Lightning optimizer hook (plain Adam)."""

    def configure_optimizers(self):
        # Learning rate comes from the model hyper-parameters.
        learning_rate = self.params.lr
        return Adam(params=self.parameters(), lr=learning_rate)
class BaseTrainMixin:
    """Mixin with the Lightning training hooks (expects ``self.criterion``)."""

    def training_step(self, batch_xy, batch_nb, *args, **kwargs):
        # Unpack the batch, run the model, and score against the target.
        x, target = batch_xy
        prediction = self(x)
        return dict(loss=self.criterion(prediction, target))

    def training_epoch_end(self, outputs):
        # Log the mean of the per-batch training losses.
        losses = [step['loss'] for step in outputs]
        mean_train_loss = torch.mean(torch.stack(losses))
        return dict(log=dict(mean_train_loss=mean_train_loss))
class BaseValMixin:
    """Mixin with the Lightning validation hooks: BCE loss, L1 error, UAR score."""

    # Shared L1 criterion for the absolute-error metric.
    absolute_loss = L1Loss()

    def validation_step(self, batch_xy, batch_idx, *args, **kwargs):
        # One validation batch: forward pass plus both losses.
        x, target = batch_xy
        prediction = self(x)
        return dict(val_loss=self.criterion(prediction, target),
                    absolute_error=self.absolute_loss(prediction, target),
                    batch_idx=batch_idx, y=prediction, batch_y=target)

    def validation_epoch_end(self, outputs):
        # Epoch means of the per-batch losses.
        overall_val_loss = torch.stack([step['val_loss'] for step in outputs]).mean()
        mean_absolute_error = torch.stack([step['absolute_error'] for step in outputs]).mean()

        # Unweighted Average Recall over the binarised (>= 0.5) predictions.
        y_true = torch.cat([step['batch_y'] for step in outputs]).cpu().numpy()
        y_pred = torch.cat([step['y'] for step in outputs]).squeeze().cpu().numpy()
        y_pred = (y_pred >= 0.5).astype(np.float32)
        uar_score = sklearn.metrics.recall_score(y_true, y_pred, labels=[0, 1], average='macro',
                                                 sample_weight=None, zero_division='warn')
        return dict(
            log=dict(mean_val_loss=overall_val_loss,
                     mean_absolute_error=mean_absolute_error,
                     uar_score=uar_score)
        )

View File

@ -1,6 +1,8 @@
from ml_lib.utils.config import Config
from models.binary_classifier import BinaryClassifier
from models.bandwise_binary_classifier import BandwiseBinaryClassifier
from models.conv_classifier import ConvClassifier
from models.bandwise_conv_classifier import BandwiseConvClassifier
from models.bandwise_conv_multihead_classifier import BandwiseConvMultiheadClassifier
from models.ensemble import Ensemble
class MConfig(Config):
    """Project config that resolves model-name strings from config files to classes.

    Note: the diff residue here contained both the pre-commit and post-commit
    ``return dict(...)`` plus an embedded hunk header, which is not valid
    Python; this is the reconstructed post-commit version.
    """

    @property
    def _model_map(self):
        # Registry used to instantiate the model class selected in the config.
        return dict(ConvClassifier=ConvClassifier,
                    BandwiseConvClassifier=BandwiseConvClassifier,
                    BandwiseConvMultiheadClassifier=BandwiseConvMultiheadClassifier,
                    Ensemble=Ensemble,
                    )

View File

@ -1,11 +0,0 @@
from pathlib import Path
from ml_lib.utils.logging import Logger
class MLogger(Logger):
    """Project logger whose output directory is taken from the training config."""

    @property
    def outpath(self):
        # FIXME: Specify a special path
        target = self.config.train.outpath
        return Path(target)

147
util/module_mixins.py Normal file
View File

@ -0,0 +1,147 @@
from collections import defaultdict
from abc import ABC
from argparse import Namespace
import sklearn
import torch
import numpy as np
from torch.nn import L1Loss
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchcontrib.optim import SWA
from torchvision.transforms import Compose, RandomApply
from ml_lib.audio_toolset.audio_augmentation import NoiseInjection, LoudnessManipulator, ShiftTime
from ml_lib.audio_toolset.audio_io import AudioToMel, MelToImage, NormalizeLocal
from ml_lib.utils.transforms import ToTensor
import variables as V
class BaseOptimizerMixin:
    """Optimizer configuration shared by the models: Adam, optionally SWA-wrapped.

    Expects ``self.params`` to provide ``lr`` and ``sto_weight_avg``, and
    optionally ``opt_reset_interval`` (epochs between optimizer-state resets).
    """

    def configure_optimizers(self):
        # Adam with the configured learning rate, optionally wrapped in
        # Stochastic Weight Averaging.
        opt = Adam(params=self.parameters(), lr=self.params.lr)
        if self.params.sto_weight_avg:
            opt = SWA(opt, swa_start=10, swa_freq=5, swa_lr=0.05)
        return opt

    def on_train_end(self):
        # SWA keeps the averaged weights on the side; swap them into the model
        # once training finishes.
        for opt in self.trainer.optimizers:
            if isinstance(opt, SWA):
                opt.swap_swa_sgd()

    def on_epoch_end(self):
        # Periodically reset the optimizer state (momentum buffers etc.).
        # Fix for the old `if False:  # FIXME: Pass a new parameter` guard:
        # the reset now runs only when the optional `opt_reset_interval`
        # parameter is present and non-zero, so existing configs are unaffected.
        opt_reset_interval = getattr(self.params, 'opt_reset_interval', 0)
        if opt_reset_interval:
            if self.current_epoch % opt_reset_interval == 0:
                for opt in self.trainer.optimizers:
                    opt.state = defaultdict(dict)
class BaseTrainMixin:
    """Training hooks: criterion on ``main_out`` plus a generic loss summary."""

    def training_step(self, batch_xy, batch_nb, *args, **kwargs):
        # Forward pass and loss for one training batch.
        x, target = batch_xy
        prediction = self(x).main_out
        return dict(loss=self.criterion(prediction, target))

    def training_epoch_end(self, outputs):
        # Mean over the epoch for every step-dict entry whose key contains 'loss'.
        keys = list(outputs[0].keys())
        log = {}
        for key in keys:
            if 'loss' in key:
                log[f'mean_{key}'] = torch.mean(torch.stack([output[key] for output in outputs]))
        return dict(log=log)
class BaseValMixin:
    """Validation hooks: BCE + L1 losses per batch, epoch means, and UAR score."""

    # Shared L1 criterion for the absolute-error metric.
    absolute_loss = L1Loss()

    def validation_step(self, batch_xy, batch_idx, *args, **kwargs):
        # Forward pass and both losses for one validation batch.
        x, target = batch_xy
        prediction = self(x).main_out
        return dict(val_bce_loss=self.criterion(prediction, target),
                    val_abs_loss=self.absolute_loss(prediction, target),
                    batch_idx=batch_idx, y=prediction, batch_y=target
                    )

    def validation_epoch_end(self, outputs):
        # Mean over the epoch for every step-dict entry whose key contains 'loss'.
        keys = list(outputs[0].keys())
        log = {f'mean_{key}': torch.mean(torch.stack([output[key] for output in outputs]))
               for key in keys if 'loss' in key}
        summary_dict = dict(log=log)

        # Unweighted Average Recall on the binarised (>= 0.5) predictions.
        y_true = torch.cat([output['batch_y'] for output in outputs]).cpu().numpy()
        y_pred = torch.cat([output['y'] for output in outputs]).squeeze().cpu().numpy()
        y_pred = (y_pred >= 0.5).astype(np.float32)
        uar_score = sklearn.metrics.recall_score(y_true, y_pred, labels=[0, 1], average='macro',
                                                 sample_weight=None, zero_division='warn')
        summary_dict['log'].update(uar_score=uar_score)
        return summary_dict
class BinaryMaskDatasetFunction:
    """Builds the train/devel/test ``BinaryMasksDataset`` triple.

    Train uses mel transforms plus random augmentations; devel/test use mel
    transforms plus normalisation only.
    """

    def build_dataset(self):
        # Mel Transforms: raw audio -> mel spectrogram rendered as an image.
        mel_transforms = Compose([
            AudioToMel(n_mels=self.params.n_mels), MelToImage()])

        # Data Augmentations: the whole group fires with probability 0.5,
        # followed by normalisation and tensor conversion.
        aug_transforms = Compose([
            RandomApply([
                NoiseInjection(self.params.noise_ratio),
                LoudnessManipulator(self.params.loudness_ratio),
                ShiftTime(self.params.shift_ratio)], p=0.5),
            # Utility
            NormalizeLocal(), ToTensor()
        ])
        # Validation/test: no augmentation.
        val_transforms = Compose([NormalizeLocal(), ToTensor()])

        # Datasets (import kept local as in the original — presumably to avoid
        # a circular module dependency; confirm before hoisting).
        from datasets.binar_masks import BinaryMasksDataset
        train_dataset = BinaryMasksDataset(self.params.root, setting=V.DATA_OPTIONS.train,
                                           mixup=self.params.mixup,
                                           mel_transforms=mel_transforms, transforms=aug_transforms)
        val_dataset = BinaryMasksDataset(self.params.root, setting=V.DATA_OPTIONS.devel,
                                         mel_transforms=mel_transforms, transforms=val_transforms)
        test_dataset = BinaryMasksDataset(self.params.root, setting=V.DATA_OPTIONS.test,
                                          mel_transforms=mel_transforms, transforms=val_transforms)
        return Namespace(train_dataset=train_dataset,
                         val_dataset=val_dataset,
                         test_dataset=test_dataset)
class BaseDataloadersMixin(ABC):
    """Dataloader construction shared by the models.

    Expects ``self.dataset`` (namespace with ``train_dataset``, ``val_dataset``,
    ``test_dataset``) and ``self.params`` providing ``batch_size`` and ``worker``.
    """

    # Dataloaders
    # ================================================================================
    # Train Dataloader
    def train_dataloader(self):
        # Shuffle training data for SGD.
        return DataLoader(dataset=self.dataset.train_dataset, shuffle=True,
                          batch_size=self.params.batch_size,
                          num_workers=self.params.worker)

    # Test Dataloader
    def test_dataloader(self):
        return DataLoader(dataset=self.dataset.test_dataset, shuffle=False,
                          batch_size=self.params.batch_size,
                          num_workers=self.params.worker)

    # Validation Dataloader
    def val_dataloader(self):
        # Fix: validation data must not be shuffled (was shuffle=True).
        # Shuffling gains nothing for evaluation and makes per-batch logs
        # non-reproducible across runs.
        return DataLoader(dataset=self.dataset.val_dataset, shuffle=False,
                          batch_size=self.params.batch_size,
                          num_workers=self.params.worker)

View File

@ -1,6 +1,8 @@
# Labels
from argparse import Namespace

# Binary class labels for the mask-detection task.
CLEAR = 0
MASK = 1

# Dataset Options: the three dataset splits, addressable by attribute
# (e.g. V.DATA_OPTIONS.train).  The diff residue carried both the old
# list form and this Namespace form; only the post-commit Namespace is kept.
DATA_OPTIONS = Namespace(test='test', devel='devel', train='train')