18 lines
464 B
Python
18 lines
464 B
Python
from typing import Union
|
|
|
|
from torch.utils.data import Dataset, ConcatDataset
|
|
|
|
from datasets.paired_dataset import TrajPairDataset
|
|
|
|
|
|
class DatasetMapping(Dataset):
    """Index-remapped view over another dataset.

    Item ``i`` of this dataset is ``dataset[mapping[i]]``, so ``mapping``
    selects (and may reorder or repeat) entries of the wrapped dataset.
    """

    def __init__(self, dataset: Union[TrajPairDataset, ConcatDataset, Dataset], mapping):
        """Store the wrapped dataset and the index mapping.

        Args:
            dataset: The underlying dataset; only ``dataset[index]`` is used.
            mapping: Indexable collection whose ``i``-th entry is the index
                into ``dataset`` for item ``i``. Either exposes ``shape``
                (array-like) or supports ``len()`` (e.g. a list or tuple).
        """
        self._dataset = dataset
        self._mapping = mapping

    def __len__(self):
        """Return the number of entries in the mapping."""
        # Array-like mappings keep the original behaviour: the length is the
        # size of the first dimension.
        shape = getattr(self._mapping, "shape", None)
        if shape is not None:
            return shape[0]
        # Plain sequences have no ``shape``; fall back to ``len()``.
        return len(self._mapping)

    def __getitem__(self, item):
        """Return the wrapped dataset's element at ``mapping[item]``."""
        return self._dataset[self._mapping[item]]