import torch
from torch.distributions import Categorical
from marl_factory_grid.algorithms.marl.iac import LoopIAC
from marl_factory_grid.algorithms.marl.base_ac import nms
from marl_factory_grid.algorithms.marl.memory import MARLActorCriticMemory


class LoopSEAC(LoopIAC):
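    """Shared Experience Actor-Critic (SEAC, Christianos et al., 2020) on top of the
    independent actor-critic loop (LoopIAC): each agent additionally learns from the
    other agents' trajectories via importance-weighted off-policy corrections."""
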
    def __init__(self, cfg):
        super().__init__(cfg)

    def actor_critic(self, tm, networks, gamma, entropy_coef, vf_coef, gae_coef=0.0, **kwargs):
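        """Compute one SEAC loss per agent network.

        Every network is evaluated on the trajectories of all agents; terms coming
        from other agents' experience are re-weighted with the importance ratio
        pi_i(a|o) / pi_k(a|o).
        """
        # Drop the first reward/done entries so they align with the actions taken at t >= 1.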
        obs, actions, done, reward = tm.observation, tm.action, tm.done[:, 1:], tm.reward[:, 1:]
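        # Each agent's network processes every agent's trajectory (shared experience).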
        outputs = [net(obs, actions, tm.hidden_actor[:, 0], tm.hidden_critic[:, 0]) for net in networks]

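        # Log-probs of the executed actions under the network of the agent that
        # collected them: the denominator of the importance ratio, so no gradients needed.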
        with torch.inference_mode(True):
            true_action_logp = torch.stack([
                torch.log_softmax(out[nms.LOGITS][ag_i, :-1], -1)
                .gather(index=actions[ag_i, 1:, None], dim=-1)
                for ag_i, out in enumerate(outputs)
            ], 0).squeeze(-1)

        losses = []

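        # One loss per agent network; dim 0 of logits/critic indexes the source agent of the experience.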
        for ag_i, out in enumerate(outputs):
            logits = out[nms.LOGITS][:, :-1]  # drop the last step; it is only needed for the bootstrap value v_{t+1}
            critic = out[nms.CRITIC]

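            # Entropy regularisation uses only the agent's own trajectory (row ag_i).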
            entropy_loss = Categorical(logits=logits[ag_i]).entropy().mean()
            advantages = self.compute_advantages(critic, reward, done, gamma, gae_coef)

            # policy loss
            log_ap = torch.log_softmax(logits, -1)
            log_ap = torch.gather(log_ap, dim=-1, index=actions[:, 1:].unsqueeze(-1)).squeeze(-1)

            # importance weights pi_i(a|o) / pi_k(a|o); equal to 1 on the agent's own row
            iw = (log_ap - true_action_logp).exp().detach()

            a2c_loss = (-iw * log_ap * advantages.detach()).mean(-1)  # mean over time -> one policy loss per source agent

            value_loss = (iw * advantages.pow(2)).mean(-1)  # importance-weighted squared advantage, shape [n_agents]

            # total loss: average the per-source-agent terms over the agent dimension
            loss = (a2c_loss + vf_coef * value_loss - entropy_coef * entropy_loss).mean()
            losses.append(loss)

        return losses

    def learn(self, tms: MARLActorCriticMemory, **kwargs):
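        """Run one SEAC update: each agent's loss gets its own backward pass,
        gradient clipping and optimiser step."""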
        losses = self.actor_critic(tms, self.net, **self.cfg[nms.ALGORITHM], **kwargs)
        for ag_i, loss in enumerate(losses):
            self.optimizer[ag_i].zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.net[ag_i].parameters(), 0.5)  # clip to a max gradient norm of 0.5
            self.optimizer[ag_i].step()