Variational Generator

This commit is contained in:
Si11ium
2020-03-10 16:59:51 +01:00
parent 21e7e31805
commit 1b5a7dc69e
10 changed files with 177 additions and 95 deletions

View File

@@ -5,8 +5,10 @@ from typing import Union, List
import torch import torch
from random import choice from random import choice
from torch.utils.data import ConcatDataset, Dataset from torch.utils.data import ConcatDataset, Dataset
import numpy as np
from lib.objects.map import Map from lib.objects.map import Map
import lib.variables as V
from PIL import Image from PIL import Image
@@ -36,13 +38,10 @@ class TrajDataset(Dataset):
if self.mode.lower() == 'just_route': if self.mode.lower() == 'just_route':
trajectory = self.map.get_random_trajectory() trajectory = self.map.get_random_trajectory()
trajectory_space = trajectory.draw_in_array(self.map.shape)
label = choice([0, 1]) label = choice([0, 1])
blank_trajectory_space = torch.zeros(self.map.shape)
for index in trajectory.vertices:
blank_trajectory_space[index] = 1
map_array = torch.as_tensor(self.map.as_array).float() map_array = torch.as_tensor(self.map.as_array).float()
return (map_array, blank_trajectory_space), label return (map_array, trajectory_space), label
while True: while True:
trajectory = self.map.get_random_trajectory() trajectory = self.map.get_random_trajectory()
@@ -55,13 +54,13 @@ class TrajDataset(Dataset):
self.last_label = label self.last_label = label
if self.mode.lower() in ['all_in_map', 'separated_arrays']: if self.mode.lower() in ['all_in_map', 'separated_arrays']:
map_array = torch.as_tensor(self.map.as_array).float() map_array = self.map.as_array
trajectory = trajectory.draw_in_array(self.map_shape)
alternative = alternative.draw_in_array(self.map_shape)
if self.mode == 'separated_arrays': if self.mode == 'separated_arrays':
return (map_array, torch.as_tensor(trajectory.draw_in_array(self.map_shape)).float(), int(label)), \ return (map_array, trajectory, label), alternative
torch.as_tensor(alternative.draw_in_array(self.map_shape)).float()
else: else:
return torch.cat((map_array, torch.as_tensor(trajectory.draw_in_array(self.map_shape)).float(), return np.concatenate((map_array, trajectory, alternative)), label
torch.as_tensor(alternative.draw_in_array(self.map_shape)).float())), int(label)
elif self.mode == 'vectors': elif self.mode == 'vectors':
return trajectory.vertices, alternative.vertices, label, self.mapname return trajectory.vertices, alternative.vertices, label, self.mapname

View File

