diff --git a/_paramters.py b/_paramters.py index 0525749..f8437ac 100644 --- a/_paramters.py +++ b/_paramters.py @@ -24,11 +24,14 @@ main_arg_parser.add_argument("--data_class_name", type=str, default='BinaryMasks main_arg_parser.add_argument("--data_normalized", type=strtobool, default=True, help="") main_arg_parser.add_argument("--data_use_preprocessed", type=strtobool, default=True, help="") main_arg_parser.add_argument("--data_n_mels", type=int, default=64, help="") +main_arg_parser.add_argument("--data_sr", type=int, default=16000, help="") +main_arg_parser.add_argument("--data_hop_length", type=int, default=62, help="") +main_arg_parser.add_argument("--data_n_fft", type=int, default=512, help="") main_arg_parser.add_argument("--data_mixup", type=strtobool, default=False, help="") # Transformation Parameters -main_arg_parser.add_argument("--data_loudness_ratio", type=float, default=0.08, help="") -main_arg_parser.add_argument("--data_shift_ratio", type=float, default=0.2, help="") +main_arg_parser.add_argument("--data_loudness_ratio", type=float, default=0.2, help="") +main_arg_parser.add_argument("--data_shift_ratio", type=float, default=0.4, help="") main_arg_parser.add_argument("--data_noise_ratio", type=float, default=0.15, help="") # Training Parameters @@ -36,9 +39,10 @@ main_arg_parser.add_argument("--train_outpath", type=str, default="output", help main_arg_parser.add_argument("--train_version", type=strtobool, required=False, help="") # FIXME: Stochastic weight Avaraging is not good, maybe its my implementation? main_arg_parser.add_argument("--train_sto_weight_avg", type=strtobool, default=False, help="") +main_arg_parser.add_argument("--train_opt_reset_interval", type=int, default=300, help="") main_arg_parser.add_argument("--train_epochs", type=int, default=600, help="") main_arg_parser.add_argument("--train_batch_size", type=int, default=250, help="") -main_arg_parser.add_argument("--train_lr", type=float, default=1e-3, help="") +main_arg_parser.add_argument("--train_lr", type=float, default=1e-4, help="") main_arg_parser.add_argument("--train_num_sanity_val_steps", type=int, default=0, help="") # Model Parameters @@ -46,12 +50,12 @@ main_arg_parser.add_argument("--model_type", type=str, default="ConvClassifier", main_arg_parser.add_argument("--model_secondary_type", type=str, default="BandwiseConvMultiheadClassifier", help="") main_arg_parser.add_argument("--model_weight_init", type=str, default="xavier_normal_", help="") main_arg_parser.add_argument("--model_activation", type=str, default="leaky_relu", help="") -main_arg_parser.add_argument("--model_filters", type=str, default="[16, 32, 64, 128]", help="") +main_arg_parser.add_argument("--model_filters", type=str, default="[16, 32, 64, 128, 64]", help="") main_arg_parser.add_argument("--model_classes", type=int, default=2, help="") main_arg_parser.add_argument("--model_lat_dim", type=int, default=8, help="") main_arg_parser.add_argument("--model_bias", type=strtobool, default=True, help="") main_arg_parser.add_argument("--model_norm", type=strtobool, default=True, help="") -main_arg_parser.add_argument("--model_dropout", type=float, default=0.25, help="") +main_arg_parser.add_argument("--model_dropout", type=float, default=0.0, help="") # Project Parameters main_arg_parser.add_argument("--project_name", type=str, default=_ROOT.name, help="") diff --git a/main.py b/main.py index 2a0b1b4..d429e69 100644 --- a/main.py +++ b/main.py @@ -78,41 +78,44 @@ def run_lightning_loop(config_obj): # Evaluate It if config_obj.main.eval: - model.eval() - if torch.cuda.is_available(): - model.cuda() - outputs = [] - from tqdm import tqdm - for idx, batch in enumerate(tqdm(model.val_dataloader())): - batch_x, label = batch - outputs.append( - model.validation_step((batch_x.to(device='cuda' if model.on_gpu else 'cpu'), label), idx) - ) - summary_dict = model.validation_epoch_end(outputs) - print(summary_dict['log']['uar_score']) - - # trainer.test() - outpath = Path(config_obj.train.outpath) - model_type = config_obj.model.type - parameters = logger.name - version = f'version_{logger.version}' - inference_out = f'{parameters}_test_out.csv' - - from main_inference import prepare_dataloader - test_dataloader = prepare_dataloader(config) - - with (outpath / model_type / parameters / version / inference_out).open(mode='w') as outfile: - outfile.write(f'file_name,prediction\n') - + with torch.no_grad(): + model.eval() + if torch.cuda.is_available(): + model.cuda() + outputs = [] from tqdm import tqdm - for batch in tqdm(test_dataloader, total=len(test_dataloader)): - batch_x, file_name = batch - batch_x = batch_x.unsqueeze(0).to(device='cuda' if model.on_gpu else 'cpu') - y = model(batch_x).main_out - prediction = (y.squeeze() >= 0.5).int().item() - import variables as V - prediction = 'clear' if prediction == V.CLEAR else 'mask' - outfile.write(f'{file_name},{prediction}\n') + for idx, batch in enumerate(tqdm(model.val_dataloader()[0])): + batch_x, label = batch + batch_x = batch_x.to(device='cuda' if model.on_gpu else 'cpu') + label = label.to(device='cuda' if model.on_gpu else 'cpu') + outputs.append( + model.validation_step((batch_x, label), idx, 1) + ) + summary_dict = model.validation_epoch_end([outputs]) + print(summary_dict['log']['uar_score']) + + # trainer.test() + outpath = Path(config_obj.train.outpath) + model_type = config_obj.model.type + parameters = logger.name + version = f'version_{logger.version}' + inference_out = f'{parameters}_test_out.csv' + + from main_inference import prepare_dataloader + test_dataloader = prepare_dataloader(config) + + with (outpath / model_type / parameters / version / inference_out).open(mode='w') as outfile: + outfile.write(f'file_name,prediction\n') + + from tqdm import tqdm + for batch in tqdm(test_dataloader, total=len(test_dataloader)): + batch_x, file_name = batch + batch_x = batch_x.unsqueeze(0).to(device='cuda' if model.on_gpu else 'cpu') + y = model(batch_x).main_out + prediction = (y.squeeze() >= 0.5).int().item() + import variables as V + prediction = 'clear' if prediction == V.CLEAR else 'mask' + outfile.write(f'{file_name},{prediction}\n') return model diff --git a/main_inference.py b/main_inference.py index fae2481..825082d 100644 --- a/main_inference.py +++ b/main_inference.py @@ -23,7 +23,10 @@ from datasets.binar_masks import BinaryMasksDataset def prepare_dataloader(config_obj): - mel_transforms = Compose([AudioToMel(n_mels=config_obj.data.n_mels), MelToImage()]) + mel_transforms = Compose([ + # Audio to Mel Transformations + AudioToMel(sr=config_obj.data.sr, n_mels=config_obj.data.n_mels, n_fft=config_obj.data.n_fft, + hop_length=config_obj.data.hop_length), MelToImage()]) transforms = Compose([NormalizeLocal(), ToTensor()]) dataset: Dataset = BinaryMasksDataset(config_obj.data.root, setting='test', diff --git a/models/bandwise_conv_classifier.py b/models/bandwise_conv_classifier.py index e08ef2d..216fc3b 100644 --- a/models/bandwise_conv_classifier.py +++ b/models/bandwise_conv_classifier.py @@ -3,9 +3,9 @@ from argparse import Namespace from torch import nn from torch.nn import ModuleDict, ModuleList -from ml_lib.modules.blocks import ConvModule +from ml_lib.modules.blocks import ConvModule, LinearModule from ml_lib.modules.utils import (LightningBaseModule, HorizontalSplitter, - HorizontalMerger) + HorizontalMerger, F_x) from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, BinaryMaskDatasetFunction, BaseDataloadersMixin) @@ -54,15 +54,11 @@ class BandwiseConvClassifier(BinaryMaskDatasetFunction, self.merge = HorizontalMerger(self.conv_dict['conv_3_1'].shape, self.n_band_sections) - self.full_1 = nn.Linear(self.flat.shape, self.params.lat_dim, self.params.bias) - self.full_2 = nn.Linear(self.full_1.out_features, self.full_1.out_features // 2, self.params.bias) + self.full_1 = LinearModule(self.flat.shape, self.params.lat_dim, **self.params.module_kwargs) + self.full_2 = LinearModule(self.full_1.shape, self.full_1.shape * 2, **self.params.module_kwargs) + self.full_3 = LinearModule(self.full_2.shape, self.full_2.out_features // 2, **self.params.module_kwargs) - self.full_out = nn.Linear(self.full_2.out_features, 1, self.params.bias) - - # Utility Modules - self.dropout = nn.Dropout2d(self.params.dropout) if self.params.dropout else lambda x: x - self.activation = self.params.activation() - self.sigmoid = nn.Sigmoid() + self.full_out = LinearModule(self.full_3.shape, 1, bias=self.params.bias, activation=nn.Sigmoid) def forward(self, batch, **kwargs): tensors = self.split(batch) @@ -74,13 +70,8 @@ class BandwiseConvClassifier(BinaryMaskDatasetFunction, tensors[idx] = self.conv_dict[f"conv_3_{idx}"](tensor) tensor = self.merge(tensors) - tensor = self.flat(tensor) tensor = self.full_1(tensor) - tensor = self.activation(tensor) - tensor = self.dropout(tensor) tensor = self.full_2(tensor) - tensor = self.activation(tensor) - tensor = self.dropout(tensor) + tensor = self.full_3(tensor) tensor = self.full_out(tensor) - tensor = self.sigmoid(tensor) return Namespace(main_out=tensor) diff --git a/models/bandwise_conv_multihead_classifier.py b/models/bandwise_conv_multihead_classifier.py index 8b635e2..8271136 100644 --- a/models/bandwise_conv_multihead_classifier.py +++ b/models/bandwise_conv_multihead_classifier.py @@ -1,12 +1,10 @@ from argparse import Namespace -from collections import defaultdict import torch from torch import nn -from torch.nn import ModuleDict, ModuleList -from torchcontrib.optim import SWA +from torch.nn import ModuleList -from ml_lib.modules.blocks import ConvModule +from ml_lib.modules.blocks import ConvModule, LinearModule from ml_lib.modules.utils import (LightningBaseModule, Flatten, HorizontalSplitter) from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, BinaryMaskDatasetFunction, BaseDataloadersMixin) @@ -59,49 +57,57 @@ class BandwiseConvMultiheadClassifier(BinaryMaskDatasetFunction, self.in_shape = self.dataset.train_dataset.sample_shape self.conv_filters = self.params.filters self.criterion = nn.BCELoss() - self.n_band_sections = 8 + self.n_band_sections = 4 + k = 3 # Base Kernel Value # Modules # ============================================================================= self.split = HorizontalSplitter(self.in_shape, self.n_band_sections) - self.conv_dict = ModuleDict() - self.conv_dict.update({f"conv_1_{band_section}": - ConvModule(self.split.shape, self.conv_filters[0], 3, conv_stride=1, **self.params.module_kwargs) - for band_section in range(self.n_band_sections)} - ) - self.conv_dict.update({f"conv_2_{band_section}": - ConvModule(self.conv_dict['conv_1_1'].shape, self.conv_filters[1], 3, conv_stride=1, - **self.params.module_kwargs) for band_section in range(self.n_band_sections)} - ) - self.conv_dict.update({f"conv_3_{band_section}": - ConvModule(self.conv_dict['conv_2_1'].shape, self.conv_filters[2], 3, conv_stride=1, - **self.params.module_kwargs) - for band_section in range(self.n_band_sections)} - ) + self.band_list = ModuleList() + for band in range(self.n_band_sections): + last_shape = self.split.shape + conv_list = ModuleList() + for filters in self.conv_filters: + conv_list.append(ConvModule(last_shape, filters, (k, k*4), conv_stride=(1, 2), + **self.params.module_kwargs)) + last_shape = conv_list[-1].shape + # self.conv_list.append(ConvModule(last_shape, 1, 1, conv_stride=1, **self.params.module_kwargs)) + # last_shape = self.conv_list[-1].shape + self.band_list.append(conv_list) - self.flat = Flatten(self.conv_dict['conv_3_1'].shape) + self.bandwise_deep_list_1 = ModuleList([ + LinearModule(self.band_list[0][-1].shape, self.params.lat_dim * 4, **self.params.module_kwargs) + for _ in range(self.n_band_sections)]) + self.bandwise_deep_list_2 = ModuleList([ + LinearModule(self.params.lat_dim * 4, self.params.lat_dim * 2, **self.params.module_kwargs) + for _ in range(self.n_band_sections)]) self.bandwise_latent_list = ModuleList([ - nn.Linear(self.flat.shape, self.params.lat_dim, self.params.bias) for _ in range(self.n_band_sections)]) - self.bandwise_classifier_list = ModuleList([nn.Linear(self.params.lat_dim, 1, self.params.bias) - for _ in range(self.n_band_sections)]) + LinearModule(self.params.lat_dim * 2, self.params.lat_dim, **self.params.module_kwargs) + for _ in range(self.n_band_sections)]) + self.bandwise_classifier_list = ModuleList([ + LinearModule(self.params.lat_dim, 1, bias=self.params.bias, activation=nn.Sigmoid) + for _ in range(self.n_band_sections)]) - self.full_out = nn.Linear(self.n_band_sections, 1, self.params.bias) - - # Utility Modules - self.sigmoid = nn.Sigmoid() + self.full_1 = LinearModule(self.n_band_sections, self.params.lat_dim * 4, **self.params.module_kwargs) + self.full_2 = LinearModule(self.full_1.shape, self.params.lat_dim * 2, **self.params.module_kwargs) + self.full_3 = LinearModule(self.full_2.shape, self.params.lat_dim, **self.params.module_kwargs) + self.full_out = LinearModule(self.full_3.shape, 1, bias=self.params.bias, activation=nn.Sigmoid) def forward(self, batch, **kwargs): tensors = self.split(batch) - for idx, tensor in enumerate(tensors): - tensor = self.conv_dict[f"conv_1_{idx}"](tensor) - tensor = self.conv_dict[f"conv_2_{idx}"](tensor) - tensor = self.conv_dict[f"conv_3_{idx}"](tensor) - tensor = self.flat(tensor) + for idx, (tensor, convs) in enumerate(zip(tensors, self.band_list)): + for conv in convs: + tensor = conv(tensor) + + tensor = self.bandwise_deep_list_1[idx](tensor) + tensor = self.bandwise_deep_list_2[idx](tensor) tensor = self.bandwise_latent_list[idx](tensor) - tensor = self.bandwise_classifier_list[idx](tensor) - tensors[idx] = self.sigmoid(tensor) + tensors[idx] = self.bandwise_classifier_list[idx](tensor) + tensor = torch.cat(tensors, dim=1) + tensor = self.full_1(tensor) + tensor = self.full_2(tensor) + tensor = self.full_3(tensor) tensor = self.full_out(tensor) - tensor = self.sigmoid(tensor) return Namespace(main_out=tensor, bands=tensors) diff --git a/models/conv_classifier.py b/models/conv_classifier.py index dac894d..f16f799 100644 --- a/models/conv_classifier.py +++ b/models/conv_classifier.py @@ -3,8 +3,8 @@ from argparse import Namespace from torch import nn from torch.nn import ModuleList -from ml_lib.modules.blocks import ConvModule -from ml_lib.modules.utils import LightningBaseModule, Flatten +from ml_lib.modules.blocks import ConvModule, LinearModule +from ml_lib.modules.utils import LightningBaseModule from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, BinaryMaskDatasetFunction, BaseDataloadersMixin) @@ -38,38 +38,21 @@ class ConvClassifier(BinaryMaskDatasetFunction, for filters in self.conv_filters: self.conv_list.append(ConvModule(last_shape, filters, (k, k*2), conv_stride=2, **self.params.module_kwargs)) last_shape = self.conv_list[-1].shape - self.conv_list.appen(ConvModule(last_shape, filters, 1, conv_stride=1, **self.params.module_kwargs)) - last_shape = self.conv_list[-1].shape - self.conv_list.appen(ConvModule(last_shape, 1, 1, conv_stride=1, **self.params.module_kwargs)) - last_shape = self.conv_list[-1].shape - k = k+2 + # self.conv_list.append(ConvModule(last_shape, 1, 1, conv_stride=1, **self.params.module_kwargs)) + # last_shape = self.conv_list[-1].shape - self.flat = Flatten(self.conv_list[-1].shape) - self.full_1 = nn.Linear(self.flat.shape, self.params.lat_dim, self.params.bias) - self.full_2 = nn.Linear(self.full_1.out_features, self.full_1.out_features * 2, self.params.bias) - self.full_3 = nn.Linear(self.full_2.out_features, self.full_2.out_features // 2, self.params.bias) + self.full_1 = LinearModule(self.flat.shape, self.params.lat_dim, **self.params.module_kwargs) + self.full_2 = LinearModule(self.full_1.out_features, self.full_1.out_features * 2, self.params.bias) + self.full_3 = LinearModule(self.full_2.out_features, self.full_2.out_features // 2, self.params.bias) - self.full_out = nn.Linear(self.full_3.out_features, 1, self.params.bias) - - # Utility Modules - self.dropout = nn.Dropout2d(self.params.dropout) if self.params.dropout else lambda x: x - self.activation = self.params.activation() - self.sigmoid = nn.Sigmoid() + self.full_out = LinearModule(self.full_3.out_features, 1, bias=self.params.bias, activation=nn.Sigmoid) def forward(self, batch, **kwargs): tensor = batch for conv in self.conv_list: tensor = conv(tensor) - tensor = self.flat(tensor) tensor = self.full_1(tensor) - tensor = self.activation(tensor) - tensor = self.dropout(tensor) tensor = self.full_2(tensor) - tensor = self.activation(tensor) - tensor = self.dropout(tensor) tensor = self.full_3(tensor) - tensor = self.activation(tensor) - tensor = self.dropout(tensor) tensor = self.full_out(tensor) - tensor = self.sigmoid(tensor) return Namespace(main_out=tensor) diff --git a/util/module_mixins.py b/util/module_mixins.py index 3897b05..9854720 100644 --- a/util/module_mixins.py +++ b/util/module_mixins.py @@ -14,6 +14,7 @@ from torchvision.transforms import Compose, RandomApply from ml_lib.audio_toolset.audio_augmentation import NoiseInjection, LoudnessManipulator, ShiftTime from ml_lib.audio_toolset.audio_io import AudioToMel, MelToImage, NormalizeLocal +from ml_lib.modules.utils import LightningBaseModule from ml_lib.utils.transforms import ToTensor import variables as V @@ -22,6 +23,7 @@ import variables as V class BaseOptimizerMixin: def configure_optimizers(self): + assert isinstance(self, LightningBaseModule) opt = Adam(params=self.parameters(), lr=self.params.lr) if self.params.sto_weight_avg: opt = SWA(opt, swa_start=10, swa_freq=5, swa_lr=0.05) @@ -33,7 +35,7 @@ class BaseOptimizerMixin: opt.swap_swa_sgd() def on_epoch_end(self): - if False: # FIXME: Pass a new parameter to model args. + if self.params.opt_reset_interval: if self.current_epoch % self.params.opt_reset_interval == 0: for opt in self.trainer.optimizers: opt.state = defaultdict(dict) @@ -42,6 +44,7 @@ class BaseOptimizerMixin: class BaseTrainMixin: def training_step(self, batch_xy, batch_nb, *args, **kwargs): + assert isinstance(self, LightningBaseModule) batch_x, batch_y = batch_xy y = self(batch_x).main_out loss = self.criterion(y, batch_y) @@ -60,7 +63,7 @@ class BaseValMixin: absolute_loss = L1Loss() - def validation_step(self, batch_xy, batch_idx, *args, **kwargs): + def validation_step(self, batch_xy, batch_idx, dataloader_idx, *args, **kwargs): batch_x, batch_y = batch_xy y = self(batch_x).main_out val_bce_loss = self.criterion(y, batch_y) @@ -69,52 +72,63 @@ class BaseValMixin: batch_idx=batch_idx, y=y, batch_y=batch_y ) - def validation_epoch_end(self, outputs): - keys = list(outputs[0].keys()) + def validation_epoch_end(self, outputs, *args, **kwargs): + summary_dict = dict(log=dict()) + for output_idx, output in enumerate(outputs): + keys = list(output[0].keys()) + ident = '' if output_idx == 0 else '_train' + summary_dict['log'].update({f'mean{ident}_{key}': torch.mean(torch.stack([output[key] + for output in output])) + for key in keys if 'loss' in key} + ) - summary_dict = dict(log={f'mean_{key}': torch.mean(torch.stack([output[key] - for output in outputs])) - for key in keys if 'loss' in key}) + # UnweightedAverageRecall + y_true = torch.cat([output['batch_y'] for output in output]) .cpu().numpy() + y_pred = torch.cat([output['y'] for output in output]).squeeze().cpu().numpy() - # UnweightedAverageRecall - y_true = torch.cat([output['batch_y'] for output in outputs]) .cpu().numpy() - y_pred = torch.cat([output['y'] for output in outputs]).squeeze().cpu().numpy() + y_pred = (y_pred >= 0.5).astype(np.float32) - y_pred = (y_pred >= 0.5).astype(np.float32) + uar_score = sklearn.metrics.recall_score(y_true, y_pred, labels=[0, 1], average='macro', + sample_weight=None, zero_division='warn') - uar_score = sklearn.metrics.recall_score(y_true, y_pred, labels=[0, 1], average='macro', - sample_weight=None, zero_division='warn') - summary_dict['log'].update(uar_score=uar_score) + summary_dict['log'].update({f'uar{ident}_score': uar_score}) return summary_dict class BinaryMaskDatasetFunction: def build_dataset(self): + assert isinstance(self, LightningBaseModule) # Dataset # ============================================================================= # Mel Transforms mel_transforms = Compose([ # Audio to Mel Transformations - AudioToMel(n_mels=self.params.n_mels), MelToImage()]) + AudioToMel(sr=self.params.sr, n_mels=self.params.n_mels, n_fft=self.params.n_fft, + hop_length=self.params.hop_length), MelToImage()]) # Data Augmentations aug_transforms = Compose([ RandomApply([ - NoiseInjection(self.params.noise_ratio), - LoudnessManipulator(self.params.loudness_ratio), - ShiftTime(self.params.shift_ratio)], p=0.5), + NoiseInjection(self.params.noise_ratio), + LoudnessManipulator(self.params.loudness_ratio), + ShiftTime(self.params.shift_ratio)], p=0.5), # Utility NormalizeLocal(), ToTensor() ]) val_transforms = Compose([NormalizeLocal(), ToTensor()]) + # sampler = RandomSampler(train_dataset, True, len(train_dataset)) if params['bootstrap'] else None + # Datasets from datasets.binar_masks import BinaryMasksDataset dataset = Namespace( **dict( - train_dataset=BinaryMasksDataset(self.params.root, setting=V.DATA_OPTIONS.train, mixup=self.params.mixup, + train_dataset=BinaryMasksDataset(self.params.root, setting=V.DATA_OPTIONS.train, + mixup=self.params.mixup, mel_transforms=mel_transforms, transforms=aug_transforms), + val_train_dataset=BinaryMasksDataset(self.params.root, setting=V.DATA_OPTIONS.train, + mel_transforms=mel_transforms, transforms=val_transforms), val_dataset=BinaryMasksDataset(self.params.root, setting=V.DATA_OPTIONS.devel, mel_transforms=mel_transforms, transforms=val_transforms), test_dataset=BinaryMasksDataset(self.params.root, setting=V.DATA_OPTIONS.test, @@ -142,6 +156,9 @@ class BaseDataloadersMixin(ABC): # Validation Dataloader def val_dataloader(self): - return DataLoader(dataset=self.dataset.val_dataset, shuffle=True, - batch_size=self.params.batch_size, - num_workers=self.params.worker) + val_dataloader = DataLoader(dataset=self.dataset.val_dataset, shuffle=True, + batch_size=self.params.batch_size, num_workers=self.params.worker) + + train_dataloader = DataLoader(self.dataset.val_train_dataset, num_workers=self.params.worker, + batch_size=self.params.batch_size, shuffle=False) + return [val_dataloader, train_dataloader]