import sys

from pathlib import Path
from collections import deque
from itertools import product

import numpy as np
import pygame
from typing import Tuple, Union
import time

from marl_factory_grid.utils.utility_classes import RenderEntity

AGENT: str = 'agent'
STATE_IDLE: str = 'idle'
STATE_MOVE: str = 'move'
STATE_VALID: str = 'valid'
STATE_INVALID: str = 'invalid'
STATE_COLLISION: str = 'agent_collision'
BLANK: str = 'blank'
DOOR: str = 'door'
OPACITY: str = 'opacity'
SCALE: str = 'scale'


class Renderer:
    """
    Pygame renderer that draws a grid world from a list of render entities.

    Every entity is drawn with a PNG asset whose file stem equals the
    lower-cased entity name; assets are collected recursively below
    ``ASSETS``. Agents are drawn after all other entities so that they appear
    on top, optionally with a translucent view overlay, a state icon and
    their id.
    """

    BG_COLOR = (178, 190, 195)         # (99, 110, 114)
    WHITE = (223, 230, 233)            # (200, 200, 200)
    AGENT_VIEW_COLOR = (9, 132, 227)
    # Root directory that is searched recursively for ``*.png`` assets.
    ASSETS = Path(__file__).parent.parent

    def __init__(self, lvl_shape: Tuple[int, int] = (16, 16),
                 lvl_padded_shape: Union[Tuple[int, int], None] = None,
                 cell_size: int = 40, fps: int = 7, factor: float = 0.9,
                 grid_lines: bool = True, view_radius: int = 2):
        """
        Initialise pygame, open the display window and load all assets.

        :param lvl_shape: (height, width) of the visible grid in cells.
        :param lvl_padded_shape: (height, width) of the padded level that entity
            positions refer to; defaults to ``lvl_shape`` (no padding). Positions
            are shifted by half the padding so the visible grid is centred.
        :param cell_size: Side length of one grid cell in pixels.
        :param fps: Frame-rate cap passed to ``Clock.tick`` in :meth:`render`.
        :param factor: Asset side length as a fraction of ``cell_size``.
        :param grid_lines: Whether to outline every cell with a 1px border.
        :param view_radius: Radius (in cells) of the agent view overlay;
            0 disables the overlay.
        """
        # TODO: Custom asset paths
        self.grid_h, self.grid_w = lvl_shape
        self.lvl_padded_shape = lvl_padded_shape if lvl_padded_shape is not None else lvl_shape
        self.cell_size = cell_size
        self.fps = fps
        self.grid_lines = grid_lines
        self.view_radius = view_radius
        pygame.init()
        self.screen_size = (self.grid_w * cell_size, self.grid_h * cell_size)
        # The display mode must exist before assets are loaded, because
        # load_asset calls convert_alpha().
        self.screen = pygame.display.set_mode(self.screen_size)
        self.clock = pygame.time.Clock()
        assets = list(self.ASSETS.rglob('*.png'))
        self.assets = {path.stem: self.load_asset(str(path), factor) for path in assets}
        self.fill_bg()

        now = time.time()
        self.font = pygame.font.Font(None, 20)
        self.font.set_bold(True)
        print('Loading System font with pygame.font.Font took', time.time() - now)

    def fill_bg(self):
        """Clear the screen to ``BG_COLOR`` and optionally draw the cell grid."""
        self.screen.fill(Renderer.BG_COLOR)
        if self.grid_lines:
            w, h = self.screen_size
            for x in range(0, w, self.cell_size):
                for y in range(0, h, self.cell_size):
                    rect = pygame.Rect(x, y, self.cell_size, self.cell_size)
                    pygame.draw.rect(self.screen, Renderer.WHITE, rect, 1)

    def blit_params(self, entity):
        """
        Build the keyword arguments for ``Surface.blit`` that draw ``entity``.

        :param entity: Render entity providing ``pos`` (row, col in padded-level
            coordinates), ``name`` (asset key, case-insensitive), ``value`` and
            ``value_operation``. ``OPACITY`` sets the surface alpha to
            ``255 * value``; ``SCALE`` scales the asset by ``value``; any other
            operation draws the asset unchanged.
        :return: ``dict(source=surface, dest=rect)`` with ``rect`` centred on the
            entity's cell.
        """
        offset_r = (self.lvl_padded_shape[0] - self.grid_h) // 2
        offset_c = (self.lvl_padded_shape[1] - self.grid_w) // 2

        r, c = entity.pos
        r, c = r - offset_r, c - offset_c

        img = self.assets[entity.name.lower()]
        if entity.value_operation == OPACITY:
            # Work on a copy: the cached asset is shared by every entity with
            # this name, and set_alpha would otherwise leak into later draws.
            img = img.copy()
            img.set_alpha(int(255 * entity.value))
        elif entity.value_operation == SCALE:
            re = img.get_rect()
            img = pygame.transform.smoothscale(
                img, (int(entity.value * re.width), int(entity.value * re.height))
            )
        half_cell = self.cell_size // 2
        rect = img.get_rect()
        rect.centerx = c * self.cell_size + half_cell
        rect.centery = r * self.cell_size + half_cell
        return dict(source=img, dest=rect)

    def load_asset(self, path, factor=1.0):
        """
        Load an image with per-pixel alpha and scale it to a square.

        :param path: Path of the image file.
        :param factor: Side length of the result as a fraction of ``cell_size``.
        :return: The scaled ``pygame.Surface``.
        """
        s = int(factor * self.cell_size)
        asset = pygame.image.load(path).convert_alpha()
        asset = pygame.transform.smoothscale(asset, (s, s))
        return asset

    def visibility_rects(self, bp, view):
        """
        Build translucent overlay blits for the cells an agent can see.

        :param bp: Blit parameters of the agent, as returned by
            :meth:`blit_params`. Overlays are positioned relative to
            ``bp['dest']`` and share its size.
        :param view: 2-D array-like indexed ``[row, col]``, at least
            ``2 * view_radius + 1`` cells per axis and centred on the agent. A
            truthy entry marks a visible cell. ``None`` yields no overlays.
        :return: List of ``dict(source=surface, dest=rect)`` blit parameters.
        """
        if view is None:
            return []
        # Every overlay has the agent rect's size, so one translucent surface
        # can be shared by all visible cells (blit does not modify its source).
        shape_surf = pygame.Surface(bp['dest'].size, pygame.SRCALPHA)
        pygame.draw.rect(shape_surf, self.AGENT_VIEW_COLOR, shape_surf.get_rect())
        shape_surf.set_alpha(64)

        rects = []
        offsets = range(-self.view_radius, self.view_radius + 1)
        for i, j in product(offsets, offsets):
            # i is the column (x) offset, j the row (y) offset.
            if bool(view[self.view_radius + j, self.view_radius + i]):
                visibility_rect = bp['dest'].copy()
                visibility_rect.centerx += i * self.cell_size
                visibility_rect.centery += j * self.cell_size
                rects.append(dict(source=shape_surf, dest=visibility_rect))
        return rects

    def render(self, entities):
        """
        Draw one frame, show it on screen and return it as an array.

        Pending pygame events are processed first. Closing the window quits
        pygame and exits the process. Non-agent entities are drawn first, then
        every agent. Agent view overlays are drawn underneath everything else.
        Each agent's state icon and id label are drawn only when its state is
        not ``BLANK``.

        :param entities: Iterable of render entities. It is consumed once, so
            generators are accepted.
        :return: ``numpy`` array of shape ``(3, width, height)`` in pixels.
            ``pygame.surfarray.array3d`` is indexed ``[x, y]``, so the spatial
            axes are (x, y), not (row, col).
        """
        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                pygame.quit()
                sys.exit()
        self.fill_bg()

        # Materialise once: the entities are partitioned in two passes below.
        entities = list(entities)
        # First all others
        blits = deque(self.blit_params(x) for x in entities if x.name.lower() != AGENT)
        # Then agents, so that agents are rendered on top.
        for agent in (x for x in entities if x.name.lower() == AGENT):
            agent_blit = self.blit_params(agent)
            if self.view_radius > 0:
                # extendleft moves the overlays to the front, so they are drawn
                # first, i.e. underneath every other entity.
                blits.extendleft(self.visibility_rects(agent_blit, agent.aux))
            # The agent itself is always drawn. Only the state icon and the id
            # label are skipped for BLANK state.
            blits.append(agent_blit)
            if agent.state != BLANK:
                state_blit = self.blit_params(
                    RenderEntity(agent.state, (agent.pos[0] + 0.12, agent.pos[1]), 0.48, SCALE)
                )
                textsurface = self.font.render(str(agent.id), False, (0, 0, 0))
                text_blit = dict(source=textsurface,
                                 dest=(agent_blit['dest'].center[0] - .07 * self.cell_size,
                                       agent_blit['dest'].center[1]))
                blits += [state_blit, text_blit]

        for blit in blits:
            self.screen.blit(**blit)

        pygame.display.flip()
        self.clock.tick(self.fps)
        rgb_obs = pygame.surfarray.array3d(self.screen)
        return np.transpose(rgb_obs, (2, 0, 1))


if __name__ == '__main__':
    # Demo: slide a single 'agent_collision' marker along row 5,
    # advancing one column per rendered frame.
    demo_renderer = Renderer(fps=2, cell_size=40)
    for column in range(15):
        marker = RenderEntity('agent_collision', [5, column], 1, 'idle', 'idle')
        demo_renderer.render([marker])