ae_toolbox_torch/viz/print_movement_in_map.py
2021-02-01 09:59:56 +01:00

26 lines
853 B
Python

from argparse import ArgumentParser
import os
from torch import device
from torch.cuda import is_available
from dataset import DataContainer
from viz.utils import Printer, MapContainer
# Command-line interface for this script; exposes a single optional --data flag.
arguments = ArgumentParser()
arguments.add_argument('--data', default='output')

# Torch device the loaded data is moved to: CUDA if torch reports it available,
# otherwise the CPU.
available_device = device('cuda') if is_available() else device('cpu')
def main():
    """Render the first validation dataset on the first validation base map.

    Both a ``MapContainer`` and a ``DataContainer`` are built from
    ``../data/validation`` (resolved against the current working directory).
    The first entry of each container's ``datasets`` is taken; the data
    container is first moved to ``available_device``. Both entries are then
    passed to ``Printer.print_trajec_on_basemap``. The output file name is
    ``<map name>_movement.png``. It has no directory component, so it lands in
    the current working directory. ``show=True`` is also passed.
    """
    # Parsed so unknown flags are rejected and --help works.
    # NOTE(review): --data is declared but its value is never read here —
    # confirm whether it was meant to select the input or output directory.
    arguments.parse_args()

    # Both containers read from the same directory; keep the path in one place.
    validation_dir = os.path.join(os.pardir, 'data', 'validation')

    maps = MapContainer(validation_dir)
    base_map = maps.datasets[0]

    # NOTE(review): the meaning of the positional arguments 9 and 6 is defined
    # by DataContainer (not visible here) — TODO name them once confirmed.
    datasets = DataContainer(validation_dir, 9, 6).to(available_device)
    dataset = datasets.datasets[0]

    # NOTE(review): what the None argument configures is defined in
    # viz.utils.Printer — confirm.
    p = Printer(None)
    p.print_trajec_on_basemap(
        dataset,
        base_map,
        # Bare file name: written relative to the current working directory.
        save=f'{base_map.name}_movement.png',
        color_by_movement=True,
        n=20,
        clustering='fastdtw',
        show=True,
    )


if __name__ == '__main__':
    main()