Grid Clusters.

This commit is contained in:
Si11ium
2020-06-07 16:47:51 +02:00
parent 5987efb169
commit 2acf91335f
2 changed files with 10 additions and 11 deletions

View File

@ -3,20 +3,17 @@ import shelve
from pathlib import Path
import numpy as np
from utils.project_config import GlobalVar
import torch
import random
def to_one_hot(idx_array):
one_hot = np.zeros((idx_array.size, len(GlobalVar.classes)))
def to_one_hot(idx_array, max_classes):
one_hot = np.zeros((idx_array.size, max_classes))
one_hot[np.arange(idx_array.size), idx_array] = 1
return one_hot
def fix_all_random_seeds(config_obj):
import numpy as np
import torch
import random
np.random.seed(config_obj.main.seed)
torch.manual_seed(config_obj.main.seed)
random.seed(config_obj.main.seed)
@ -39,4 +36,4 @@ def load_from_shelve(file_path, key):
def check_path(file_path):
assert isinstance(file_path, Path)
assert str(file_path).endswith('.pik')
assert str(file_path).endswith('.pik')