from typing import Union

from torch.utils.data import Dataset, ConcatDataset

from datasets.paired_dataset import TrajPairDataset


class DatasetMapping(Dataset):
    """Expose a re-indexed view of another dataset.

    Position ``i`` of this view returns ``dataset[mapping[i]]``. ``mapping``
    therefore selects (and may reorder or repeat) items of the wrapped
    dataset; the wrapped dataset itself is only referenced, not copied.
    """

    def __init__(self, dataset: "Union[TrajPairDataset, ConcatDataset, Dataset]", mapping):
        """
        Args:
            dataset: The dataset to index into.
            mapping: Indexable container of indices into ``dataset`` (e.g. a
                list, NumPy array, or torch tensor). Its length is the length
                of this view.
        """
        super().__init__()
        self._dataset = dataset
        self._mapping = mapping

    def __len__(self):
        # len() instead of .shape[0] so plain Python sequences work as well;
        # for arrays/tensors len() equals the size of the first dimension.
        return len(self._mapping)

    def __getitem__(self, item):
        # The mapped index is forwarded unchanged (e.g. a NumPy integer or a
        # 0-d tensor when mapping is an array); the wrapped dataset must
        # accept whatever type mapping yields.
        return self._dataset[self._mapping[item]]