Train Active
This commit is contained in:
@ -106,7 +106,7 @@ class ResidualModule(nn.Module):
|
||||
self.in_shape = in_shape
|
||||
module_paramters.update(in_shape=in_shape)
|
||||
self.activation = activation() if activation else lambda x: x
|
||||
self.residual_block = [module_class(**module_paramters) for _ in range(n)]
|
||||
self.residual_block = nn.ModuleList([module_class(**module_paramters) for _ in range(n)])
|
||||
assert self.in_shape == self.shape, f'The in_shape: {self.in_shape} - must match the out_shape: {self.shape}.'
|
||||
|
||||
def forward(self, x):
|
||||
|
@ -133,12 +133,6 @@ class LightningBaseModule(pl.LightningModule, ABC):
|
||||
def forward(self, *args, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
def validation_step(self, *args, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
def validation_end(self, outputs):
|
||||
raise NotImplementedError
|
||||
|
||||
def training_step(self, batch_xy, batch_nb, *args, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
@ -146,21 +140,7 @@ class LightningBaseModule(pl.LightningModule, ABC):
|
||||
raise NotImplementedError
|
||||
|
||||
def test_end(self, outputs):
|
||||
from sklearn.metrics import roc_auc_score
|
||||
|
||||
y_scores, y_true = [], []
|
||||
for output in outputs:
|
||||
y_scores.append(output['y_pred'])
|
||||
y_true.append(output['y_true'])
|
||||
|
||||
y_true = torch.cat(y_true, dim=0)
|
||||
# FIXME: What did this do do i need it?
|
||||
# y_true = (y_true != V.HOMOTOPIC).long()
|
||||
y_scores = torch.cat(y_scores, dim=0)
|
||||
|
||||
roc_auc_scores = roc_auc_score(y_true.cpu().numpy(), y_scores.cpu().numpy())
|
||||
print(f'AUC Score: {roc_auc_scores}')
|
||||
return {'roc_auc_scores': roc_auc_scores}
|
||||
raise NotImplementedError
|
||||
|
||||
def init_weights(self):
|
||||
def _weight_init(m):
|
||||
|
Reference in New Issue
Block a user