Done: Latent Space Viz

ToDo: Visualization for variational spaces
Trajectory Coloring
Post Processing
Metric
Slurm Skript
This commit is contained in:
Si11ium
2019-08-23 09:54:00 +02:00
parent 744c0c50b7
commit 1a0400d736
9 changed files with 159 additions and 76 deletions

View File

@ -4,6 +4,8 @@ import torch
from torch.utils.data import DataLoader
from pytorch_lightning import data_loader
from dataset import DataContainer
from collections import defaultdict
from tqdm import tqdm
import os
from sklearn.manifold import TSNE
@ -12,30 +14,28 @@ from sklearn.decomposition import PCA
import seaborn as sns; sns.set()
import matplotlib.pyplot as plt
from run_models import SeparatingAdversarialModel
path = 'output'
mylightningmodule = 'weired name, loaded from disk'
# FIXME: How to store hyperparamters in testtube element?
from run_models import *
def search_for_weights(folder):
while not os.path.exists(folder):
if len(os.path.split(folder)) >= 50:
raise FileNotFoundError(f'The folder "{folder}" could not be found')
folder = os.path.join(os.pardir, folder)
for element in os.scandir(folder):
if os.path.exists(element):
if element.is_dir():
search_for_weights(element.path)
elif element.is_file() and element.name.endswith('.ckpt'):
load_and_viz(element)
load_and_predict(element)
else:
continue
def load_and_viz(path_like_element):
def load_and_predict(path_like_element):
# Define Loop to search for models and folder with visualizations
pretrained_model = SeparatingAdversarialModel.load_from_metrics(
model = globals()[path_like_element.path.split(os.sep)[-3]]
pretrained_model = model.load_from_metrics(
weights_path=path_like_element.path,
tags_csv=os.path.join(os.path.dirname(path_like_element), 'default', 'version_0', 'meta_tags.csv'),
on_gpu=True if torch.cuda.is_available() else False,
@ -46,19 +46,26 @@ def load_and_viz(path_like_element):
pretrained_model.eval()
pretrained_model.freeze()
# Load the data fpr prediction
dataset = DataContainer('data', 5, 5)
# Load the data for prediction
dataset = DataContainer(os.path.join(os.pardir, 'data'), 5, 5)
# Do the inference
predictions = []
for i in range(len(dataset)):
z, _ = pretrained_model(dataset[i])
predictions.append(z)
predictions = torch.cat(predictions)
if predictions.shape[-1] <= 1:
prediction_dict = defaultdict(list)
for i in tqdm(range(len(dataset)), total=len(dataset)):
p_X = pretrained_model(dataset[i].unsqueeze(0))
for idx in range(len(p_X) - 1):
prediction_dict[idx].append(p_X[idx])
predictions = [torch.cat(prediction).detach().numpy() for prediction in prediction_dict.values()]
for prediction in predictions:
viz_latent(prediction)
def viz_latent(prediction):
if prediction.shape[-1] <= 1:
raise ValueError('How did this happen?')
elif predictions.shape[-1] == 2:
ax = sns.scatterplot(x=predictions[:, 0], y=predictions[:, 1])
elif prediction.shape[-1] == 2:
ax = sns.scatterplot(x=prediction[:, 0], y=prediction[:, 1])
plt.show()
return ax
else:
@ -69,3 +76,7 @@ def load_and_viz(path_like_element):
tsne_plot = sns.scatterplot(x=predictions_tsne[:, 0], y=predictions_tsne[:, 1], ax=axs[1])
plt.show()
return fig, axs, pca_plot, tsne_plot
if __name__ == '__main__':
path = 'output'
search_for_weights(path)