import copy
import random

from pathlib import Path
from typing import Union

import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib.ticker import ScalarFormatter
from tqdm import tqdm
from matplotlib import pyplot as plt
from torch.nn import functional as F
from tabulate import tabulate

from functionalities_test import test_for_fixpoints, is_zero_fixpoint, is_divergent, is_identity_function
from network import Net
from visualization import plot_loss, bar_chart_fixpoints, plot_3d_soup, line_chart_fixpoints


def prng():
    """Return one pseudo-random float drawn from ``random.random()``."""
    sample = random.random()
    return sample


class SoupRobustnessExperiment:
    """Evolve a population ("soup") of networks and probe the robustness of its identity functions.

    Constructing an instance runs the complete experiment: the population is created, evolved
    (random attacks plus self-training), the result plots are written to ``directory`` and finally
    every network reported as an identity function is perturbed with noise and self-applied until
    ``is_zero_fixpoint`` or ``is_divergent`` reports true for it.
    """

    # Keys of the tally kept in ``self.fixpoint_counters``; the order is the dict's insertion order.
    _COUNTER_KEYS = ("identity_func", "divergent", "fix_zero", "fix_weak", "fix_sec", "other_func")

    def __init__(self, population_size, net_i_size, net_h_size, net_o_size, learning_rate, attack_chance,
                 train_nets, ST_steps, epochs, log_step_size, directory: Union[str, Path]):
        """Store the settings, create the output directory and run the experiment.

        :param population_size: number of networks in the soup.
        :param net_i_size: input size handed to every ``Net``.
        :param net_h_size: hidden size handed to every ``Net``.
        :param net_o_size: output size handed to every ``Net``.
        :param learning_rate: learning rate used for self-training.
        :param attack_chance: chance in percent (compared against ``random.randint(1, 100)``) that
            one network attacks another in an epoch.
        :param train_nets: stored on the instance only; not read anywhere in this class.
        :param ST_steps: self-training calls per network and epoch; also the epoch interval at
            which the fixpoint share is recorded.
        :param epochs: number of evolution epochs.
        :param log_step_size: forwarded to ``Net.self_train`` and ``Net.self_application``.
        :param directory: output directory, created (including parents) if it does not exist.
        """
        super().__init__()
        self.population_size = population_size

        self.net_input_size = net_i_size
        self.net_hidden_size = net_h_size
        self.net_out_size = net_o_size
        self.net_learning_rate = learning_rate
        self.attack_chance = attack_chance
        self.train_nets = train_nets
        self.ST_steps = ST_steps
        self.epochs = epochs
        self.log_step_size = log_step_size

        self.loss_history = []

        self.fixpoint_counters = dict.fromkeys(self._COUNTER_KEYS, 0)
        # One entry per recorded epoch: identity-function count divided by the population size.
        self.fixpoint_counters_history = []
        self.id_functions = []

        self.directory = Path(directory)
        self.directory.mkdir(parents=True, exist_ok=True)

        self.population = []
        self.populate_environment()

        self.evolve()
        self.fixpoint_percentage()
        self.weights_evolution_3d_experiment()
        self.count_fixpoints()
        self.visualize_loss()

        # test_robustness() returns (time_as_fixpoint, time_to_vergence); unpack in that order so
        # each attribute holds the table its name announces.
        self.time_as_fixpoint, self.time_to_vergence = self.test_robustness()

    def populate_environment(self):
        """Create ``population_size`` networks named ``soup_network_<i>`` and add them to the soup."""
        progress = tqdm(range(self.population_size))
        # Iterate the very bar whose description is updated, so the displayed bar actually advances.
        for i in progress:
            progress.set_description("Populating soup experiment %s" % i)

            net = Net(self.net_input_size, self.net_hidden_size, self.net_out_size, f"soup_network_{i}")
            self.population.append(net)

    def evolve(self):
        """Run ``epochs`` rounds of a possible attack followed by self-training of every network.

        Every ``ST_steps`` epochs the population is tested for fixpoints and the identity-function
        share (rounded to one decimal) is appended to ``fixpoint_counters_history``.
        """
        loop_epochs = tqdm(range(self.epochs))
        for i in loop_epochs:
            loop_epochs.set_description("Evolving soup %s" % i)

            # With a chance of ``attack_chance`` percent one random network attacks a different one.
            if random.randint(1, 100) <= self.attack_chance:
                attacker_idx, victim_idx = random.sample(range(self.population_size), 2)
                attacker = self.population[attacker_idx]
                victim = self.population[victim_idx]
                print(f"\n Attack: {attacker.name} -> {victim.name}")
                attacker.attack(victim)

            # Self-training each network in the population.
            for net in self.population:
                for _ in range(self.ST_steps):
                    net.self_train(1, self.log_step_size, self.net_learning_rate)

            # Record the fixpoint share once per ``ST_steps`` epochs.
            if i % self.ST_steps == 0:
                test_for_fixpoints(self.fixpoint_counters, self.population)
                fixpoints_percentage = round(self.fixpoint_counters["identity_func"] / self.population_size, 1)
                self.fixpoint_counters_history.append(fixpoints_percentage)

            # The counters are reset after every epoch, the last one included: the former guard
            # ``i < self.epochs`` was true for every ``i`` in ``range(self.epochs)``.
            # count_fixpoints() runs test_for_fixpoints again before drawing the bar chart.
            self.reset_fixpoint_counters()

    def test_robustness(self, print_it=True, noise_levels=10, seeds=10):
        """Perturb every identity function with noise and self-apply it until it degenerates.

        For each network in ``self.id_functions``, each of ``seeds`` repetitions and each noise
        level ``n`` in ``range(noise_levels)``, a clone receives noise of magnitude ``10 ** -n``
        and is self-applied while neither ``is_zero_fixpoint`` nor ``is_divergent`` holds.
        The per-application L1 loss is plotted as a box-plot grid saved in ``self.directory``.

        NOTE(review): the loop has no step limit, and ``seed`` is only a repetition counter - no
        random generator is seeded here.

        :param print_it: print both result tables when true.
        :param noise_levels: number of noise magnitudes to test.
        :param seeds: repetitions per network and noise level.
        :return: ``(time_as_fixpoint, time_to_vergence)`` - one row per identity function, one
            column per noise level. Despite the ``avg_`` variable names the cells are step counts
            summed over all repetitions; no division takes place.
        """
        network_count = len(self.id_functions)
        # True for the "several networks, one repetition" setting; the recorded ``seed`` column
        # then carries the network index, otherwise the repetition index.
        is_synthetic = network_count > 1 and seeds == 1

        # The tables are indexed by the network index below, so they need one row per network
        # (sizing them by ``seeds`` raised an IndexError for several networks with ``seeds == 1``).
        avg_time_to_vergence = [[0] * noise_levels for _ in range(network_count)]
        avg_time_as_fixpoint = [[0] * noise_levels for _ in range(network_count)]
        row_headers = []
        # Rows of [setting, noise_level, application_step, absolute_loss], collected in a list and
        # converted to a DataFrame once instead of growing a DataFrame row by row.
        records = []

        for i, fixpoint in enumerate(self.id_functions):
            row_headers.append(fixpoint.name)
            for seed in range(seeds):
                setting = i if is_synthetic else seed
                for noise_level in range(noise_levels):
                    self_application_steps = 1
                    clone = Net(fixpoint.input_size, fixpoint.hidden_size, fixpoint.out_size,
                                f"{fixpoint.name}_clone_noise10e-{noise_level}")
                    clone.load_state_dict(copy.deepcopy(fixpoint.state_dict()))
                    clone = clone.apply_noise(pow(10, -noise_level))

                    while not is_zero_fixpoint(clone) and not is_divergent(clone):
                        if is_identity_function(clone):
                            avg_time_as_fixpoint[i][noise_level] += 1

                        # Target weights before the self-application ...
                        weights_before = clone.input_weight_matrix()
                        target_before = clone.create_target_weights(weights_before)

                        clone.self_application(1, self.log_step_size)
                        avg_time_to_vergence[i][noise_level] += 1

                        # ... and after it; their L1 distance is the loss of this step.
                        weights_after = clone.input_weight_matrix()
                        target_after = clone.create_target_weights(weights_after)

                        absolute_loss = F.l1_loss(target_before, target_after).item()

                        records.append([setting, noise_level, self_application_steps, absolute_loss])
                        self_application_steps += 1

        self._plot_absolute_loss(records)

        if print_it:
            col_headers = [str(f"10-{d}") for d in range(noise_levels)]

            print("\nAppplications steps until divergence / zero: ")
            print(tabulate(avg_time_to_vergence, showindex=row_headers, headers=col_headers, tablefmt='orgtbl'))

            print("\nTime as fixpoint: ")
            print(tabulate(avg_time_as_fixpoint, showindex=row_headers, headers=col_headers, tablefmt='orgtbl'))

        return avg_time_as_fixpoint, avg_time_to_vergence

    def _plot_absolute_loss(self, records):
        """Save a box-plot grid (one facet per noise level) of the per-application absolute loss.

        Nothing is plotted when there is no finite loss below 1, e.g. because the soup produced
        no identity function.
        """
        if not records:
            return

        df = pd.DataFrame(records, columns=['seed', 'noise_level', 'application_step', 'absolute_loss'])
        # Drop rows with infinite or missing values before plotting.
        df = df.replace([np.inf, -np.inf], np.nan).dropna()
        plot_data = df[df['absolute_loss'] < 1]
        if plot_data.empty:
            return

        sns.set_theme(style="ticks")
        sns.catplot(data=plot_data, y='absolute_loss', x='application_step', kind='box',
                    col='noise_level', col_wrap=3, showfliers=False)

        # Write next to the other outputs of this run instead of a hard-coded 'output/robustness'.
        filepath = self.directory / "absolute_loss_perapplication_boxplot_grid.png"
        plt.savefig(str(filepath))

    def weights_evolution_3d_experiment(self):
        """Plot the population via ``plot_3d_soup`` and return its result."""
        exp_name = f"soup_{self.population_size}_nets_{self.ST_steps}_training_{self.epochs}_epochs"
        return plot_3d_soup(self.population, exp_name, self.directory)

    def count_fixpoints(self):
        """Test the population for fixpoints, keep the result as ``id_functions`` and draw the bar chart."""
        self.id_functions = test_for_fixpoints(self.fixpoint_counters, self.population)
        exp_details = f"Evolution steps: {self.epochs} epochs"
        bar_chart_fixpoints(self.fixpoint_counters, self.population_size, self.directory, self.net_learning_rate,
                            exp_details)

    def fixpoint_percentage(self):
        """Draw the line chart of the fixpoint share recorded during ``evolve``."""
        runs = self.epochs / self.ST_steps
        SA_steps = None  # this experiment has no self-application phase during evolution
        line_chart_fixpoints(self.fixpoint_counters_history, runs, self.ST_steps, SA_steps, self.directory,
                             self.population_size)

    def visualize_loss(self):
        """Collect every network's ``loss_history`` and plot them via ``plot_loss``."""
        self.loss_history.extend(net.loss_history for net in self.population)

        plot_loss(self.loss_history, self.directory)

    def reset_fixpoint_counters(self):
        """Replace ``fixpoint_counters`` with a fresh dict in which every counter is zero."""
        self.fixpoint_counters = dict.fromkeys(self._COUNTER_KEYS, 0)


