Config Anpassungen
This commit is contained in:
5
main.py
5
main.py
@@ -13,9 +13,8 @@ from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from ml_lib.modules.utils import LightningBaseModule
|
||||
from ml_lib.utils.config import Config
|
||||
from ml_lib.utils.logging import Logger
|
||||
from ml_lib.utils.model_io import SavedLightningModels
|
||||
from util.config import MConfig
|
||||
|
||||
warnings.filterwarnings('ignore', category=FutureWarning)
|
||||
warnings.filterwarnings('ignore', category=UserWarning)
|
||||
@@ -137,5 +136,5 @@ def run_lightning_loop(config_obj):
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
config = Config.read_namespace(args)
|
||||
config = MConfig.read_namespace(args)
|
||||
trained_model = run_lightning_loop(config)
|
||||
|
@@ -31,11 +31,14 @@ class BinaryClassifier(LightningBaseModule):
|
||||
|
||||
def __init__(self, hparams):
|
||||
super(BinaryClassifier, self).__init__(hparams)
|
||||
|
||||
self.criterion = nn.BCELoss()
|
||||
|
||||
# Additional parameters
|
||||
self.in_shape = ()
|
||||
|
||||
#
|
||||
|
||||
# Model Modules
|
||||
self.conv_1 = ConvModule(self.in_shape, 32, 5, )
|
||||
self.conv_2 = ConvModule(64)
|
||||
|
0
util/__init__.py
Normal file
0
util/__init__.py
Normal file
9
util/config.py
Normal file
9
util/config.py
Normal file
@@ -0,0 +1,9 @@
|
||||
from ml_lib.utils.config import Config
|
||||
from models.binary_classifier import BinaryClassifier
|
||||
|
||||
|
||||
class MConfig(Config):
|
||||
|
||||
@property
|
||||
def model_map(self):
|
||||
return dict(BinaryClassifier=BinaryClassifier)
|
Reference in New Issue
Block a user