from collections import defaultdict

from torch.optim import Adam, AdamW
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, LambdaLR
from torchcontrib.optim import SWA

from ml_lib.modules.util import LightningBaseModule


class OptimizerMixin:
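    """Mixin for LightningBaseModule subclasses that wires up the AdamW optimizer,
    optional stochastic weight averaging (torchcontrib SWA), learning-rate
    scheduling and periodic optimizer-state resets."""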

    def configure_optimizers(self):
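        """Build the optimizer/scheduler dictionary returned to PyTorch Lightning."""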
        assert isinstance(self, LightningBaseModule)
        optimizer_dict = dict(
            # 'optimizer': optimizer,        # The Optimizer
            # 'lr_scheduler': scheduler,     # The LR scheduler
            frequency=1,                     # The frequency of the scheduler
            interval='epoch',                # The unit of the scheduler's step size
            # 'reduce_on_plateau': False,    # For ReduceLROnPlateau scheduler
            # 'monitor': 'mean_val_loss'     # Metric to monitor
        )
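
        # AdamW is the base optimizer; when stochastic weight averaging is enabled,
        # it is wrapped in torchcontrib's SWA optimizer.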
        optimizer = AdamW(params=self.parameters(), lr=self.params.lr, weight_decay=self.params.weight_decay)
        if self.params.sto_weight_avg:
            optimizer = SWA(optimizer, swa_start=10, swa_freq=5, swa_lr=0.05)
        optimizer_dict.update(optimizer=optimizer)
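
        # Select the LR scheduler by class name; `lr_scheduler_parameter` is T_0 for
        # CosineAnnealingWarmRestarts and the per-epoch decay ratio for LambdaLR.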
        if self.params.scheduler == CosineAnnealingWarmRestarts.__name__:
            scheduler = CosineAnnealingWarmRestarts(optimizer, self.params.lr_scheduler_parameter)
        elif self.params.scheduler == LambdaLR.__name__:
            lr_reduce_ratio = self.params.lr_scheduler_parameter
            scheduler = LambdaLR(optimizer, lr_lambda=lambda epoch: lr_reduce_ratio ** epoch)
        else:
            scheduler = None
        if scheduler:
            optimizer_dict.update(lr_scheduler=scheduler)
        return optimizer_dict

    def on_train_end(self):
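        """Swap the averaged SWA weights into the model once training is finished."""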
        assert isinstance(self, LightningBaseModule)
        for opt in self.trainer.optimizers:
            if isinstance(opt, SWA):
                opt.swap_swa_sgd()

    def on_epoch_end(self):
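        """Reset the optimizer state (e.g. Adam's running moments) every
        `opt_reset_interval` epochs, if that interval is configured."""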
        assert isinstance(self, LightningBaseModule)
        if self.params.opt_reset_interval:
            if self.current_epoch % self.params.opt_reset_interval == 0:
                for opt in self.trainer.optimizers:
                    opt.state = defaultdict(dict)