eval running - offline logger implemented -> Test it!
This commit is contained in:
41
utils/data_util.py
Normal file
41
utils/data_util.py
Normal file
@ -0,0 +1,41 @@
|
||||
import torch
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
|
||||
def chunks(l, n):
|
||||
"""Yield successive n-sized chunks from l."""
|
||||
for i in range(0, len(l), n):
|
||||
yield l[i:i + n]
|
||||
|
||||
|
||||
class ReMapDataset(Dataset):
|
||||
@property
|
||||
def sample_shape(self):
|
||||
return list(self[0][0].shape)
|
||||
|
||||
def __init__(self, ds, mapping):
|
||||
super(ReMapDataset, self).__init__()
|
||||
# here is a mapping from this index to the mother ds index
|
||||
self.mapping = mapping
|
||||
self.ds = ds
|
||||
|
||||
def __getitem__(self, index):
|
||||
return self.ds[self.mapping[index]]
|
||||
|
||||
def __len__(self):
|
||||
return self.mapping.shape[0]
|
||||
|
||||
@classmethod
|
||||
def do_train_vali_split(cls, ds, split_fold=0.1):
|
||||
|
||||
indices = torch.randperm(len(ds))
|
||||
|
||||
valid_size = int(len(ds) * split_fold)
|
||||
|
||||
train_mapping = indices[valid_size:]
|
||||
valid_mapping = indices[:valid_size]
|
||||
|
||||
train = cls(ds, train_mapping)
|
||||
valid = cls(ds, valid_mapping)
|
||||
|
||||
return train, valid
|
Reference in New Issue
Block a user