Fixed the Model classes, Visualization

This commit is contained in:
Si11ium
2019-08-23 13:10:47 +02:00
parent 0e879bfdb1
commit 7b0b96eaa3
16 changed files with 141 additions and 469 deletions

View File

@ -1,21 +1,17 @@
# TODO: THIS
import seaborn as sb
import torch
from torch.utils.data import DataLoader
from pytorch_lightning import data_loader
from dataset import DataContainer
from collections import defaultdict
from tqdm import tqdm
import os
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
import seaborn as sns; sns.set()
import seaborn as sns
import matplotlib.pyplot as plt
from run_models import *
sns.set()
def search_for_weights(folder):
while not os.path.exists(folder):
if len(os.path.split(folder)) >= 50:
@ -32,6 +28,8 @@ def search_for_weights(folder):
def load_and_predict(path_like_element):
if any([x.name.endswith('.png') for x in os.scandir(os.path.dirname(path_like_element))]):
return
# Define Loop to search for models and folder with visualizations
model = globals()[path_like_element.path.split(os.sep)[-3]]
@ -46,36 +44,50 @@ def load_and_predict(path_like_element):
pretrained_model.eval()
pretrained_model.freeze()
# Load the data for prediction
dataset = DataContainer(os.path.join(os.pardir, 'data'), 5, 5)
with torch.no_grad():
# Do the inference
prediction_dict = defaultdict(list)
for i in tqdm(range(len(dataset)), total=len(dataset)):
p_X = pretrained_model(dataset[i].unsqueeze(0))
for idx in range(len(p_X) - 1):
prediction_dict[idx].append(p_X[idx])
# Load the data for prediction
dataset = DataContainer(os.path.join(os.pardir, 'data'), 5, 5)
# Do the inference
prediction_dict = defaultdict(list)
for i in tqdm(range(len(dataset)), total=len(dataset)):
p_X = pretrained_model(dataset[i].unsqueeze(0))
for idx in range(len(p_X) - 1):
prediction_dict[idx].append(p_X[idx])
predictions = [torch.cat(prediction).detach().numpy() for prediction in prediction_dict.values()]
for prediction in predictions:
viz_latent(prediction)
for idx, prediction in enumerate(predictions):
plot, _ = viz_latent(prediction)
plot.savefig(os.path.join(os.path.dirname(path_like_element), f'latent_space_{idx}.png'))
def viz_latent(prediction):
def viz_latent(prediction, title=f'Latent Space '):
if prediction.shape[-1] <= 1:
raise ValueError('How did this happen?')
elif prediction.shape[-1] == 2:
ax = sns.scatterplot(x=prediction[:, 0], y=prediction[:, 1])
plt.show()
return ax
try:
plt.show()
except:
pass
return ax.figure, (ax)
else:
fig, axs = plt.subplots(ncols=2)
predictions_pca = PCA(n_components=2)
predictions_tsne = TSNE(n_components=2)
pca_plot = sns.scatterplot(x=predictions_pca[:, 0], y=predictions_pca[:, 1], ax=axs[0])
tsne_plot = sns.scatterplot(x=predictions_tsne[:, 0], y=predictions_tsne[:, 1], ax=axs[1])
plt.show()
return fig, axs, pca_plot, tsne_plot
plots = []
for idx, dim_reducer in enumerate([PCA, TSNE]):
predictions_reduced = dim_reducer(n_components=2).fit_transform(prediction)
plot = sns.scatterplot(x=predictions_reduced[:, 0], y=predictions_reduced[:, 1],
ax=axs[idx])
plot.set_title(dim_reducer.__name__)
plots.append(plot)
try:
plt.show()
except:
pass
return fig, (*plots, )
if __name__ == '__main__':
path = 'output'