from typing import Union

from torch.utils.data import Dataset, ConcatDataset

from datasets.paired_dataset import TrajPairDataset


class DatasetMapping(Dataset):
    """Index-remapped view over another dataset.

    Position ``i`` of this view returns ``dataset[mapping[i]]``. The mapping
    can therefore select a subset of ``dataset``, reorder it, or repeat
    entries.

    Args:
        dataset: The underlying dataset that samples are read from.
        mapping: A sized, indexable collection of indices into ``dataset``,
            e.g. a list, a NumPy array or a 1-D tensor. Its length is the
            length of this view. Entries are passed to ``dataset`` as-is,
            without conversion.
    """

    def __init__(self, dataset: Union[TrajPairDataset, ConcatDataset, Dataset], mapping):
        self._dataset = dataset
        self._mapping = mapping

    def __len__(self):
        """Return the number of entries in the mapping."""
        # len() rather than .shape[0]: for NumPy arrays and tensors it gives
        # the first-axis size (same value), but it also supports plain
        # Python sequences such as lists, tuples and ranges.
        return len(self._mapping)

    def __getitem__(self, item):
        """Return the underlying sample at index ``mapping[item]``."""
        return self._dataset[self._mapping[item]]