Visualization approach 1
This commit is contained in:
30
viz/utils.py
Normal file
30
viz/utils.py
Normal file
@ -0,0 +1,30 @@
|
||||
import os
|
||||
|
||||
|
||||
def search_for_weights(func, folder):
|
||||
while not os.path.exists(folder):
|
||||
if len(os.path.split(folder)) >= 50:
|
||||
raise FileNotFoundError(f'The folder "{folder}" could not be found')
|
||||
folder = os.path.join(os.pardir, folder)
|
||||
|
||||
if any([x.name.endswith('.png') for x in os.scandir(folder)]):
|
||||
return
|
||||
|
||||
if any(['.ckpt' in element.name and element.is_dir() for element in os.scandir(folder)]):
|
||||
_, _, filenames = next(os.walk(os.path.join(folder, 'weights.ckpt')))
|
||||
filenames.sort(key=lambda f: int(''.join(filter(str.isdigit, f))))
|
||||
func(os.path.join(folder, 'weights.ckpt', filenames[-1]))
|
||||
return
|
||||
|
||||
for element in os.scandir(folder):
|
||||
if os.path.exists(element):
|
||||
if element.is_dir():
|
||||
search_for_weights(func, element.path)
|
||||
elif element.is_file() and element.name.endswith('.ckpt'):
|
||||
func(element)
|
||||
else:
|
||||
continue
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
raise PermissionError('This file should not be called.')
|
@ -1,6 +1,3 @@
|
||||
from collections import defaultdict
|
||||
from tqdm import tqdm
|
||||
|
||||
from sklearn.manifold import TSNE
|
||||
from sklearn.decomposition import PCA
|
||||
|
||||
@ -12,57 +9,52 @@ from run_models import *
|
||||
sns.set()
|
||||
|
||||
|
||||
def search_for_weights(folder):
|
||||
while not os.path.exists(folder):
|
||||
if len(os.path.split(folder)) >= 50:
|
||||
raise FileNotFoundError(f'The folder "{folder}" could not be found')
|
||||
folder = os.path.join(os.pardir, folder)
|
||||
for element in os.scandir(folder):
|
||||
if os.path.exists(element):
|
||||
if element.is_dir():
|
||||
search_for_weights(element.path)
|
||||
elif element.is_file() and element.name.endswith('.ckpt'):
|
||||
load_and_predict(element)
|
||||
else:
|
||||
continue
|
||||
|
||||
|
||||
def load_and_predict(path_like_element):
|
||||
if any([x.name.endswith('.png') for x in os.scandir(os.path.dirname(path_like_element))]):
|
||||
return
|
||||
|
||||
# Define Loop to search for models and folder with visualizations
|
||||
model = globals()[path_like_element.path.split(os.sep)[-3]]
|
||||
splitpath = path_like_element.split(os.sep)
|
||||
base_dir = os.path.join(*splitpath[:4])
|
||||
model = globals()[splitpath[2]]
|
||||
print(f'... loading model named: "{Model.name}" from timestamp: {splitpath[3]}')
|
||||
pretrained_model = model.load_from_metrics(
|
||||
weights_path=path_like_element.path,
|
||||
tags_csv=os.path.join(os.path.dirname(path_like_element), 'default', 'version_0', 'meta_tags.csv'),
|
||||
weights_path=path_like_element,
|
||||
tags_csv=os.path.join(base_dir, 'default', 'version_0', 'meta_tags.csv'),
|
||||
on_gpu=True if torch.cuda.is_available() else False,
|
||||
map_location=None
|
||||
)
|
||||
|
||||
# Init model and freeze its weights ( for faster inference)
|
||||
pretrained_model = pretrained_model.to(device)
|
||||
pretrained_model.eval()
|
||||
pretrained_model.freeze()
|
||||
|
||||
with torch.no_grad():
|
||||
|
||||
# Load the data for prediction
|
||||
dataset = DataContainer(os.path.join(os.pardir, 'data'), 5, 5)
|
||||
|
||||
# TODO!!!!!!!!!:
|
||||
# Hier müssen natürlich auch die date parameter geladen werden!
|
||||
# Muss ich die val-sets automatisch neu setzen, also immer auf refresh haben, wenn ich validieren möchte?
|
||||
# Was ist denn eigentlich mein Val Dataset?
|
||||
# Hab ich irgendwo eine ganze karte?
|
||||
# Wie sorge ich dafür, dass gewisse karten, also größenverhältnisse usw nicht überrepräsentiert sind?
|
||||
dataset = DataContainer(os.path.join(os.pardir, 'data', 'validation'), 9, 6).to(device)
|
||||
dataloader = DataLoader(dataset, shuffle=True, batch_size=len(dataset))
|
||||
|
||||
# Do the inference
|
||||
prediction_dict = defaultdict(list)
|
||||
for i in tqdm(range(len(dataset)), total=len(dataset)):
|
||||
p_X = pretrained_model(dataset[i].unsqueeze(0))
|
||||
for idx in range(len(p_X) - 1):
|
||||
prediction_dict[idx].append(p_X[idx])
|
||||
test_pred = [pretrained_model(test_sample)[:-1] for test_sample in dataloader][0]
|
||||
|
||||
predictions = [torch.cat(prediction).detach().numpy() for prediction in prediction_dict.values()]
|
||||
for idx, prediction in enumerate(predictions):
|
||||
for idx, prediction in enumerate(test_pred):
|
||||
plot, _ = viz_latent(prediction)
|
||||
plot.savefig(os.path.join(os.path.dirname(path_like_element), f'latent_space_{idx}.png'))
|
||||
plot.savefig(os.path.join(base_dir, f'latent_space_{idx}.png'))
|
||||
|
||||
|
||||
def viz_latent(prediction):
|
||||
try:
|
||||
prediction = prediction.cpu()
|
||||
prediction = prediction.numpy()
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
if prediction.shape[-1] <= 1:
|
||||
raise ValueError('How did this happen?')
|
||||
elif prediction.shape[-1] == 2:
|
||||
@ -91,4 +83,4 @@ def viz_latent(prediction):
|
||||
|
||||
if __name__ == '__main__':
|
||||
path = 'output'
|
||||
search_for_weights(path)
|
||||
search_for_weights(search_for_weights, path)
|
50
viz/viz_map.py
Normal file
50
viz/viz_map.py
Normal file
@ -0,0 +1,50 @@
|
||||
|
||||
from dataset import *
|
||||
# Plotting
|
||||
# import matplotlib as mlp
|
||||
from matplotlib import pyplot as plt
|
||||
from matplotlib.patches import Polygon
|
||||
from matplotlib.collections import LineCollection, PatchCollection
|
||||
import matplotlib.colors as mcolors
|
||||
import matplotlib.cm as cmaps
|
||||
|
||||
from sklearn.manifold import TSNE
|
||||
from sklearn.decomposition import PCA
|
||||
|
||||
import seaborn as sns
|
||||
from argparse import ArgumentParser
|
||||
|
||||
from viz.utils import search_for_weights
|
||||
|
||||
from run_models import *
|
||||
|
||||
sns.set()
|
||||
|
||||
|
||||
arguments = ArgumentParser()
|
||||
arguments.add_argument('--data', default=os.path.join('data', 'validation'))
|
||||
|
||||
dataset = DataContainer(os.path.join(os.pardir, 'data', 'validation'), 9, 6).to(device)
|
||||
dataloader = DataLoader(dataset, shuffle=True, batch_size=len(dataset))
|
||||
|
||||
|
||||
|
||||
def viz_map(self, base_map: MapContainer):
|
||||
# Base Map Plotting
|
||||
# filled Triangle
|
||||
patches = [Polygon(base_map.get_triangle_by_key(i), True, color='k') for i in range(len(base_map))]
|
||||
patch_collection = PatchCollection(patches, color='k')
|
||||
|
||||
self.ax.add_collection(patch_collection)
|
||||
print('Basemap Plotted')
|
||||
|
||||
patches = [Polygon(base_map.get_triangle_by_key(i), True, color='k') for i in range(len(base_map))]
|
||||
return PatchCollection(patches, color='k')
|
||||
|
||||
def load_and_predict(folder):
|
||||
pass
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
search_for_weights(load_and_predict, arguments.data)
|
||||
# ToDo: THIS
|
Reference in New Issue
Block a user