if __name__ == "__main__":
    # Architecture shared by every network of the soup.
    NET_INPUT_SIZE = 4
    NET_OUT_SIZE = 1
    soup_net_hidden_size = 2
    soup_net_learning_rate = 0.04

    # Size of the soup and length of the evolution.
    soup_population_size = 4
    soup_epochs = 100
    soup_ST_steps = 20
    soup_log_step_size = 5

    # Attack chance, given in percent.
    soup_attack_chance = 10

    # Not used yet: soup_train_nets has 3 possible values "no", "before_SA", "after_SA".
    soup_train_nets = "no"
    soup_synthetic = True

    # A random 32-bit number names the output folder of this run.
    soup_name_hash = random.getrandbits(32)
    soup_directory = Path('output') / 'robustness' / str(soup_name_hash)

    experiment_settings = {
        "population_size": soup_population_size,
        "net_i_size": NET_INPUT_SIZE,
        "net_h_size": soup_net_hidden_size,
        "net_o_size": NET_OUT_SIZE,
        "learning_rate": soup_net_learning_rate,
        "attack_chance": soup_attack_chance,
        "train_nets": soup_train_nets,
        "ST_steps": soup_ST_steps,
        "epochs": soup_epochs,
        "log_step_size": soup_log_step_size,
        "directory": soup_directory,
    }

    print("Running the robustness comparison experiment:")
    # Constructing the experiment runs it from start to finish.
    SoupRobustnessExperiment(**experiment_settings)