CNN Classifier

This commit is contained in:
Si11ium
2020-02-21 09:44:09 +01:00
parent 537e5371c9
commit 7b3f781d19
12 changed files with 247 additions and 109 deletions

18
datasets/utils.py Normal file
View File

@ -0,0 +1,18 @@
from typing import Union
from torch.utils.data import Dataset, ConcatDataset
from datasets.paired_dataset import TrajPairDataset
class DatasetMapping(Dataset):
def __init__(self, dataset: Union[TrajPairDataset, ConcatDataset, Dataset], mapping):
self._dataset = dataset
self._mapping = mapping
def __len__(self):
return self._mapping.shape[0]
def __getitem__(self, item):
return self._dataset[self._mapping[item]]