@@ -25,59 +25,33 @@ class CNNRouteGeneratorModel(LightningBaseModule):
return Adam(self.parameters(), lr=self.hparams.train_param.lr) return Adam(self.parameters(), lr=self.hparams.train_param.lr)
def training_step(self, batch_xy, batch_nb, *args, **kwargs): def training_step(self, batch_xy, batch_nb, *args, **kwargs):
batch_x, label = batch_xy batch_x, alternative = batch_xy
generated_alternative, z, mu, logvar = self(batch_x)
generated_alternative, z, mu, logvar = self(batch_x + [label, ]) mse_loss = self.criterion(generated_alternative, alternative)
map_array, trajectory = batch_x
map_stack = torch.cat((map_array, trajectory, generated_alternative), dim=1)
pred_label = self.discriminator(map_stack)
discriminated_bce_loss = self.criterion(pred_label, label.float().unsqueeze(-1))
# see Appendix B from VAE paper: # see Appendix B from VAE paper:
# Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014 # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
# https://arxiv.org/abs/1312.6114 # https://arxiv.org/abs/1312.6114
# 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2) # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
kld_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) kld_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
# Dimensional Resizing # Dimensional Resizing TODO: Does This make sense? Sanity Check it!
kld_loss /= reduce(mul, self.in_shape) kld_loss /= reduce(mul, self.in_shape)
loss = (kld_loss + discriminated_bce_loss) / 2 loss = (kld_loss + mse_loss) / 2
return dict(loss=loss, log=dict(loss=loss, return dict(loss=loss, log=dict(loss=loss, mse_loss=mse_loss, kld_loss=kld_loss))
discriminated_bce_loss=discriminated_bce_loss,
kld_loss=kld_loss)
)
def _test_val_step(self, batch_xy, batch_nb, *args): def _test_val_step(self, batch_xy, batch_nb, *args):
batch_x, label = batch_xy batch_x, alternative = batch_xy
map_array, trajectory, label = batch_x
generated_alternative, z, mu, logvar = self(batch_x + [label, ]) generated_alternative, z, mu, logvar = self(batch_x)
map_array, trajectory = batch_x
map_stack = torch.cat((map_array, trajectory, generated_alternative), dim=1) return dict(batch_nb=batch_nb, label=label, generated_alternative=generated_alternative, pred_label=-1)
pred_label = self.discriminator(map_stack)
discriminated_bce_loss = self.criterion(pred_label, label.float().unsqueeze(-1))
return dict(discriminated_bce_loss=discriminated_bce_loss, batch_nb=batch_nb,
pred_label=pred_label, label=label, generated_alternative=generated_alternative)
def validation_step(self, *args):
return self._test_val_step(*args)
def validation_epoch_end(self, outputs: list):
return self._test_val_epoch_end(outputs)
def _test_val_epoch_end(self, outputs, test=False): def _test_val_epoch_end(self, outputs, test=False):
evaluation = ROCEvaluation(plot_roc=True)
pred_label = torch.cat([x['pred_label'] for x in outputs])
labels = torch.cat([x['label'] for x in outputs]).unsqueeze(1)
mean_losses = torch.stack([x['discriminated_bce_loss'] for x in outputs]).mean()
# Sci-py call ROC eval call is eval(true_label, prediction) labels = torch.cat([x['label'] for x in outputs]).unsqueeze(1)
roc_auc, tpr, fpr = evaluation(labels.cpu().numpy(), pred_label.cpu().numpy(), )
if test: if test:
# self.logger.log_metrics(score_dict)
self.logger.log_image(f'{self.name}_ROC-Curve', plt.gcf()) self.logger.log_image(f'{self.name}_ROC-Curve', plt.gcf())
plt.clf() plt.clf()
maps, trajectories, labels, val_restul_dict = self.generate_random() maps, trajectories, labels, val_restul_dict = self.generate_random()
@@ -87,7 +61,13 @@ class CNNRouteGeneratorModel(LightningBaseModule):
fig = g.draw() fig = g.draw()
self.logger.log_image(f'{self.name}_Output', fig) self.logger.log_image(f'{self.name}_Output', fig)
return dict(mean_losses=mean_losses, roc_auc=roc_auc, epoch=self.current_epoch) return dict(epoch=self.current_epoch)
def validation_step(self, *args):
return self._test_val_step(*args)
def validation_epoch_end(self, outputs: list):
return self._test_val_epoch_end(outputs)
def test_step(self, *args): def test_step(self, *args):
return self._test_val_step(*args) return self._test_val_step(*args)
@@ -95,31 +75,20 @@ class CNNRouteGeneratorModel(LightningBaseModule):
def test_epoch_end(self, outputs): def test_epoch_end(self, outputs):
return self._test_val_epoch_end(outputs, test=True) return self._test_val_epoch_end(outputs, test=True)
def __init__(self, *params, issubclassed=False):
@property
def discriminator(self):
if self._disc is None:
raise RuntimeError('Set the Discriminator first; "set_discriminator(disc_model)')
return self._disc
def set_discriminator(self, disc_model):
if self._disc is not None:
raise RuntimeError('Discriminator has already been set... What are trying to do?')
self._disc = disc_model
def __init__(self, *params):
super(CNNRouteGeneratorModel, self).__init__(*params) super(CNNRouteGeneratorModel, self).__init__(*params)
# Dataset if not issubclassed:
self.dataset = TrajData(self.hparams.data_param.map_root, mode='just_route', # Dataset
length=self.hparams.data_param.dataset_length) self.dataset = TrajData(self.hparams.data_param.map_root, mode='separated_arrays',
length=self.hparams.data_param.dataset_length)
self.criterion = nn.MSELoss()
# Additional Attributes # Additional Attributes
self.in_shape = self.dataset.map_shapes_max self.in_shape = self.dataset.map_shapes_max
# Todo: Better naming and size in Parameters # Todo: Better naming and size in Parameters
self.feature_dim = 10 self.feature_dim = 10
self.lat_dim = self.feature_dim + self.feature_dim + 1 self.lat_dim = self.feature_dim + self.feature_dim + 1
self._disc = None
# NN Nodes # NN Nodes
################################################### ###################################################
@@ -127,7 +96,6 @@ class CNNRouteGeneratorModel(LightningBaseModule):
# Utils # Utils
self.relu = nn.ReLU() self.relu = nn.ReLU()
self.sigmoid = nn.Sigmoid() self.sigmoid = nn.Sigmoid()
self.criterion = nn.MSELoss()
# #
# Map Encoder # Map Encoder
@@ -222,7 +190,7 @@ class CNNRouteGeneratorModel(LightningBaseModule):
alt_tensor = self.alt_deconv_2(alt_tensor) alt_tensor = self.alt_deconv_2(alt_tensor)
alt_tensor = self.alt_deconv_3(alt_tensor) alt_tensor = self.alt_deconv_3(alt_tensor)
alt_tensor = self.alt_deconv_out(alt_tensor) alt_tensor = self.alt_deconv_out(alt_tensor)
alt_tensor = self.sigmoid(alt_tensor) # alt_tensor = self.sigmoid(alt_tensor)
return alt_tensor return alt_tensor
def encode(self, map_array, trajectory, label): def encode(self, map_array, trajectory, label):
@@ -266,4 +234,100 @@ class CNNRouteGeneratorModel(LightningBaseModule):
maps = self._move_to_model_device(torch.stack(maps)) maps = self._move_to_model_device(torch.stack(maps))
labels = self._move_to_model_device(torch.as_tensor([0] * n + [1] * n)) labels = self._move_to_model_device(torch.as_tensor([0] * n + [1] * n))
return maps, trajectories, labels, self._test_val_step(([maps, trajectories], labels), -9999) return maps, trajectories, labels, self._test_val_step(((maps, trajectories, labels), None), -9999)
class CNNRouteGeneratorDiscriminated(CNNRouteGeneratorModel):
name = 'CNNRouteGeneratorDiscriminated'
def training_step(self, batch_xy, batch_nb, *args, **kwargs):
batch_x, label = batch_xy
generated_alternative, z, mu, logvar = self(batch_x)
map_array, trajectory = batch_x
map_stack = torch.cat((map_array, trajectory, generated_alternative), dim=1)
pred_label = self.discriminator(map_stack)
discriminated_bce_loss = self.criterion(pred_label, label.float().unsqueeze(-1))
# see Appendix B from VAE paper:
# Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
# https://arxiv.org/abs/1312.6114
# 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
kld_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
# Dimensional Resizing
kld_loss /= reduce(mul, self.in_shape)
loss = (kld_loss + discriminated_bce_loss) / 2
return dict(loss=loss, log=dict(loss=loss,
discriminated_bce_loss=discriminated_bce_loss,
kld_loss=kld_loss)
)
def _test_val_step(self, batch_xy, batch_nb, *args):
batch_x, label = batch_xy
generated_alternative, z, mu, logvar = self(batch_x)
map_array, trajectory = batch_x
map_stack = torch.cat((map_array, trajectory, generated_alternative), dim=1)
pred_label = self.discriminator(map_stack)
discriminated_bce_loss = self.criterion(pred_label, label.float().unsqueeze(-1))
return dict(discriminated_bce_loss=discriminated_bce_loss, batch_nb=batch_nb,
pred_label=pred_label, label=label, generated_alternative=generated_alternative)
def validation_step(self, *args):
return self._test_val_step(*args)
def validation_epoch_end(self, outputs: list):
return self._test_val_epoch_end(outputs)
def _test_val_epoch_end(self, outputs, test=False):
evaluation = ROCEvaluation(plot_roc=True)
pred_label = torch.cat([x['pred_label'] for x in outputs])
labels = torch.cat([x['label'] for x in outputs]).unsqueeze(1)
mean_losses = torch.stack([x['discriminated_bce_loss'] for x in outputs]).mean()
# Sci-py call ROC eval call is eval(true_label, prediction)
roc_auc, tpr, fpr = evaluation(labels.cpu().numpy(), pred_label.cpu().numpy(), )
if test:
# self.logger.log_metrics(score_dict)
self.logger.log_image(f'{self.name}_ROC-Curve', plt.gcf())
plt.clf()
maps, trajectories, labels, val_restul_dict = self.generate_random()
from lib.visualization.generator_eval import GeneratorVisualizer
g = GeneratorVisualizer(maps, trajectories, labels, val_restul_dict)
fig = g.draw()
self.logger.log_image(f'{self.name}_Output', fig)
return dict(mean_losses=mean_losses, roc_auc=roc_auc, epoch=self.current_epoch)
def test_step(self, *args):
return self._test_val_step(*args)
def test_epoch_end(self, outputs):
return self._test_val_epoch_end(outputs, test=True)
@property
def discriminator(self):
if self._disc is None:
raise RuntimeError('Set the Discriminator first; "set_discriminator(disc_model)')
return self._disc
def set_discriminator(self, disc_model):
if self._disc is not None:
raise RuntimeError('Discriminator has already been set... What are trying to do?')
self._disc = disc_model
def __init__(self, *params):
super(CNNRouteGeneratorDiscriminated, self).__init__(*params, issubclassed=True)
self._disc = None
self.criterion = nn.BCELoss()
self.dataset = TrajData(self.hparams.data_param.map_root, mode='just_route',
length=self.hparams.data_param.dataset_length)

