fingerprinted now should work correctly
This commit is contained in:
parent
f57e25efdc
commit
c083207235
@ -32,7 +32,7 @@ main_arg_parser.add_argument("--data_stretch", type=strtobool, default=True, hel
|
||||
|
||||
# Transformation Parameters
|
||||
main_arg_parser.add_argument("--data_loudness_ratio", type=float, default=0, help="") # 0.4
|
||||
main_arg_parser.add_argument("--data_shift_ratio", type=float, default=0.3, help="") # 0.3
|
||||
main_arg_parser.add_argument("--data_shift_ratio", type=float, default=0.3, help="") # 0.4
|
||||
main_arg_parser.add_argument("--data_noise_ratio", type=float, default=0, help="") # 0.4
|
||||
main_arg_parser.add_argument("--data_mask_ratio", type=float, default=0, help="") # 0.2
|
||||
main_arg_parser.add_argument("--data_speed_ratio", type=float, default=0, help="") # 0.3
|
||||
@ -40,7 +40,6 @@ main_arg_parser.add_argument("--data_speed_factor", type=float, default=0, help=
|
||||
|
||||
# Model Parameters
|
||||
main_arg_parser.add_argument("--model_type", type=str, default="RCC", help="")
|
||||
main_arg_parser.add_argument("--model_secondary_type", type=str, default="RCC", help="")
|
||||
main_arg_parser.add_argument("--model_weight_init", type=str, default="xavier_normal_", help="")
|
||||
main_arg_parser.add_argument("--model_activation", type=str, default="leaky_relu", help="")
|
||||
main_arg_parser.add_argument("--model_filters", type=str, default="[32, 64, 128, 64]", help="")
|
||||
@ -55,6 +54,7 @@ main_arg_parser.add_argument("--train_outpath", type=str, default="output", help
|
||||
main_arg_parser.add_argument("--train_version", type=strtobool, required=False, help="")
|
||||
# FIXME: Stochastic weight Avaraging is not good, maybe its my implementation?
|
||||
main_arg_parser.add_argument("--train_sto_weight_avg", type=strtobool, default=False, help="")
|
||||
main_arg_parser.add_argument("--train_weight_decay", type=float, default=1e-7, help="")
|
||||
main_arg_parser.add_argument("--train_opt_reset_interval", type=int, default=0, help="")
|
||||
main_arg_parser.add_argument("--train_epochs", type=int, default=51, help="")
|
||||
main_arg_parser.add_argument("--train_batch_size", type=int, default=300, help="")
|
||||
|
@ -8,7 +8,7 @@ from torch.utils.data import Dataset
|
||||
import torch
|
||||
|
||||
import variables as V
|
||||
from ml_lib.modules.utils import F_x
|
||||
from ml_lib.modules.util import F_x
|
||||
|
||||
|
||||
class BinaryMasksDataset(Dataset):
|
||||
|
5
main.py
5
main.py
@ -2,14 +2,13 @@
|
||||
# =============================================================================
|
||||
from pathlib import Path
|
||||
|
||||
from tqdm import tqdm
|
||||
import warnings
|
||||
|
||||
import torch
|
||||
from pytorch_lightning import Trainer
|
||||
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
|
||||
|
||||
from ml_lib.modules.utils import LightningBaseModule
|
||||
from ml_lib.modules.util import LightningBaseModule
|
||||
from ml_lib.utils.logging import Logger
|
||||
|
||||
# Project Specific Logger SubClasses
|
||||
@ -111,7 +110,7 @@ def run_lightning_loop(config_obj):
|
||||
inference_out = f'{parameters}_test_out.csv'
|
||||
|
||||
from main_inference import prepare_dataloader
|
||||
test_dataloader = prepare_dataloader(config)
|
||||
test_dataloader = prepare_dataloader(config_obj)
|
||||
|
||||
with (outpath / model_type / parameters / version / inference_out).open(mode='w') as outfile:
|
||||
outfile.write(f'file_name,prediction\n')
|
||||
|
@ -4,7 +4,7 @@ from torch import nn
|
||||
from torch.nn import ModuleList
|
||||
|
||||
from ml_lib.modules.blocks import ConvModule, LinearModule
|
||||
from ml_lib.modules.utils import (LightningBaseModule, HorizontalSplitter, HorizontalMerger)
|
||||
from ml_lib.modules.util import (LightningBaseModule, HorizontalSplitter, HorizontalMerger)
|
||||
from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, BinaryMaskDatasetMixin,
|
||||
BaseDataloadersMixin)
|
||||
|
||||
|
@ -5,7 +5,7 @@ from torch import nn
|
||||
from torch.nn import ModuleList
|
||||
|
||||
from ml_lib.modules.blocks import ConvModule, LinearModule
|
||||
from ml_lib.modules.utils import (LightningBaseModule, Flatten, HorizontalSplitter)
|
||||
from ml_lib.modules.util import (LightningBaseModule, Flatten, HorizontalSplitter)
|
||||
from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, BinaryMaskDatasetMixin,
|
||||
BaseDataloadersMixin)
|
||||
|
||||
|
@ -4,7 +4,7 @@ from torch import nn
|
||||
from torch.nn import ModuleList
|
||||
|
||||
from ml_lib.modules.blocks import ConvModule, LinearModule
|
||||
from ml_lib.modules.utils import LightningBaseModule
|
||||
from ml_lib.modules.util import LightningBaseModule
|
||||
from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, BinaryMaskDatasetMixin,
|
||||
BaseDataloadersMixin)
|
||||
|
||||
|
@ -5,7 +5,7 @@ import torch
|
||||
from torch import nn
|
||||
from torch.nn import ModuleList
|
||||
|
||||
from ml_lib.modules.utils import LightningBaseModule
|
||||
from ml_lib.modules.util import LightningBaseModule
|
||||
from ml_lib.utils.config import Config
|
||||
from ml_lib.utils.model_io import SavedLightningModels
|
||||
from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, BinaryMaskDatasetMixin,
|
||||
|
@ -4,7 +4,7 @@ from torch import nn
|
||||
from torch.nn import ModuleList
|
||||
|
||||
from ml_lib.modules.blocks import ConvModule, LinearModule, ResidualModule
|
||||
from ml_lib.modules.utils import LightningBaseModule
|
||||
from ml_lib.modules.util import LightningBaseModule
|
||||
from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, BinaryMaskDatasetMixin,
|
||||
BaseDataloadersMixin)
|
||||
|
||||
|
52
multi_run.py
Normal file
52
multi_run.py
Normal file
@ -0,0 +1,52 @@
|
||||
import shutil
|
||||
import warnings
|
||||
|
||||
from util.config import MConfig
|
||||
|
||||
warnings.filterwarnings('ignore', category=FutureWarning)
|
||||
warnings.filterwarnings('ignore', category=UserWarning)
|
||||
|
||||
# Imports
|
||||
# =============================================================================
|
||||
|
||||
from main import run_lightning_loop
|
||||
from _paramters import main_arg_parser
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
args = main_arg_parser.parse_args()
|
||||
# Model Settings
|
||||
config = MConfig().read_namespace(args)
|
||||
|
||||
arg_dict = dict()
|
||||
for seed in range(40, 45):
|
||||
arg_dict.update(main_seed=seed)
|
||||
for model in ['CC', 'BCMC', 'BCC', 'RCC']:
|
||||
arg_dict.update(model_type=model)
|
||||
raw_conf = dict(data_speed_factor=0.0, data_speed_ratio=0.0, data_mask_ratio=0.0,
|
||||
data_noise_ratio=0.0, data_shift_ratio=0.0, data_loudness_ratio=0.0)
|
||||
all_conf = dict(data_speed_factor=0.7, data_speed_ratio=0.2, data_mask_ratio=0.2,
|
||||
data_noise_ratio=0.4, data_shift_ratio=0.4, data_loudness_ratio=0.4)
|
||||
speed_conf = dict(data_speed_factor=0.7, data_speed_ratio=0.2, data_mask_ratio=0.0,
|
||||
data_noise_ratio=0.0, data_shift_ratio=0.0, data_loudness_ratio=0.0)
|
||||
mask_conf = dict(data_speed_factor=0.0, data_speed_ratio=0.0, data_mask_ratio=0.2,
|
||||
data_noise_ratio=0.0, data_shift_ratio=0.0, data_loudness_ratio=0.0)
|
||||
noise_conf = dict(data_speed_factor=0.0, data_speed_ratio=0.0, data_mask_ratio=0.0,
|
||||
data_noise_ratio=0.4, data_shift_ratio=0.0, data_loudness_ratio=0.0)
|
||||
shift_conf = dict(data_speed_factor=0.0, data_speed_ratio=0.0, data_mask_ratio=0.0,
|
||||
data_noise_ratio=0.0, data_shift_ratio=0.4, data_loudness_ratio=0.0)
|
||||
loudness_conf = dict(data_speed_factor=0.0, data_speed_ratio=0.0, data_mask_ratio=0.0,
|
||||
data_noise_ratio=0.0, data_shift_ratio=0.0, data_loudness_ratio=0.4)
|
||||
|
||||
for dicts in [raw_conf, all_conf, speed_conf, mask_conf,noise_conf, shift_conf, loudness_conf]:
|
||||
|
||||
arg_dict.update(dicts)
|
||||
config = config.update(arg_dict)
|
||||
version_path = config.exp_path / config.version
|
||||
if version_path.exists():
|
||||
if not (version_path / 'weights.ckpt').exists():
|
||||
shutil.rmtree(version_path)
|
||||
else:
|
||||
continue
|
||||
run_lightning_loop(config)
|
@ -15,7 +15,7 @@ from torchvision.transforms import Compose, RandomApply
|
||||
from ml_lib.audio_toolset.audio_augmentation import Speed
|
||||
from ml_lib.audio_toolset.mel_augmentation import NoiseInjection, LoudnessManipulator, ShiftTime, MaskAug
|
||||
from ml_lib.audio_toolset.audio_io import AudioToMel, MelToImage, NormalizeLocal
|
||||
from ml_lib.modules.utils import LightningBaseModule
|
||||
from ml_lib.modules.util import LightningBaseModule
|
||||
from ml_lib.utils.transforms import ToTensor
|
||||
|
||||
import variables as V
|
||||
@ -25,7 +25,7 @@ class BaseOptimizerMixin:
|
||||
|
||||
def configure_optimizers(self):
|
||||
assert isinstance(self, LightningBaseModule)
|
||||
opt = Adam(params=self.parameters(), lr=self.params.lr, weight_decay=1e-7)
|
||||
opt = Adam(params=self.parameters(), lr=self.params.lr, weight_decay=self.params.weight_decay)
|
||||
if self.params.sto_weight_avg:
|
||||
opt = SWA(opt, swa_start=10, swa_freq=5, swa_lr=0.05)
|
||||
return opt
|
||||
@ -181,7 +181,7 @@ class BaseDataloadersMixin(ABC):
|
||||
# Validation Dataloader
|
||||
def val_dataloader(self):
|
||||
assert isinstance(self, LightningBaseModule)
|
||||
val_dataloader = DataLoader(dataset=self.dataset.val_dataset, shuffle=True,
|
||||
val_dataloader = DataLoader(dataset=self.dataset.val_dataset, shuffle=False,
|
||||
batch_size=self.params.batch_size, num_workers=self.params.worker)
|
||||
|
||||
train_dataloader = DataLoader(self.dataset.val_train_dataset, num_workers=self.params.worker,
|
||||
|
Loading…
x
Reference in New Issue
Block a user