fingerprinted now should work correctly
This commit is contained in:
parent
f57e25efdc
commit
c083207235
@ -32,7 +32,7 @@ main_arg_parser.add_argument("--data_stretch", type=strtobool, default=True, hel
|
|||||||
|
|
||||||
# Transformation Parameters
|
# Transformation Parameters
|
||||||
main_arg_parser.add_argument("--data_loudness_ratio", type=float, default=0, help="") # 0.4
|
main_arg_parser.add_argument("--data_loudness_ratio", type=float, default=0, help="") # 0.4
|
||||||
main_arg_parser.add_argument("--data_shift_ratio", type=float, default=0.3, help="") # 0.3
|
main_arg_parser.add_argument("--data_shift_ratio", type=float, default=0.3, help="") # 0.4
|
||||||
main_arg_parser.add_argument("--data_noise_ratio", type=float, default=0, help="") # 0.4
|
main_arg_parser.add_argument("--data_noise_ratio", type=float, default=0, help="") # 0.4
|
||||||
main_arg_parser.add_argument("--data_mask_ratio", type=float, default=0, help="") # 0.2
|
main_arg_parser.add_argument("--data_mask_ratio", type=float, default=0, help="") # 0.2
|
||||||
main_arg_parser.add_argument("--data_speed_ratio", type=float, default=0, help="") # 0.3
|
main_arg_parser.add_argument("--data_speed_ratio", type=float, default=0, help="") # 0.3
|
||||||
@ -40,7 +40,6 @@ main_arg_parser.add_argument("--data_speed_factor", type=float, default=0, help=
|
|||||||
|
|
||||||
# Model Parameters
|
# Model Parameters
|
||||||
main_arg_parser.add_argument("--model_type", type=str, default="RCC", help="")
|
main_arg_parser.add_argument("--model_type", type=str, default="RCC", help="")
|
||||||
main_arg_parser.add_argument("--model_secondary_type", type=str, default="RCC", help="")
|
|
||||||
main_arg_parser.add_argument("--model_weight_init", type=str, default="xavier_normal_", help="")
|
main_arg_parser.add_argument("--model_weight_init", type=str, default="xavier_normal_", help="")
|
||||||
main_arg_parser.add_argument("--model_activation", type=str, default="leaky_relu", help="")
|
main_arg_parser.add_argument("--model_activation", type=str, default="leaky_relu", help="")
|
||||||
main_arg_parser.add_argument("--model_filters", type=str, default="[32, 64, 128, 64]", help="")
|
main_arg_parser.add_argument("--model_filters", type=str, default="[32, 64, 128, 64]", help="")
|
||||||
@ -55,6 +54,7 @@ main_arg_parser.add_argument("--train_outpath", type=str, default="output", help
|
|||||||
main_arg_parser.add_argument("--train_version", type=strtobool, required=False, help="")
|
main_arg_parser.add_argument("--train_version", type=strtobool, required=False, help="")
|
||||||
# FIXME: Stochastic weight Avaraging is not good, maybe its my implementation?
|
# FIXME: Stochastic weight Avaraging is not good, maybe its my implementation?
|
||||||
main_arg_parser.add_argument("--train_sto_weight_avg", type=strtobool, default=False, help="")
|
main_arg_parser.add_argument("--train_sto_weight_avg", type=strtobool, default=False, help="")
|
||||||
|
main_arg_parser.add_argument("--train_weight_decay", type=float, default=1e-7, help="")
|
||||||
main_arg_parser.add_argument("--train_opt_reset_interval", type=int, default=0, help="")
|
main_arg_parser.add_argument("--train_opt_reset_interval", type=int, default=0, help="")
|
||||||
main_arg_parser.add_argument("--train_epochs", type=int, default=51, help="")
|
main_arg_parser.add_argument("--train_epochs", type=int, default=51, help="")
|
||||||
main_arg_parser.add_argument("--train_batch_size", type=int, default=300, help="")
|
main_arg_parser.add_argument("--train_batch_size", type=int, default=300, help="")
|
||||||
|
@ -8,7 +8,7 @@ from torch.utils.data import Dataset
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
import variables as V
|
import variables as V
|
||||||
from ml_lib.modules.utils import F_x
|
from ml_lib.modules.util import F_x
|
||||||
|
|
||||||
|
|
||||||
class BinaryMasksDataset(Dataset):
|
class BinaryMasksDataset(Dataset):
|
||||||
|
5
main.py
5
main.py
@ -2,14 +2,13 @@
|
|||||||
# =============================================================================
|
# =============================================================================
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from tqdm import tqdm
|
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from pytorch_lightning import Trainer
|
from pytorch_lightning import Trainer
|
||||||
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
|
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
|
||||||
|
|
||||||
from ml_lib.modules.utils import LightningBaseModule
|
from ml_lib.modules.util import LightningBaseModule
|
||||||
from ml_lib.utils.logging import Logger
|
from ml_lib.utils.logging import Logger
|
||||||
|
|
||||||
# Project Specific Logger SubClasses
|
# Project Specific Logger SubClasses
|
||||||
@ -111,7 +110,7 @@ def run_lightning_loop(config_obj):
|
|||||||
inference_out = f'{parameters}_test_out.csv'
|
inference_out = f'{parameters}_test_out.csv'
|
||||||
|
|
||||||
from main_inference import prepare_dataloader
|
from main_inference import prepare_dataloader
|
||||||
test_dataloader = prepare_dataloader(config)
|
test_dataloader = prepare_dataloader(config_obj)
|
||||||
|
|
||||||
with (outpath / model_type / parameters / version / inference_out).open(mode='w') as outfile:
|
with (outpath / model_type / parameters / version / inference_out).open(mode='w') as outfile:
|
||||||
outfile.write(f'file_name,prediction\n')
|
outfile.write(f'file_name,prediction\n')
|
||||||
|
@ -4,7 +4,7 @@ from torch import nn
|
|||||||
from torch.nn import ModuleList
|
from torch.nn import ModuleList
|
||||||
|
|
||||||
from ml_lib.modules.blocks import ConvModule, LinearModule
|
from ml_lib.modules.blocks import ConvModule, LinearModule
|
||||||
from ml_lib.modules.utils import (LightningBaseModule, HorizontalSplitter, HorizontalMerger)
|
from ml_lib.modules.util import (LightningBaseModule, HorizontalSplitter, HorizontalMerger)
|
||||||
from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, BinaryMaskDatasetMixin,
|
from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, BinaryMaskDatasetMixin,
|
||||||
BaseDataloadersMixin)
|
BaseDataloadersMixin)
|
||||||
|
|
||||||
|
@ -5,7 +5,7 @@ from torch import nn
|
|||||||
from torch.nn import ModuleList
|
from torch.nn import ModuleList
|
||||||
|
|
||||||
from ml_lib.modules.blocks import ConvModule, LinearModule
|
from ml_lib.modules.blocks import ConvModule, LinearModule
|
||||||
from ml_lib.modules.utils import (LightningBaseModule, Flatten, HorizontalSplitter)
|
from ml_lib.modules.util import (LightningBaseModule, Flatten, HorizontalSplitter)
|
||||||
from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, BinaryMaskDatasetMixin,
|
from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, BinaryMaskDatasetMixin,
|
||||||
BaseDataloadersMixin)
|
BaseDataloadersMixin)
|
||||||
|
|
||||||
|
@ -4,7 +4,7 @@ from torch import nn
|
|||||||
from torch.nn import ModuleList
|
from torch.nn import ModuleList
|
||||||
|
|
||||||
from ml_lib.modules.blocks import ConvModule, LinearModule
|
from ml_lib.modules.blocks import ConvModule, LinearModule
|
||||||
from ml_lib.modules.utils import LightningBaseModule
|
from ml_lib.modules.util import LightningBaseModule
|
||||||
from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, BinaryMaskDatasetMixin,
|
from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, BinaryMaskDatasetMixin,
|
||||||
BaseDataloadersMixin)
|
BaseDataloadersMixin)
|
||||||
|
|
||||||
|
@ -5,7 +5,7 @@ import torch
|
|||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.nn import ModuleList
|
from torch.nn import ModuleList
|
||||||
|
|
||||||
from ml_lib.modules.utils import LightningBaseModule
|
from ml_lib.modules.util import LightningBaseModule
|
||||||
from ml_lib.utils.config import Config
|
from ml_lib.utils.config import Config
|
||||||
from ml_lib.utils.model_io import SavedLightningModels
|
from ml_lib.utils.model_io import SavedLightningModels
|
||||||
from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, BinaryMaskDatasetMixin,
|
from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, BinaryMaskDatasetMixin,
|
||||||
|
@ -4,7 +4,7 @@ from torch import nn
|
|||||||
from torch.nn import ModuleList
|
from torch.nn import ModuleList
|
||||||
|
|
||||||
from ml_lib.modules.blocks import ConvModule, LinearModule, ResidualModule
|
from ml_lib.modules.blocks import ConvModule, LinearModule, ResidualModule
|
||||||
from ml_lib.modules.utils import LightningBaseModule
|
from ml_lib.modules.util import LightningBaseModule
|
||||||
from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, BinaryMaskDatasetMixin,
|
from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, BinaryMaskDatasetMixin,
|
||||||
BaseDataloadersMixin)
|
BaseDataloadersMixin)
|
||||||
|
|
||||||
|
52
multi_run.py
Normal file
52
multi_run.py
Normal file
@ -0,0 +1,52 @@
|
|||||||
|
import shutil
|
||||||
|
import warnings
|
||||||
|
|
||||||
|
from util.config import MConfig
|
||||||
|
|
||||||
|
warnings.filterwarnings('ignore', category=FutureWarning)
|
||||||
|
warnings.filterwarnings('ignore', category=UserWarning)
|
||||||
|
|
||||||
|
# Imports
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
from main import run_lightning_loop
|
||||||
|
from _paramters import main_arg_parser
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
|
||||||
|
args = main_arg_parser.parse_args()
|
||||||
|
# Model Settings
|
||||||
|
config = MConfig().read_namespace(args)
|
||||||
|
|
||||||
|
arg_dict = dict()
|
||||||
|
for seed in range(40, 45):
|
||||||
|
arg_dict.update(main_seed=seed)
|
||||||
|
for model in ['CC', 'BCMC', 'BCC', 'RCC']:
|
||||||
|
arg_dict.update(model_type=model)
|
||||||
|
raw_conf = dict(data_speed_factor=0.0, data_speed_ratio=0.0, data_mask_ratio=0.0,
|
||||||
|
data_noise_ratio=0.0, data_shift_ratio=0.0, data_loudness_ratio=0.0)
|
||||||
|
all_conf = dict(data_speed_factor=0.7, data_speed_ratio=0.2, data_mask_ratio=0.2,
|
||||||
|
data_noise_ratio=0.4, data_shift_ratio=0.4, data_loudness_ratio=0.4)
|
||||||
|
speed_conf = dict(data_speed_factor=0.7, data_speed_ratio=0.2, data_mask_ratio=0.0,
|
||||||
|
data_noise_ratio=0.0, data_shift_ratio=0.0, data_loudness_ratio=0.0)
|
||||||
|
mask_conf = dict(data_speed_factor=0.0, data_speed_ratio=0.0, data_mask_ratio=0.2,
|
||||||
|
data_noise_ratio=0.0, data_shift_ratio=0.0, data_loudness_ratio=0.0)
|
||||||
|
noise_conf = dict(data_speed_factor=0.0, data_speed_ratio=0.0, data_mask_ratio=0.0,
|
||||||
|
data_noise_ratio=0.4, data_shift_ratio=0.0, data_loudness_ratio=0.0)
|
||||||
|
shift_conf = dict(data_speed_factor=0.0, data_speed_ratio=0.0, data_mask_ratio=0.0,
|
||||||
|
data_noise_ratio=0.0, data_shift_ratio=0.4, data_loudness_ratio=0.0)
|
||||||
|
loudness_conf = dict(data_speed_factor=0.0, data_speed_ratio=0.0, data_mask_ratio=0.0,
|
||||||
|
data_noise_ratio=0.0, data_shift_ratio=0.0, data_loudness_ratio=0.4)
|
||||||
|
|
||||||
|
for dicts in [raw_conf, all_conf, speed_conf, mask_conf,noise_conf, shift_conf, loudness_conf]:
|
||||||
|
|
||||||
|
arg_dict.update(dicts)
|
||||||
|
config = config.update(arg_dict)
|
||||||
|
version_path = config.exp_path / config.version
|
||||||
|
if version_path.exists():
|
||||||
|
if not (version_path / 'weights.ckpt').exists():
|
||||||
|
shutil.rmtree(version_path)
|
||||||
|
else:
|
||||||
|
continue
|
||||||
|
run_lightning_loop(config)
|
@ -15,7 +15,7 @@ from torchvision.transforms import Compose, RandomApply
|
|||||||
from ml_lib.audio_toolset.audio_augmentation import Speed
|
from ml_lib.audio_toolset.audio_augmentation import Speed
|
||||||
from ml_lib.audio_toolset.mel_augmentation import NoiseInjection, LoudnessManipulator, ShiftTime, MaskAug
|
from ml_lib.audio_toolset.mel_augmentation import NoiseInjection, LoudnessManipulator, ShiftTime, MaskAug
|
||||||
from ml_lib.audio_toolset.audio_io import AudioToMel, MelToImage, NormalizeLocal
|
from ml_lib.audio_toolset.audio_io import AudioToMel, MelToImage, NormalizeLocal
|
||||||
from ml_lib.modules.utils import LightningBaseModule
|
from ml_lib.modules.util import LightningBaseModule
|
||||||
from ml_lib.utils.transforms import ToTensor
|
from ml_lib.utils.transforms import ToTensor
|
||||||
|
|
||||||
import variables as V
|
import variables as V
|
||||||
@ -25,7 +25,7 @@ class BaseOptimizerMixin:
|
|||||||
|
|
||||||
def configure_optimizers(self):
|
def configure_optimizers(self):
|
||||||
assert isinstance(self, LightningBaseModule)
|
assert isinstance(self, LightningBaseModule)
|
||||||
opt = Adam(params=self.parameters(), lr=self.params.lr, weight_decay=1e-7)
|
opt = Adam(params=self.parameters(), lr=self.params.lr, weight_decay=self.params.weight_decay)
|
||||||
if self.params.sto_weight_avg:
|
if self.params.sto_weight_avg:
|
||||||
opt = SWA(opt, swa_start=10, swa_freq=5, swa_lr=0.05)
|
opt = SWA(opt, swa_start=10, swa_freq=5, swa_lr=0.05)
|
||||||
return opt
|
return opt
|
||||||
@ -181,7 +181,7 @@ class BaseDataloadersMixin(ABC):
|
|||||||
# Validation Dataloader
|
# Validation Dataloader
|
||||||
def val_dataloader(self):
|
def val_dataloader(self):
|
||||||
assert isinstance(self, LightningBaseModule)
|
assert isinstance(self, LightningBaseModule)
|
||||||
val_dataloader = DataLoader(dataset=self.dataset.val_dataset, shuffle=True,
|
val_dataloader = DataLoader(dataset=self.dataset.val_dataset, shuffle=False,
|
||||||
batch_size=self.params.batch_size, num_workers=self.params.worker)
|
batch_size=self.params.batch_size, num_workers=self.params.worker)
|
||||||
|
|
||||||
train_dataloader = DataLoader(self.dataset.val_train_dataset, num_workers=self.params.worker,
|
train_dataloader = DataLoader(self.dataset.val_train_dataset, num_workers=self.params.worker,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user