View File

@@ -32,24 +32,35 @@ class ConvHomDetector(LightningBaseModule):
pred_y = self(batch_x) pred_y = self(batch_x)
return dict(prediction=pred_y, label=batch_y, batch_nb=batch_nb) return dict(prediction=pred_y, label=batch_y, batch_nb=batch_nb)
def validation_step(self, batch_xy, batch_nb, **kwargs):
batch_x, batch_y = batch_xy
pred_y = self(batch_x)
return dict(prediction=pred_y, label=batch_y, batch_nb=batch_nb)
def test_epoch_end(self, outputs): def test_epoch_end(self, outputs):
evaluation = ROCEvaluation(plot_roc=True) return self._val_test_end(outputs)
def validation_epoch_end(self, outputs: list):
return self._val_test_end(outputs)
def _val_test_end(self, outputs, test=True):
evaluation = ROCEvaluation(plot_roc=True if test else False)
predictions = torch.cat([x['prediction'] for x in outputs]) predictions = torch.cat([x['prediction'] for x in outputs])
labels = torch.cat([x['label'] for x in outputs]).unsqueeze(1) labels = torch.cat([x['label'] for x in outputs]).unsqueeze(1)
# Sci-py call ROC eval call is eval(true_label, prediction) # Sci-py call ROC eval call is eval(true_label, prediction)
roc_auc, tpr, fpr = evaluation(labels.cpu().numpy(), predictions.cpu().numpy(), ) roc_auc, tpr, fpr = evaluation(labels.cpu().numpy(), predictions.cpu().numpy())
score_dict = dict(roc_auc=roc_auc)
# self.logger.log_metrics(score_dict) # self.logger.log_metrics(score_dict)
self.logger.log_image(f'{self.name}', plt.gcf()) if test:
self.logger.log_image(f'{self.name}', plt.gcf())
return dict(log=score_dict) return dict(score=roc_auc, log=dict(roc_auc=roc_auc))
def __init__(self, hparams): def __init__(self, hparams):
super(ConvHomDetector, self).__init__(hparams) super(ConvHomDetector, self).__init__(hparams)
# Dataset # Dataset
self.dataset = TrajData(self.hparams.data_param.map_root, mode='all_in_map') self.dataset = TrajData(self.hparams.data_param.map_root, mode='all_in_map', )
# Additional Attributes # Additional Attributes
self.map_shape = self.dataset.map_shapes_max self.map_shape = self.dataset.map_shapes_max
@@ -59,6 +70,7 @@ class ConvHomDetector(LightningBaseModule):
assert len(self.in_shape) == 3, f'Image or map shape has to have 3 dims, but had: {len(self.in_shape)}' assert len(self.in_shape) == 3, f'Image or map shape has to have 3 dims, but had: {len(self.in_shape)}'
self.criterion = nn.BCELoss() self.criterion = nn.BCELoss()
self.sigmoid = nn.Sigmoid() self.sigmoid = nn.Sigmoid()
self.relu = nn.ReLU()
# NN Nodes # NN Nodes
# ============================ # ============================
@@ -100,6 +112,7 @@ class ConvHomDetector(LightningBaseModule):
tensor = self.map_conv_3(tensor) tensor = self.map_conv_3(tensor)
tensor = self.flatten(tensor) tensor = self.flatten(tensor)
tensor = self.linear(tensor) tensor = self.linear(tensor)
tensor = self.relu(tensor)
tensor = self.classifier(tensor) tensor = self.classifier(tensor)
tensor = self.sigmoid(tensor) tensor = self.sigmoid(tensor)
return tensor return tensor

