# Imports
# =============================================================================
from pathlib import Path

from tqdm import tqdm
import warnings

import torch
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping

from ml_lib.modules.utils import LightningBaseModule
from ml_lib.utils.logging import Logger

# Project-specific configuration
from util.config import MConfig


# Suppress every FutureWarning and UserWarning for the whole process.
for _ignored_category in (FutureWarning, UserWarning):
    warnings.filterwarnings('ignore', category=_ignored_category)
del _ignored_category


def fix_all_random_seeds(config_obj):
    import numpy as np
    import torch
    import random
    np.random.seed(config.main.seed)
    torch.manual_seed(config.main.seed)
    random.seed(config.main.seed)


def _run_validation(model):
    """Run ``model`` over its first validation dataloader; return ``validation_epoch_end``'s result.

    The caller is expected to have put the model in eval mode, moved it to its
    device, and wrapped this call in ``torch.no_grad()``.
    """
    device = 'cuda' if model.on_gpu else 'cpu'
    outputs = []
    for idx, (batch_x, label) in enumerate(tqdm(model.val_dataloader()[0])):
        batch_x = batch_x.to(device=device)
        label = label.to(device=device)
        # NOTE(review): dataloader index 1 is passed although loader [0] is
        # iterated; kept as in the original — confirm this is intended.
        outputs.append(model.validation_step((batch_x, label), idx, 1))
    # Aggregate once, after all batches have been processed.
    return model.validation_epoch_end([outputs])


def _write_test_predictions(config_obj, model, logger):
    """Predict on the test dataloader and write ``file_name,prediction`` rows to a CSV.

    The file is written to
    ``<train.outpath>/<model.type>/<logger.name>/version_<logger.version>/<logger.name>_test_out.csv``.
    A score ``>= 0.5`` is mapped to class 1. The class equal to ``variables.CLEAR``
    is written as ``'clear'``, anything else as ``'mask'``.
    """
    # Project-local modules, imported lazily as in the rest of this script.
    from main_inference import prepare_dataloader
    import variables as V

    parameters = logger.name
    out_dir = (Path(config_obj.train.outpath) / config_obj.model.type
               / parameters / f'version_{logger.version}')
    # Make sure the target directory exists before opening the file for writing.
    out_dir.mkdir(parents=True, exist_ok=True)
    inference_out = out_dir / f'{parameters}_test_out.csv'

    test_dataloader = prepare_dataloader(config_obj)
    device = 'cuda' if model.on_gpu else 'cpu'

    with inference_out.open(mode='w') as outfile:
        outfile.write('file_name,prediction\n')
        for batch_x, file_name in tqdm(test_dataloader, total=len(test_dataloader)):
            # Samples arrive unbatched; add a leading batch dimension.
            batch_x = batch_x.unsqueeze(0).to(device=device)
            y = model(batch_x).main_out
            prediction = (y.squeeze() >= 0.5).int().item()
            prediction = 'clear' if prediction == V.CLEAR else 'mask'
            outfile.write(f'{file_name},{prediction}\n')


def run_lightning_loop(config_obj):
    """Build, optionally train, and optionally evaluate the model described by ``config_obj``.

    Steps:
      1. Open a ``Logger`` built from ``config_obj``.
      2. Create a ``Trainer`` with top-5 checkpointing on ``uar_score``.
      3. Build and initialise the model via ``config_obj.build_and_init_model()``
         and log its parameter count.
      4. Unless ``config_obj.model.type`` is ``'ensemble'``, fit the model. Then save
         the final checkpoint (``weights.ckpt``) and the model into ``logger.log_dir``.
      5. If ``config_obj.main.eval`` is set, print the validation ``uar_score`` and
         write test-set predictions to a CSV file.

    Args:
        config_obj: Project configuration object (``MConfig``).

    Returns:
        The built (and, for non-ensemble models, trained) model.
    """
    # Logging
    # ================================================================================
    with Logger(config_obj) as logger:
        # Callbacks
        # =============================================================================
        # Checkpoint Saving
        checkpoint_callback = ModelCheckpoint(
            monitor='uar_score',
            filepath=str(logger.log_dir / 'ckpt_weights'),
            verbose=False,
            save_top_k=5,
        )

        # Early Stopping
        # NOTE: constructed but intentionally NOT passed to the Trainer yet
        # (early_stop_callback=None below).
        # TODO: For This to work, set a validation step and End Eval and Score
        early_stopping_callback = EarlyStopping(
            monitor='uar_score',
            min_delta=0.01,
            patience=10,
        )

        # Trainer
        # =============================================================================
        trainer = Trainer(max_epochs=config_obj.train.epochs,
                          show_progress_bar=True,
                          weights_save_path=logger.log_dir,
                          gpus=[0] if torch.cuda.is_available() else None,
                          check_val_every_n_epoch=10,
                          # num_sanity_val_steps=config_obj.train.num_sanity_val_steps,
                          # row_log_interval=(model.n_train_batches * 0.1),  # TODO: Better Value / Setting
                          # log_save_interval=(model.n_train_batches * 0.2),  # TODO: Better Value / Setting
                          checkpoint_callback=checkpoint_callback,
                          logger=logger,
                          fast_dev_run=config_obj.main.debug,
                          early_stop_callback=None
                          )

        # Model
        # =============================================================================
        # Build and Init its Weights
        model: LightningBaseModule = config_obj.build_and_init_model()
        # Log parameters
        pytorch_total_params = sum(p.numel() for p in model.parameters())
        logger.log_text('n_parameters', pytorch_total_params)

        # Train It (ensembles are assembled from already-trained members)
        if config_obj.model.type.lower() != 'ensemble':
            trainer.fit(model)

            # Save the last state & all parameters
            trainer.save_checkpoint(logger.log_dir / 'weights.ckpt')
            model.save_to_disk(logger.log_dir)

        # Evaluate It
        if config_obj.main.eval:
            with torch.no_grad():
                model.eval()
                if torch.cuda.is_available():
                    model.cuda()

                summary_dict = _run_validation(model)
                print(summary_dict['log']['uar_score'])

                _write_test_predictions(config_obj, model, logger)
    return model


if __name__ == "__main__":
    # Script entry point: build the run configuration from the command line,
    # seed every RNG, then train and/or evaluate.
    from _paramters import main_arg_parser as cli_parser  # (sic) module name is `_paramters`

    # Bound at module level as `config`; keep this name in case anything in
    # this module reads the global.
    config = MConfig.read_argparser(cli_parser)

    fix_all_random_seeds(config)

    trained_model = run_lightning_loop(config)