Transformer running
@@ -1,6 +1,6 @@
 from collections import defaultdict
-from torch.optim import Adam
+from torch.optim import Adam, AdamW
 from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, LambdaLR
 from torchcontrib.optim import SWA
 
@@ -20,12 +20,12 @@ class OptimizerMixin:
             # 'monitor': 'mean_val_loss'  # Metric to monitor
         )
 
-        optimizer = Adam(params=self.parameters(), lr=self.params.lr, weight_decay=self.params.weight_decay)
+        optimizer = AdamW(params=self.parameters(), lr=self.params.lr, weight_decay=self.params.weight_decay)
         if self.params.sto_weight_avg:
             optimizer = SWA(optimizer, swa_start=10, swa_freq=5, swa_lr=0.05)
         optimizer_dict.update(optimizer=optimizer)
-        if self.params.lr_warmup_steps:
-            scheduler = CosineAnnealingWarmRestarts(optimizer, self.params.lr_warmup_steps)
+        if self.params.lr_warm_restart_epochs:
+            scheduler = CosineAnnealingWarmRestarts(optimizer, self.params.lr_warm_restart_epochs)
         else:
             scheduler = LambdaLR(optimizer, lr_lambda=lambda epoch: 0.95 ** epoch)
         optimizer_dict.update(lr_scheduler=scheduler)
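For context, a minimal sketch of how the changed lines above could slot into a PyTorch Lightning-style configure_optimizers(). The hyperparameter names (lr, weight_decay, sto_weight_avg, lr_warm_restart_epochs) come from the diff; the method name, the optimizer_dict layout, and the return statement are assumptions.

# Sketch only: assembles the AdamW / SWA / scheduler selection shown above.
# configure_optimizers() and the surrounding dict/return are assumed, not from the diff.
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, LambdaLR
from torchcontrib.optim import SWA


class OptimizerMixin:
    def configure_optimizers(self):
        optimizer_dict = dict(
            # 'monitor': 'mean_val_loss'  # Metric to monitor
        )

        # AdamW applies weight decay directly to the weights instead of folding
        # an L2 penalty into the gradient as Adam does
        optimizer = AdamW(params=self.parameters(), lr=self.params.lr,
                          weight_decay=self.params.weight_decay)

        # Optionally wrap the optimizer in stochastic weight averaging
        if self.params.sto_weight_avg:
            optimizer = SWA(optimizer, swa_start=10, swa_freq=5, swa_lr=0.05)
        optimizer_dict.update(optimizer=optimizer)

        # Cosine warm restarts every lr_warm_restart_epochs epochs,
        # otherwise a plain 0.95-per-epoch exponential decay
        if self.params.lr_warm_restart_epochs:
            scheduler = CosineAnnealingWarmRestarts(optimizer, self.params.lr_warm_restart_epochs)
        else:
            scheduler = LambdaLR(optimizer, lr_lambda=lambda epoch: 0.95 ** epoch)
        optimizer_dict.update(lr_scheduler=scheduler)

        return optimizer_dict

The switch from Adam to AdamW only changes behaviour when weight_decay is non-zero, since that is the term the two optimizers handle differently.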
@@ -42,4 +42,4 @@ class OptimizerMixin:
         if self.params.opt_reset_interval:
             if self.current_epoch % self.params.opt_reset_interval == 0:
                 for opt in self.trainer.optimizers:
-                    opt.state = defaultdict(dict)
+                    opt.state = defaultdict(dict)
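The periodic state reset in the last hunk, written out as a sketch of an epoch hook. opt_reset_interval and the defaultdict(dict) reset come from the diff; the hook name on_train_epoch_start and the trainer.optimizers access are assumptions.

# Sketch only: the epoch-interval optimizer reset from the hunk above.
from collections import defaultdict


class OptimizerMixin:
    def on_train_epoch_start(self):  # hook name is an assumption
        if self.params.opt_reset_interval:
            if self.current_epoch % self.params.opt_reset_interval == 0:
                for opt in self.trainer.optimizers:
                    # torch.optim.Optimizer keeps per-parameter state (e.g. Adam's
                    # running moments) in a defaultdict(dict); swapping in a fresh one
                    # discards that state so the optimizer effectively restarts cold.
                    opt.state = defaultdict(dict)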