from argparse import ArgumentParser
import os

from dataset import DataContainer
from viz.utils import MotionAnalyser, Printer, MapContainer, search_for_weights
import torch

# NOTE(review): wildcard imports hide which names these modules contribute.
# `load_and_viz` resolves model classes through `globals()`, so it presumably
# relies on `run_models` exporting them - confirm before making these explicit.
from viz.utils import *
from run_models import *

# The module-level names below are defined exactly once and *after* the
# wildcard imports, so nothing exported by those modules can shadow them.

# Command-line interface: `--data` is handed to `search_for_weights` in the
# `__main__` section and defaults to 'output'.
arguments = ArgumentParser()
arguments.add_argument('--data', default='output')

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def load_and_viz(path_like_element):
    """Load a pretrained model from a weights file and plot it on a base map.

    The path is split on ``os.sep`` and interpreted by position:

    * components ``0..3`` joined together form the run directory, which holds
      ``default/version_0/meta_tags.csv`` and receives the rendered image;
    * component ``2`` is the name under which the model class is looked up in
      this module's globals;
    * component ``3`` is reported as the run's timestamp.

    Args:
        path_like_element: Path string of the weights file, laid out as
            described above.

    Raises:
        IndexError: If the path has fewer than four components.
        KeyError: If no module-level name matches the model component.
    """
    splitpath = path_like_element.split(os.sep)
    base_dir = os.path.join(*splitpath[:4])
    model_name = splitpath[2]
    try:
        model = globals()[model_name]
    except KeyError:
        # Same exception type as before, but say what was being looked up.
        raise KeyError(
            f'no model named "{model_name}" is available in this module '
            f'(derived from path: {path_like_element})'
        ) from None
    print(f'... loading model named: "{model.name}" from timestamp: {splitpath[3]}')
    pretrained_model = model.load_from_metrics(
        weights_path=path_like_element,
        tags_csv=os.path.join(base_dir, 'default', 'version_0', 'meta_tags.csv'),
        on_gpu=torch.cuda.is_available(),
    )

    # Move the model to the target device and freeze its weights for inference.
    pretrained_model = pretrained_model.to(device)
    pretrained_model.eval()
    pretrained_model.freeze()

    # Only the first validation dataset / map pair is visualised.
    data_index = 0
    validation_dir = os.path.join(os.pardir, 'data', 'validation')

    # NOTE(review): the meaning of the positional arguments 9 and 6 is not
    # visible from this file - confirm against DataContainer's signature.
    datasets = DataContainer(validation_dir, 9, 6).to(device)
    dataset = datasets.datasets[data_index]
    # TODO: iterate with a DataLoader instead of indexing a single dataset.

    maps = MapContainer(validation_dir)
    base_map = maps.datasets[data_index]

    # The image is written next to the run's other artefacts.
    p = Printer(pretrained_model)
    p.print_trajec_on_basemap(dataset, base_map, save=os.path.join(base_dir, f'{base_map.name}_map.png'))


if __name__ == '__main__':
    # Parse the command line, then let `search_for_weights` invoke
    # `load_and_viz`, starting from the `--data` location.
    cli_args = arguments.parse_args()
    search_for_weights(
        load_and_viz,
        cli_args.data,
        file_type='map',
    )