from typing import List

import numpy as np

from environments.factory.base.objects import Floor, Agent
from environments.factory.base.registers import EntityCollection, BoundEnvObjCollection, ObjectCollection
from environments.factory.additional.item.item_entities import Item, DropOffLocation


class ItemRegister(EntityCollection):
    """Entity collection whose accepted object type is :class:`Item`."""

    _accepted_objects = Item

    def spawn_items(self, tiles: List[Floor]):
        """Build one ``Item(tile, self)`` per entry of ``tiles`` and add them in a single batch."""
        self.add_additional_items([Item(floor_tile, self) for floor_tile in tiles])

    def despawn_items(self, items: List[Item]):
        """Delete the given item(s) from this register.

        ``items`` may be a list of ``Item`` objects or a single ``Item``;
        a lone item is wrapped in a one-element list before removal.
        """
        to_remove = items if not isinstance(items, Item) else [items]
        for entity in to_remove:
            # Actual removal is delegated to the base collection's ``__delitem__``.
            del self[entity]


class Inventory(BoundEnvObjCollection):
    """Collection of items bound to a single agent.

    ``is_blocking_light`` and ``can_be_shadowed`` are always passed to the
    base class as ``False``. ``capacity`` is only stored on the instance;
    nothing in this class enforces it.
    """

    @property
    def name(self):
        """Display name of the form ``<ClassName>(<bound entity name>)``."""
        owner_name = self._bound_entity.name
        return f'{type(self).__name__}({owner_name})'

    def __init__(self, agent: Agent, capacity: int, *args, **kwargs):
        super().__init__(agent, *args, is_blocking_light=False, can_be_shadowed=False, **kwargs)
        self.capacity = capacity

    def as_array(self):
        """Return the base-class array, lazily allocating a zero buffer of shape ``(1, *self._shape)``."""
        if self._array is None:
            self._array = np.zeros((1,) + tuple(self._shape))
        return super().as_array()

    def summarize_states(self, **kwargs):
        """Summarize this inventory as a dict.

        Contains every public instance attribute (except ``data``) as a string,
        an ``items`` mapping of each held item's ``summarize_state(**kwargs)``,
        and the inventory ``name``.
        """
        summary = {}
        for attr_name, attr_value in self.__dict__.items():
            if attr_name.startswith('_') or attr_name == 'data':
                continue
            summary[attr_name] = str(attr_value)
        summary['items'] = {key: entity.summarize_state(**kwargs) for key, entity in self.items()}
        summary['name'] = self.name
        return summary

    def pop(self):
        """Remove and return the first item (``self[0]``).

        NOTE(review): behaviour on an empty inventory depends on the base
        class's ``__getitem__`` — confirm which exception it raises.
        """
        head = self[0]
        self.delete_env_object(head)
        return head


class Inventories(ObjectCollection):
    """Per-agent collection of :class:`Inventory` objects.

    Constructed with ``is_per_agent=True`` and ``individual_slices=True``;
    ``obs_shape`` is stored and handed to every inventory it spawns.
    """

    _accepted_objects = Inventory
    # Class-level flags; presumably consulted by the base collection /
    # observation code — TODO confirm where they are read.
    is_blocking_light = False
    can_be_shadowed = False

    def __init__(self, obs_shape, *args, **kwargs):
        super().__init__(*args, is_per_agent=True, individual_slices=True, **kwargs)
        self._obs_shape = obs_shape

    def as_array(self):
        """Stack every inventory's array along a new leading axis.

        ``np.stack`` raises ``ValueError`` when the collection is empty.
        """
        return np.stack([inventory.as_array() for inventory in self])

    def spawn_inventories(self, agents, capacity):
        """Create ``Inventory(agent, capacity, obs_shape)`` for each agent and register them."""
        inventories = [self._accepted_objects(agent, capacity, self._obs_shape) for agent in agents]
        self.add_additional_items(inventories)

    def idx_by_entity(self, entity):
        """Return the index of the first inventory bound to ``entity``, or ``None`` if there is none."""
        return next((idx for idx, inv in enumerate(self) if inv.belongs_to_entity(entity)), None)

    def by_entity(self, entity):
        """Return the first inventory bound to ``entity``, or ``None`` if there is none."""
        return next((inv for inv in self if inv.belongs_to_entity(entity)), None)

    def summarize_states(self, **kwargs):
        """Map each key of the collection to that inventory's ``summarize_states(**kwargs)``."""
        return {key: val.summarize_states(**kwargs) for key, val in self.items()}


class DropOffLocations(EntityCollection):
    """Entity collection whose accepted object type is :class:`DropOffLocation`."""

    _accepted_objects = DropOffLocation