Variational Generator

This commit is contained in:
Si11ium
2020-03-10 16:59:51 +01:00
parent 21e7e31805
commit 1b5a7dc69e
10 changed files with 177 additions and 95 deletions

View File

@@ -5,7 +5,7 @@ from collections import defaultdict
from configparser import ConfigParser
from pathlib import Path
from lib.models.generators.cnn import CNNRouteGeneratorModel
from lib.models.generators.cnn import CNNRouteGeneratorModel, CNNRouteGeneratorDiscriminated
from lib.models.homotopy_classification.cnn_based import ConvHomDetector
from lib.utils.model_io import ModelParameters
@@ -28,7 +28,10 @@ class Config(ConfigParser):
@property
def model_class(self):
model_dict = dict(classifier_cnn=ConvHomDetector, generator_cnn=CNNRouteGeneratorModel)
model_dict = dict(ConvHomDetector=ConvHomDetector,
CNNRouteGenerator=CNNRouteGeneratorModel,
CNNRouteGeneratorDiscriminated=CNNRouteGeneratorDiscriminated
)
try:
return model_dict[self.get('model', 'type')]
except KeyError as e: