Visualization approach 1

Si11ium
2019-09-13 13:36:13 +02:00
parent 18305a9e7e
commit 1386cdfd33
9 changed files with 185 additions and 50 deletions

viz/utils.py (new file, +30)

@@ -0,0 +1,30 @@
import os


def search_for_weights(func, folder):
    # Walk upward via os.pardir until `folder` exists; give up after 50 levels
    # so a misspelled folder cannot loop forever.
    while not os.path.exists(folder):
        if folder.count(os.pardir) >= 50:
            raise FileNotFoundError(f'The folder "{folder}" could not be found')
        folder = os.path.join(os.pardir, folder)
    # A .png next to the weights means this run was already visualized.
    if any(x.name.endswith('.png') for x in os.scandir(folder)):
        return
    # A 'weights.ckpt' directory holds numbered checkpoint files; hand the newest to `func`.
    if any('.ckpt' in element.name and element.is_dir() for element in os.scandir(folder)):
        _, _, filenames = next(os.walk(os.path.join(folder, 'weights.ckpt')))
        filenames.sort(key=lambda f: int(''.join(filter(str.isdigit, f))))
        func(os.path.join(folder, 'weights.ckpt', filenames[-1]))
        return
    # Otherwise recurse into subfolders and pass every checkpoint file along.
    for element in os.scandir(folder):
        if os.path.exists(element):
            if element.is_dir():
                search_for_weights(func, element.path)
            elif element.is_file() and element.name.endswith('.ckpt'):
                func(element.path)


if __name__ == '__main__':
    raise PermissionError('This file should not be called.')
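
For orientation, a minimal, hypothetical driver for this helper; the callback body is illustrative and not part of the commit:

from viz.utils import search_for_weights

def print_checkpoint(checkpoint_path):
    # Stand-in for a real consumer such as load_and_predict.
    print(f'found checkpoint: {checkpoint_path}')

search_for_weights(print_checkpoint, 'output')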

(modified file)

@@ -1,6 +1,3 @@
-from collections import defaultdict
-from tqdm import tqdm
 from sklearn.manifold import TSNE
 from sklearn.decomposition import PCA
@@ -12,57 +9,52 @@ from run_models import *
 sns.set()

-def search_for_weights(folder):
-    while not os.path.exists(folder):
-        if len(os.path.split(folder)) >= 50:
-            raise FileNotFoundError(f'The folder "{folder}" could not be found')
-        folder = os.path.join(os.pardir, folder)
-    for element in os.scandir(folder):
-        if os.path.exists(element):
-            if element.is_dir():
-                search_for_weights(element.path)
-            elif element.is_file() and element.name.endswith('.ckpt'):
-                load_and_predict(element)
-            else:
-                continue

 def load_and_predict(path_like_element):
     if any([x.name.endswith('.png') for x in os.scandir(os.path.dirname(path_like_element))]):
         return
     # Define Loop to search for models and folder with visualizations
-    model = globals()[path_like_element.path.split(os.sep)[-3]]
+    splitpath = path_like_element.split(os.sep)
+    base_dir = os.path.join(*splitpath[:4])
+    model = globals()[splitpath[2]]
+    print(f'... loading model named: "{model.name}" from timestamp: {splitpath[3]}')
     pretrained_model = model.load_from_metrics(
-        weights_path=path_like_element.path,
-        tags_csv=os.path.join(os.path.dirname(path_like_element), 'default', 'version_0', 'meta_tags.csv'),
+        weights_path=path_like_element,
+        tags_csv=os.path.join(base_dir, 'default', 'version_0', 'meta_tags.csv'),
         on_gpu=True if torch.cuda.is_available() else False,
         map_location=None
     )

     # Init model and freeze its weights ( for faster inference)
     pretrained_model = pretrained_model.to(device)
     pretrained_model.eval()
     pretrained_model.freeze()

     with torch.no_grad():
         # Load the data for prediction
-        dataset = DataContainer(os.path.join(os.pardir, 'data'), 5, 5)
+        # TODO!!!!!!!!!:
+        # Of course the data parameters have to be loaded here as well!
+        # Do I have to reset the val sets automatically, i.e. always refresh them, whenever I want to validate?
+        # What actually is my val dataset?
+        # Do I have a whole map anywhere?
+        # How do I make sure that certain maps, i.e. their size ratios etc., are not overrepresented?
+        dataset = DataContainer(os.path.join(os.pardir, 'data', 'validation'), 9, 6).to(device)
+        dataloader = DataLoader(dataset, shuffle=True, batch_size=len(dataset))
         # Do the inference
-        prediction_dict = defaultdict(list)
-        for i in tqdm(range(len(dataset)), total=len(dataset)):
-            p_X = pretrained_model(dataset[i].unsqueeze(0))
-            for idx in range(len(p_X) - 1):
-                prediction_dict[idx].append(p_X[idx])
+        test_pred = [pretrained_model(test_sample)[:-1] for test_sample in dataloader][0]

-        predictions = [torch.cat(prediction).detach().numpy() for prediction in prediction_dict.values()]
-        for idx, prediction in enumerate(predictions):
+        for idx, prediction in enumerate(test_pred):
             plot, _ = viz_latent(prediction)
-            plot.savefig(os.path.join(os.path.dirname(path_like_element), f'latent_space_{idx}.png'))
+            plot.savefig(os.path.join(base_dir, f'latent_space_{idx}.png'))
 def viz_latent(prediction):
     try:
         prediction = prediction.cpu()
         prediction = prediction.numpy()
     except AttributeError:
         pass
     if prediction.shape[-1] <= 1:
         raise ValueError('How did this happen?')
     elif prediction.shape[-1] == 2:
@@ -91,4 +83,4 @@ def viz_latent(prediction):
 if __name__ == '__main__':
     path = 'output'
-    search_for_weights(path)
+    search_for_weights(load_and_predict, path)
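
The remainder of viz_latent is truncated in this hunk. Given the TSNE and PCA imports, the branch for latents wider than two dimensions presumably embeds them in 2-D before plotting; a minimal sketch of that step, under that assumption and with illustrative names:

import numpy as np
from sklearn.manifold import TSNE

def project_to_2d(latents: np.ndarray) -> np.ndarray:
    # Two-dimensional latents can be scattered directly; wider ones are embedded first.
    if latents.shape[-1] == 2:
        return latents
    return TSNE(n_components=2).fit_transform(latents)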

viz/viz_map.py (new file, +50)

@@ -0,0 +1,50 @@
from dataset import *
# Plotting
# import matplotlib as mlp
from matplotlib import pyplot as plt
from matplotlib.patches import Polygon
from matplotlib.collections import LineCollection, PatchCollection
import matplotlib.colors as mcolors
import matplotlib.cm as cmaps
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
import seaborn as sns
from argparse import ArgumentParser
from viz.utils import search_for_weights
from run_models import *

sns.set()

arguments = ArgumentParser()
arguments.add_argument('--data', default=os.path.join('data', 'validation'))
args = arguments.parse_args()

dataset = DataContainer(os.path.join(os.pardir, 'data', 'validation'), 9, 6).to(device)
dataloader = DataLoader(dataset, shuffle=True, batch_size=len(dataset))


def viz_map(ax, base_map: MapContainer):
    # Draw the base map as a collection of filled black triangles.
    patches = [Polygon(base_map.get_triangle_by_key(i), True, color='k') for i in range(len(base_map))]
    patch_collection = PatchCollection(patches, color='k')
    ax.add_collection(patch_collection)
    print('Basemap Plotted')
    return patch_collection


def load_and_predict(folder):
    # Placeholder: restore a model from `folder` and overlay its predictions on the map.
    pass


if __name__ == '__main__':
    search_for_weights(load_and_predict, args.data)
    # ToDo: THIS
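
Once load_and_predict is filled in, viz_map would be driven roughly as follows; the MapContainer constructor arguments are assumptions, not taken from this commit:

# Hypothetical driver; the MapContainer argument is assumed.
from matplotlib import pyplot as plt

fig, ax = plt.subplots()
base_map = MapContainer(os.path.join('data', 'maps'))
viz_map(ax, base_map)
ax.autoscale()
plt.show()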