Visualization approach
@@ -1,12 +1,12 @@
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA

import seaborn as sns
import matplotlib.pyplot as plt

import warnings
warnings.filterwarnings('ignore', category=FutureWarning)
import torch
from dataset import DataContainer
from viz.utils import search_for_weights, Printer
from run_models import *

sns.set()

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

@@ -14,73 +14,39 @@ def load_and_predict(path_like_element):
    splitpath = path_like_element.split(os.sep)
    base_dir = os.path.join(*splitpath[:4])
    model = globals()[splitpath[2]]
    pretrained_model = model.load_from_metrics(
        weights_path=path_like_element,
        tags_csv=os.path.join(base_dir, 'default', 'version_0', 'meta_tags.csv'),
        on_gpu=True if torch.cuda.is_available() else False,
        # map_location=None
    )
    print(f'... loading model named: "{model.name}" from timestamp: {splitpath[3]}')

    # Init model and freeze its weights (for faster inference)
    pretrained_model = pretrained_model.to(device)
    pretrained_model.eval()
    pretrained_model.freeze()

    with torch.no_grad():
        # Load the data for prediction
        # TODO!!!!!!!!!:
        # Of course, the data parameters need to be loaded here as well!
        # Do I have to reset the val sets automatically, i.e. always keep them on refresh, whenever I want to validate?
        # What actually is my val dataset?
        # Do I have a complete map anywhere?
        # How do I make sure that certain maps, i.e. size ratios etc., are not overrepresented?
        dataset = DataContainer(os.path.join(os.pardir, 'data', 'validation'), 9, 6).to(device)
        dataloader = DataLoader(dataset, shuffle=True, batch_size=len(dataset))

        # Do the inference
        test_pred = [pretrained_model(test_sample)[:-1] for test_sample in dataloader][0]

        for idx, prediction in enumerate(test_pred):
            plot, _ = viz_latent(prediction)
            plot.savefig(os.path.join(base_dir, f'latent_space_{idx}.png'))


def viz_latent(prediction):
    try:
        prediction = prediction.cpu()
        prediction = prediction.numpy()
    except AttributeError:
        pass

    if prediction.shape[-1] <= 1:
        raise ValueError('How did this happen?')
    elif prediction.shape[-1] == 2:
        ax = sns.scatterplot(x=prediction[:, 0], y=prediction[:, 1])
        try:
            plt.show()
        except:
            pass
        return ax.figure, (ax, )
    else:
        fig, axs = plt.subplots(ncols=2)
        plots = []
        for idx, dim_reducer in enumerate([PCA, TSNE]):
            predictions_reduced = dim_reducer(n_components=2).fit_transform(prediction)
            plot = sns.scatterplot(x=predictions_reduced[:, 0], y=predictions_reduced[:, 1],
                                   ax=axs[idx])
            plot.set_title(dim_reducer.__name__)
            plots.append(plot)

        try:
            plt.show()
        except:
            pass
        return fig, (*plots, )
    p = Printer(pretrained_model)
    # Important:
    # Use all given validation samples, even if they relate to different maps. This is important since we want to have a
    # view on the complete latent space, not just in relation to a single basemap, which would be a major bias.
    p.print_possible_latent_spaces(dataset, save=os.path.join(base_dir, f'latent_space'))


if __name__ == '__main__':
    path = 'output'
    search_for_weights(load_and_predict, path, file_type='latent')
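
For context, the core of this change is the PCA-vs-t-SNE comparison in viz_latent. Below is a minimal standalone sketch of that pattern, using synthetic latent vectors in place of real model predictions; the array sizes, the rng seed, and the output filename are illustrative assumptions, not part of the commit.

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE

# Synthetic stand-in for a batch of latent vectors (200 samples, 16 dims).
latents = np.random.default_rng(0).normal(size=(200, 16))

fig, axs = plt.subplots(ncols=2, figsize=(10, 4))
for ax, dim_reducer in zip(axs, [PCA, TSNE]):
    # Both reducers share sklearn's fit_transform interface, so they can be
    # iterated over interchangeably, exactly as viz_latent does.
    reduced = dim_reducer(n_components=2).fit_transform(latents)
    sns.scatterplot(x=reduced[:, 0], y=reduced[:, 1], ax=ax)
    ax.set_title(dim_reducer.__name__)
fig.savefig('latent_space_demo.png')

Note that t-SNE is stochastic, so its panel will differ between runs unless a random_state is fixed; PCA is deterministic and makes a useful reference panel next to it.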