adjustment fot CCS, notebook folder
This commit is contained in:
parent
ed260f1c2a
commit
d3e7bf7efb
@ -2,7 +2,6 @@ from itertools import cycle
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from pytorch_lightning.metrics import Recall
|
|
||||||
from sklearn.metrics import f1_score, roc_curve, auc, roc_auc_score, ConfusionMatrixDisplay, confusion_matrix, \
|
from sklearn.metrics import f1_score, roc_curve, auc, roc_auc_score, ConfusionMatrixDisplay, confusion_matrix, \
|
||||||
recall_score
|
recall_score
|
||||||
|
|
||||||
|
@ -35,8 +35,9 @@ try:
|
|||||||
# self.precision_recall_curve = pl.metrics.PrecisionRecallCurve(self.n_classes, compute_on_step=False)
|
# self.precision_recall_curve = pl.metrics.PrecisionRecallCurve(self.n_classes, compute_on_step=False)
|
||||||
# self.average_prec = pl.metrics.AveragePrecision(self.n_classes, compute_on_step=True)
|
# self.average_prec = pl.metrics.AveragePrecision(self.n_classes, compute_on_step=True)
|
||||||
# self.roc = pl.metrics.ROC(self.n_classes, compute_on_step=False)
|
# self.roc = pl.metrics.ROC(self.n_classes, compute_on_step=False)
|
||||||
self.fbeta = pl.metrics.FBeta(self.n_classes, average='macro', compute_on_step=False)
|
if self.n_classes > 2:
|
||||||
self.f1 = pl.metrics.F1(self.n_classes, average='macro', compute_on_step=False)
|
self.fbeta = pl.metrics.FBeta(self.n_classes, average='macro', compute_on_step=False)
|
||||||
|
self.f1 = pl.metrics.F1(self.n_classes, average='macro', compute_on_step=False)
|
||||||
|
|
||||||
def __iter__(self):
|
def __iter__(self):
|
||||||
return iter(((name, metric) for name, metric in self._modules.items()))
|
return iter(((name, metric) for name, metric in self._modules.items()))
|
||||||
|
@ -71,10 +71,6 @@ def parse_comandline_args_add_defaults(filepath, overrides=None):
|
|||||||
auto_lr_find=not args['debug'],
|
auto_lr_find=not args['debug'],
|
||||||
weights_summary='top',
|
weights_summary='top',
|
||||||
check_val_every_n_epoch=1 if args['debug'] else args.get('check_val_every_n_epoch', 1),
|
check_val_every_n_epoch=1 if args['debug'] else args.get('check_val_every_n_epoch', 1),
|
||||||
limit_train_batches = 2.0,
|
|
||||||
limit_val_batches = 2.0,
|
|
||||||
limit_test_batches = 2.0,
|
|
||||||
limit_predict_batches = 2.0,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if overrides is not None and isinstance(overrides, (Mapping, Dict)):
|
if overrides is not None and isinstance(overrides, (Mapping, Dict)):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user