from collections import defaultdict

from torch.optim import Adam, AdamW
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, LambdaLR
from torchcontrib.optim import SWA

from ml_lib.modules.util import LightningBaseModule


class OptimizerMixin:
    """Adds optimizer/scheduler configuration and the SWA / optimizer-reset
    hooks to a LightningBaseModule."""

    def configure_optimizers(self):
        assert isinstance(self, LightningBaseModule)
        optimizer_dict = dict(
            # 'optimizer': optimizer,      # The Optimizer
            # 'lr_scheduler': scheduler,   # The LR scheduler
            frequency=1,                   # The frequency of the scheduler
            interval='epoch',              # The unit of the scheduler's step size
            # 'reduce_on_plateau': False,  # For ReduceLROnPlateau scheduler
            # 'monitor': 'mean_val_loss'   # Metric to monitor
        )

        optimizer = AdamW(params=self.parameters(), lr=self.params.lr, weight_decay=self.params.weight_decay)
        if self.params.sto_weight_avg:
            # Wrap the base optimizer in Stochastic Weight Averaging (torchcontrib).
            optimizer = SWA(optimizer, swa_start=10, swa_freq=5, swa_lr=0.05)
        optimizer_dict.update(optimizer=optimizer)

        # Select the LR scheduler by class name; fall back to a constant learning rate.
        if self.params.scheduler == CosineAnnealingWarmRestarts.__name__:
            scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=self.params.lr_scheduler_parameter)
        elif self.params.scheduler == LambdaLR.__name__:
            lr_reduce_ratio = self.params.lr_scheduler_parameter
            scheduler = LambdaLR(optimizer, lr_lambda=lambda epoch: lr_reduce_ratio ** epoch)
        else:
            scheduler = None
        if scheduler:
            optimizer_dict.update(lr_scheduler=scheduler)
        return optimizer_dict

    def on_train_end(self):
        assert isinstance(self, LightningBaseModule)
        # Copy the SWA running averages into the model parameters once training ends.
        for opt in self.trainer.optimizers:
            if isinstance(opt, SWA):
                opt.swap_swa_sgd()

    def on_epoch_end(self):
        assert isinstance(self, LightningBaseModule)
        if self.params.opt_reset_interval:
            # Periodically clear the optimizer state (e.g. Adam moment estimates).
            if self.current_epoch % self.params.opt_reset_interval == 0:
                for opt in self.trainer.optimizers:
                    opt.state = defaultdict(dict)
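

# ---------------------------------------------------------------------------
# Illustrative usage sketch, not part of the original module: it shows how the
# mixin is meant to be combined with a LightningBaseModule subclass. The class
# name `ExampleModel`, the Namespace and all concrete values below are
# placeholders; only the attribute names read by the mixin above
# (lr, weight_decay, sto_weight_avg, scheduler, lr_scheduler_parameter,
# opt_reset_interval) are taken from its code.
#
# from argparse import Namespace
#
#
# class ExampleModel(OptimizerMixin, LightningBaseModule):
#     """Hypothetical model; LightningBaseModule is assumed to expose the
#     hyperparameter Namespace as `self.params`."""
#     pass
#
#
# example_params = Namespace(
#     lr=1e-3,                                  # base learning rate for AdamW
#     weight_decay=1e-2,                        # AdamW weight decay
#     sto_weight_avg=False,                     # True -> wrap the optimizer in SWA
#     scheduler='CosineAnnealingWarmRestarts',  # or 'LambdaLR', or '' for no scheduler
#     lr_scheduler_parameter=10,                # T_0 for warm restarts / decay ratio for LambdaLR
#     opt_reset_interval=0,                     # reset optimizer state every N epochs (0 = never)
# )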