Transformer running
This commit is contained in:
@ -1,8 +1,6 @@
|
||||
import inspect
|
||||
from argparse import Namespace
|
||||
|
||||
import variables as v
|
||||
|
||||
from torch import nn
|
||||
|
||||
from ml_lib.metrics.multi_class_classification import MultiClassScores
|
||||
@ -17,7 +15,7 @@ class CNNBaseline(CombinedModelMixins,
|
||||
):
|
||||
|
||||
def __init__(self, in_shape, n_classes, weight_init, activation, use_bias, use_norm, dropout, lat_dim, features,
|
||||
filters):
|
||||
filters, lr, weight_decay, sto_weight_avg, lr_warm_restart_epochs, opt_reset_interval, loss):
|
||||
|
||||
# TODO: Move this to parent class, or make it much easieer to access....
|
||||
a = dict(locals())
|
||||
@ -50,4 +48,4 @@ class CNNBaseline(CombinedModelMixins,
|
||||
return Namespace(main_out=tensor)
|
||||
|
||||
def additional_scores(self, outputs):
|
||||
return MultiClassScores(self)
|
||||
return MultiClassScores(self)(outputs)
|
||||
|
Reference in New Issue
Block a user