View File

@@ -17,7 +17,7 @@ class ConvModule(nn.Module):
output = self(x) output = self(x)
return output.shape[1:] return output.shape[1:]
def __init__(self, in_shape, activation: nn.Module = nn.ELU, pooling_size=None, use_bias=True, use_norm=True, def __init__(self, in_shape, activation: nn.Module = nn.ELU, pooling_size=None, use_bias=True, use_norm=False,
dropout: Union[int, float] = 0, conv_class=nn.Conv2d, dropout: Union[int, float] = 0, conv_class=nn.Conv2d,
conv_filters=64, conv_kernel=5, conv_stride=1, conv_padding=0): conv_filters=64, conv_kernel=5, conv_stride=1, conv_padding=0):
super(ConvModule, self).__init__() super(ConvModule, self).__init__()

View File

@@ -154,7 +154,7 @@ class LightningBaseModule(pl.LightningModule, ABC):
# Validation Dataloader # Validation Dataloader
def val_dataloader(self): def val_dataloader(self):
return DataLoader(dataset=self.dataset.val_dataset, shuffle=False, return DataLoader(dataset=self.dataset.val_dataset, shuffle=True,
batch_size=self.hparams.train_param.batch_size, batch_size=self.hparams.train_param.batch_size,
num_workers=self.hparams.data_param.worker) num_workers=self.hparams.data_param.worker)

