Merge remote-tracking branch 'origin/master'

# Conflicts:
#	models/transformer_model.py
#	multi_run.py
This commit is contained in:
Steffen
2021-03-19 17:19:08 +01:00
5 changed files with 19 additions and 14 deletions

View File

@@ -152,7 +152,7 @@ class TestMixin:
class_names = {val: key for val, key in ['negative', 'positive']}
df = pd.DataFrame(data=dict(filename=[Path(x).name for x in sorted_y.keys()],
prediction=y_max.cpu().numpy()))
prediction=[class_names[x.item()] for x in y_max.cpu()]))
result_file = Path(self.logger.log_dir / 'predictions.csv')
if result_file.exists():
try:

View File

@@ -26,7 +26,7 @@ class OptimizerMixin:
optimizer_dict.update(optimizer=optimizer)
if self.params.scheduler == CosineAnnealingWarmRestarts.__name__:
scheduler = CosineAnnealingWarmRestarts(optimizer, self.params.lr_scheduler_parameter)
scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=self.params.lr_scheduler_parameter)
elif self.params.scheduler == LambdaLR.__name__:
lr_reduce_ratio = self.params.lr_scheduler_parameter
scheduler = LambdaLR(optimizer, lr_lambda=lambda epoch: lr_reduce_ratio ** epoch)