26 lines
853 B
Python
26 lines
853 B
Python
from argparse import ArgumentParser
|
|
import os
|
|
|
|
from torch import device
|
|
from torch.cuda import is_available
|
|
|
|
from dataset import DataContainer
|
|
from viz.utils import Printer, MapContainer
|
|
|
|
# Use the CUDA device when torch reports one is available, otherwise the CPU.
available_device = device('cuda' if is_available() else 'cpu')

# Command-line interface of this script.
arguments = ArgumentParser()
arguments.add_argument('--data', default='output')


def main():
    """Draw the first validation dataset on the first validation base map.

    Loads a ``MapContainer`` and a ``DataContainer`` from
    ``../data/validation``, takes element 0 of each container's ``datasets``
    attribute, and hands both to ``Printer.print_trajec_on_basemap`` with
    ``save`` set to ``'<base_map.name>_movement.png'``.
    """
    # NOTE(review): the value of ``--data`` is never read below. The call is
    # kept so that ``--help`` and rejection of unknown arguments still work;
    # confirm whether ``--data`` was meant to select the input/output folder.
    arguments.parse_args()

    # Both containers read from the same directory, so build the path once.
    validation_dir = os.path.join(os.pardir, 'data', 'validation')

    maps = MapContainer(validation_dir)
    base_map = maps.datasets[0]

    # NOTE(review): the meaning of the positional arguments 9 and 6 is defined
    # by DataContainer, which is not visible here -- confirm before changing.
    datasets = DataContainer(validation_dir, 9, 6).to(available_device)
    dataset = datasets.datasets[0]

    p = Printer(None)
    p.print_trajec_on_basemap(
        dataset,
        base_map,
        save=f'{base_map.name}_movement.png',
        color_by_movement=True,
        n=20,
        clustering='fastdtw',
        show=True,
    )


if __name__ == '__main__':
    main()