import torch
import torch.nn.functional as F
from tqdm import trange

from algorithms.q_learner import QLearner


class MQLearner(QLearner):
    # Munchausen Q-Learning (M-DQN), cf. Vieillard et al., "Munchausen Reinforcement Learning" (2020)
    def __init__(self, *args, temperature=0.03, alpha=0.9, clip_l0=-1.0, **kwargs):
        super(MQLearner, self).__init__(*args, **kwargs)
        assert self.n_agents == 1, 'M-DQN currently only supports single agent training'
        self.temperature = temperature  # tau: temperature of the softmax policy
        self.alpha = alpha              # scaling factor of the Munchausen log-policy bonus
        self.clip0 = clip_l0            # lower clipping value l_0 for the log-policy term

    def tau_ln_pi(self, qs):
        # Computes tau * ln(softmax(qs / tau)), i.e. the scaled log-policy.
        # Uses the log-sum-exp trick from the Munchausen RL paper (appendix, p. 18)
        # for numerical stability.
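        # An equivalent (but less numerically stable) one-liner, for reference:
        #   self.temperature * F.log_softmax(qs / self.temperature, dim=-1)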
        v_k = qs.max(-1)[0].unsqueeze(-1)
        advantage = qs - v_k
        logsum = torch.logsumexp(advantage / self.temperature, -1).unsqueeze(-1)
        tau_ln_pi = advantage - self.temperature * logsum
        return tau_ln_pi

    def train(self):
        # Skip training until the buffer holds at least one full batch
        if len(self.buffer) < self.batch_size:
            return
        for _ in range(self.n_grad_steps):

            experience = self.buffer.sample(self.batch_size, cer=self.train_every[-1])

            with torch.no_grad():
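                # Target-network Q-values and scaled log-policy at the next observation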
                q_target_next = self.target_q_net(experience.next_observation)
                tau_log_pi_next = self.tau_ln_pi(q_target_next)

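                # Target-network Q-values and scaled log-policy at the current observation,
                # used for the Munchausen bonus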
                q_k_targets = self.target_q_net(experience.observation)
                log_pi = self.tau_ln_pi(q_k_targets)

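                # Target policy pi(.|s') = softmax(Q_target(s', .) / tau) and the discounted
                # soft next-state value: gamma * sum_a pi(a|s') * (Q(s', a) - tau * ln pi(a|s')),
                # masked out for terminal transitions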
                pi_target = F.softmax(q_target_next / self.temperature, dim=-1)
                q_target = (self.gamma * (pi_target * (q_target_next - tau_log_pi_next) * (1 - experience.done)).sum(-1)).unsqueeze(-1)

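                # Munchausen bonus: tau * ln pi(a|s) for the action that was actually taken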
                munchausen_addon = log_pi.gather(-1, experience.action)

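                # Augmented reward r + alpha * clamp(tau * ln pi(a|s), l_0, 0); clamping bounds
                # the (non-positive) log-policy penalty when pi(a|s) is close to zero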
                munchausen_reward = (experience.reward + self.alpha * torch.clamp(munchausen_addon, min=self.clip0, max=0))

                # Full Munchausen target: augmented reward plus discounted soft next-state value
                m_q_target = munchausen_reward + q_target

            # Q-value estimates of the online network for the actions actually taken
            q_k = self.q_net(experience.observation)
            pred_q = q_k.gather(-1, experience.action)

            # Loss: squared TD error against the Munchausen target plus a linear
            # penalty on the predicted Q-values, weighted by reg_weight
            loss = torch.mean(self.reg_weight * pred_q + torch.pow(pred_q - m_q_target, 2))
            self._backprop_loss(loss)


class MQICMLearner(MQLearner):
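    # Munchausen Q-Learning with an Intrinsic Curiosity Module (ICM) that is fitted on
    # transitions from the replay buffer once the main training run has finished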
    def __init__(self, *args, icm, **kwargs):
        super(MQICMLearner, self).__init__(*args, **kwargs)
        self.icm = icm
        self.icm_optimizer = torch.optim.Adam(self.icm.parameters())

    def on_all_done(self):
        # After the main training loop has finished, fit the ICM on transitions
        # sampled from the replay buffer
        for b in trange(50000):
            batch = self.buffer.sample(128, 0)
            s0, s1, a = batch.observation, batch.next_observation, batch.action
            loss = self.icm(s0, s1, a.squeeze())['loss']
            self.icm_optimizer.zero_grad()
            loss.backward()
            self.icm_optimizer.step()
            if b % 100 == 0:
                print(loss.item())