import os
import warnings

# Installed before the remaining imports, presumably so FutureWarnings raised while
# importing torch / project modules are silenced as well.
warnings.filterwarnings('ignore', category=FutureWarning)

import torch

from dataset import DataContainer
from viz.utils import search_for_weights, Printer
# Star import brings the model classes into this module's namespace; load_and_predict
# looks them up by name via globals().
from run_models import *

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def load_and_predict(path_like_element):
    """Load a trained model from a weights file and plot its latent space.

    Callback for ``search_for_weights``; called once per weights file found.

    Args:
        path_like_element: Path to a weights file, split on ``os.sep``. Expected layout
            (inferred from the indices used below -- confirm against search_for_weights):
            component [2] is the name of a model class imported from ``run_models``,
            component [3] is a run timestamp, and the first four components form the run
            directory. That directory must contain ``default/version_0/meta_tags.csv``.

    Side effects:
        Passes ``<run dir>/latent_space`` as the ``save`` target to
        ``Printer.print_possible_latent_spaces``.

    Raises:
        KeyError: If component [2] does not name a class available in this module.
    """
    splitpath = path_like_element.split(os.sep)
    base_dir = os.path.join(*splitpath[:4])

    # Resolve the model class from its name in the path.
    model_name = splitpath[2]
    try:
        model = globals()[model_name]
    except KeyError as err:
        raise KeyError(
            f'No model class named "{model_name}" found in this module '
            f'(expected from run_models) for weights file: {path_like_element}'
        ) from err

    pretrained_model = model.load_from_metrics(
        weights_path=path_like_element,
        tags_csv=os.path.join(base_dir, 'default', 'version_0', 'meta_tags.csv'),
        on_gpu=torch.cuda.is_available(),
    )
    print(f'... loading model named: "{model.name}" from timestamp: {splitpath[3]}')

    # Move to the target device, switch to eval mode and freeze the weights (for faster inference).
    pretrained_model = pretrained_model.to(device)
    pretrained_model.eval()
    pretrained_model.freeze()

    # Load the data for prediction.
    # TODO: the data parameters must also be loaded here instead of being hard-coded
    #       (presumably the 9, 6 passed to DataContainer -- confirm).
    # Open questions from the original author:
    #   - Should the validation sets be regenerated automatically (always "refresh")
    #     whenever validation is run?
    #   - What exactly is the validation dataset?
    #   - Is there a complete map available somewhere?
    #   - How to ensure certain maps (e.g. size ratios) are not over-represented?
    dataset = DataContainer(os.path.join(os.pardir, 'data', 'validation'), 9, 6).to(device)

    p = Printer(pretrained_model)
    # Important:
    # Use all given validation samples, even if they relate to different maps. We want a view
    # on the complete latent space, not just in relation to a single base map, which would
    # introduce a major bias.
    p.print_possible_latent_spaces(dataset, save=os.path.join(base_dir, 'latent_space'), cluster_by_motion=False)


if __name__ == '__main__':
    path = 'output'
    search_for_weights(load_and_predict, path, file_type='latent')