Variational Generator
This commit is contained in:
@@ -5,7 +5,7 @@ from collections import defaultdict
|
||||
from configparser import ConfigParser
|
||||
from pathlib import Path
|
||||
|
||||
from lib.models.generators.cnn import CNNRouteGeneratorModel
|
||||
from lib.models.generators.cnn import CNNRouteGeneratorModel, CNNRouteGeneratorDiscriminated
|
||||
from lib.models.homotopy_classification.cnn_based import ConvHomDetector
|
||||
from lib.utils.model_io import ModelParameters
|
||||
|
||||
@@ -28,7 +28,10 @@ class Config(ConfigParser):
|
||||
|
||||
@property
|
||||
def model_class(self):
|
||||
model_dict = dict(classifier_cnn=ConvHomDetector, generator_cnn=CNNRouteGeneratorModel)
|
||||
model_dict = dict(ConvHomDetector=ConvHomDetector,
|
||||
CNNRouteGenerator=CNNRouteGeneratorModel,
|
||||
CNNRouteGeneratorDiscriminated=CNNRouteGeneratorDiscriminated
|
||||
)
|
||||
try:
|
||||
return model_dict[self.get('model', 'type')]
|
||||
except KeyError as e:
|
||||
|
||||
Reference in New Issue
Block a user