2020-02-21 09:44:09 +01:00

18 lines
464 B
Python

from typing import Union
from torch.utils.data import Dataset, ConcatDataset
from datasets.paired_dataset import TrajPairDataset
class DatasetMapping(Dataset):
    """Re-indexed view over another dataset.

    Position ``i`` of this view yields ``dataset[mapping[i]]``. ``mapping``
    therefore selects, and may reorder or repeat, entries of the wrapped
    dataset. The wrapped dataset is held by reference, not copied.
    """

    def __init__(self, dataset: Union[TrajPairDataset, ConcatDataset, Dataset], mapping):
        """
        Args:
            dataset: The underlying dataset. It is indexed with the values
                stored in ``mapping``.
            mapping: Sized, indexable container of indices into ``dataset``.
                Examples: a 1-D NumPy array, a tensor, a list or a range.
                It is stored by reference, not copied.
        """
        self._dataset = dataset
        self._mapping = mapping

    def __len__(self):
        """Return the number of entries in ``mapping``."""
        # Use len() rather than .shape[0]. For arrays/tensors both give the
        # size of the first axis, but len() also works for plain sequences
        # (list, tuple, range), which have no ``.shape`` attribute.
        return len(self._mapping)

    def __getitem__(self, item):
        """Return ``dataset[mapping[item]]``."""
        return self._dataset[self._mapping[item]]