project Refactor, CNN Classifier Basics

This commit is contained in:
Steffen Illium
2020-03-08 23:46:02 +01:00
parent 75e8a61628
commit cd4fdf2de3
20 changed files with 441 additions and 239 deletions

View File

@@ -5,6 +5,7 @@ from collections import defaultdict
from configparser import ConfigParser
from pathlib import Path
from lib.models.generators.cnn import CNNRouteGeneratorModel
from lib.models.homotopy_classification.cnn_based import ConvHomDetector
from lib.utils.model_io import ModelParameters
@@ -27,7 +28,7 @@ class Config(ConfigParser):
@property
def model_class(self):
model_dict = dict(classifier_cnn=ConvHomDetector)
model_dict = dict(classifier_cnn=ConvHomDetector, generator_cnn=CNNRouteGeneratorModel)
try:
return model_dict[self.get('model', 'type')]
except KeyError as e:

View File

@@ -1,8 +1,8 @@
from pathlib import Path
from pytorch_lightning.logging.base import LightningLoggerBase
from pytorch_lightning.logging.neptune import NeptuneLogger
from pytorch_lightning.logging.test_tube import TestTubeLogger
from pytorch_lightning.loggers.base import LightningLoggerBase
from pytorch_lightning.loggers.neptune import NeptuneLogger
from pytorch_lightning.loggers.test_tube import TestTubeLogger
from lib.utils.config import Config

View File

@@ -1,5 +1,7 @@
from argparse import Namespace
from pathlib import Path
import torch
from natsort import natsorted
from torch import nn
@@ -35,30 +37,25 @@ class ModelParameters(Namespace):
class SavedLightningModels(object):
@classmethod
def load_checkpoint(cls, models_root_path, model, n=-1, tags_file_path=''):
def load_checkpoint(cls, models_root_path, model=None, n=-1, tags_file_path=''):
assert models_root_path.exists(), f'The path {models_root_path.absolute()} does not exist!'
found_checkpoints = list(Path(models_root_path).rglob('*.ckpt'))
found_checkpoints = natsorted(found_checkpoints, key=lambda y: y.name)
if model is None:
model = torch.load(models_root_path / 'model_class.obj')
assert model is not None
if not tags_file_path:
tag_files = models_root_path.rglob('meta_tags.csv')
tags_file_path = list(tag_files)[0]
return cls(weights=found_checkpoints[n], model=model, tags=tags_file_path)
return cls(weights=found_checkpoints[n], model=model)
def __init__(self, **kwargs):
self.weights: str = kwargs.get('weights', '')
self.tags: str = kwargs.get('tags', '')
self.model = kwargs.get('model', None)
assert self.model is not None
def restore(self):
pretrained_model = self.model.load_from_metrics(
weights_path=self.weights,
tags_csv=self.tags
)
pretrained_model = self.model.load_from_checkpoint(self.weights)
pretrained_model.eval()
pretrained_model.freeze()
return pretrained_model