Config Anpassungen

This commit is contained in:
Si11ium
2020-04-15 17:15:51 +02:00
parent 427a6463cb
commit 76c0e6aa05
4 changed files with 14 additions and 3 deletions

View File

@@ -13,9 +13,8 @@ from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from torch.utils.data import DataLoader
from ml_lib.modules.utils import LightningBaseModule
from ml_lib.utils.config import Config
from ml_lib.utils.logging import Logger
from ml_lib.utils.model_io import SavedLightningModels
from util.config import MConfig
warnings.filterwarnings('ignore', category=FutureWarning)
warnings.filterwarnings('ignore', category=UserWarning)
@@ -137,5 +136,5 @@ def run_lightning_loop(config_obj):
if __name__ == "__main__":
config = Config.read_namespace(args)
config = MConfig.read_namespace(args)
trained_model = run_lightning_loop(config)

View File

@@ -31,11 +31,14 @@ class BinaryClassifier(LightningBaseModule):
def __init__(self, hparams):
super(BinaryClassifier, self).__init__(hparams)
self.criterion = nn.BCELoss()
# Additional parameters
self.in_shape = ()
#
# Model Modules
self.conv_1 = ConvModule(self.in_shape, 32, 5, )
self.conv_2 = ConvModule(64)

0
util/__init__.py Normal file
View File

9
util/config.py Normal file
View File

@@ -0,0 +1,9 @@
from ml_lib.utils.config import Config
from models.binary_classifier import BinaryClassifier
class MConfig(Config):
@property
def model_map(self):
return dict(BinaryClassifier=BinaryClassifier)