diff --git a/main.py b/main.py index ac4232c..d09100d 100644 --- a/main.py +++ b/main.py @@ -13,9 +13,8 @@ from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping from torch.utils.data import DataLoader from ml_lib.modules.utils import LightningBaseModule -from ml_lib.utils.config import Config from ml_lib.utils.logging import Logger -from ml_lib.utils.model_io import SavedLightningModels +from util.config import MConfig warnings.filterwarnings('ignore', category=FutureWarning) warnings.filterwarnings('ignore', category=UserWarning) @@ -137,5 +136,5 @@ def run_lightning_loop(config_obj): if __name__ == "__main__": - config = Config.read_namespace(args) + config = MConfig.read_namespace(args) trained_model = run_lightning_loop(config) diff --git a/models/binary_classifier.py b/models/binary_classifier.py index 3840a9e..ef55a39 100644 --- a/models/binary_classifier.py +++ b/models/binary_classifier.py @@ -31,11 +31,14 @@ class BinaryClassifier(LightningBaseModule): def __init__(self, hparams): super(BinaryClassifier, self).__init__(hparams) + self.criterion = nn.BCELoss() # Additional parameters self.in_shape = () + # + # Model Modules self.conv_1 = ConvModule(self.in_shape, 32, 5, ) self.conv_2 = ConvModule(64) diff --git a/util/__init__.py b/util/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/util/config.py b/util/config.py new file mode 100644 index 0000000..fd23974 --- /dev/null +++ b/util/config.py @@ -0,0 +1,9 @@ +from ml_lib.utils.config import Config +from models.binary_classifier import BinaryClassifier + + +class MConfig(Config): + + @property + def model_map(self): + return dict(BinaryClassifier=BinaryClassifier)