ae_toolbox_torch/viz/viz_latent.py
2021-02-01 09:59:56 +01:00

51 lines
2.1 KiB
Python

import os
import warnings
warnings.filterwarnings('ignore', category=FutureWarning)
import torch
from dataset import DataContainer
from viz.utils import search_for_weights, Printer
from run_models import *
# Select the CUDA device when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
def load_and_predict(path_like_element):
    """Load a trained model from a weights file and plot its latent space.

    Passed as the callback to ``search_for_weights``. The path is split on
    ``os.sep`` and is expected to have at least four components:

    * the third component (index 2) names a model class that must be reachable
      through this module's globals (e.g. via ``from run_models import *``),
    * the fourth component (index 3) is printed as the run's timestamp,
    * the first four components joined together form the run's base directory,
      from which ``default/version_0/meta_tags.csv`` is read and into which the
      latent-space plot is saved as ``latent_space``.

    Args:
        path_like_element: Relative, ``os.sep``-separated path to a weights file.
    """
    # Derive the run directory and the model class from the path layout.
    splitpath = path_like_element.split(os.sep)
    base_dir = os.path.join(*splitpath[:4])
    model = globals()[splitpath[2]]  # model class is looked up by name among module globals
    pretrained_model = model.load_from_metrics(
        weights_path=path_like_element,
        tags_csv=os.path.join(base_dir, 'default', 'version_0', 'meta_tags.csv'),
        on_gpu=torch.cuda.is_available(),
    )
    # NOTE(review): `name` is read from the class object, not the loaded instance. Confirm the
    # model classes define it at class level; a property or instance attribute would not print
    # the intended string here.
    print(f'... loading model named: "{model.name}" from timestamp: {splitpath[3]}')

    # Move the model to the target device and freeze its weights for faster inference.
    pretrained_model = pretrained_model.to(device)
    pretrained_model.eval()
    pretrained_model.freeze()

    # Load the data for prediction.
    # TODO: the data parameters should be loaded from the run as well instead of being hard-coded.
    # Open questions:
    #  - Do the validation sets have to be regenerated automatically (always "refresh") whenever
    #    validating?
    #  - What exactly is the validation dataset?
    #  - Is there a complete map available anywhere?
    #  - How to ensure that certain maps (e.g. size ratios) are not over-represented?
    # NOTE(review): the data path is relative to the current working directory, and the meaning of
    # the positional arguments 9 and 6 is not visible here; confirm against DataContainer.
    dataset = DataContainer(os.path.join(os.pardir, 'data', 'validation'), 9, 6).to(device)

    # Important: use all given validation samples, even if they relate to different maps. This
    # gives a view of the complete latent space rather than one tied to a single base map, which
    # would be a major bias.
    printer = Printer(pretrained_model)
    printer.print_possible_latent_spaces(
        dataset,
        save=os.path.join(base_dir, 'latent_space'),
        cluster_by_motion=False,
    )
if __name__ == '__main__':
    # Scan the 'output' directory tree and hand every weights file found to load_and_predict.
    # NOTE(review): the exact meaning of file_type='latent' is defined by search_for_weights.
    output_root = 'output'
    search_for_weights(load_and_predict, output_root, file_type='latent')