ml_lib/utils/data_util.py

41 lines
993 B
Python

import torch
from torch.utils.data import Dataset
def chunks(l, n):
    """Yield successive ``n``-sized chunks from ``l``.

    The final chunk is shorter than ``n`` when ``len(l)`` is not a multiple
    of ``n``.

    Args:
        l: Any object supporting ``len()`` and slicing (list, str, tensor, ...).
            The parameter name is kept for backward compatibility with
            keyword callers.
        n: Chunk size; must be a positive integer.

    Yields:
        Consecutive slices ``l[i:i + n]``.

    Raises:
        ValueError: If ``n`` is not positive. Because this is a generator, the
            error surfaces on the first iteration.
    """
    # A negative step would make range() produce nothing, silently hiding a
    # caller bug; zero would fail with an opaque range() error. Reject both.
    if n <= 0:
        raise ValueError(f"chunk size must be positive, got {n}")
    for start in range(0, len(l), n):
        yield l[start:start + n]
class ReMapDataset(Dataset):
    """View of another dataset through an index mapping.

    Item ``i`` of this dataset is item ``mapping[i]`` of the wrapped
    ("mother") dataset ``ds``. Subsets such as train/validation splits can
    therefore be carved out of one dataset without copying any samples.
    """

    @property
    def sample_shape(self):
        """Shape of element 0 of the first sample, as a list.

        NOTE(review): assumes samples are indexable (e.g. ``(input, target)``
        tuples) whose element 0 has a ``.shape`` attribute; confirm against
        the wrapped datasets. Raises ``IndexError`` if this view is empty.
        """
        return list(self[0][0].shape)

    def __init__(self, ds, mapping):
        """Wrap ``ds`` so that item ``i`` resolves to ``ds[mapping[i]]``.

        Args:
            ds: The mother dataset; must support ``ds[k]`` for every index
                ``k`` stored in ``mapping``.
            mapping: 1-D sequence of mother-dataset indices (tensor, NumPy
                array or list). Position ``i`` holds the mother index of
                item ``i``.
        """
        super().__init__()
        # Mapping from this dataset's index to the mother dataset's index.
        self.mapping = mapping
        self.ds = ds

    def __getitem__(self, index):
        """Return ``ds[mapping[index]]``."""
        mother_index = self.mapping[index]
        # Indexing a tensor mapping with an int yields a 0-d tensor. Hand the
        # mother dataset a plain int instead, as torch's random_split/Subset
        # do: 0-d tensors hash by identity, so e.g. dict-backed datasets
        # cannot look them up. Slice/array results are passed through as-is.
        if isinstance(mother_index, torch.Tensor) and mother_index.dim() == 0:
            mother_index = int(mother_index)
        return self.ds[mother_index]

    def __len__(self):
        """Number of items in this view, i.e. the length of ``mapping``."""
        # len() works for tensors, NumPy arrays and plain lists alike,
        # whereas ``.shape[0]`` required an array-like mapping.
        return len(self.mapping)

    @classmethod
    def do_train_vali_split(cls, ds, split_fold=0.1, *, generator=None):
        """Randomly split ``ds`` into disjoint train and validation views.

        Args:
            ds: The dataset to split; must support ``len()``.
            split_fold: Fraction of ``ds`` assigned to validation, in
                ``[0, 1]``. The validation size is
                ``int(len(ds) * split_fold)`` (rounded down); every other
                sample goes to training.
            generator: Optional ``torch.Generator`` forwarded to
                ``torch.randperm`` to make the split reproducible. ``None``
                uses the global RNG, matching the previous behaviour.

        Returns:
            ``(train, valid)``: two instances of ``cls`` over ``ds`` whose
            index mappings together form a permutation of ``range(len(ds))``.

        Raises:
            ValueError: If ``split_fold`` is outside ``[0, 1]`` (or NaN).
        """
        # Out-of-range fractions would produce negative or oversized slice
        # bounds below and yield silently wrong splits.
        if not 0.0 <= split_fold <= 1.0:
            raise ValueError(
                f"split_fold must be within [0, 1], got {split_fold!r}"
            )
        indices = torch.randperm(len(ds), generator=generator)
        valid_size = int(len(ds) * split_fold)
        train = cls(ds, indices[valid_size:])
        valid = cls(ds, indices[:valid_size])
        return train, valid