Dataset rdy
util/optimizer_mixin.py (Normal file, 45 lines)
@@ -0,0 +1,45 @@
from collections import defaultdict

from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, LambdaLR
from torchcontrib.optim import SWA

from ml_lib.modules.util import LightningBaseModule


class OptimizerMixin:

    def configure_optimizers(self):
        assert isinstance(self, LightningBaseModule)
        optimizer_dict = dict(
            # 'optimizer': optimizer,          # The Optimizer
            # 'lr_scheduler': scheduler,       # The LR scheduler
            frequency=1,                       # The frequency of the scheduler
            interval='epoch',                  # The unit of the scheduler's step size
            # 'reduce_on_plateau': False,      # For ReduceLROnPlateau scheduler
            # 'monitor': 'mean_val_loss'       # Metric to monitor
        )

        optimizer = Adam(params=self.parameters(), lr=self.params.lr, weight_decay=self.params.weight_decay)
        if self.params.sto_weight_avg:
            # Wrap the base optimizer in Stochastic Weight Averaging (torchcontrib)
            optimizer = SWA(optimizer, swa_start=10, swa_freq=5, swa_lr=0.05)
        optimizer_dict.update(optimizer=optimizer)
        if self.params.lr_warmup_steps:
            scheduler = CosineAnnealingWarmRestarts(optimizer, self.params.lr_warmup_steps)
        else:
            scheduler = LambdaLR(optimizer, lr_lambda=lambda epoch: 0.95 ** epoch)
        optimizer_dict.update(lr_scheduler=scheduler)
        return optimizer_dict

    def on_train_end(self):
        assert isinstance(self, LightningBaseModule)
        for opt in self.trainer.optimizers:
            if isinstance(opt, SWA):
                # Copy the averaged SWA weights back into the model once training finishes
                opt.swap_swa_sgd()

    def on_epoch_end(self):
        assert isinstance(self, LightningBaseModule)
        if self.params.opt_reset_interval:
            if self.current_epoch % self.params.opt_reset_interval == 0:
                for opt in self.trainer.optimizers:
                    # Periodically clear optimizer state (e.g. Adam's running moments)
                    opt.state = defaultdict(dict)
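For orientation, a minimal usage sketch follows. It is hypothetical: the DemoModule class is illustrative only, and it assumes util/ is importable as a package and that LightningBaseModule provides the self.params, self.trainer and self.current_epoch attributes the mixin reads.

from util.optimizer_mixin import OptimizerMixin
from ml_lib.modules.util import LightningBaseModule


class DemoModule(OptimizerMixin, LightningBaseModule):
    """Hypothetical module: the mixin contributes configure_optimizers(),
    on_train_end() and on_epoch_end(), provided self.params carries
    lr, weight_decay, sto_weight_avg, lr_warmup_steps and opt_reset_interval."""

Listing OptimizerMixin before LightningBaseModule in the base classes puts its hooks first in the method resolution order, so they override any defaults of the base module.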