LinearModule

This commit is contained in:
Si11ium
2020-05-09 21:56:58 +02:00
parent 5e6b0e598f
commit 3fbc98dfa3
7 changed files with 145 additions and 138 deletions
+9 -5
View File
@@ -24,11 +24,14 @@ main_arg_parser.add_argument("--data_class_name", type=str, default='BinaryMasks
main_arg_parser.add_argument("--data_normalized", type=strtobool, default=True, help="") main_arg_parser.add_argument("--data_normalized", type=strtobool, default=True, help="")
main_arg_parser.add_argument("--data_use_preprocessed", type=strtobool, default=True, help="") main_arg_parser.add_argument("--data_use_preprocessed", type=strtobool, default=True, help="")
main_arg_parser.add_argument("--data_n_mels", type=int, default=64, help="") main_arg_parser.add_argument("--data_n_mels", type=int, default=64, help="")
main_arg_parser.add_argument("--data_sr", type=int, default=16000, help="")
main_arg_parser.add_argument("--data_hop_length", type=int, default=62, help="")
main_arg_parser.add_argument("--data_n_fft", type=int, default=512, help="")
main_arg_parser.add_argument("--data_mixup", type=strtobool, default=False, help="") main_arg_parser.add_argument("--data_mixup", type=strtobool, default=False, help="")
# Transformation Parameters # Transformation Parameters
main_arg_parser.add_argument("--data_loudness_ratio", type=float, default=0.08, help="") main_arg_parser.add_argument("--data_loudness_ratio", type=float, default=0.2, help="")
main_arg_parser.add_argument("--data_shift_ratio", type=float, default=0.2, help="") main_arg_parser.add_argument("--data_shift_ratio", type=float, default=0.4, help="")
main_arg_parser.add_argument("--data_noise_ratio", type=float, default=0.15, help="") main_arg_parser.add_argument("--data_noise_ratio", type=float, default=0.15, help="")
# Training Parameters # Training Parameters
@@ -36,9 +39,10 @@ main_arg_parser.add_argument("--train_outpath", type=str, default="output", help
main_arg_parser.add_argument("--train_version", type=strtobool, required=False, help="") main_arg_parser.add_argument("--train_version", type=strtobool, required=False, help="")
# FIXME: Stochastic weight Avaraging is not good, maybe its my implementation? # FIXME: Stochastic weight Avaraging is not good, maybe its my implementation?
main_arg_parser.add_argument("--train_sto_weight_avg", type=strtobool, default=False, help="") main_arg_parser.add_argument("--train_sto_weight_avg", type=strtobool, default=False, help="")
main_arg_parser.add_argument("--train_opt_reset_interval", type=int, default=300, help="")
main_arg_parser.add_argument("--train_epochs", type=int, default=600, help="") main_arg_parser.add_argument("--train_epochs", type=int, default=600, help="")
main_arg_parser.add_argument("--train_batch_size", type=int, default=250, help="") main_arg_parser.add_argument("--train_batch_size", type=int, default=250, help="")
main_arg_parser.add_argument("--train_lr", type=float, default=1e-3, help="") main_arg_parser.add_argument("--train_lr", type=float, default=1e-4, help="")
main_arg_parser.add_argument("--train_num_sanity_val_steps", type=int, default=0, help="") main_arg_parser.add_argument("--train_num_sanity_val_steps", type=int, default=0, help="")
# Model Parameters # Model Parameters
@@ -46,12 +50,12 @@ main_arg_parser.add_argument("--model_type", type=str, default="ConvClassifier",
main_arg_parser.add_argument("--model_secondary_type", type=str, default="BandwiseConvMultiheadClassifier", help="") main_arg_parser.add_argument("--model_secondary_type", type=str, default="BandwiseConvMultiheadClassifier", help="")
main_arg_parser.add_argument("--model_weight_init", type=str, default="xavier_normal_", help="") main_arg_parser.add_argument("--model_weight_init", type=str, default="xavier_normal_", help="")
main_arg_parser.add_argument("--model_activation", type=str, default="leaky_relu", help="") main_arg_parser.add_argument("--model_activation", type=str, default="leaky_relu", help="")
main_arg_parser.add_argument("--model_filters", type=str, default="[16, 32, 64, 128]", help="") main_arg_parser.add_argument("--model_filters", type=str, default="[16, 32, 64, 128, 64]", help="")
main_arg_parser.add_argument("--model_classes", type=int, default=2, help="") main_arg_parser.add_argument("--model_classes", type=int, default=2, help="")
main_arg_parser.add_argument("--model_lat_dim", type=int, default=8, help="") main_arg_parser.add_argument("--model_lat_dim", type=int, default=8, help="")
main_arg_parser.add_argument("--model_bias", type=strtobool, default=True, help="") main_arg_parser.add_argument("--model_bias", type=strtobool, default=True, help="")
main_arg_parser.add_argument("--model_norm", type=strtobool, default=True, help="") main_arg_parser.add_argument("--model_norm", type=strtobool, default=True, help="")
main_arg_parser.add_argument("--model_dropout", type=float, default=0.25, help="") main_arg_parser.add_argument("--model_dropout", type=float, default=0.0, help="")
# Project Parameters # Project Parameters
main_arg_parser.add_argument("--project_name", type=str, default=_ROOT.name, help="") main_arg_parser.add_argument("--project_name", type=str, default=_ROOT.name, help="")
+37 -34
View File
@@ -78,41 +78,44 @@ def run_lightning_loop(config_obj):
# Evaluate It # Evaluate It
if config_obj.main.eval: if config_obj.main.eval:
model.eval() with torch.no_grad():
if torch.cuda.is_available(): model.eval()
model.cuda() if torch.cuda.is_available():
outputs = [] model.cuda()
from tqdm import tqdm outputs = []
for idx, batch in enumerate(tqdm(model.val_dataloader())):
batch_x, label = batch
outputs.append(
model.validation_step((batch_x.to(device='cuda' if model.on_gpu else 'cpu'), label), idx)
)
summary_dict = model.validation_epoch_end(outputs)
print(summary_dict['log']['uar_score'])
# trainer.test()
outpath = Path(config_obj.train.outpath)
model_type = config_obj.model.type
parameters = logger.name
version = f'version_{logger.version}'
inference_out = f'{parameters}_test_out.csv'
from main_inference import prepare_dataloader
test_dataloader = prepare_dataloader(config)
with (outpath / model_type / parameters / version / inference_out).open(mode='w') as outfile:
outfile.write(f'file_name,prediction\n')
from tqdm import tqdm from tqdm import tqdm
for batch in tqdm(test_dataloader, total=len(test_dataloader)): for idx, batch in enumerate(tqdm(model.val_dataloader()[0])):
batch_x, file_name = batch batch_x, label = batch
batch_x = batch_x.unsqueeze(0).to(device='cuda' if model.on_gpu else 'cpu') batch_x = batch_x.to(device='cuda' if model.on_gpu else 'cpu')
y = model(batch_x).main_out label = label.to(device='cuda' if model.on_gpu else 'cpu')
prediction = (y.squeeze() >= 0.5).int().item() outputs.append(
import variables as V model.validation_step((batch_x, label), idx, 1)
prediction = 'clear' if prediction == V.CLEAR else 'mask' )
outfile.write(f'{file_name},{prediction}\n') summary_dict = model.validation_epoch_end([outputs])
print(summary_dict['log']['uar_score'])
# trainer.test()
outpath = Path(config_obj.train.outpath)
model_type = config_obj.model.type
parameters = logger.name
version = f'version_{logger.version}'
inference_out = f'{parameters}_test_out.csv'
from main_inference import prepare_dataloader
test_dataloader = prepare_dataloader(config)
with (outpath / model_type / parameters / version / inference_out).open(mode='w') as outfile:
outfile.write(f'file_name,prediction\n')
from tqdm import tqdm
for batch in tqdm(test_dataloader, total=len(test_dataloader)):
batch_x, file_name = batch
batch_x = batch_x.unsqueeze(0).to(device='cuda' if model.on_gpu else 'cpu')
y = model(batch_x).main_out
prediction = (y.squeeze() >= 0.5).int().item()
import variables as V
prediction = 'clear' if prediction == V.CLEAR else 'mask'
outfile.write(f'{file_name},{prediction}\n')
return model return model
+4 -1
View File
@@ -23,7 +23,10 @@ from datasets.binar_masks import BinaryMasksDataset
def prepare_dataloader(config_obj): def prepare_dataloader(config_obj):
mel_transforms = Compose([AudioToMel(n_mels=config_obj.data.n_mels), MelToImage()]) mel_transforms = Compose([
# Audio to Mel Transformations
AudioToMel(sr=config_obj.data.sr, n_mels=config_obj.data.n_mels, n_fft=config_obj.data.n_fft,
hop_length=config_obj.data.hop_length), MelToImage()])
transforms = Compose([NormalizeLocal(), ToTensor()]) transforms = Compose([NormalizeLocal(), ToTensor()])
dataset: Dataset = BinaryMasksDataset(config_obj.data.root, setting='test', dataset: Dataset = BinaryMasksDataset(config_obj.data.root, setting='test',
+7 -16
View File
@@ -3,9 +3,9 @@ from argparse import Namespace
from torch import nn from torch import nn
from torch.nn import ModuleDict, ModuleList from torch.nn import ModuleDict, ModuleList
from ml_lib.modules.blocks import ConvModule from ml_lib.modules.blocks import ConvModule, LinearModule
from ml_lib.modules.utils import (LightningBaseModule, HorizontalSplitter, from ml_lib.modules.utils import (LightningBaseModule, HorizontalSplitter,
HorizontalMerger) HorizontalMerger, F_x)
from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, BinaryMaskDatasetFunction, from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, BinaryMaskDatasetFunction,
BaseDataloadersMixin) BaseDataloadersMixin)
@@ -54,15 +54,11 @@ class BandwiseConvClassifier(BinaryMaskDatasetFunction,
self.merge = HorizontalMerger(self.conv_dict['conv_3_1'].shape, self.n_band_sections) self.merge = HorizontalMerger(self.conv_dict['conv_3_1'].shape, self.n_band_sections)
self.full_1 = nn.Linear(self.flat.shape, self.params.lat_dim, self.params.bias) self.full_1 = LinearModule(self.flat.shape, self.params.lat_dim, **self.params.module_kwargs)
self.full_2 = nn.Linear(self.full_1.out_features, self.full_1.out_features // 2, self.params.bias) self.full_2 = LinearModule(self.full_1.shape, self.full_1.shape * 2, **self.params.module_kwargs)
self.full_3 = LinearModule(self.full_2.shape, self.full_2.out_features // 2, **self.params.module_kwargs)
self.full_out = nn.Linear(self.full_2.out_features, 1, self.params.bias) self.full_out = LinearModule(self.full_3.shape, 1, bias=self.params.bias, activation=nn.Sigmoid)
# Utility Modules
self.dropout = nn.Dropout2d(self.params.dropout) if self.params.dropout else lambda x: x
self.activation = self.params.activation()
self.sigmoid = nn.Sigmoid()
def forward(self, batch, **kwargs): def forward(self, batch, **kwargs):
tensors = self.split(batch) tensors = self.split(batch)
@@ -74,13 +70,8 @@ class BandwiseConvClassifier(BinaryMaskDatasetFunction,
tensors[idx] = self.conv_dict[f"conv_3_{idx}"](tensor) tensors[idx] = self.conv_dict[f"conv_3_{idx}"](tensor)
tensor = self.merge(tensors) tensor = self.merge(tensors)
tensor = self.flat(tensor)
tensor = self.full_1(tensor) tensor = self.full_1(tensor)
tensor = self.activation(tensor)
tensor = self.dropout(tensor)
tensor = self.full_2(tensor) tensor = self.full_2(tensor)
tensor = self.activation(tensor) tensor = self.full_3(tensor)
tensor = self.dropout(tensor)
tensor = self.full_out(tensor) tensor = self.full_out(tensor)
tensor = self.sigmoid(tensor)
return Namespace(main_out=tensor) return Namespace(main_out=tensor)
+41 -35
View File
@@ -1,12 +1,10 @@
from argparse import Namespace from argparse import Namespace
from collections import defaultdict
import torch import torch
from torch import nn from torch import nn
from torch.nn import ModuleDict, ModuleList from torch.nn import ModuleList
from torchcontrib.optim import SWA
from ml_lib.modules.blocks import ConvModule from ml_lib.modules.blocks import ConvModule, LinearModule
from ml_lib.modules.utils import (LightningBaseModule, Flatten, HorizontalSplitter) from ml_lib.modules.utils import (LightningBaseModule, Flatten, HorizontalSplitter)
from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, BinaryMaskDatasetFunction, from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, BinaryMaskDatasetFunction,
BaseDataloadersMixin) BaseDataloadersMixin)
@@ -59,49 +57,57 @@ class BandwiseConvMultiheadClassifier(BinaryMaskDatasetFunction,
self.in_shape = self.dataset.train_dataset.sample_shape self.in_shape = self.dataset.train_dataset.sample_shape
self.conv_filters = self.params.filters self.conv_filters = self.params.filters
self.criterion = nn.BCELoss() self.criterion = nn.BCELoss()
self.n_band_sections = 8 self.n_band_sections = 4
k = 3 # Base Kernel Value
# Modules # Modules
# ============================================================================= # =============================================================================
self.split = HorizontalSplitter(self.in_shape, self.n_band_sections) self.split = HorizontalSplitter(self.in_shape, self.n_band_sections)
self.conv_dict = ModuleDict()
self.conv_dict.update({f"conv_1_{band_section}": self.band_list = ModuleList()
ConvModule(self.split.shape, self.conv_filters[0], 3, conv_stride=1, **self.params.module_kwargs) for band in range(self.n_band_sections):
for band_section in range(self.n_band_sections)} last_shape = self.split.shape
) conv_list = ModuleList()
self.conv_dict.update({f"conv_2_{band_section}": for filters in self.conv_filters:
ConvModule(self.conv_dict['conv_1_1'].shape, self.conv_filters[1], 3, conv_stride=1, conv_list.append(ConvModule(last_shape, filters, (k, k*4), conv_stride=(1, 2),
**self.params.module_kwargs) for band_section in range(self.n_band_sections)} **self.params.module_kwargs))
) last_shape = conv_list[-1].shape
self.conv_dict.update({f"conv_3_{band_section}": # self.conv_list.append(ConvModule(last_shape, 1, 1, conv_stride=1, **self.params.module_kwargs))
ConvModule(self.conv_dict['conv_2_1'].shape, self.conv_filters[2], 3, conv_stride=1, # last_shape = self.conv_list[-1].shape
**self.params.module_kwargs) self.band_list.append(conv_list)
for band_section in range(self.n_band_sections)}
)
self.flat = Flatten(self.conv_dict['conv_3_1'].shape) self.bandwise_deep_list_1 = ModuleList([
LinearModule(self.band_list[0][-1].shape, self.params.lat_dim * 4, **self.params.module_kwargs)
for _ in range(self.n_band_sections)])
self.bandwise_deep_list_2 = ModuleList([
LinearModule(self.params.lat_dim * 4, self.params.lat_dim * 2, **self.params.module_kwargs)
for _ in range(self.n_band_sections)])
self.bandwise_latent_list = ModuleList([ self.bandwise_latent_list = ModuleList([
nn.Linear(self.flat.shape, self.params.lat_dim, self.params.bias) for _ in range(self.n_band_sections)]) LinearModule(self.params.lat_dim * 2, self.params.lat_dim, **self.params.module_kwargs)
self.bandwise_classifier_list = ModuleList([nn.Linear(self.params.lat_dim, 1, self.params.bias) for _ in range(self.n_band_sections)])
for _ in range(self.n_band_sections)]) self.bandwise_classifier_list = ModuleList([
LinearModule(self.params.lat_dim, 1, bias=self.params.bias, activation=nn.Sigmoid)
for _ in range(self.n_band_sections)])
self.full_out = nn.Linear(self.n_band_sections, 1, self.params.bias) self.full_1 = LinearModule(self.n_band_sections, self.params.lat_dim * 4, **self.params.module_kwargs)
self.full_2 = LinearModule(self.full_1.shape, self.params.lat_dim * 2, **self.params.module_kwargs)
# Utility Modules self.full_3 = LinearModule(self.full_2.shape, self.params.lat_dim, **self.params.module_kwargs)
self.sigmoid = nn.Sigmoid() self.full_out = LinearModule(self.full_3.shape, 1, bias=self.params.bias, activation=nn.Sigmoid)
def forward(self, batch, **kwargs): def forward(self, batch, **kwargs):
tensors = self.split(batch) tensors = self.split(batch)
for idx, tensor in enumerate(tensors): for idx, (tensor, convs) in enumerate(zip(tensors, self.band_list)):
tensor = self.conv_dict[f"conv_1_{idx}"](tensor) for conv in convs:
tensor = self.conv_dict[f"conv_2_{idx}"](tensor) tensor = conv(tensor)
tensor = self.conv_dict[f"conv_3_{idx}"](tensor)
tensor = self.flat(tensor) tensor = self.bandwise_deep_list_1[idx](tensor)
tensor = self.bandwise_deep_list_2[idx](tensor)
tensor = self.bandwise_latent_list[idx](tensor) tensor = self.bandwise_latent_list[idx](tensor)
tensor = self.bandwise_classifier_list[idx](tensor) tensors[idx] = self.bandwise_classifier_list[idx](tensor)
tensors[idx] = self.sigmoid(tensor)
tensor = torch.cat(tensors, dim=1) tensor = torch.cat(tensors, dim=1)
tensor = self.full_1(tensor)
tensor = self.full_2(tensor)
tensor = self.full_3(tensor)
tensor = self.full_out(tensor) tensor = self.full_out(tensor)
tensor = self.sigmoid(tensor)
return Namespace(main_out=tensor, bands=tensors) return Namespace(main_out=tensor, bands=tensors)
+8 -25
View File
@@ -3,8 +3,8 @@ from argparse import Namespace
from torch import nn from torch import nn
from torch.nn import ModuleList from torch.nn import ModuleList
from ml_lib.modules.blocks import ConvModule from ml_lib.modules.blocks import ConvModule, LinearModule
from ml_lib.modules.utils import LightningBaseModule, Flatten from ml_lib.modules.utils import LightningBaseModule
from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, BinaryMaskDatasetFunction, from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, BinaryMaskDatasetFunction,
BaseDataloadersMixin) BaseDataloadersMixin)
@@ -38,38 +38,21 @@ class ConvClassifier(BinaryMaskDatasetFunction,
for filters in self.conv_filters: for filters in self.conv_filters:
self.conv_list.append(ConvModule(last_shape, filters, (k, k*2), conv_stride=2, **self.params.module_kwargs)) self.conv_list.append(ConvModule(last_shape, filters, (k, k*2), conv_stride=2, **self.params.module_kwargs))
last_shape = self.conv_list[-1].shape last_shape = self.conv_list[-1].shape
self.conv_list.appen(ConvModule(last_shape, filters, 1, conv_stride=1, **self.params.module_kwargs)) # self.conv_list.append(ConvModule(last_shape, 1, 1, conv_stride=1, **self.params.module_kwargs))
last_shape = self.conv_list[-1].shape # last_shape = self.conv_list[-1].shape
self.conv_list.appen(ConvModule(last_shape, 1, 1, conv_stride=1, **self.params.module_kwargs))
last_shape = self.conv_list[-1].shape
k = k+2
self.flat = Flatten(self.conv_list[-1].shape) self.full_1 = LinearModule(self.flat.shape, self.params.lat_dim, **self.params.module_kwargs)
self.full_1 = nn.Linear(self.flat.shape, self.params.lat_dim, self.params.bias) self.full_2 = LinearModule(self.full_1.out_features, self.full_1.out_features * 2, self.params.bias)
self.full_2 = nn.Linear(self.full_1.out_features, self.full_1.out_features * 2, self.params.bias) self.full_3 = LinearModule(self.full_2.out_features, self.full_2.out_features // 2, self.params.bias)
self.full_3 = nn.Linear(self.full_2.out_features, self.full_2.out_features // 2, self.params.bias)
self.full_out = nn.Linear(self.full_3.out_features, 1, self.params.bias) self.full_out = LinearModule(self.full_3.out_features, 1, bias=self.params.bias, activation=nn.Sigmoid)
# Utility Modules
self.dropout = nn.Dropout2d(self.params.dropout) if self.params.dropout else lambda x: x
self.activation = self.params.activation()
self.sigmoid = nn.Sigmoid()
def forward(self, batch, **kwargs): def forward(self, batch, **kwargs):
tensor = batch tensor = batch
for conv in self.conv_list: for conv in self.conv_list:
tensor = conv(tensor) tensor = conv(tensor)
tensor = self.flat(tensor)
tensor = self.full_1(tensor) tensor = self.full_1(tensor)
tensor = self.activation(tensor)
tensor = self.dropout(tensor)
tensor = self.full_2(tensor) tensor = self.full_2(tensor)
tensor = self.activation(tensor)
tensor = self.dropout(tensor)
tensor = self.full_3(tensor) tensor = self.full_3(tensor)
tensor = self.activation(tensor)
tensor = self.dropout(tensor)
tensor = self.full_out(tensor) tensor = self.full_out(tensor)
tensor = self.sigmoid(tensor)
return Namespace(main_out=tensor) return Namespace(main_out=tensor)
+39 -22
View File
@@ -14,6 +14,7 @@ from torchvision.transforms import Compose, RandomApply
from ml_lib.audio_toolset.audio_augmentation import NoiseInjection, LoudnessManipulator, ShiftTime from ml_lib.audio_toolset.audio_augmentation import NoiseInjection, LoudnessManipulator, ShiftTime
from ml_lib.audio_toolset.audio_io import AudioToMel, MelToImage, NormalizeLocal from ml_lib.audio_toolset.audio_io import AudioToMel, MelToImage, NormalizeLocal
from ml_lib.modules.utils import LightningBaseModule
from ml_lib.utils.transforms import ToTensor from ml_lib.utils.transforms import ToTensor
import variables as V import variables as V
@@ -22,6 +23,7 @@ import variables as V
class BaseOptimizerMixin: class BaseOptimizerMixin:
def configure_optimizers(self): def configure_optimizers(self):
assert isinstance(self, LightningBaseModule)
opt = Adam(params=self.parameters(), lr=self.params.lr) opt = Adam(params=self.parameters(), lr=self.params.lr)
if self.params.sto_weight_avg: if self.params.sto_weight_avg:
opt = SWA(opt, swa_start=10, swa_freq=5, swa_lr=0.05) opt = SWA(opt, swa_start=10, swa_freq=5, swa_lr=0.05)
@@ -33,7 +35,7 @@ class BaseOptimizerMixin:
opt.swap_swa_sgd() opt.swap_swa_sgd()
def on_epoch_end(self): def on_epoch_end(self):
if False: # FIXME: Pass a new parameter to model args. if self.params.opt_reset_interval:
if self.current_epoch % self.params.opt_reset_interval == 0: if self.current_epoch % self.params.opt_reset_interval == 0:
for opt in self.trainer.optimizers: for opt in self.trainer.optimizers:
opt.state = defaultdict(dict) opt.state = defaultdict(dict)
@@ -42,6 +44,7 @@ class BaseOptimizerMixin:
class BaseTrainMixin: class BaseTrainMixin:
def training_step(self, batch_xy, batch_nb, *args, **kwargs): def training_step(self, batch_xy, batch_nb, *args, **kwargs):
assert isinstance(self, LightningBaseModule)
batch_x, batch_y = batch_xy batch_x, batch_y = batch_xy
y = self(batch_x).main_out y = self(batch_x).main_out
loss = self.criterion(y, batch_y) loss = self.criterion(y, batch_y)
@@ -60,7 +63,7 @@ class BaseValMixin:
absolute_loss = L1Loss() absolute_loss = L1Loss()
def validation_step(self, batch_xy, batch_idx, *args, **kwargs): def validation_step(self, batch_xy, batch_idx, dataloader_idx, *args, **kwargs):
batch_x, batch_y = batch_xy batch_x, batch_y = batch_xy
y = self(batch_x).main_out y = self(batch_x).main_out
val_bce_loss = self.criterion(y, batch_y) val_bce_loss = self.criterion(y, batch_y)
@@ -69,52 +72,63 @@ class BaseValMixin:
batch_idx=batch_idx, y=y, batch_y=batch_y batch_idx=batch_idx, y=y, batch_y=batch_y
) )
def validation_epoch_end(self, outputs): def validation_epoch_end(self, outputs, *args, **kwargs):
keys = list(outputs[0].keys()) summary_dict = dict(log=dict())
for output_idx, output in enumerate(outputs):
keys = list(output[0].keys())
ident = '' if output_idx == 0 else '_train'
summary_dict['log'].update({f'mean{ident}_{key}': torch.mean(torch.stack([output[key]
for output in output]))
for key in keys if 'loss' in key}
)
summary_dict = dict(log={f'mean_{key}': torch.mean(torch.stack([output[key] # UnweightedAverageRecall
for output in outputs])) y_true = torch.cat([output['batch_y'] for output in output]) .cpu().numpy()
for key in keys if 'loss' in key}) y_pred = torch.cat([output['y'] for output in output]).squeeze().cpu().numpy()
# UnweightedAverageRecall y_pred = (y_pred >= 0.5).astype(np.float32)
y_true = torch.cat([output['batch_y'] for output in outputs]) .cpu().numpy()
y_pred = torch.cat([output['y'] for output in outputs]).squeeze().cpu().numpy()
y_pred = (y_pred >= 0.5).astype(np.float32) uar_score = sklearn.metrics.recall_score(y_true, y_pred, labels=[0, 1], average='macro',
sample_weight=None, zero_division='warn')
uar_score = sklearn.metrics.recall_score(y_true, y_pred, labels=[0, 1], average='macro', summary_dict['log'].update({f'uar{ident}_score': uar_score})
sample_weight=None, zero_division='warn')
summary_dict['log'].update(uar_score=uar_score)
return summary_dict return summary_dict
class BinaryMaskDatasetFunction: class BinaryMaskDatasetFunction:
def build_dataset(self): def build_dataset(self):
assert isinstance(self, LightningBaseModule)
# Dataset # Dataset
# ============================================================================= # =============================================================================
# Mel Transforms # Mel Transforms
mel_transforms = Compose([ mel_transforms = Compose([
# Audio to Mel Transformations # Audio to Mel Transformations
AudioToMel(n_mels=self.params.n_mels), MelToImage()]) AudioToMel(sr=self.params.sr, n_mels=self.params.n_mels, n_fft=self.params.n_fft,
hop_length=self.params.hop_length), MelToImage()])
# Data Augmentations # Data Augmentations
aug_transforms = Compose([ aug_transforms = Compose([
RandomApply([ RandomApply([
NoiseInjection(self.params.noise_ratio), NoiseInjection(self.params.noise_ratio),
LoudnessManipulator(self.params.loudness_ratio), LoudnessManipulator(self.params.loudness_ratio),
ShiftTime(self.params.shift_ratio)], p=0.5), ShiftTime(self.params.shift_ratio)], p=0.5),
# Utility # Utility
NormalizeLocal(), ToTensor() NormalizeLocal(), ToTensor()
]) ])
val_transforms = Compose([NormalizeLocal(), ToTensor()]) val_transforms = Compose([NormalizeLocal(), ToTensor()])
# sampler = RandomSampler(train_dataset, True, len(train_dataset)) if params['bootstrap'] else None
# Datasets # Datasets
from datasets.binar_masks import BinaryMasksDataset from datasets.binar_masks import BinaryMasksDataset
dataset = Namespace( dataset = Namespace(
**dict( **dict(
train_dataset=BinaryMasksDataset(self.params.root, setting=V.DATA_OPTIONS.train, mixup=self.params.mixup, train_dataset=BinaryMasksDataset(self.params.root, setting=V.DATA_OPTIONS.train,
mixup=self.params.mixup,
mel_transforms=mel_transforms, transforms=aug_transforms), mel_transforms=mel_transforms, transforms=aug_transforms),
val_train_dataset=BinaryMasksDataset(self.params.root, setting=V.DATA_OPTIONS.train,
mel_transforms=mel_transforms, transforms=val_transforms),
val_dataset=BinaryMasksDataset(self.params.root, setting=V.DATA_OPTIONS.devel, val_dataset=BinaryMasksDataset(self.params.root, setting=V.DATA_OPTIONS.devel,
mel_transforms=mel_transforms, transforms=val_transforms), mel_transforms=mel_transforms, transforms=val_transforms),
test_dataset=BinaryMasksDataset(self.params.root, setting=V.DATA_OPTIONS.test, test_dataset=BinaryMasksDataset(self.params.root, setting=V.DATA_OPTIONS.test,
@@ -142,6 +156,9 @@ class BaseDataloadersMixin(ABC):
# Validation Dataloader # Validation Dataloader
def val_dataloader(self): def val_dataloader(self):
return DataLoader(dataset=self.dataset.val_dataset, shuffle=True, val_dataloader = DataLoader(dataset=self.dataset.val_dataset, shuffle=True,
batch_size=self.params.batch_size, batch_size=self.params.batch_size, num_workers=self.params.worker)
num_workers=self.params.worker)
train_dataloader = DataLoader(self.dataset.val_train_dataset, num_workers=self.params.worker,
batch_size=self.params.batch_size, shuffle=False)
return [val_dataloader, train_dataloader]