from typing import List

from marl_factory_grid.environment.rules import Rule
from marl_factory_grid.utils.results import TickResult, DoneResult
from marl_factory_grid.environment import constants as c
from . import constants as M


class MoveMaintainers(Rule):

    def __init__(self):
        """
        This rule is responsible for moving the maintainers at every step of the environment.
        """
        super().__init__()

    def tick_step(self, state) -> List[TickResult]:
        move_results = []
        for maintainer in state[M.MAINTAINERS]:
            result = maintainer.tick(state)
            move_results.append(result)
        return move_results


class DoneAtMaintainerCollision(Rule):

    def __init__(self):
        """
        When active, this rule stops the environment after a maintainer reports a collision with another entity.
        """
        super().__init__()

    def on_check_done(self, state) -> List[DoneResult]:
        agents = list(state[c.AGENT].values())
        m_pos = state[M.MAINTAINERS].positions
        done_results = []
        for agent in agents:
            if agent.pos in m_pos:
                done_results.append(DoneResult(entity=agent, validity=c.VALID, identifier=self.name,
                                               reward=M.MAINTAINER_COLLISION_REWARD))
        return done_results