from collections import defaultdict

from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, LambdaLR
from torchcontrib.optim import SWA

from ml_lib.modules.util import LightningBaseModule


class OptimizerMixin:

    def configure_optimizers(self):
        """Build an AdamW optimizer (optionally wrapped in SWA) and an optional LR scheduler."""
        assert isinstance(self, LightningBaseModule)
        optimizer_dict = dict(
            # 'optimizer': optimizer,       # The Optimizer
            # 'lr_scheduler': scheduler,    # The LR scheduler
            frequency=1,                    # The frequency of the scheduler
            interval='epoch',               # The unit of the scheduler's step size
            # 'reduce_on_plateau': False,   # For ReduceLROnPlateau scheduler
            # 'monitor': 'mean_val_loss'    # Metric to monitor
        )

        optimizer = AdamW(params=self.parameters(), lr=self.params.lr,
                          weight_decay=self.params.weight_decay)
        if self.params.sto_weight_avg:
            # Stochastic Weight Averaging (torchcontrib): average weights every
            # `swa_freq` steps once `swa_start` has been reached.
            optimizer = SWA(optimizer, swa_start=10, swa_freq=5, swa_lr=0.05)
        optimizer_dict.update(optimizer=optimizer)

        if self.params.scheduler == CosineAnnealingWarmRestarts.__name__:
            scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=self.params.lr_scheduler_parameter)
        elif self.params.scheduler == LambdaLR.__name__:
            lr_reduce_ratio = self.params.lr_scheduler_parameter
            scheduler = LambdaLR(optimizer, lr_lambda=lambda epoch: lr_reduce_ratio ** epoch)
        else:
            scheduler = None

        if scheduler is not None:
            optimizer_dict.update(lr_scheduler=scheduler)
        return optimizer_dict

    def on_train_end(self):
        """Swap the averaged SWA weights into the model once training has finished."""
        assert isinstance(self, LightningBaseModule)
        for opt in self.trainer.optimizers:
            if isinstance(opt, SWA):
                opt.swap_swa_sgd()

    def on_epoch_end(self):
        """Periodically clear the optimizer state (e.g. Adam moments) if configured."""
        assert isinstance(self, LightningBaseModule)
        if self.params.opt_reset_interval:
            if self.current_epoch % self.params.opt_reset_interval == 0:
                for opt in self.trainer.optimizers:
                    opt.state = defaultdict(dict)
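

# --- Usage sketch (illustrative only) -----------------------------------------
# A minimal sketch of how this mixin is meant to be combined with a
# LightningBaseModule subclass, and which hyperparameters it reads from
# `self.params`. The Namespace values and the model class below are
# hypothetical; only the attribute names match what the mixin accesses above.
if __name__ == '__main__':
    from argparse import Namespace

    example_params = Namespace(
        lr=1e-3,                                         # AdamW learning rate (example value)
        weight_decay=1e-2,                               # AdamW weight decay (example value)
        sto_weight_avg=False,                            # wrap the optimizer in torchcontrib SWA
        scheduler=CosineAnnealingWarmRestarts.__name__,  # or 'LambdaLR', or '' for no scheduler
        lr_scheduler_parameter=10,                       # T_0 for cosine restarts / decay ratio for LambdaLR
        opt_reset_interval=0,                            # reset optimizer state every N epochs (falsy = never)
    )

    # class MyModel(OptimizerMixin, LightningBaseModule):  # hypothetical subclass
    #     ...
    # model = MyModel(example_params)                      # exact constructor depends on LightningBaseModule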