Fingerprinting should now work correctly

commit c083207235 (parent f57e25efdc)
Author: Si11ium
Date: 2020-05-19 08:33:05 +02:00

10 changed files with 65 additions and 14 deletions


@@ -32,7 +32,7 @@ main_arg_parser.add_argument("--data_stretch", type=strtobool, default=True, hel
 # Transformation Parameters
 main_arg_parser.add_argument("--data_loudness_ratio", type=float, default=0, help="")  # 0.4
-main_arg_parser.add_argument("--data_shift_ratio", type=float, default=0.3, help="")  # 0.3
+main_arg_parser.add_argument("--data_shift_ratio", type=float, default=0.3, help="")  # 0.4
 main_arg_parser.add_argument("--data_noise_ratio", type=float, default=0, help="")  # 0.4
 main_arg_parser.add_argument("--data_mask_ratio", type=float, default=0, help="")  # 0.2
 main_arg_parser.add_argument("--data_speed_ratio", type=float, default=0, help="")  # 0.3
@@ -40,7 +40,6 @@ main_arg_parser.add_argument("--data_speed_factor", type=float, default=0, help=
 # Model Parameters
 main_arg_parser.add_argument("--model_type", type=str, default="RCC", help="")
-main_arg_parser.add_argument("--model_secondary_type", type=str, default="RCC", help="")
 main_arg_parser.add_argument("--model_weight_init", type=str, default="xavier_normal_", help="")
 main_arg_parser.add_argument("--model_activation", type=str, default="leaky_relu", help="")
 main_arg_parser.add_argument("--model_filters", type=str, default="[32, 64, 128, 64]", help="")
@@ -55,6 +54,7 @@ main_arg_parser.add_argument("--train_outpath", type=str, default="output", help
 main_arg_parser.add_argument("--train_version", type=strtobool, required=False, help="")
 # FIXME: Stochastic weight Avaraging is not good, maybe its my implementation?
 main_arg_parser.add_argument("--train_sto_weight_avg", type=strtobool, default=False, help="")
+main_arg_parser.add_argument("--train_weight_decay", type=float, default=1e-7, help="")
 main_arg_parser.add_argument("--train_opt_reset_interval", type=int, default=0, help="")
 main_arg_parser.add_argument("--train_epochs", type=int, default=51, help="")
 main_arg_parser.add_argument("--train_batch_size", type=int, default=300, help="")


@@ -8,7 +8,7 @@ from torch.utils.data import Dataset
 import torch
 import variables as V
-from ml_lib.modules.utils import F_x
+from ml_lib.modules.util import F_x
 class BinaryMasksDataset(Dataset):
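Note: this is the first of several identical import fixes in this commit. The module ml_lib.modules.utils was renamed to ml_lib.modules.util, and every model and mixin file below updates only the module path; the imported names themselves are unchanged.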


@@ -2,14 +2,13 @@
 # =============================================================================
 from pathlib import Path
-from tqdm import tqdm
 import warnings
 import torch
 from pytorch_lightning import Trainer
 from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
-from ml_lib.modules.utils import LightningBaseModule
+from ml_lib.modules.util import LightningBaseModule
 from ml_lib.utils.logging import Logger
 # Project Specific Logger SubClasses
@@ -111,7 +110,7 @@ def run_lightning_loop(config_obj):
     inference_out = f'{parameters}_test_out.csv'
     from main_inference import prepare_dataloader
-    test_dataloader = prepare_dataloader(config)
+    test_dataloader = prepare_dataloader(config_obj)
     with (outpath / model_type / parameters / version / inference_out).open(mode='w') as outfile:
         outfile.write(f'file_name,prediction\n')
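Note: prepare_dataloader was previously called with config, a name that is not defined inside run_lightning_loop (the parameter is config_obj), so the post-training inference step would presumably have failed with a NameError; passing config_obj through is the actual fix in this hunk.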


@@ -4,7 +4,7 @@ from torch import nn
 from torch.nn import ModuleList
 from ml_lib.modules.blocks import ConvModule, LinearModule
-from ml_lib.modules.utils import (LightningBaseModule, HorizontalSplitter, HorizontalMerger)
+from ml_lib.modules.util import (LightningBaseModule, HorizontalSplitter, HorizontalMerger)
 from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, BinaryMaskDatasetMixin,
                                 BaseDataloadersMixin)


@@ -5,7 +5,7 @@ from torch import nn
 from torch.nn import ModuleList
 from ml_lib.modules.blocks import ConvModule, LinearModule
-from ml_lib.modules.utils import (LightningBaseModule, Flatten, HorizontalSplitter)
+from ml_lib.modules.util import (LightningBaseModule, Flatten, HorizontalSplitter)
 from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, BinaryMaskDatasetMixin,
                                 BaseDataloadersMixin)


@@ -4,7 +4,7 @@ from torch import nn
 from torch.nn import ModuleList
 from ml_lib.modules.blocks import ConvModule, LinearModule
-from ml_lib.modules.utils import LightningBaseModule
+from ml_lib.modules.util import LightningBaseModule
 from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, BinaryMaskDatasetMixin,
                                 BaseDataloadersMixin)


@@ -5,7 +5,7 @@ import torch
 from torch import nn
 from torch.nn import ModuleList
-from ml_lib.modules.utils import LightningBaseModule
+from ml_lib.modules.util import LightningBaseModule
 from ml_lib.utils.config import Config
 from ml_lib.utils.model_io import SavedLightningModels
 from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, BinaryMaskDatasetMixin,


@@ -4,7 +4,7 @@ from torch import nn
 from torch.nn import ModuleList
 from ml_lib.modules.blocks import ConvModule, LinearModule, ResidualModule
-from ml_lib.modules.utils import LightningBaseModule
+from ml_lib.modules.util import LightningBaseModule
 from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, BinaryMaskDatasetMixin,
                                 BaseDataloadersMixin)

multi_run.py (new file, 52 lines)

@@ -0,0 +1,52 @@
import shutil
import warnings

from util.config import MConfig

warnings.filterwarnings('ignore', category=FutureWarning)
warnings.filterwarnings('ignore', category=UserWarning)

# Imports
# =============================================================================
from main import run_lightning_loop
from _paramters import main_arg_parser


if __name__ == '__main__':
    args = main_arg_parser.parse_args()

    # Model Settings
    config = MConfig().read_namespace(args)
    arg_dict = dict()
    for seed in range(40, 45):
        arg_dict.update(main_seed=seed)
        for model in ['CC', 'BCMC', 'BCC', 'RCC']:
            arg_dict.update(model_type=model)
            raw_conf = dict(data_speed_factor=0.0, data_speed_ratio=0.0, data_mask_ratio=0.0,
                            data_noise_ratio=0.0, data_shift_ratio=0.0, data_loudness_ratio=0.0)
            all_conf = dict(data_speed_factor=0.7, data_speed_ratio=0.2, data_mask_ratio=0.2,
                            data_noise_ratio=0.4, data_shift_ratio=0.4, data_loudness_ratio=0.4)
            speed_conf = dict(data_speed_factor=0.7, data_speed_ratio=0.2, data_mask_ratio=0.0,
                              data_noise_ratio=0.0, data_shift_ratio=0.0, data_loudness_ratio=0.0)
            mask_conf = dict(data_speed_factor=0.0, data_speed_ratio=0.0, data_mask_ratio=0.2,
                             data_noise_ratio=0.0, data_shift_ratio=0.0, data_loudness_ratio=0.0)
            noise_conf = dict(data_speed_factor=0.0, data_speed_ratio=0.0, data_mask_ratio=0.0,
                              data_noise_ratio=0.4, data_shift_ratio=0.0, data_loudness_ratio=0.0)
            shift_conf = dict(data_speed_factor=0.0, data_speed_ratio=0.0, data_mask_ratio=0.0,
                              data_noise_ratio=0.0, data_shift_ratio=0.4, data_loudness_ratio=0.0)
            loudness_conf = dict(data_speed_factor=0.0, data_speed_ratio=0.0, data_mask_ratio=0.0,
                                 data_noise_ratio=0.0, data_shift_ratio=0.0, data_loudness_ratio=0.4)

            for dicts in [raw_conf, all_conf, speed_conf, mask_conf, noise_conf, shift_conf, loudness_conf]:
                arg_dict.update(dicts)
                config = config.update(arg_dict)
                version_path = config.exp_path / config.version
                if version_path.exists():
                    if not (version_path / 'weights.ckpt').exists():
                        shutil.rmtree(version_path)
                    else:
                        continue
                run_lightning_loop(config)
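Note: this sweep covers 5 seeds × 4 model types × 7 augmentation configurations, i.e. up to 140 trainings from a single `python multi_run.py` invocation, with all other settings taken from the _paramters defaults. Version directories that already contain a weights.ckpt are skipped, while half-finished ones (a directory without a checkpoint) are deleted and retrained, which makes the script safe to interrupt and restart.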


@@ -15,7 +15,7 @@ from torchvision.transforms import Compose, RandomApply
 from ml_lib.audio_toolset.audio_augmentation import Speed
 from ml_lib.audio_toolset.mel_augmentation import NoiseInjection, LoudnessManipulator, ShiftTime, MaskAug
 from ml_lib.audio_toolset.audio_io import AudioToMel, MelToImage, NormalizeLocal
-from ml_lib.modules.utils import LightningBaseModule
+from ml_lib.modules.util import LightningBaseModule
 from ml_lib.utils.transforms import ToTensor
 import variables as V
@@ -25,7 +25,7 @@ class BaseOptimizerMixin:
     def configure_optimizers(self):
         assert isinstance(self, LightningBaseModule)
-        opt = Adam(params=self.parameters(), lr=self.params.lr, weight_decay=1e-7)
+        opt = Adam(params=self.parameters(), lr=self.params.lr, weight_decay=self.params.weight_decay)
         if self.params.sto_weight_avg:
             opt = SWA(opt, swa_start=10, swa_freq=5, swa_lr=0.05)
         return opt
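Note on the FIXME about stochastic weight averaging: the swa_start/swa_freq/swa_lr signature matches torchcontrib's SWA wrapper, which keeps the running average in a separate buffer; the averaged weights only reach the model when swap_swa_sgd() is called at the end of training. If that call is missing, evaluation runs on the raw weights, which would be one plausible explanation for SWA looking worse here. A hedged sketch (not part of this commit), assuming torchcontrib and a Lightning-style hook:

    # Hypothetical end-of-training hook, assuming torchcontrib.optim.SWA
    # is the wrapper in use:
    def on_train_end(self):
        if self.params.sto_weight_avg:
            # Swap the SWA running average into the model parameters
            # before the final checkpoint / test pass.
            self.trainer.optimizers[0].swap_swa_sgd()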
@@ -181,7 +181,7 @@ class BaseDataloadersMixin(ABC):
     # Validation Dataloader
     def val_dataloader(self):
         assert isinstance(self, LightningBaseModule)
-        val_dataloader = DataLoader(dataset=self.dataset.val_dataset, shuffle=True,
+        val_dataloader = DataLoader(dataset=self.dataset.val_dataset, shuffle=False,
                                     batch_size=self.params.batch_size, num_workers=self.params.worker)
         train_dataloader = DataLoader(self.dataset.val_train_dataset, num_workers=self.params.worker,