fingerprinted now should work correctly
This commit is contained in:
@ -15,7 +15,7 @@ from torchvision.transforms import Compose, RandomApply
|
||||
from ml_lib.audio_toolset.audio_augmentation import Speed
|
||||
from ml_lib.audio_toolset.mel_augmentation import NoiseInjection, LoudnessManipulator, ShiftTime, MaskAug
|
||||
from ml_lib.audio_toolset.audio_io import AudioToMel, MelToImage, NormalizeLocal
|
||||
from ml_lib.modules.utils import LightningBaseModule
|
||||
from ml_lib.modules.util import LightningBaseModule
|
||||
from ml_lib.utils.transforms import ToTensor
|
||||
|
||||
import variables as V
|
||||
@ -25,7 +25,7 @@ class BaseOptimizerMixin:
|
||||
|
||||
def configure_optimizers(self):
|
||||
assert isinstance(self, LightningBaseModule)
|
||||
opt = Adam(params=self.parameters(), lr=self.params.lr, weight_decay=1e-7)
|
||||
opt = Adam(params=self.parameters(), lr=self.params.lr, weight_decay=self.params.weight_decay)
|
||||
if self.params.sto_weight_avg:
|
||||
opt = SWA(opt, swa_start=10, swa_freq=5, swa_lr=0.05)
|
||||
return opt
|
||||
@ -181,7 +181,7 @@ class BaseDataloadersMixin(ABC):
|
||||
# Validation Dataloader
|
||||
def val_dataloader(self):
|
||||
assert isinstance(self, LightningBaseModule)
|
||||
val_dataloader = DataLoader(dataset=self.dataset.val_dataset, shuffle=True,
|
||||
val_dataloader = DataLoader(dataset=self.dataset.val_dataset, shuffle=False,
|
||||
batch_size=self.params.batch_size, num_workers=self.params.worker)
|
||||
|
||||
train_dataloader = DataLoader(self.dataset.val_train_dataset, num_workers=self.params.worker,
|
||||
|
Reference in New Issue
Block a user