Train Active
This commit is contained in:
@ -133,12 +133,6 @@ class LightningBaseModule(pl.LightningModule, ABC):
|
||||
def forward(self, *args, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
def validation_step(self, *args, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
def validation_end(self, outputs):
|
||||
raise NotImplementedError
|
||||
|
||||
def training_step(self, batch_xy, batch_nb, *args, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
@ -146,21 +140,7 @@ class LightningBaseModule(pl.LightningModule, ABC):
|
||||
raise NotImplementedError
|
||||
|
||||
def test_end(self, outputs):
|
||||
from sklearn.metrics import roc_auc_score
|
||||
|
||||
y_scores, y_true = [], []
|
||||
for output in outputs:
|
||||
y_scores.append(output['y_pred'])
|
||||
y_true.append(output['y_true'])
|
||||
|
||||
y_true = torch.cat(y_true, dim=0)
|
||||
# FIXME: What did this do do i need it?
|
||||
# y_true = (y_true != V.HOMOTOPIC).long()
|
||||
y_scores = torch.cat(y_scores, dim=0)
|
||||
|
||||
roc_auc_scores = roc_auc_score(y_true.cpu().numpy(), y_scores.cpu().numpy())
|
||||
print(f'AUC Score: {roc_auc_scores}')
|
||||
return {'roc_auc_scores': roc_auc_scores}
|
||||
raise NotImplementedError
|
||||
|
||||
def init_weights(self):
|
||||
def _weight_init(m):
|
||||
|
Reference in New Issue
Block a user