
ToDo: Visualization for variational spaces, Trajectory Coloring, Post-Processing Metric, Slurm Script
82 lines
2.8 KiB
Python
82 lines
2.8 KiB
Python
# TODO: THIS
|
|
import seaborn as sb
|
|
import torch
|
|
from torch.utils.data import DataLoader
|
|
from pytorch_lightning import data_loader
|
|
from dataset import DataContainer
|
|
from collections import defaultdict
|
|
from tqdm import tqdm
|
|
import os
|
|
|
|
from sklearn.manifold import TSNE
|
|
from sklearn.decomposition import PCA
|
|
|
|
import seaborn as sns; sns.set()
|
|
import matplotlib.pyplot as plt
|
|
|
|
from run_models import *
|
|
|
|
def search_for_weights(folder, on_checkpoint=None, max_depth=50):
    """Locate ``folder`` and call ``on_checkpoint`` for every ``*.ckpt`` file beneath it.

    If ``folder`` does not exist relative to the current working directory, the
    lookup is retried with ``os.pardir`` prepended, climbing at most
    ``max_depth`` parent levels. Once found, the directory tree is walked
    recursively.

    Args:
        folder: Directory name or path to search for.
        on_checkpoint: Callable invoked with the ``os.DirEntry`` of each
            checkpoint file. Defaults to ``load_and_predict``.
        max_depth: Maximum number of parent directories to climb before giving up.

    Raises:
        FileNotFoundError: If ``folder`` cannot be found within ``max_depth`` levels.
    """
    if on_checkpoint is None:
        on_checkpoint = load_and_predict

    # Count the climbed levels explicitly: a length check on os.path.split()
    # can never trigger (it always returns a 2-tuple), which made a missing
    # folder loop forever.
    candidate = folder
    depth = 0
    while not os.path.exists(candidate):
        if depth >= max_depth:
            raise FileNotFoundError(f'The folder "{folder}" could not be found')
        candidate = os.path.join(os.pardir, candidate)
        depth += 1

    # Use scandir as a context manager so the directory handle is released.
    with os.scandir(candidate) as entries:
        for entry in entries:
            if entry.is_dir():
                search_for_weights(entry.path, on_checkpoint, max_depth)
            elif entry.is_file() and entry.name.endswith('.ckpt'):
                on_checkpoint(entry)
|
|
|
|
|
|
def load_and_predict(path_like_element):
    """Load a checkpoint, run it over the dataset and plot each collected output.

    The model class is looked up by name in this module's globals (which
    ``from run_models import *`` populates), using the third-to-last component
    of the checkpoint path: ``.../<ModelClassName>/<subdir>/<file>.ckpt``.
    Hyper-parameters are read from ``<subdir>/default/version_0/meta_tags.csv``.

    Args:
        path_like_element: ``os.DirEntry`` (or any object with a ``.path``
            attribute) pointing at a ``*.ckpt`` file.
    """
    ckpt_path = path_like_element.path
    model_cls = globals()[ckpt_path.split(os.sep)[-3]]

    pretrained_model = model_cls.load_from_metrics(
        weights_path=ckpt_path,
        tags_csv=os.path.join(os.path.dirname(ckpt_path), 'default', 'version_0', 'meta_tags.csv'),
        on_gpu=torch.cuda.is_available(),
        map_location=None,
    )

    # Inference mode, with frozen weights for faster inference.
    pretrained_model.eval()
    pretrained_model.freeze()

    # Load the data for prediction.
    dataset = DataContainer(os.path.join(os.pardir, 'data'), 5, 5)

    # Collect the idx-th model output of every sample under key idx.
    prediction_dict = defaultdict(list)
    with torch.no_grad():
        for i in tqdm(range(len(dataset))):
            p_X = pretrained_model(dataset[i].unsqueeze(0))
            # NOTE(review): the last element of the model output is skipped;
            # presumably it is not a latent code (e.g. a reconstruction) — confirm.
            for idx in range(len(p_X) - 1):
                prediction_dict[idx].append(p_X[idx])

    # .cpu() is required before .numpy() if the tensors live on a GPU.
    predictions = [torch.cat(prediction).detach().cpu().numpy() for prediction in prediction_dict.values()]
    for prediction in predictions:
        viz_latent(prediction)
|
|
|
|
|
|
def viz_latent(prediction):
    """Scatter-plot latent codes, projecting them to 2-D when necessary.

    Args:
        prediction: Array indexed as ``prediction[:, k]``; the last axis is the
            latent dimension (rows are presumably samples — confirm with caller).

    Returns:
        For 2-D codes, the seaborn axes. For higher-dimensional codes, the tuple
        ``(fig, axs, pca_plot, tsne_plot)`` with a PCA projection on the left
        and a t-SNE embedding on the right.

    Raises:
        ValueError: If the latent dimension is smaller than 2.
    """
    n_dims = prediction.shape[-1]
    if n_dims <= 1:
        raise ValueError(
            f'Cannot visualize latent codes of shape {prediction.shape}: '
            'need at least 2 latent dimensions'
        )

    if n_dims == 2:
        # Draw on a fresh figure rather than whatever axes happen to be current.
        _, ax = plt.subplots()
        ax = sns.scatterplot(x=prediction[:, 0], y=prediction[:, 1], ax=ax)
        plt.show()
        return ax

    fig, axs = plt.subplots(ncols=2)
    # PCA/TSNE are estimators; fit_transform returns the 2-D embedding to plot.
    predictions_pca = PCA(n_components=2).fit_transform(prediction)
    predictions_tsne = TSNE(n_components=2).fit_transform(prediction)
    pca_plot = sns.scatterplot(x=predictions_pca[:, 0], y=predictions_pca[:, 1], ax=axs[0])
    tsne_plot = sns.scatterplot(x=predictions_tsne[:, 0], y=predictions_tsne[:, 1], ax=axs[1])
    axs[0].set_title('PCA')
    axs[1].set_title('t-SNE')
    plt.show()
    return fig, axs, pca_plot, tsne_plot
|
|
|
|
if __name__ == '__main__':
    # Entry point: search for the 'output' directory (relative to the working
    # directory or one of its ancestors) and process each checkpoint found in it.
    path = 'output'
    search_for_weights(path)