"""Script: draw trajectories from the validation data on top of a base map.

Loads the map container and the data container from ``../data/validation``,
takes the first entry of each, and passes them to
``Printer.print_trajec_on_basemap``.
"""
from argparse import ArgumentParser
import os

from torch import device
from torch.cuda import is_available

from dataset import DataContainer
from viz.utils import Printer, MapContainer

# Use the GPU when torch reports CUDA as available, otherwise the CPU.
available_device = device('cuda' if is_available() else 'cpu')

arguments = ArgumentParser()
arguments.add_argument('--data', default='output')


def main():
    """Parse the command line and render the first validation dataset."""
    # NOTE(review): the parsed ``--data`` value is not read anywhere below;
    # the input directory is fixed to ../data/validation. Parsing is kept so
    # that ``--help`` and rejection of unknown options still work.
    arguments.parse_args()

    # Maps and datasets are both read from the same directory.
    validation_dir = os.path.join(os.pardir, 'data', 'validation')

    map_container = MapContainer(validation_dir)
    base_map = map_container.datasets[0]

    # NOTE(review): the meaning of the positional 9 and 6 is defined by
    # DataContainer, which is not visible here - confirm before changing.
    data_container = DataContainer(validation_dir, 9, 6).to(available_device)
    dataset = data_container.datasets[0]

    # The output file name is derived from the base map's name.
    output_file = os.path.join(f'{base_map.name}_movement.png')

    printer = Printer(None)
    printer.print_trajec_on_basemap(
        dataset,
        base_map,
        save=output_file,
        color_by_movement=True,
        n=20,
        clustering='fastdtw',
        show=True,
    )


if __name__ == '__main__':
    main()