View File

@@ -18,10 +18,6 @@ import lib.variables as V
class Map(object): class Map(object):
# This setting is for Img mode "L" aka GreyScale Image; values: 0-255
white = 255
black = 0
def __copy__(self): def __copy__(self):
return copy.deepcopy(self) return copy.deepcopy(self)
@@ -51,6 +47,7 @@ class Map(object):
def __init__(self, name='', array_like_map_representation=None): def __init__(self, name='', array_like_map_representation=None):
if array_like_map_representation is not None: if array_like_map_representation is not None:
array_like_map_representation = array_like_map_representation.astype(np.float32)
if array_like_map_representation.ndim == 2: if array_like_map_representation.ndim == 2:
array_like_map_representation = np.expand_dims(array_like_map_representation, axis=0) array_like_map_representation = np.expand_dims(array_like_map_representation, axis=0)
assert array_like_map_representation.ndim == 3 assert array_like_map_representation.ndim == 3
@@ -70,7 +67,7 @@ class Map(object):
# Check pixels for their color (determine if walkable) # Check pixels for their color (determine if walkable)
for idx, value in np.ndenumerate(self.map_array): for idx, value in np.ndenumerate(self.map_array):
if value != self.black: if value != V.BLACK:
# IF walkable, add node # IF walkable, add node
graph.add_node(idx, count=0) graph.add_node(idx, count=0)
# Fully connect to all surrounding neighbors # Fully connect to all surrounding neighbors
@@ -91,10 +88,9 @@ class Map(object):
if image.mode != 'L': if image.mode != 'L':
image = image.convert('L') image = image.convert('L')
map_array = np.expand_dims(np.array(image), axis=0) map_array = np.expand_dims(np.array(image), axis=0)
map_array = np.where(np.asarray(map_array) == cls.white, 1, 0)
if embedding_size: if embedding_size:
assert isinstance(embedding_size, tuple), f'embedding_size was of type: {type(embedding_size)}' assert isinstance(embedding_size, tuple), f'embedding_size was of type: {type(embedding_size)}'
embedding = np.zeros(embedding_size) embedding = np.full(embedding_size, V.BLACK)
embedding[:map_array.shape[0], :map_array.shape[1], :map_array.shape[2]] = map_array embedding[:map_array.shape[0], :map_array.shape[1], :map_array.shape[2]] = map_array
map_array = embedding map_array = embedding
@@ -146,12 +142,15 @@ class Map(object):
polyline = trajectory.xy_vertices polyline = trajectory.xy_vertices
polyline.extend(reversed(other_trajectory.xy_vertices)) polyline.extend(reversed(other_trajectory.xy_vertices))
img = Image.new('L', (self.height, self.width), 0) img = Image.new('L', (self.height, self.width), color=V.WHITE)
draw = ImageDraw.Draw(img) draw = ImageDraw.Draw(img)
draw.polygon(polyline, outline=1, fill=1) draw.polygon(polyline, outline=V.BLACK, fill=V.BLACK)
a = (np.asarray(img) * np.where(self.as_2d_array == self.black, 1, 0)).sum() binary_img = np.where(np.asarray(img).squeeze() == V.BLACK, 1, 0)
binary_map = np.where(self.as_2d_array == V.BLACK, 1, 0)
a = (binary_img * binary_map).sum()
if a: if a:
return V.ALTERNATIVE # Non-Homotoph return V.ALTERNATIVE # Non-Homotoph

