diff --git a/_paramters.py b/_paramters.py index 8d84f6d..c5d4785 100644 --- a/_paramters.py +++ b/_paramters.py @@ -32,7 +32,7 @@ main_arg_parser.add_argument("--data_stretch", type=strtobool, default=True, hel # Transformation Parameters main_arg_parser.add_argument("--data_loudness_ratio", type=float, default=0, help="") # 0.4 -main_arg_parser.add_argument("--data_shift_ratio", type=float, default=0.3, help="") # 0.3 +main_arg_parser.add_argument("--data_shift_ratio", type=float, default=0.3, help="") # 0.4 main_arg_parser.add_argument("--data_noise_ratio", type=float, default=0, help="") # 0.4 main_arg_parser.add_argument("--data_mask_ratio", type=float, default=0, help="") # 0.2 main_arg_parser.add_argument("--data_speed_ratio", type=float, default=0, help="") # 0.3 @@ -40,7 +40,6 @@ main_arg_parser.add_argument("--data_speed_factor", type=float, default=0, help= # Model Parameters main_arg_parser.add_argument("--model_type", type=str, default="RCC", help="") -main_arg_parser.add_argument("--model_secondary_type", type=str, default="RCC", help="") main_arg_parser.add_argument("--model_weight_init", type=str, default="xavier_normal_", help="") main_arg_parser.add_argument("--model_activation", type=str, default="leaky_relu", help="") main_arg_parser.add_argument("--model_filters", type=str, default="[32, 64, 128, 64]", help="") @@ -55,6 +54,7 @@ main_arg_parser.add_argument("--train_outpath", type=str, default="output", help main_arg_parser.add_argument("--train_version", type=strtobool, required=False, help="") # FIXME: Stochastic weight Avaraging is not good, maybe its my implementation? main_arg_parser.add_argument("--train_sto_weight_avg", type=strtobool, default=False, help="") +main_arg_parser.add_argument("--train_weight_decay", type=float, default=1e-7, help="") main_arg_parser.add_argument("--train_opt_reset_interval", type=int, default=0, help="") main_arg_parser.add_argument("--train_epochs", type=int, default=51, help="") main_arg_parser.add_argument("--train_batch_size", type=int, default=300, help="") diff --git a/datasets/binar_masks.py b/datasets/binar_masks.py index 38477db..9d147d9 100644 --- a/datasets/binar_masks.py +++ b/datasets/binar_masks.py @@ -8,7 +8,7 @@ from torch.utils.data import Dataset import torch import variables as V -from ml_lib.modules.utils import F_x +from ml_lib.modules.util import F_x class BinaryMasksDataset(Dataset): diff --git a/main.py b/main.py index 3d17f8f..1ca2d99 100644 --- a/main.py +++ b/main.py @@ -2,14 +2,13 @@ # ============================================================================= from pathlib import Path -from tqdm import tqdm import warnings import torch from pytorch_lightning import Trainer from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping -from ml_lib.modules.utils import LightningBaseModule +from ml_lib.modules.util import LightningBaseModule from ml_lib.utils.logging import Logger # Project Specific Logger SubClasses @@ -111,7 +110,7 @@ def run_lightning_loop(config_obj): inference_out = f'{parameters}_test_out.csv' from main_inference import prepare_dataloader - test_dataloader = prepare_dataloader(config) + test_dataloader = prepare_dataloader(config_obj) with (outpath / model_type / parameters / version / inference_out).open(mode='w') as outfile: outfile.write(f'file_name,prediction\n') diff --git a/models/bandwise_conv_classifier.py b/models/bandwise_conv_classifier.py index b54000d..854b516 100644 --- a/models/bandwise_conv_classifier.py +++ b/models/bandwise_conv_classifier.py @@ -4,7 +4,7 @@ from torch import nn from torch.nn import ModuleList from ml_lib.modules.blocks import ConvModule, LinearModule -from ml_lib.modules.utils import (LightningBaseModule, HorizontalSplitter, HorizontalMerger) +from ml_lib.modules.util import (LightningBaseModule, HorizontalSplitter, HorizontalMerger) from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, BinaryMaskDatasetMixin, BaseDataloadersMixin) diff --git a/models/bandwise_conv_multihead_classifier.py b/models/bandwise_conv_multihead_classifier.py index 7b5741f..8b3ff63 100644 --- a/models/bandwise_conv_multihead_classifier.py +++ b/models/bandwise_conv_multihead_classifier.py @@ -5,7 +5,7 @@ from torch import nn from torch.nn import ModuleList from ml_lib.modules.blocks import ConvModule, LinearModule -from ml_lib.modules.utils import (LightningBaseModule, Flatten, HorizontalSplitter) +from ml_lib.modules.util import (LightningBaseModule, Flatten, HorizontalSplitter) from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, BinaryMaskDatasetMixin, BaseDataloadersMixin) diff --git a/models/conv_classifier.py b/models/conv_classifier.py index 4fac038..f89930a 100644 --- a/models/conv_classifier.py +++ b/models/conv_classifier.py @@ -4,7 +4,7 @@ from torch import nn from torch.nn import ModuleList from ml_lib.modules.blocks import ConvModule, LinearModule -from ml_lib.modules.utils import LightningBaseModule +from ml_lib.modules.util import LightningBaseModule from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, BinaryMaskDatasetMixin, BaseDataloadersMixin) diff --git a/models/ensemble.py b/models/ensemble.py index df41657..5b446b1 100644 --- a/models/ensemble.py +++ b/models/ensemble.py @@ -5,7 +5,7 @@ import torch from torch import nn from torch.nn import ModuleList -from ml_lib.modules.utils import LightningBaseModule +from ml_lib.modules.util import LightningBaseModule from ml_lib.utils.config import Config from ml_lib.utils.model_io import SavedLightningModels from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, BinaryMaskDatasetMixin, diff --git a/models/residual_conv_classifier.py b/models/residual_conv_classifier.py index c1465e8..51fd7fd 100644 --- a/models/residual_conv_classifier.py +++ b/models/residual_conv_classifier.py @@ -4,7 +4,7 @@ from torch import nn from torch.nn import ModuleList from ml_lib.modules.blocks import ConvModule, LinearModule, ResidualModule -from ml_lib.modules.utils import LightningBaseModule +from ml_lib.modules.util import LightningBaseModule from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, BinaryMaskDatasetMixin, BaseDataloadersMixin) diff --git a/multi_run.py b/multi_run.py new file mode 100644 index 0000000..b3cf957 --- /dev/null +++ b/multi_run.py @@ -0,0 +1,52 @@ +import shutil +import warnings + +from util.config import MConfig + +warnings.filterwarnings('ignore', category=FutureWarning) +warnings.filterwarnings('ignore', category=UserWarning) + +# Imports +# ============================================================================= + +from main import run_lightning_loop +from _paramters import main_arg_parser + + +if __name__ == '__main__': + + args = main_arg_parser.parse_args() + # Model Settings + config = MConfig().read_namespace(args) + + arg_dict = dict() + for seed in range(40, 45): + arg_dict.update(main_seed=seed) + for model in ['CC', 'BCMC', 'BCC', 'RCC']: + arg_dict.update(model_type=model) + raw_conf = dict(data_speed_factor=0.0, data_speed_ratio=0.0, data_mask_ratio=0.0, + data_noise_ratio=0.0, data_shift_ratio=0.0, data_loudness_ratio=0.0) + all_conf = dict(data_speed_factor=0.7, data_speed_ratio=0.2, data_mask_ratio=0.2, + data_noise_ratio=0.4, data_shift_ratio=0.4, data_loudness_ratio=0.4) + speed_conf = dict(data_speed_factor=0.7, data_speed_ratio=0.2, data_mask_ratio=0.0, + data_noise_ratio=0.0, data_shift_ratio=0.0, data_loudness_ratio=0.0) + mask_conf = dict(data_speed_factor=0.0, data_speed_ratio=0.0, data_mask_ratio=0.2, + data_noise_ratio=0.0, data_shift_ratio=0.0, data_loudness_ratio=0.0) + noise_conf = dict(data_speed_factor=0.0, data_speed_ratio=0.0, data_mask_ratio=0.0, + data_noise_ratio=0.4, data_shift_ratio=0.0, data_loudness_ratio=0.0) + shift_conf = dict(data_speed_factor=0.0, data_speed_ratio=0.0, data_mask_ratio=0.0, + data_noise_ratio=0.0, data_shift_ratio=0.4, data_loudness_ratio=0.0) + loudness_conf = dict(data_speed_factor=0.0, data_speed_ratio=0.0, data_mask_ratio=0.0, + data_noise_ratio=0.0, data_shift_ratio=0.0, data_loudness_ratio=0.4) + + for dicts in [raw_conf, all_conf, speed_conf, mask_conf,noise_conf, shift_conf, loudness_conf]: + + arg_dict.update(dicts) + config = config.update(arg_dict) + version_path = config.exp_path / config.version + if version_path.exists(): + if not (version_path / 'weights.ckpt').exists(): + shutil.rmtree(version_path) + else: + continue + run_lightning_loop(config) diff --git a/util/module_mixins.py b/util/module_mixins.py index ee93d81..ae2da15 100644 --- a/util/module_mixins.py +++ b/util/module_mixins.py @@ -15,7 +15,7 @@ from torchvision.transforms import Compose, RandomApply from ml_lib.audio_toolset.audio_augmentation import Speed from ml_lib.audio_toolset.mel_augmentation import NoiseInjection, LoudnessManipulator, ShiftTime, MaskAug from ml_lib.audio_toolset.audio_io import AudioToMel, MelToImage, NormalizeLocal -from ml_lib.modules.utils import LightningBaseModule +from ml_lib.modules.util import LightningBaseModule from ml_lib.utils.transforms import ToTensor import variables as V @@ -25,7 +25,7 @@ class BaseOptimizerMixin: def configure_optimizers(self): assert isinstance(self, LightningBaseModule) - opt = Adam(params=self.parameters(), lr=self.params.lr, weight_decay=1e-7) + opt = Adam(params=self.parameters(), lr=self.params.lr, weight_decay=self.params.weight_decay) if self.params.sto_weight_avg: opt = SWA(opt, swa_start=10, swa_freq=5, swa_lr=0.05) return opt @@ -181,7 +181,7 @@ class BaseDataloadersMixin(ABC): # Validation Dataloader def val_dataloader(self): assert isinstance(self, LightningBaseModule) - val_dataloader = DataLoader(dataset=self.dataset.val_dataset, shuffle=True, + val_dataloader = DataLoader(dataset=self.dataset.val_dataset, shuffle=False, batch_size=self.params.batch_size, num_workers=self.params.worker) train_dataloader = DataLoader(self.dataset.val_train_dataset, num_workers=self.params.worker,