Done: First VIsualization
ToDo: Visualization for all classes, latent space setups
This commit is contained in:
71
viz/viz_latent.py
Normal file
71
viz/viz_latent.py
Normal file
@ -0,0 +1,71 @@
|
||||
# TODO: THIS
|
||||
import seaborn as sb
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
from pytorch_lightning import data_loader
|
||||
from dataset import DataContainer
|
||||
import os
|
||||
|
||||
from sklearn.manifold import TSNE
|
||||
from sklearn.decomposition import PCA
|
||||
|
||||
import seaborn as sns; sns.set()
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
from run_models import SeparatingAdversarialModel
|
||||
|
||||
path = 'output'
|
||||
mylightningmodule = 'weired name, loaded from disk'
|
||||
|
||||
|
||||
# FIXME: How to store hyperparamters in testtube element?
|
||||
|
||||
|
||||
def search_for_weights(folder):
|
||||
for element in os.scandir(folder):
|
||||
if os.path.exists(element):
|
||||
if element.is_dir():
|
||||
search_for_weights(element.path)
|
||||
elif element.is_file() and element.name.endswith('.ckpt'):
|
||||
load_and_viz(element)
|
||||
else:
|
||||
continue
|
||||
|
||||
|
||||
def load_and_viz(path_like_element):
|
||||
|
||||
# Define Loop to search for models and folder with visualizations
|
||||
pretrained_model = SeparatingAdversarialModel.load_from_metrics(
|
||||
weights_path=path_like_element.path,
|
||||
tags_csv=os.path.join(os.path.dirname(path_like_element), 'default', 'version_0', 'meta_tags.csv'),
|
||||
on_gpu=True if torch.cuda.is_available() else False,
|
||||
map_location=None
|
||||
)
|
||||
|
||||
# Init model and freeze its weights ( for faster inference)
|
||||
pretrained_model.eval()
|
||||
pretrained_model.freeze()
|
||||
|
||||
# Load the data fpr prediction
|
||||
dataset = DataContainer('data', 5, 5)
|
||||
|
||||
# Do the inference
|
||||
predictions = []
|
||||
for i in range(len(dataset)):
|
||||
z, _ = pretrained_model(dataset[i])
|
||||
predictions.append(z)
|
||||
predictions = torch.cat(predictions)
|
||||
if predictions.shape[-1] <= 1:
|
||||
raise ValueError('How did this happen?')
|
||||
elif predictions.shape[-1] == 2:
|
||||
ax = sns.scatterplot(x=predictions[:, 0], y=predictions[:, 1])
|
||||
plt.show()
|
||||
return ax
|
||||
else:
|
||||
fig, axs = plt.subplots(ncols=2)
|
||||
predictions_pca = PCA(n_components=2)
|
||||
predictions_tsne = TSNE(n_components=2)
|
||||
pca_plot = sns.scatterplot(x=predictions_pca[:, 0], y=predictions_pca[:, 1], ax=axs[0])
|
||||
tsne_plot = sns.scatterplot(x=predictions_tsne[:, 0], y=predictions_tsne[:, 1], ax=axs[1])
|
||||
plt.show()
|
||||
return fig, axs, pca_plot, tsne_plot
|
Reference in New Issue
Block a user