View File

@@ -42,9 +42,9 @@ class Trajectory(object):
return list(zip(self._vertices[:-1], self._vertices[1:])) return list(zip(self._vertices[:-1], self._vertices[1:]))
def draw_in_array(self, shape): def draw_in_array(self, shape):
trajectory_space = np.zeros(shape) trajectory_space = np.zeros(shape).astype(np.float32)
for index in self.vertices: for index in self.vertices:
trajectory_space[index] = 1 trajectory_space[index] = V.WHITE
return trajectory_space return trajectory_space
@property @property

View File

@@ -5,7 +5,7 @@ from collections import defaultdict
from configparser import ConfigParser from configparser import ConfigParser
from pathlib import Path from pathlib import Path
from lib.models.generators.cnn import CNNRouteGeneratorModel from lib.models.generators.cnn import CNNRouteGeneratorModel, CNNRouteGeneratorDiscriminated
from lib.models.homotopy_classification.cnn_based import ConvHomDetector from lib.models.homotopy_classification.cnn_based import ConvHomDetector
from lib.utils.model_io import ModelParameters from lib.utils.model_io import ModelParameters
@@ -28,7 +28,10 @@ class Config(ConfigParser):
@property @property
def model_class(self): def model_class(self):
model_dict = dict(classifier_cnn=ConvHomDetector, generator_cnn=CNNRouteGeneratorModel) model_dict = dict(ConvHomDetector=ConvHomDetector,
CNNRouteGenerator=CNNRouteGeneratorModel,
CNNRouteGeneratorDiscriminated=CNNRouteGeneratorDiscriminated
)
try: try:
return model_dict[self.get('model', 'type')] return model_dict[self.get('model', 'type')]
except KeyError as e: except KeyError as e:

View File

@@ -3,3 +3,5 @@ _ROOT = Path('..')
HOMOTOPIC = 1 HOMOTOPIC = 1
ALTERNATIVE = 0 ALTERNATIVE = 0
WHITE = 255
BLACK = 0

10
main.py
View File

@@ -48,7 +48,7 @@ main_arg_parser.add_argument("--train_batch_size", type=int, default=256, help="
main_arg_parser.add_argument("--train_lr", type=float, default=0.002, help="") main_arg_parser.add_argument("--train_lr", type=float, default=0.002, help="")
# Model # Model
main_arg_parser.add_argument("--model_type", type=str, default="generator_cnn", help="") main_arg_parser.add_argument("--model_type", type=str, default="CNNRouteGenerator", help="")
main_arg_parser.add_argument("--model_activation", type=str, default="relu", help="") main_arg_parser.add_argument("--model_activation", type=str, default="relu", help="")
main_arg_parser.add_argument("--model_filters", type=str, default="[32, 16, 4]", help="") main_arg_parser.add_argument("--model_filters", type=str, default="[32, 16, 4]", help="")
main_arg_parser.add_argument("--model_classes", type=int, default=2, help="") main_arg_parser.add_argument("--model_classes", type=int, default=2, help="")
@@ -94,9 +94,9 @@ def run_lightning_loop(config_obj):
# Init # Init
model: LightningBaseModule = config_obj.model_class(config_obj.model_paramters) model: LightningBaseModule = config_obj.model_class(config_obj.model_paramters)
model.init_weights() model.init_weights()
if model.name == 'CNNRouteGenerator': if model.name == 'CNNRouteGeneratorDiscriminated':
# ToDo: Make this dependent on the used seed # ToDo: Make this dependent on the used seed
path = Path(Path(config_obj.train.outpath) / 'classifier_cnn' / 'version_0') path = Path(Path(config_obj.train.outpath) / 'classifier_cnn' / 'trained')
disc_model = SavedLightningModels.load_checkpoint(path).restore() disc_model = SavedLightningModels.load_checkpoint(path).restore()
model.set_discriminator(disc_model) model.set_discriminator(disc_model)
@@ -112,7 +112,9 @@ def run_lightning_loop(config_obj):
checkpoint_callback=checkpoint_callback, checkpoint_callback=checkpoint_callback,
logger=logger, logger=logger,
fast_dev_run=config_obj.main.debug, fast_dev_run=config_obj.main.debug,
early_stop_callback=None early_stop_callback=None,
val_percent_check=0.10,
num_sanity_val_steps=1,
) )
# Train It # Train It