Transformer running

Steffen Illium
2021-03-04 12:01:09 +01:00
parent 7edd3834a1
commit ad254dae92
14 changed files with 679 additions and 134 deletions

@@ -1,6 +1,6 @@
 from collections import defaultdict
-from torch.optim import Adam
+from torch.optim import Adam, AdamW
 from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, LambdaLR
 from torchcontrib.optim import SWA
@@ -20,12 +20,12 @@ class OptimizerMixin:
             # 'monitor': 'mean_val_loss'  # Metric to monitor
         )
-        optimizer = Adam(params=self.parameters(), lr=self.params.lr, weight_decay=self.params.weight_decay)
+        optimizer = AdamW(params=self.parameters(), lr=self.params.lr, weight_decay=self.params.weight_decay)
         if self.params.sto_weight_avg:
             optimizer = SWA(optimizer, swa_start=10, swa_freq=5, swa_lr=0.05)
         optimizer_dict.update(optimizer=optimizer)
-        if self.params.lr_warmup_steps:
-            scheduler = CosineAnnealingWarmRestarts(optimizer, self.params.lr_warmup_steps)
+        if self.params.lr_warm_restart_epochs:
+            scheduler = CosineAnnealingWarmRestarts(optimizer, self.params.lr_warm_restart_epochs)
         else:
             scheduler = LambdaLR(optimizer, lr_lambda=lambda epoch: 0.95 ** epoch)
         optimizer_dict.update(lr_scheduler=scheduler)
@@ -42,4 +42,4 @@ class OptimizerMixin:
         if self.params.opt_reset_interval:
             if self.current_epoch % self.params.opt_reset_interval == 0:
                 for opt in self.trainer.optimizers:
-                    opt.state = defaultdict(dict)
+                    opt.state = defaultdict(dict)
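
For context, the hunks swap Adam for AdamW (decoupled weight decay) and re-key the scheduler from a warmup-step count to a warm-restart period: in that argument position, CosineAnnealingWarmRestarts takes T_0, the number of epochs until the first restart. A minimal standalone sketch of the resulting combination (plain PyTorch rather than the repo's Lightning mixin; the model and hyperparameter values are placeholders, not values from this commit):

    from collections import defaultdict
    import torch
    from torch import nn
    from torch.optim import AdamW
    from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

    model = nn.Linear(16, 4)  # stand-in for the actual Transformer
    optimizer = AdamW(model.parameters(), lr=1e-3, weight_decay=1e-2)
    # T_0: epochs until the first warm restart; the diff passes
    # self.params.lr_warm_restart_epochs in this position.
    scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=10)

    x, y = torch.randn(8, 16), torch.randn(8, 4)  # dummy batch
    for epoch in range(30):
        optimizer.zero_grad()
        loss = nn.functional.mse_loss(model(x), y)
        loss.backward()
        optimizer.step()
        scheduler.step()  # advance the cosine schedule once per epoch

        # Periodic optimizer-state reset, as in the last hunk: reassigning
        # `state` wipes AdamW's running moment estimates, which are then
        # re-initialized on the next step(). Interval chosen arbitrarily here.
        if epoch and epoch % 10 == 0:
            optimizer.state = defaultdict(dict)

Whether resetting optimizer state mid-schedule helps depends on the training setup; the mixin gates it behind opt_reset_interval, so it is off unless that hyperparameter is set.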