LinearModule
@@ -24,11 +24,14 @@ main_arg_parser.add_argument("--data_class_name", type=str, default='BinaryMasks
main_arg_parser.add_argument("--data_normalized", type=strtobool, default=True, help="")
main_arg_parser.add_argument("--data_use_preprocessed", type=strtobool, default=True, help="")
main_arg_parser.add_argument("--data_n_mels", type=int, default=64, help="")
main_arg_parser.add_argument("--data_sr", type=int, default=16000, help="")
main_arg_parser.add_argument("--data_hop_length", type=int, default=62, help="")
main_arg_parser.add_argument("--data_n_fft", type=int, default=512, help="")
main_arg_parser.add_argument("--data_mixup", type=strtobool, default=False, help="")

# Transformation Parameters
main_arg_parser.add_argument("--data_loudness_ratio", type=float, default=0.08, help="")
main_arg_parser.add_argument("--data_shift_ratio", type=float, default=0.2, help="")
main_arg_parser.add_argument("--data_loudness_ratio", type=float, default=0.2, help="")
main_arg_parser.add_argument("--data_shift_ratio", type=float, default=0.4, help="")
main_arg_parser.add_argument("--data_noise_ratio", type=float, default=0.15, help="")

# Training Parameters
@@ -36,9 +39,10 @@ main_arg_parser.add_argument("--train_outpath", type=str, default="output", help
main_arg_parser.add_argument("--train_version", type=strtobool, required=False, help="")
# FIXME: Stochastic weight Avaraging is not good, maybe its my implementation?
main_arg_parser.add_argument("--train_sto_weight_avg", type=strtobool, default=False, help="")
main_arg_parser.add_argument("--train_opt_reset_interval", type=int, default=300, help="")
main_arg_parser.add_argument("--train_epochs", type=int, default=600, help="")
main_arg_parser.add_argument("--train_batch_size", type=int, default=250, help="")
main_arg_parser.add_argument("--train_lr", type=float, default=1e-3, help="")
main_arg_parser.add_argument("--train_lr", type=float, default=1e-4, help="")
main_arg_parser.add_argument("--train_num_sanity_val_steps", type=int, default=0, help="")

# Model Parameters
@@ -46,12 +50,12 @@ main_arg_parser.add_argument("--model_type", type=str, default="ConvClassifier",
main_arg_parser.add_argument("--model_secondary_type", type=str, default="BandwiseConvMultiheadClassifier", help="")
main_arg_parser.add_argument("--model_weight_init", type=str, default="xavier_normal_", help="")
main_arg_parser.add_argument("--model_activation", type=str, default="leaky_relu", help="")
main_arg_parser.add_argument("--model_filters", type=str, default="[16, 32, 64, 128]", help="")
main_arg_parser.add_argument("--model_filters", type=str, default="[16, 32, 64, 128, 64]", help="")
main_arg_parser.add_argument("--model_classes", type=int, default=2, help="")
main_arg_parser.add_argument("--model_lat_dim", type=int, default=8, help="")
main_arg_parser.add_argument("--model_bias", type=strtobool, default=True, help="")
main_arg_parser.add_argument("--model_norm", type=strtobool, default=True, help="")
main_arg_parser.add_argument("--model_dropout", type=float, default=0.25, help="")
main_arg_parser.add_argument("--model_dropout", type=float, default=0.0, help="")

# Project Parameters
main_arg_parser.add_argument("--project_name", type=str, default=_ROOT.name, help="")

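Note that `model_filters` is stored as the string form of a Python list and the boolean flags go through `distutils.util.strtobool` (which yields 0/1). How the parsed values are consumed is not part of this diff; a plausible pattern, sketched here as an assumption, would be:

# Hedged sketch, not taken from this repository: turning the string-typed
# defaults above into usable Python values after parsing.
from ast import literal_eval

args = main_arg_parser.parse_args()
conv_filters = literal_eval(args.model_filters)   # "[16, 32, 64, 128, 64]" -> [16, 32, 64, 128, 64]
use_norm = bool(args.model_norm)                  # strtobool-typed flags arrive as 0/1 ints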
							
								
								
									
main.py (71 changed lines)

@@ -78,41 +78,44 @@ def run_lightning_loop(config_obj):

        # Evaluate It
        if config_obj.main.eval:
            model.eval()
            if torch.cuda.is_available():
                model.cuda()
            outputs = []
            from tqdm import tqdm
            for idx, batch in enumerate(tqdm(model.val_dataloader())):
                batch_x, label = batch
                outputs.append(
                    model.validation_step((batch_x.to(device='cuda' if model.on_gpu else 'cpu'), label), idx)
                )
                summary_dict = model.validation_epoch_end(outputs)
            print(summary_dict['log']['uar_score'])

            # trainer.test()
            outpath = Path(config_obj.train.outpath)
            model_type = config_obj.model.type
            parameters = logger.name
            version = f'version_{logger.version}'
            inference_out = f'{parameters}_test_out.csv'

            from main_inference import prepare_dataloader
            test_dataloader = prepare_dataloader(config)

            with (outpath / model_type / parameters / version / inference_out).open(mode='w') as outfile:
                outfile.write(f'file_name,prediction\n')

            with torch.no_grad():
                model.eval()
                if torch.cuda.is_available():
                    model.cuda()
                outputs = []
                from tqdm import tqdm
                for batch in tqdm(test_dataloader, total=len(test_dataloader)):
                    batch_x, file_name = batch
                    batch_x = batch_x.unsqueeze(0).to(device='cuda' if model.on_gpu else 'cpu')
                    y = model(batch_x).main_out
                    prediction = (y.squeeze() >= 0.5).int().item()
                    import variables as V
                    prediction = 'clear' if prediction == V.CLEAR else 'mask'
                    outfile.write(f'{file_name},{prediction}\n')
                for idx, batch in enumerate(tqdm(model.val_dataloader()[0])):
                    batch_x, label = batch
                    batch_x = batch_x.to(device='cuda' if model.on_gpu else 'cpu')
                    label = label.to(device='cuda' if model.on_gpu else 'cpu')
                    outputs.append(
                        model.validation_step((batch_x, label), idx, 1)
                    )
                    summary_dict = model.validation_epoch_end([outputs])
                print(summary_dict['log']['uar_score'])

                # trainer.test()
                outpath = Path(config_obj.train.outpath)
                model_type = config_obj.model.type
                parameters = logger.name
                version = f'version_{logger.version}'
                inference_out = f'{parameters}_test_out.csv'

                from main_inference import prepare_dataloader
                test_dataloader = prepare_dataloader(config)

                with (outpath / model_type / parameters / version / inference_out).open(mode='w') as outfile:
                    outfile.write(f'file_name,prediction\n')

                    from tqdm import tqdm
                    for batch in tqdm(test_dataloader, total=len(test_dataloader)):
                        batch_x, file_name = batch
                        batch_x = batch_x.unsqueeze(0).to(device='cuda' if model.on_gpu else 'cpu')
                        y = model(batch_x).main_out
                        prediction = (y.squeeze() >= 0.5).int().item()
                        import variables as V
                        prediction = 'clear' if prediction == V.CLEAR else 'mask'
                        outfile.write(f'{file_name},{prediction}\n')
    return model

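The single-element wrapper in `model.validation_epoch_end([outputs])` mirrors the mixin change further down: validation now runs over a list of dataloaders, so the epoch-end hook expects one list of step outputs per dataloader. Roughly (an illustration, not code from this commit):

# outputs as the reworked validation_epoch_end expects them:
# [
#     [step_out_0, step_out_1, ...],   # dataloader 0: devel split
#     [step_out_0, step_out_1, ...],   # dataloader 1: val_train split
# ]
# The manual loop above only evaluates the first dataloader, hence the
# single-element list passed to validation_epoch_end.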
@@ -23,7 +23,10 @@ from datasets.binar_masks import BinaryMasksDataset


def prepare_dataloader(config_obj):
    mel_transforms = Compose([AudioToMel(n_mels=config_obj.data.n_mels), MelToImage()])
    mel_transforms = Compose([
        # Audio to Mel Transformations
        AudioToMel(sr=config_obj.data.sr, n_mels=config_obj.data.n_mels, n_fft=config_obj.data.n_fft,
                   hop_length=config_obj.data.hop_length), MelToImage()])
    transforms = Compose([NormalizeLocal(), ToTensor()])

    dataset: Dataset = BinaryMasksDataset(config_obj.data.root, setting='test',

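`AudioToMel` now receives the sample rate, FFT size, and hop length from the config instead of relying on its internal defaults. Its implementation is not part of this diff; presumably it wraps librosa along these lines (a hedged sketch with assumed names, using the new config defaults):

# Hypothetical equivalent of AudioToMel(sr=..., n_mels=..., n_fft=..., hop_length=...)
import librosa
import numpy as np

def audio_to_mel(y, sr=16000, n_mels=64, n_fft=512, hop_length=62):
    mel = librosa.feature.melspectrogram(y=y, sr=sr, n_mels=n_mels,
                                         n_fft=n_fft, hop_length=hop_length)
    # MelToImage-style dB scaling is likewise an assumption here
    return librosa.power_to_db(mel, ref=np.max)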
@@ -3,9 +3,9 @@ from argparse import Namespace
from torch import nn
from torch.nn import ModuleDict, ModuleList

from ml_lib.modules.blocks import ConvModule
from ml_lib.modules.blocks import ConvModule, LinearModule
from ml_lib.modules.utils import (LightningBaseModule, HorizontalSplitter,
                                  HorizontalMerger)
                                  HorizontalMerger, F_x)
from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, BinaryMaskDatasetFunction,
                                BaseDataloadersMixin)

@@ -54,15 +54,11 @@ class BandwiseConvClassifier(BinaryMaskDatasetFunction,

        self.merge = HorizontalMerger(self.conv_dict['conv_3_1'].shape, self.n_band_sections)

        self.full_1 = nn.Linear(self.flat.shape, self.params.lat_dim, self.params.bias)
        self.full_2 = nn.Linear(self.full_1.out_features, self.full_1.out_features // 2, self.params.bias)
        self.full_1 = LinearModule(self.flat.shape, self.params.lat_dim, **self.params.module_kwargs)
        self.full_2 = LinearModule(self.full_1.shape, self.full_1.shape * 2, **self.params.module_kwargs)
        self.full_3 = LinearModule(self.full_2.shape, self.full_2.out_features // 2, **self.params.module_kwargs)

        self.full_out = nn.Linear(self.full_2.out_features, 1, self.params.bias)

        # Utility Modules
        self.dropout = nn.Dropout2d(self.params.dropout) if self.params.dropout else lambda x: x
        self.activation = self.params.activation()
        self.sigmoid = nn.Sigmoid()
        self.full_out = LinearModule(self.full_3.shape, 1, bias=self.params.bias, activation=nn.Sigmoid)

    def forward(self, batch, **kwargs):
        tensors = self.split(batch)
@@ -74,13 +70,8 @@ class BandwiseConvClassifier(BinaryMaskDatasetFunction,
            tensors[idx] = self.conv_dict[f"conv_3_{idx}"](tensor)

        tensor = self.merge(tensors)
        tensor = self.flat(tensor)
        tensor = self.full_1(tensor)
        tensor = self.activation(tensor)
        tensor = self.dropout(tensor)
        tensor = self.full_2(tensor)
        tensor = self.activation(tensor)
        tensor = self.dropout(tensor)
        tensor = self.full_3(tensor)
        tensor = self.full_out(tensor)
        tensor = self.sigmoid(tensor)
        return Namespace(main_out=tensor)

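`LinearModule` itself lives in `ml_lib.modules.blocks` and is not shown in this commit. From its call sites one can infer that it bundles `nn.Linear` with the activation/norm/dropout handling that the removed utility modules used to do by hand, accepts a flat or multi-dimensional input shape, and exposes its output width for chaining. A minimal sketch under those assumptions, not the actual implementation:

# Hypothetical stand-in for ml_lib.modules.blocks.LinearModule, inferred from
# how it is called in this commit.
import numpy as np
from torch import nn

class LinearModuleSketch(nn.Module):
    def __init__(self, in_shape, out_features, bias=True, norm=False,
                 dropout=0.0, activation=nn.LeakyReLU, **kwargs):
        super().__init__()
        in_features = int(np.prod(in_shape))           # flat int or conv output shape
        self.shape = self.out_features = out_features  # both attributes are used in the diff
        self.flat = nn.Flatten()
        self.linear = nn.Linear(in_features, out_features, bias=bias)
        self.norm = nn.LayerNorm(out_features) if norm else nn.Identity()
        self.dropout = nn.Dropout(dropout) if dropout else nn.Identity()
        self.activation = activation() if activation else nn.Identity()

    def forward(self, x):
        return self.activation(self.dropout(self.norm(self.linear(self.flat(x)))))

Exposing both `.shape` and `.out_features` matches the mixed usage above (`self.full_1.shape` next to `self.full_2.out_features`), and passing the activation as a class matches `activation=nn.Sigmoid` on the output heads.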
@@ -1,12 +1,10 @@
from argparse import Namespace
from collections import defaultdict

import torch
from torch import nn
from torch.nn import ModuleDict, ModuleList
from torchcontrib.optim import SWA
from torch.nn import ModuleList

from ml_lib.modules.blocks import ConvModule
from ml_lib.modules.blocks import ConvModule, LinearModule
from ml_lib.modules.utils import (LightningBaseModule, Flatten, HorizontalSplitter)
from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, BinaryMaskDatasetFunction,
                                BaseDataloadersMixin)
@@ -59,49 +57,57 @@ class BandwiseConvMultiheadClassifier(BinaryMaskDatasetFunction,
        self.in_shape = self.dataset.train_dataset.sample_shape
        self.conv_filters = self.params.filters
        self.criterion = nn.BCELoss()
        self.n_band_sections = 8
        self.n_band_sections = 4
        k = 3  # Base Kernel Value

        # Modules
        # =============================================================================
        self.split = HorizontalSplitter(self.in_shape, self.n_band_sections)
        self.conv_dict = ModuleDict()

        self.conv_dict.update({f"conv_1_{band_section}":
                                   ConvModule(self.split.shape, self.conv_filters[0], 3, conv_stride=1, **self.params.module_kwargs)
                               for band_section in range(self.n_band_sections)}
                              )
        self.conv_dict.update({f"conv_2_{band_section}":
                                   ConvModule(self.conv_dict['conv_1_1'].shape, self.conv_filters[1], 3, conv_stride=1,
                                              **self.params.module_kwargs) for band_section in range(self.n_band_sections)}
                              )
        self.conv_dict.update({f"conv_3_{band_section}":
                                   ConvModule(self.conv_dict['conv_2_1'].shape, self.conv_filters[2], 3, conv_stride=1,
                                              **self.params.module_kwargs)
                               for band_section in range(self.n_band_sections)}
                              )
        self.band_list = ModuleList()
        for band in range(self.n_band_sections):
            last_shape = self.split.shape
            conv_list = ModuleList()
            for filters in self.conv_filters:
                conv_list.append(ConvModule(last_shape, filters, (k, k*4), conv_stride=(1, 2),
                                            **self.params.module_kwargs))
                last_shape = conv_list[-1].shape
                # self.conv_list.append(ConvModule(last_shape, 1, 1, conv_stride=1, **self.params.module_kwargs))
                # last_shape = self.conv_list[-1].shape
            self.band_list.append(conv_list)

        self.flat = Flatten(self.conv_dict['conv_3_1'].shape)
        self.bandwise_deep_list_1 = ModuleList([
            LinearModule(self.band_list[0][-1].shape, self.params.lat_dim * 4, **self.params.module_kwargs)
            for _ in range(self.n_band_sections)])
        self.bandwise_deep_list_2 = ModuleList([
            LinearModule(self.params.lat_dim * 4, self.params.lat_dim * 2, **self.params.module_kwargs)
            for _ in range(self.n_band_sections)])
        self.bandwise_latent_list = ModuleList([
            nn.Linear(self.flat.shape, self.params.lat_dim, self.params.bias) for _ in range(self.n_band_sections)])
        self.bandwise_classifier_list = ModuleList([nn.Linear(self.params.lat_dim, 1, self.params.bias)
                                                    for _ in range(self.n_band_sections)])
            LinearModule(self.params.lat_dim * 2, self.params.lat_dim, **self.params.module_kwargs)
            for _ in range(self.n_band_sections)])
        self.bandwise_classifier_list = ModuleList([
            LinearModule(self.params.lat_dim, 1, bias=self.params.bias, activation=nn.Sigmoid)
            for _ in range(self.n_band_sections)])

        self.full_out = nn.Linear(self.n_band_sections, 1, self.params.bias)

        # Utility Modules
        self.sigmoid = nn.Sigmoid()
        self.full_1 = LinearModule(self.n_band_sections, self.params.lat_dim * 4, **self.params.module_kwargs)
        self.full_2 = LinearModule(self.full_1.shape, self.params.lat_dim * 2, **self.params.module_kwargs)
        self.full_3 = LinearModule(self.full_2.shape, self.params.lat_dim, **self.params.module_kwargs)
        self.full_out = LinearModule(self.full_3.shape, 1, bias=self.params.bias, activation=nn.Sigmoid)

    def forward(self, batch, **kwargs):
        tensors = self.split(batch)
        for idx, tensor in enumerate(tensors):
            tensor = self.conv_dict[f"conv_1_{idx}"](tensor)
            tensor = self.conv_dict[f"conv_2_{idx}"](tensor)
            tensor = self.conv_dict[f"conv_3_{idx}"](tensor)
            tensor = self.flat(tensor)
        for idx, (tensor, convs) in enumerate(zip(tensors, self.band_list)):
            for conv in convs:
                tensor = conv(tensor)

            tensor = self.bandwise_deep_list_1[idx](tensor)
            tensor = self.bandwise_deep_list_2[idx](tensor)
            tensor = self.bandwise_latent_list[idx](tensor)
            tensor = self.bandwise_classifier_list[idx](tensor)
            tensors[idx] = self.sigmoid(tensor)
            tensors[idx] = self.bandwise_classifier_list[idx](tensor)

        tensor = torch.cat(tensors, dim=1)
        tensor = self.full_1(tensor)
        tensor = self.full_2(tensor)
        tensor = self.full_3(tensor)
        tensor = self.full_out(tensor)
        tensor = self.sigmoid(tensor)
        return Namespace(main_out=tensor, bands=tensors)

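With this rewrite each of the four band sections gets its own conv stack and its own small classification head; the per-band sigmoid scores are then concatenated and classified again by the `full_*` stack. Assuming the default `lat_dim=8` and batch size B, the tensor widths flow roughly as follows (illustrative only, spatial sizes omitted):

# split(batch)                       -> 4 band tensors
# per-band ConvModule stack          -> (B, filters[-1], h, w)
# bandwise_deep_list_1 / _2          -> (B, 32) -> (B, 16)
# bandwise_latent_list               -> (B, 8)
# bandwise_classifier_list (sigmoid) -> (B, 1) score per band
# torch.cat(tensors, dim=1)          -> (B, 4)
# full_1 -> full_2 -> full_3         -> (B, 32) -> (B, 16) -> (B, 8)
# full_out (sigmoid)                 -> (B, 1) final prediction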
@@ -3,8 +3,8 @@ from argparse import Namespace
from torch import nn
from torch.nn import ModuleList

from ml_lib.modules.blocks import ConvModule
from ml_lib.modules.utils import LightningBaseModule, Flatten
from ml_lib.modules.blocks import ConvModule, LinearModule
from ml_lib.modules.utils import LightningBaseModule
from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, BinaryMaskDatasetFunction,
                                BaseDataloadersMixin)

@@ -38,38 +38,21 @@ class ConvClassifier(BinaryMaskDatasetFunction,
        for filters in self.conv_filters:
            self.conv_list.append(ConvModule(last_shape, filters, (k, k*2), conv_stride=2, **self.params.module_kwargs))
            last_shape = self.conv_list[-1].shape
            self.conv_list.appen(ConvModule(last_shape, filters, 1, conv_stride=1, **self.params.module_kwargs))
            last_shape = self.conv_list[-1].shape
            self.conv_list.appen(ConvModule(last_shape, 1, 1, conv_stride=1, **self.params.module_kwargs))
            last_shape = self.conv_list[-1].shape
            k = k+2
            # self.conv_list.append(ConvModule(last_shape, 1, 1, conv_stride=1, **self.params.module_kwargs))
            # last_shape = self.conv_list[-1].shape

        self.flat = Flatten(self.conv_list[-1].shape)
        self.full_1 = nn.Linear(self.flat.shape, self.params.lat_dim, self.params.bias)
        self.full_2 = nn.Linear(self.full_1.out_features, self.full_1.out_features * 2, self.params.bias)
        self.full_3 = nn.Linear(self.full_2.out_features, self.full_2.out_features // 2, self.params.bias)
        self.full_1 = LinearModule(self.flat.shape, self.params.lat_dim, **self.params.module_kwargs)
        self.full_2 = LinearModule(self.full_1.out_features, self.full_1.out_features * 2, self.params.bias)
        self.full_3 = LinearModule(self.full_2.out_features, self.full_2.out_features // 2, self.params.bias)

        self.full_out = nn.Linear(self.full_3.out_features, 1, self.params.bias)

        # Utility Modules
        self.dropout = nn.Dropout2d(self.params.dropout) if self.params.dropout else lambda x: x
        self.activation = self.params.activation()
        self.sigmoid = nn.Sigmoid()
        self.full_out = LinearModule(self.full_3.out_features, 1, bias=self.params.bias, activation=nn.Sigmoid)

    def forward(self, batch, **kwargs):
        tensor = batch
        for conv in self.conv_list:
            tensor = conv(tensor)
        tensor = self.flat(tensor)
        tensor = self.full_1(tensor)
        tensor = self.activation(tensor)
        tensor = self.dropout(tensor)
        tensor = self.full_2(tensor)
        tensor = self.activation(tensor)
        tensor = self.dropout(tensor)
        tensor = self.full_3(tensor)
        tensor = self.activation(tensor)
        tensor = self.dropout(tensor)
        tensor = self.full_out(tensor)
        tensor = self.sigmoid(tensor)
        return Namespace(main_out=tensor)

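All of these classifiers forward `**self.params.module_kwargs` into every ConvModule/LinearModule; how that dict is assembled from the parsed arguments is not visible in this diff. An assumed mapping, consistent with the `--model_*` defaults above, might be:

# Assumption only: module_kwargs as it would follow from the --model_* defaults.
from torch import nn

module_kwargs = dict(
    activation=nn.LeakyReLU,   # --model_activation "leaky_relu", passed as a class
    norm=True,                 # --model_norm
    bias=True,                 # --model_bias
    dropout=0.0,               # --model_dropout (new default in this commit)
)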
@@ -14,6 +14,7 @@ from torchvision.transforms import Compose, RandomApply

from ml_lib.audio_toolset.audio_augmentation import NoiseInjection, LoudnessManipulator, ShiftTime
from ml_lib.audio_toolset.audio_io import AudioToMel, MelToImage, NormalizeLocal
from ml_lib.modules.utils import LightningBaseModule
from ml_lib.utils.transforms import ToTensor

import variables as V
@@ -22,6 +23,7 @@ import variables as V
class BaseOptimizerMixin:

    def configure_optimizers(self):
        assert isinstance(self, LightningBaseModule)
        opt = Adam(params=self.parameters(), lr=self.params.lr)
        if self.params.sto_weight_avg:
            opt = SWA(opt, swa_start=10, swa_freq=5, swa_lr=0.05)
@@ -33,7 +35,7 @@ class BaseOptimizerMixin:
                opt.swap_swa_sgd()

    def on_epoch_end(self):
        if False:  # FIXME: Pass a new parameter to model args.
        if self.params.opt_reset_interval:
            if self.current_epoch % self.params.opt_reset_interval == 0:
                for opt in self.trainer.optimizers:
                    opt.state = defaultdict(dict)
@@ -42,6 +44,7 @@ class BaseOptimizerMixin:
class BaseTrainMixin:

    def training_step(self, batch_xy, batch_nb, *args, **kwargs):
        assert isinstance(self, LightningBaseModule)
        batch_x, batch_y = batch_xy
        y = self(batch_x).main_out
        loss = self.criterion(y, batch_y)
@@ -60,7 +63,7 @@ class BaseValMixin:

    absolute_loss = L1Loss()

    def validation_step(self, batch_xy, batch_idx, *args, **kwargs):
    def validation_step(self, batch_xy, batch_idx, dataloader_idx, *args, **kwargs):
        batch_x, batch_y = batch_xy
        y = self(batch_x).main_out
        val_bce_loss = self.criterion(y, batch_y)
@@ -69,52 +72,63 @@ class BaseValMixin:
                    batch_idx=batch_idx, y=y, batch_y=batch_y
                    )

    def validation_epoch_end(self, outputs):
        keys = list(outputs[0].keys())
    def validation_epoch_end(self, outputs, *args, **kwargs):
        summary_dict = dict(log=dict())
        for output_idx, output in enumerate(outputs):
            keys = list(output[0].keys())
            ident = '' if output_idx == 0 else '_train'
            summary_dict['log'].update({f'mean{ident}_{key}': torch.mean(torch.stack([output[key]
                                                                                      for output in output]))
                                        for key in keys if 'loss' in key}
                                       )

        summary_dict = dict(log={f'mean_{key}': torch.mean(torch.stack([output[key]
                                                                        for output in outputs]))
                                 for key in keys if 'loss' in key})
            # UnweightedAverageRecall
            y_true = torch.cat([output['batch_y'] for output in output]) .cpu().numpy()
            y_pred = torch.cat([output['y'] for output in output]).squeeze().cpu().numpy()

        # UnweightedAverageRecall
        y_true = torch.cat([output['batch_y'] for output in outputs]) .cpu().numpy()
        y_pred = torch.cat([output['y'] for output in outputs]).squeeze().cpu().numpy()
            y_pred = (y_pred >= 0.5).astype(np.float32)

        y_pred = (y_pred >= 0.5).astype(np.float32)
            uar_score = sklearn.metrics.recall_score(y_true, y_pred, labels=[0, 1], average='macro',
                                                     sample_weight=None, zero_division='warn')

        uar_score = sklearn.metrics.recall_score(y_true, y_pred, labels=[0, 1], average='macro',
                                                 sample_weight=None, zero_division='warn')
        summary_dict['log'].update(uar_score=uar_score)
            summary_dict['log'].update({f'uar{ident}_score': uar_score})
        return summary_dict


class BinaryMaskDatasetFunction:

    def build_dataset(self):
        assert isinstance(self, LightningBaseModule)

        # Dataset
        # =============================================================================
        # Mel Transforms
        mel_transforms = Compose([
            # Audio to Mel Transformations
            AudioToMel(n_mels=self.params.n_mels), MelToImage()])
            AudioToMel(sr=self.params.sr, n_mels=self.params.n_mels, n_fft=self.params.n_fft,
                       hop_length=self.params.hop_length), MelToImage()])
        # Data Augmentations
        aug_transforms = Compose([
            RandomApply([
            NoiseInjection(self.params.noise_ratio),
            LoudnessManipulator(self.params.loudness_ratio),
            ShiftTime(self.params.shift_ratio)], p=0.5),
                NoiseInjection(self.params.noise_ratio),
                LoudnessManipulator(self.params.loudness_ratio),
                ShiftTime(self.params.shift_ratio)], p=0.5),
            # Utility
            NormalizeLocal(), ToTensor()
        ])
        val_transforms = Compose([NormalizeLocal(), ToTensor()])

        # sampler = RandomSampler(train_dataset, True, len(train_dataset)) if params['bootstrap'] else None

        # Datasets
        from datasets.binar_masks import BinaryMasksDataset
        dataset = Namespace(
            **dict(
                train_dataset=BinaryMasksDataset(self.params.root, setting=V.DATA_OPTIONS.train, mixup=self.params.mixup,
                train_dataset=BinaryMasksDataset(self.params.root, setting=V.DATA_OPTIONS.train,
                                                 mixup=self.params.mixup,
                                                 mel_transforms=mel_transforms, transforms=aug_transforms),
                val_train_dataset=BinaryMasksDataset(self.params.root, setting=V.DATA_OPTIONS.train,
                                                     mel_transforms=mel_transforms, transforms=val_transforms),
                val_dataset=BinaryMasksDataset(self.params.root, setting=V.DATA_OPTIONS.devel,
                                               mel_transforms=mel_transforms, transforms=val_transforms),
                test_dataset=BinaryMasksDataset(self.params.root, setting=V.DATA_OPTIONS.test,
@@ -142,6 +156,9 @@ class BaseDataloadersMixin(ABC):

    # Validation Dataloader
    def val_dataloader(self):
        return DataLoader(dataset=self.dataset.val_dataset, shuffle=True,
                          batch_size=self.params.batch_size,
                          num_workers=self.params.worker)
        val_dataloader = DataLoader(dataset=self.dataset.val_dataset, shuffle=True,
                                    batch_size=self.params.batch_size, num_workers=self.params.worker)

        train_dataloader = DataLoader(self.dataset.val_train_dataset, num_workers=self.params.worker,
                                      batch_size=self.params.batch_size, shuffle=False)
        return [val_dataloader, train_dataloader]

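Taken together: `val_dataloader` now returns the devel loader plus a non-shuffled loader over the training split, `validation_step` carries a `dataloader_idx`, and `validation_epoch_end` receives one output list per loader. The `ident` suffix therefore produces a second, `_train`-tagged set of metrics. Assuming `validation_step` returns a `val_bce_loss` entry, the resulting log dict looks roughly like this (key layout only, values are placeholders):

summary_dict = {
    'log': {
        'mean_val_bce_loss': ...,        # dataloader 0: devel split
        'uar_score': ...,
        'mean_train_val_bce_loss': ...,  # dataloader 1: val_train split
        'uar_train_score': ...,
    },
}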