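"""Search for saved model checkpoints under the output directory, reload each model,
run it on the validation data, and render its latent space."""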
import os
import warnings

warnings.filterwarnings('ignore', category=FutureWarning)

import torch

from dataset import DataContainer
from viz.utils import search_for_weights, Printer
from run_models import *


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def load_and_predict(path_like_element):
    # Invoked via search_for_weights (see __main__) for every checkpoint found;
    # derive the model class and the run's base directory from the checkpoint path.
    splitpath = path_like_element.split(os.sep)
    base_dir = os.path.join(*splitpath[:4])
    model = globals()[splitpath[2]]
    pretrained_model = model.load_from_metrics(
        weights_path=path_like_element,
        tags_csv=os.path.join(base_dir, 'default', 'version_0', 'meta_tags.csv'),
        on_gpu=torch.cuda.is_available(),
    )
    print(f'... loading model named: "{model.name}" from timestamp: {splitpath[3]}')
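
    # Note on the expected checkpoint layout (inferred from the path handling above): the
    # third path component must name a model class, presumably exported by run_models, the
    # fourth is treated as the run's timestamp, and the first four components together form
    # the run's base directory.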

    # Move the model to the target device and freeze its weights (for faster inference).
    pretrained_model = pretrained_model.to(device)
    pretrained_model.eval()
    pretrained_model.freeze()

    # Load the data for prediction

    # TODO:
    # Of course the date parameters have to be loaded here as well!
    # Do I have to reset the val sets automatically, i.e. always keep them refreshed, whenever I want to validate?
    # What actually is my val dataset?
    # Do I have a complete map anywhere?
    # How do I ensure that certain maps (i.e. certain size ratios etc.) are not over-represented?
    dataset = DataContainer(os.path.join(os.pardir, 'data', 'validation'), 9, 6).to(device)

    # Do the inference
    # test_pred = [pretrained_model(test_sample)[:-1] for test_sample in dataloader][0]

    p = Printer(pretrained_model)
    # Important:
    # Use all given validation samples, even if they relate to different maps. This is important since we want a
    # view of the complete latent space, not just in relation to a single basemap, which would be a major bias.
    p.print_possible_latent_spaces(dataset, save=os.path.join(base_dir, 'latent_space'), cluster_by_motion=False)


if __name__ == '__main__':
    path = 'output'
    search_for_weights(load_and_predict, path, file_type='latent')
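
# Note: 'output' and os.path.join(os.pardir, 'data', 'validation') are relative paths, so
# which checkpoints and validation data are found depends on the working directory from
# which this script is launched.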