41 lines
993 B
Python
41 lines
993 B
Python
import torch
|
|
from torch.utils.data import Dataset
|
|
|
|
|
|
def chunks(l, n):
    """Split ``l`` into consecutive slices of length ``n``.

    Slices are produced lazily, in order; the last one is shorter when
    ``len(l)`` is not a multiple of ``n``. ``l`` must support ``len()`` and
    slicing.
    """
    starts = range(0, len(l), n)
    yield from (l[begin:begin + n] for begin in starts)
|
|
|
|
|
|
class ReMapDataset(Dataset):
    """Dataset that exposes another dataset through an index mapping.

    Item ``i`` of this dataset is ``ds[mapping[i]]``. The wrapped dataset is
    only referenced, never copied.
    """

    @property
    def sample_shape(self):
        """Return ``self[0][0].shape`` as a list.

        Requires at least one item, and that the first element of an item
        exposes a ``.shape`` attribute.
        """
        return list(self[0][0].shape)

    def __init__(self, ds, mapping):
        """Wrap ``ds`` so that index ``i`` resolves to ``ds[mapping[i]]``.

        Args:
            ds: The wrapped ("mother") dataset; must support indexing.
            mapping: Indexable, sized collection of indices into ``ds``
                (for example a 1-D tensor, a list, or a range).
        """
        super(ReMapDataset, self).__init__()
        # here is a mapping from this index to the mother ds index
        self.mapping = mapping
        self.ds = ds

    def __getitem__(self, index):
        """Return the item of the wrapped dataset that ``index`` maps to."""
        return self.ds[self.mapping[index]]

    def __len__(self):
        """Return the number of entries in the mapping."""
        # len() instead of .shape[0]: identical for 1-D tensors/arrays, and it
        # also works for lists, tuples and ranges, which have no .shape.
        return len(self.mapping)

    @classmethod
    def do_train_vali_split(cls, ds, split_fold=0.1):
        """Randomly split ``ds`` into a training and a validation view.

        Args:
            ds: Dataset to split; must support ``len()`` and indexing.
            split_fold: Fraction of ``ds`` assigned to the validation view.
                The validation size is ``int(len(ds) * split_fold)``, i.e.
                rounded down.

        Returns:
            A ``(train, valid)`` tuple of ``cls`` instances that share ``ds``
            and hold disjoint parts of one random permutation of its indices.
        """
        indices = torch.randperm(len(ds))

        valid_size = int(len(ds) * split_fold)

        # The first ``valid_size`` shuffled indices go to validation, the
        # remainder to training.
        train_mapping = indices[valid_size:]
        valid_mapping = indices[:valid_size]

        train = cls(ds, train_mapping)
        valid = cls(ds, valid_mapping)

        return train, valid