import pickle
import warnings
from typing import Union, List
from os import PathLike
from pathlib import Path
import time

import pandas as pd

from stable_baselines3.common.callbacks import CallbackList
from stable_baselines3.common.vec_env import SubprocVecEnv

from environments.factory.double_task_factory import DoubleTaskFactory, ItemProperties
from environments.factory.simple_factory import DirtProperties, SimpleFactory
from environments.helpers import IGNORED_DF_COLUMNS
from environments.logging.monitor import MonitorCallback
from environments.logging.plotting import prepare_plot
from environments.logging.recorder import RecorderCallback
from environments.utility_classes import MovementProperties

warnings.filterwarnings('ignore', category=FutureWarning)
warnings.filterwarnings('ignore', category=UserWarning)


def combine_runs(run_path: Union[str, PathLike]):
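    """Collect every monitor_*.pick file below run_path, tag each DataFrame with its
    run index, and plot a smoothed per-episode line plot of all monitored columns."""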
    run_path = Path(run_path)
    df_list = list()
    for run, monitor_file in enumerate(run_path.rglob('monitor_*.pick')):
        with monitor_file.open('rb') as f:
            monitor_df = pickle.load(f)

        monitor_df['run'] = run
        monitor_df = monitor_df.fillna(0)
        df_list.append(monitor_df)

    df = pd.concat(df_list, ignore_index=True)
    df = df.fillna(0).rename(columns={'episode': 'Episode', 'run': 'Run'})
    columns = [col for col in df.columns if col not in IGNORED_DF_COLUMNS]

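    # Smooth each (Run, Episode) trace with a rolling mean over roll_n rows,
    # then keep only every skip_n-th episode to thin out the plot.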
    roll_n = 50
    skip_n = 40

    non_overlap_window = df.groupby(['Run', 'Episode']).rolling(roll_n, min_periods=1).mean()

    df_melted = non_overlap_window[columns].reset_index().melt(id_vars=['Episode', 'Run'],
                                                               value_vars=columns, var_name="Measurement",
                                                               value_name="Score")
    df_melted = df_melted[df_melted['Episode'] % skip_n == 0]

    prepare_plot(run_path / f'{run_path.name}_monitor_lineplot.png', df_melted)
    print('Plotting done.')


def compare_runs(run_path: Union[str, PathLike], run_identifier: int, parameter: Union[str, List[str]]):
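    """Compare models across the run directories under run_path whose names contain
    run_identifier, plotting the given monitor column(s) with one hue per model."""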
    run_path = Path(run_path)
    df_list = list()
    parameter = [parameter] if isinstance(parameter, str) else parameter
    for path in run_path.iterdir():
        if path.is_dir() and str(run_identifier) in path.name:
            for run, monitor_file in enumerate(path.rglob('monitor_*.pick')):
                with monitor_file.open('rb') as f:
                    monitor_df = pickle.load(f)

                monitor_df['run'] = run
                monitor_df['model'] = path.name.split('_')[0]
                monitor_df = monitor_df.fillna(0)
                df_list.append(monitor_df)

    df = pd.concat(df_list, ignore_index=True)
    df = df.fillna(0).rename(columns={'episode': 'Episode', 'run': 'Run', 'model': 'Model'})
    columns = [col for col in df.columns if col in parameter]

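    # Same smoothing and thinning scheme as in combine_runs, with a tighter window.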
    roll_n = 40
    skip_n = 20

    non_overlap_window = df.groupby(['Model', 'Run', 'Episode']).rolling(roll_n, min_periods=1).mean()

    df_melted = non_overlap_window[columns].reset_index().melt(id_vars=['Episode', 'Run', 'Model'],
                                                               value_vars=columns, var_name="Measurement",
                                                               value_name="Score")
    df_melted = df_melted[df_melted['Episode'] % skip_n == 0]

    style = 'Measurement' if len(columns) > 1 else None
    prepare_plot(run_path / f'{run_identifier}_compare_{"_".join(parameter)}.png', df_melted, hue='Model', style=style)
    print('Plotting done.')


def make_env(env_kwargs_dict):
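    """Return a zero-argument initializer that builds a SimpleFactory from
    env_kwargs_dict, in the factory-function form SubprocVecEnv expects."""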

    def _init():
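        # Note: returning from inside the `with` block triggers SimpleFactory.__exit__
        # before the env is used; this assumes __exit__ leaves the instance usable.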
        with SimpleFactory(**env_kwargs_dict) as init_env:
            return init_env

    return _init


if __name__ == '__main__':

    # combine_runs(Path('debug_out') / 'A2C_1630314192')
    # exit()

    # compare_runs(Path('debug_out'), 1623052687, ['step_reward'])
    # exit()

    from stable_baselines3 import PPO, DQN, A2C
    from algorithms.reg_dqn import RegDQN
    # from sb3_contrib import QRDQN

    dirt_props = DirtProperties(clean_amount=2, gain_amount=0.1, max_global_amount=20,
                                max_local_amount=1, spawn_frequency=3, max_spawn_ratio=0.05,
                                dirt_smear_amount=0.0, agent_can_interact=True)
    item_props = ItemProperties(n_items=5, agent_can_interact=True)
    move_props = MovementProperties(allow_diagonal_movement=False,
                                    allow_square_movement=True,
                                    allow_no_op=False)
    train_steps = 1e6
    time_stamp = int(time.time())
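    # One shared timestamp tags all runs of this sweep, so combine_runs/compare_runs
    # can group the resulting output directories later.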

    out_path = None

    for model_type in [A2C, PPO, DQN]:  # RegDQN, QRDQN
        for seed in range(3):
            env_kwargs = dict(n_agents=1,
                              # with_dirt=True,
                              # item_properties=item_props,
                              dirt_properties=dirt_props,
                              movement_properties=move_props,
                              pomdp_r=2, max_steps=400, parse_doors=True,
                              level_name='simple', frames_to_stack=6,
                              omit_agent_in_obs=True, combin_agent_obs=True, record_episodes=False,
                              cast_shadows=True, doors_have_area=False, env_seed=seed, verbose=False,
                              )

            # env = make_env(env_kwargs)()
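            # Vectorize rollouts across 12 worker processes; 'spawn' starts each
            # subprocess with a fresh interpreter instead of forking the parent.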
            env = SubprocVecEnv([make_env(env_kwargs) for _ in range(12)], start_method="spawn")

            if model_type.__name__ in ["PPO", "A2C"]:
                # On-policy algorithms: only the entropy bonus is tuned here.
                kwargs = dict(ent_coef=0.01)
            elif model_type.__name__ in ["RegDQN", "DQN", "QRDQN"]:
                # Off-policy algorithms: replay-buffer and exploration settings.
                kwargs = dict(buffer_size=50000,
                              learning_starts=64,
                              batch_size=64,
                              target_update_interval=5000,
                              exploration_fraction=0.25,
                              exploration_final_eps=0.025)
            else:
                raise NameError(f'No hyperparameters defined for model "{model_type.__name__}".')
            model = model_type("MlpPolicy", env, verbose=1, seed=seed, device='cpu', **kwargs)

            out_path = Path('debug_out') / f'{model.__class__.__name__}_{time_stamp}'

            identifier = f'{seed}_{model.__class__.__name__}_{time_stamp}'
            out_path /= identifier

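            # Log per-episode monitor data; the recorder runs too, but with both
            # of its map outputs disabled.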
            callbacks = CallbackList(
                [MonitorCallback(filepath=out_path / f'monitor_{identifier}.pick', plotting=False),
                 RecorderCallback(filepath=out_path / f'recorder_{identifier}.json', occupation_map=False,
                                  trajectory_map=False
                                  )]
            )

            model.learn(total_timesteps=int(train_steps), callback=callbacks)

            save_path = out_path / f'model_{identifier}.zip'
            save_path.parent.mkdir(parents=True, exist_ok=True)
            model.save(save_path)
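            # env_method forwards save_params to every subprocess env; since all
            # envs share the same config, they write one identical yaml per group.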
            env.env_method('save_params', out_path.parent / f'env_{model.__class__.__name__}_{time_stamp}.yaml')
            print("Model Trained and saved")
        print("Model Group Done.. Plotting...")

        if out_path:
            combine_runs(out_path.parent)
    print("All Models Done... Evaluating")
    if out_path:
        compare_runs(Path('debug_out'), time_stamp, 'step_reward')