From 8a97f599060a5d15f202da5940ecf063c9abe94a Mon Sep 17 00:00:00 2001
From: Si11ium
Date: Mon, 27 Apr 2020 17:30:44 +0200
Subject: [PATCH] Model Running

TODO: Redo the Dataset Label Processing
---
 main.py                     |  4 ++--
 models/binary_classifier.py | 25 ++++++++++++++++---------
 2 files changed, 18 insertions(+), 11 deletions(-)

diff --git a/main.py b/main.py
index 1fc1a73..05f268e 100644
--- a/main.py
+++ b/main.py
@@ -57,8 +57,8 @@ main_arg_parser.add_argument("--model_activation", type=str, default="leaky_relu
 main_arg_parser.add_argument("--model_filters", type=str, default="[16, 32, 64]", help="")
 main_arg_parser.add_argument("--model_classes", type=int, default=2, help="")
 main_arg_parser.add_argument("--model_lat_dim", type=int, default=16, help="")
-main_arg_parser.add_argument("--model_use_bias", type=strtobool, default=True, help="")
-main_arg_parser.add_argument("--model_use_norm", type=strtobool, default=False, help="")
+main_arg_parser.add_argument("--model_bias", type=strtobool, default=True, help="")
+main_arg_parser.add_argument("--model_norm", type=strtobool, default=False, help="")
 main_arg_parser.add_argument("--model_dropout", type=float, default=0.00, help="")
 
 # Project Parameters

diff --git a/models/binary_classifier.py b/models/binary_classifier.py
index b636c1c..11cabfe 100644
--- a/models/binary_classifier.py
+++ b/models/binary_classifier.py
@@ -7,16 +7,22 @@ from ml_lib.modules.utils import LightningBaseModule, Flatten
 
 class BinaryClassifier(LightningBaseModule):
 
+    def test_step(self, *args, **kwargs):
+        pass
+
+    def test_epoch_end(self, outputs):
+        pass
+
     @classmethod
     def name(cls):
         return cls.__name__
 
     def configure_optimizers(self):
-        return Adam(params=self.Parameters, lr=self.hparams.train.lr)
+        return Adam(params=self.parameters(), lr=self.hparams.lr)
 
     def training_step(self, batch_xy, batch_nb, *args, **kwargs):
         batch_x, batch_y = batch_xy
-        y = self(batch_y)
+        y = self(batch_x)
         loss = self.criterion(y, batch_y)
         return dict(loss=loss)
 
@@ -41,26 +47,27 @@ class BinaryClassifier(LightningBaseModule):
         self.in_shape = self.hparams.in_shape
 
         # Model Modules
-        self.conv_1 = ConvModule(self.in_shape, 32, 5, conv_stride=4, **hparams)
-        self.conv_2 = ConvModule(self.conv_1.shape, 64, 7, conv_stride=2, **hparams)
-        self.conv_3 = ConvModule(self.conv_2.shape, 128, 9, conv_stride=2, **hparams)
+        self.conv_1 = ConvModule(self.in_shape, 32, 3, conv_stride=2, **self.hparams.module_paramters)
+        self.conv_2 = ConvModule(self.conv_1.shape, 64, 5, conv_stride=2, **self.hparams.module_paramters)
+        self.conv_3 = ConvModule(self.conv_2.shape, 128, 7, conv_stride=2, **self.hparams.module_paramters)
 
         self.flat = Flatten(self.conv_3.shape)
-        self.full_1 = nn.Linear(self.flat.shape, 32)
-        self.full_2 = nn.Linear(self.full_1.out_features, self.full_1.out_features // 2)
+        self.full_1 = nn.Linear(self.flat.shape, 32, self.hparams.bias)
+        self.full_2 = nn.Linear(self.full_1.out_features, self.full_1.out_features // 2, self.hparams.bias)
 
         self.activation = self.hparams.activation()
-        self.full_out = nn.Linear(self.full_2.out_features, 2)
+        self.full_out = nn.Linear(self.full_2.out_features, 1, self.hparams.bias)
         self.sigmoid = nn.Sigmoid()
 
     def forward(self, batch, **kwargs):
         tensor = self.conv_1(batch)
         tensor = self.conv_2(tensor)
         tensor = self.conv_3(tensor)
+        tensor = self.flat(tensor)
         tensor = self.full_1(tensor)
         tensor = self.activation(tensor)
         tensor = self.full_2(tensor)
         tensor = self.activation(tensor)
         tensor = self.full_out(tensor)
         tensor = self.sigmoid(tensor)
-        return batch
+        return tensor
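
Note (not part of the patch): the head change switches full_out from two logits to a single sigmoid probability, which pairs with a binary cross-entropy criterion; the patch does not show what self.criterion is, so that pairing is an assumption. A minimal, self-contained sketch of that output/loss pattern in plain PyTorch, with no repo code assumed:

    import torch
    from torch import nn

    # Single-logit binary head as in the patched BinaryClassifier:
    # Linear(out_features=1) followed by Sigmoid, trained with BCELoss.
    head = nn.Sequential(nn.Linear(64, 1, bias=True), nn.Sigmoid())
    criterion = nn.BCELoss()

    features = torch.randn(8, 64)                    # stand-in for the flattened conv features
    targets = torch.randint(0, 2, (8, 1)).float()    # binary labels, shape (8, 1)

    probs = head(features)                           # shape (8, 1), values in (0, 1)
    loss = criterion(probs, targets)
    loss.backward()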