Transformer running

This commit is contained in:
Steffen Illium
2021-03-04 12:01:09 +01:00
parent 7edd3834a1
commit ad254dae92
14 changed files with 679 additions and 134 deletions

View File

@ -1,8 +1,6 @@
import inspect
from argparse import Namespace
import variables as v
from torch import nn
from ml_lib.metrics.multi_class_classification import MultiClassScores
@ -17,7 +15,7 @@ class CNNBaseline(CombinedModelMixins,
):
def __init__(self, in_shape, n_classes, weight_init, activation, use_bias, use_norm, dropout, lat_dim, features,
filters):
filters, lr, weight_decay, sto_weight_avg, lr_warm_restart_epochs, opt_reset_interval, loss):
# TODO: Move this to parent class, or make it much easieer to access....
a = dict(locals())
@ -50,4 +48,4 @@ class CNNBaseline(CombinedModelMixins,
return Namespace(main_out=tensor)
def additional_scores(self, outputs):
return MultiClassScores(self)
return MultiClassScores(self)(outputs)