Visualization approach n
This commit is contained in:
50
viz/print_movement_in_map.py
Normal file
50
viz/print_movement_in_map.py
Normal file
@ -0,0 +1,50 @@
|
||||
from argparse import ArgumentParser
|
||||
import os
|
||||
|
||||
from dataset import DataContainer
|
||||
from viz.utils import MotionAnalyser, Printer, MapContainer, search_for_weights
|
||||
import torch
|
||||
from run_models import SAAE_Model, AAE_Model, VAE_Model, AE_Model
|
||||
|
||||
arguments = ArgumentParser()
|
||||
arguments.add_argument('--data', default='output')
|
||||
|
||||
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
|
||||
|
||||
def load_and_viz(path_like_element):
|
||||
# Define Loop to search for models and folder with visualizations
|
||||
splitpath = path_like_element.split(os.sep)
|
||||
base_dir = os.path.join(*splitpath[:4])
|
||||
model = globals()[splitpath[2]]
|
||||
print(f'... loading model named: "{model.name}" from timestamp: {splitpath[3]}')
|
||||
pretrained_model = model.load_from_metrics(
|
||||
weights_path=path_like_element,
|
||||
tags_csv=os.path.join(base_dir, 'default', 'version_0', 'meta_tags.csv'),
|
||||
on_gpu=True if torch.cuda.is_available() else False,
|
||||
# map_location=None
|
||||
)
|
||||
|
||||
# Init model and freeze its weights ( for faster inference)
|
||||
pretrained_model = pretrained_model.to(device)
|
||||
pretrained_model.eval()
|
||||
pretrained_model.freeze()
|
||||
|
||||
dataIndex = 0
|
||||
|
||||
datasets = DataContainer(os.path.join(os.pardir, 'data', 'validation'), 9, 6).to(device)
|
||||
dataset = datasets.datasets[dataIndex]
|
||||
# ToDO: use dataloader for iteration instead! - dataloader = DataLoader(dataset, )
|
||||
|
||||
maps = MapContainer(os.path.join(os.pardir, 'data', 'validation'))
|
||||
base_map = maps.datasets[dataIndex]
|
||||
|
||||
p = Printer(pretrained_model)
|
||||
p.print_trajec_on_basemap(dataset, base_map, save=os.path.join(base_dir, f'{base_map.name}_movement.png'),
|
||||
color_by_movement=True)
|
||||
return True
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
args = arguments.parse_args()
|
||||
search_for_weights(load_and_viz, args.data, file_type='movement')
|
Reference in New Issue
Block a user