from pathlib import Path

import torch
from torch_geometric.data import Data
from tqdm import tqdm

from torch.utils.data import DataLoader

# Dataset and Dataloaders
# =============================================================================

# Transforms and model I/O
from ml_lib.point_toolset.point_io import BatchToData
from ml_lib.utils.model_io import SavedLightningModels


# Datasets
from datasets.shapenet import ShapeNetPartSegDataset
from utils.project_config import ThisConfig

# Deliberate import-time guard: this script uses names that are not imported
# in this file (``GlobalVar`` in ``prepare_dataloader`` and
# ``FullCloudsDataset`` in the ``__main__`` block), so it refuses to run at all
# until those imports are restored. Remove this line once they are fixed.
# NOTE(review): ``BrokenPipeError`` is an odd exception type for this;
# ``ImportError`` would describe the problem more accurately.
raise BrokenPipeError('There are Imports that need to be fixed first!!!!')

def prepare_dataloader(config_obj):
    """Build a non-shuffling DataLoader over a ``ShapeNetPartSegDataset``.

    Args:
        config_obj: Parsed configuration object; this function reads
            ``data.root``, ``data.worker``, ``model.type`` and
            ``train.batch_size`` from it.

    Returns:
        A ``torch.utils.data.DataLoader`` over the split selected by
        ``GlobalVar.data_split.test``, with ``shuffle=False`` so the output
        order is reproducible.
    """
    # NOTE(review): ``GlobalVar`` is not imported in this file; confirm where
    # it should come from before this function can run.
    root_dir = config_obj.data.root
    split = GlobalVar.data_split.test
    dataset_setting = GlobalVar.settings[config_obj.model.type]
    test_set = ShapeNetPartSegDataset(root_dir, mode=split, setting=dataset_setting)

    loader_options = dict(
        batch_size=config_obj.train.batch_size,
        num_workers=config_obj.data.worker,
        shuffle=False,
    )
    # noinspection PyTypeChecker
    return DataLoader(test_set, **loader_options)


def restore_logger_and_model(log_dir):
    """Restore a saved Lightning model from *log_dir* and move it to a device.

    Despite the name, only the model is returned; no logger is restored here.

    Args:
        log_dir: Directory passed to ``SavedLightningModels.load_checkpoint`` as
            ``models_root_path``. The checkpoint chosen is the one selected by
            ``n=-1`` (presumably the most recent; verify against ml_lib).

    Returns:
        The restored model, moved to CUDA when available and to CPU otherwise.
    """
    checkpoint = SavedLightningModels.load_checkpoint(models_root_path=log_dir, n=-1)
    restored = checkpoint.restore()
    # Pick the bound placement method first, then call it once.
    move_to_device = restored.cuda if torch.cuda.is_available() else restored.cpu
    move_to_device()
    return restored


def predicted_labels(logits):
    """Return class indices from a logits tensor as a 1-D NumPy array.

    ``argmax`` is taken over the last dimension. The result is flattened with
    ``reshape(-1)`` instead of ``squeeze()`` so that a batch holding a single
    point still yields a 1-element array that can be indexed by row;
    ``squeeze()`` would collapse it to a 0-d array.

    Args:
        logits: Tensor whose last dimension holds per-class scores.

    Returns:
        ``numpy.ndarray`` of integer class indices.
    """
    return torch.argmax(logits, dim=-1).reshape(-1).cpu().numpy()


def secondary_labels(model_out, num_nodes):
    """Return the secondary per-point prediction carried by a model output.

    ``prim_out`` takes precedence over ``grid_out``, matching the original
    script where a ``prim_out`` result overwrote a ``grid_out`` one. When the
    output carries neither head (or both are ``None``), a placeholder of ``-1``
    per point is returned. The CSV writer can index that placeholder by row,
    whereas the original bare ``-1`` integer raised ``TypeError``.

    Args:
        model_out: Object returned by the model's forward pass.
        num_nodes: Number of points; sizes the ``-1`` placeholder.

    Returns:
        Indexable sequence with one label per point.
    """
    for head in ('prim_out', 'grid_out'):
        logits = getattr(model_out, head, None)
        if logits is not None:
            return predicted_labels(logits)
    return [-1] * num_nodes


if __name__ == '__main__':
    import sys

    # The run directory can be overridden as the first CLI argument; without
    # one, the original hard-coded location is used.
    default_model_path = '/home/steffen/projects/point_to_primitive/output/P2G/PG_9f7ac027e3359fa5f5e5bcd32044a167/version_69'
    model_path = Path(sys.argv[1]) if len(sys.argv) > 1 else Path(default_model_path)
    config_filename = 'config.ini'
    inference_out = 'manual_test_out.csv'

    config = ThisConfig()
    # Close the config file once parsed (previously the handle was leaked).
    with (model_path / config_filename).open('r') as config_file:
        config.read_file(config_file)
    test_dataloader = prepare_dataloader(config)

    loaded_model = restore_logger_and_model(model_path)
    loaded_model.eval()
    # Loop-invariant; must match the device chosen in restore_logger_and_model.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # NOTE(review): ``FullCloudsDataset`` is not imported in this file. The
    # first 6 headers are assumed to name the pos + x columns written below;
    # confirm.
    header = f'{",".join(FullCloudsDataset.headers[:6])},class,cluster\n'

    batch_to_data = BatchToData()
    # no_grad: pure inference, so no autograd graph needs to be built.
    with (model_path / inference_out).open(mode='w') as outfile, torch.no_grad():
        outfile.write(header)
        for batch in tqdm(test_dataloader, total=len(test_dataloader)):
            data = batch if isinstance(batch, Data) else batch_to_data(*batch)
            y = loaded_model(data.to(device=device))
            y_primary = predicted_labels(y.main_out)
            y_secondary = secondary_labels(y, data.num_nodes)
            # Convert once per batch instead of once per point; per-row
            # .tolist() forces a device sync for every point.
            positions = data.pos.tolist()
            features = data.x.tolist()
            for row in range(data.num_nodes):
                outfile.write(f'{",".join(map(str, positions[row]))},'
                              f'{",".join(map(str, features[row]))},'
                              f'{y_primary[row]},{y_secondary[row]}\n')
    print('Done')