eval running - offline logger implemented -> Test it!
This commit is contained in:
41
utils/data_util.py
Normal file
41
utils/data_util.py
Normal file
@@ -0,0 +1,41 @@
|
||||
import torch
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
|
||||
def chunks(l, n):
|
||||
"""Yield successive n-sized chunks from l."""
|
||||
for i in range(0, len(l), n):
|
||||
yield l[i:i + n]
|
||||
|
||||
|
||||
class ReMapDataset(Dataset):
|
||||
@property
|
||||
def sample_shape(self):
|
||||
return list(self[0][0].shape)
|
||||
|
||||
def __init__(self, ds, mapping):
|
||||
super(ReMapDataset, self).__init__()
|
||||
# here is a mapping from this index to the mother ds index
|
||||
self.mapping = mapping
|
||||
self.ds = ds
|
||||
|
||||
def __getitem__(self, index):
|
||||
return self.ds[self.mapping[index]]
|
||||
|
||||
def __len__(self):
|
||||
return self.mapping.shape[0]
|
||||
|
||||
@classmethod
|
||||
def do_train_vali_split(cls, ds, split_fold=0.1):
|
||||
|
||||
indices = torch.randperm(len(ds))
|
||||
|
||||
valid_size = int(len(ds) * split_fold)
|
||||
|
||||
train_mapping = indices[valid_size:]
|
||||
valid_mapping = indices[:valid_size]
|
||||
|
||||
train = cls(ds, train_mapping)
|
||||
valid = cls(ds, valid_mapping)
|
||||
|
||||
return train, valid
|
||||
@@ -1,3 +1,6 @@
|
||||
import argparse
|
||||
from typing import Union, Dict, Optional, Any
|
||||
|
||||
from abc import ABC
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
@@ -2,6 +2,16 @@ import pickle
|
||||
import shelve
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
|
||||
from utils.project_config import GlobalVar
|
||||
|
||||
|
||||
def to_one_hot(idx_array):
|
||||
one_hot = np.zeros((idx_array.size, len(GlobalVar.classes)))
|
||||
one_hot[np.arange(idx_array.size), idx_array] = 1
|
||||
return one_hot
|
||||
|
||||
|
||||
def fix_all_random_seeds(config_obj):
|
||||
import numpy as np
|
||||
|
||||
Reference in New Issue
Block a user