adjustment fot CCS, notebook folder
This commit is contained in:
parent
ed260f1c2a
commit
d3e7bf7efb
@ -2,7 +2,6 @@ from itertools import cycle
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from pytorch_lightning.metrics import Recall
|
||||
from sklearn.metrics import f1_score, roc_curve, auc, roc_auc_score, ConfusionMatrixDisplay, confusion_matrix, \
|
||||
recall_score
|
||||
|
||||
|
@ -35,8 +35,9 @@ try:
|
||||
# self.precision_recall_curve = pl.metrics.PrecisionRecallCurve(self.n_classes, compute_on_step=False)
|
||||
# self.average_prec = pl.metrics.AveragePrecision(self.n_classes, compute_on_step=True)
|
||||
# self.roc = pl.metrics.ROC(self.n_classes, compute_on_step=False)
|
||||
self.fbeta = pl.metrics.FBeta(self.n_classes, average='macro', compute_on_step=False)
|
||||
self.f1 = pl.metrics.F1(self.n_classes, average='macro', compute_on_step=False)
|
||||
if self.n_classes > 2:
|
||||
self.fbeta = pl.metrics.FBeta(self.n_classes, average='macro', compute_on_step=False)
|
||||
self.f1 = pl.metrics.F1(self.n_classes, average='macro', compute_on_step=False)
|
||||
|
||||
def __iter__(self):
|
||||
return iter(((name, metric) for name, metric in self._modules.items()))
|
||||
|
@ -71,10 +71,6 @@ def parse_comandline_args_add_defaults(filepath, overrides=None):
|
||||
auto_lr_find=not args['debug'],
|
||||
weights_summary='top',
|
||||
check_val_every_n_epoch=1 if args['debug'] else args.get('check_val_every_n_epoch', 1),
|
||||
limit_train_batches = 2.0,
|
||||
limit_val_batches = 2.0,
|
||||
limit_test_batches = 2.0,
|
||||
limit_predict_batches = 2.0,
|
||||
)
|
||||
|
||||
if overrides is not None and isinstance(overrides, (Mapping, Dict)):
|
||||
|
Loading…
x
Reference in New Issue
Block a user