# ae_toolbox_torch/viz/viz_prediction_in_map.py
# 2019-09-29 09:37:30 +02:00
#
# 56 lines
# 1.9 KiB
# Python

import os
from argparse import ArgumentParser

import torch

from dataset import DataContainer
from viz.utils import MotionAnalyser, Printer, MapContainer, search_for_weights
# The star imports stay last, as in the original, so any names they re-export
# take precedence over the explicit imports above. run_models presumably
# provides the model classes that load_and_viz looks up by name in globals().
from viz.utils import *
from run_models import *

# Command-line interface; --data is the root passed to search_for_weights
# in __main__. Defined once, after the star imports, so it cannot be shadowed.
arguments = ArgumentParser()
arguments.add_argument('--data', default='output')

# Inference device: CUDA when available, otherwise CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
def load_and_viz(path_like_element, data_index=0,
                 data_root=os.path.join(os.pardir, 'data', 'validation')):
    """Load a trained checkpoint and plot its trajectories on a base map.

    Used as the callback passed to ``search_for_weights`` in ``__main__``.
    The path is split on ``os.sep``:
    - component 2 names the model class, which is looked up in this module's
      globals (presumably populated by ``from run_models import *``);
    - component 3 is reported as the timestamp;
    - the first four components form the run directory. The run directory is
      where ``default/version_0/meta_tags.csv`` is read from and where the
      resulting ``<map name>_map.png`` is written.

    Args:
        path_like_element: Path to the weights file to load.
        data_index: Index into both the validation data container and the
            map container. Defaults to 0, the original hard-coded value.
        data_root: Directory holding the validation data and maps. Defaults
            to ``../data/validation``, the original hard-coded value.

    Raises:
        KeyError: If the model name taken from the path is not a global name
            in this module.
    """
    splitpath = path_like_element.split(os.sep)
    base_dir = os.path.join(*splitpath[:4])
    model_name = splitpath[2]
    try:
        model = globals()[model_name]
    except KeyError:
        # Re-raise with context; the bare name alone is hard to debug.
        raise KeyError(
            f'No model class named "{model_name}" found in globals; '
            f'check the directory layout of {path_like_element!r}'
        ) from None
    print(f'... loading model named: "{model.name}" from timestamp: {splitpath[3]}')
    pretrained_model = model.load_from_metrics(
        weights_path=path_like_element,
        tags_csv=os.path.join(base_dir, 'default', 'version_0', 'meta_tags.csv'),
        on_gpu=torch.cuda.is_available(),
    )

    # Move to the inference device and freeze the weights for faster inference.
    pretrained_model = pretrained_model.to(device)
    pretrained_model.eval()
    pretrained_model.freeze()

    # TODO: iterate with a DataLoader instead of indexing the container directly.
    # NOTE(review): the meaning of the positional arguments 9 and 6 is defined by
    # DataContainer and is not visible here -- confirm and document.
    datasets = DataContainer(data_root, 9, 6).to(device)
    dataset = datasets.datasets[data_index]
    maps = MapContainer(data_root)
    base_map = maps.datasets[data_index]

    printer = Printer(pretrained_model)
    printer.print_trajec_on_basemap(
        dataset,
        base_map,
        save=os.path.join(base_dir, f'{base_map.name}_map.png'),
    )
if __name__ == '__main__':
    # Read --data from the command line, then let search_for_weights
    # invoke load_and_viz for each 'map'-type weights file it finds.
    parsed_cli = arguments.parse_args()
    data_root_dir = parsed_cli.data
    search_for_weights(load_and_viz, data_root_dir, file_type='map')