Train Active
This commit is contained in:
@ -18,21 +18,12 @@ class ConvHomDetector(LightningBaseModule):
|
||||
def configure_optimizers(self):
|
||||
return Adam(self.parameters(), lr=self.hparams.lr)
|
||||
|
||||
def validation_step(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def validation_end(self, outputs):
|
||||
pass
|
||||
|
||||
def training_step(self, batch_xy, batch_nb, *args, **kwargs):
|
||||
batch_x, batch_y = batch_xy
|
||||
pred_y = self(batch_x)
|
||||
loss = F.binary_cross_entropy(pred_y, batch_y)
|
||||
loss = F.binary_cross_entropy(pred_y, batch_y.float())
|
||||
return {'loss': loss, 'log': dict(loss=loss)}
|
||||
|
||||
def test_step(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def __init__(self, *params):
|
||||
super(ConvHomDetector, self).__init__(*params)
|
||||
|
||||
@ -75,8 +66,9 @@ class ConvHomDetector(LightningBaseModule):
|
||||
#
|
||||
|
||||
self.linear = nn.Linear(reduce(mul, self.flatten.shape), self.hparams.model_param.classes * 10)
|
||||
self.classifier = nn.Linear(self.hparams.model_param.classes * 10, self.hparams.model_param.classes)
|
||||
self.softmax = nn.Softmax()
|
||||
# Comments on Multi Class labels
|
||||
self.classifier = nn.Linear(self.hparams.model_param.classes * 10, 1) # self.hparams.model_param.classes)
|
||||
self.out_activation = nn.Sigmoid() # nn.Softmax
|
||||
|
||||
def forward(self, x):
|
||||
tensor = self.map_conv_0(x)
|
||||
@ -88,5 +80,5 @@ class ConvHomDetector(LightningBaseModule):
|
||||
tensor = self.flatten(tensor)
|
||||
tensor = self.linear(tensor)
|
||||
tensor = self.classifier(tensor)
|
||||
tensor = self.softmax(tensor)
|
||||
tensor = self.out_activation(tensor)
|
||||
return tensor
|
||||
|
Reference in New Issue
Block a user