Grid Clusters.
This commit is contained in:
@ -3,20 +3,17 @@ import shelve
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
|
||||
from utils.project_config import GlobalVar
|
||||
import torch
|
||||
import random
|
||||
|
||||
|
||||
def to_one_hot(idx_array):
|
||||
one_hot = np.zeros((idx_array.size, len(GlobalVar.classes)))
|
||||
def to_one_hot(idx_array, max_classes):
|
||||
one_hot = np.zeros((idx_array.size, max_classes))
|
||||
one_hot[np.arange(idx_array.size), idx_array] = 1
|
||||
return one_hot
|
||||
|
||||
|
||||
def fix_all_random_seeds(config_obj):
|
||||
import numpy as np
|
||||
import torch
|
||||
import random
|
||||
np.random.seed(config_obj.main.seed)
|
||||
torch.manual_seed(config_obj.main.seed)
|
||||
random.seed(config_obj.main.seed)
|
||||
@ -39,4 +36,4 @@ def load_from_shelve(file_path, key):
|
||||
|
||||
def check_path(file_path):
|
||||
assert isinstance(file_path, Path)
|
||||
assert str(file_path).endswith('.pik')
|
||||
assert str(file_path).endswith('.pik')
|
||||
|
Reference in New Issue
Block a user