fingerprinted now should work correctly

This commit is contained in:
Si11ium
2020-05-19 08:33:05 +02:00
parent f57e25efdc
commit c083207235
10 changed files with 65 additions and 14 deletions

View File

@ -15,7 +15,7 @@ from torchvision.transforms import Compose, RandomApply
from ml_lib.audio_toolset.audio_augmentation import Speed
from ml_lib.audio_toolset.mel_augmentation import NoiseInjection, LoudnessManipulator, ShiftTime, MaskAug
from ml_lib.audio_toolset.audio_io import AudioToMel, MelToImage, NormalizeLocal
from ml_lib.modules.utils import LightningBaseModule
from ml_lib.modules.util import LightningBaseModule
from ml_lib.utils.transforms import ToTensor
import variables as V
@ -25,7 +25,7 @@ class BaseOptimizerMixin:
def configure_optimizers(self):
assert isinstance(self, LightningBaseModule)
opt = Adam(params=self.parameters(), lr=self.params.lr, weight_decay=1e-7)
opt = Adam(params=self.parameters(), lr=self.params.lr, weight_decay=self.params.weight_decay)
if self.params.sto_weight_avg:
opt = SWA(opt, swa_start=10, swa_freq=5, swa_lr=0.05)
return opt
@ -181,7 +181,7 @@ class BaseDataloadersMixin(ABC):
# Validation Dataloader
def val_dataloader(self):
assert isinstance(self, LightningBaseModule)
val_dataloader = DataLoader(dataset=self.dataset.val_dataset, shuffle=True,
val_dataloader = DataLoader(dataset=self.dataset.val_dataset, shuffle=False,
batch_size=self.params.batch_size, num_workers=self.params.worker)
train_dataloader = DataLoader(self.dataset.val_train_dataset, num_workers=self.params.worker,