Train Active

This commit is contained in:
Si11ium
2020-03-03 15:10:17 +01:00
parent 44f6589259
commit 1f612a968c
13 changed files with 102 additions and 98 deletions

View File

@ -133,12 +133,6 @@ class LightningBaseModule(pl.LightningModule, ABC):
def forward(self, *args, **kwargs):
raise NotImplementedError
def validation_step(self, *args, **kwargs):
raise NotImplementedError
def validation_end(self, outputs):
raise NotImplementedError
def training_step(self, batch_xy, batch_nb, *args, **kwargs):
raise NotImplementedError
@ -146,21 +140,7 @@ class LightningBaseModule(pl.LightningModule, ABC):
raise NotImplementedError
def test_end(self, outputs):
from sklearn.metrics import roc_auc_score
y_scores, y_true = [], []
for output in outputs:
y_scores.append(output['y_pred'])
y_true.append(output['y_true'])
y_true = torch.cat(y_true, dim=0)
# FIXME: What did this do do i need it?
# y_true = (y_true != V.HOMOTOPIC).long()
y_scores = torch.cat(y_scores, dim=0)
roc_auc_scores = roc_auc_score(y_true.cpu().numpy(), y_scores.cpu().numpy())
print(f'AUC Score: {roc_auc_scores}')
return {'roc_auc_scores': roc_auc_scores}
raise NotImplementedError
def init_weights(self):
def _weight_init(m):