Grid Clusters.

This commit is contained in:
Si11ium 2020-06-07 16:47:51 +02:00
parent 5987efb169
commit 2acf91335f
2 changed files with 10 additions and 11 deletions

View File

@ -6,19 +6,21 @@ class BatchToData(object):
def __init__(self): def __init__(self):
super(BatchToData, self).__init__() super(BatchToData, self).__init__()
def __call__(self, batch_x: torch.Tensor, batch_pos: torch.Tensor, batch_y: torch.Tensor): def __call__(self, batch_x: torch.Tensor, batch_pos: torch.Tensor,
batch_y_l: torch.Tensor, batch_y_c: torch.Tensor):
# Convert to torch_geometric.data.Data type # Convert to torch_geometric.data.Data type
# data = data.transpose(1, 2).contiguous() # data = data.transpose(1, 2).contiguous()
batch_size, num_points, _ = batch_x.shape # (batch_size, num_points, 3) batch_size, num_points, _ = batch_x.shape # (batch_size, num_points, 3)
x = batch_x.reshape(batch_size * num_points, -1) x = batch_x.reshape(batch_size * num_points, -1)
pos = batch_pos.reshape(batch_size * num_points, -1) pos = batch_pos.reshape(batch_size * num_points, -1)
batch_y = batch_y.reshape(batch_size * num_points) batch_y_l = batch_y_l.reshape(batch_size * num_points)
batch_y_c = batch_y_c.reshape(batch_size * num_points)
batch = torch.zeros((batch_size, num_points), device=pos.device, dtype=torch.long) batch = torch.zeros((batch_size, num_points), device=pos.device, dtype=torch.long)
for i in range(batch_size): for i in range(batch_size):
batch[i] = i batch[i] = i
batch = batch.view(-1) batch = batch.view(-1)
data = Data() data = Data()
data.x, data.pos, data.batch, data.y = x, pos, batch, batch_y data.x, data.pos, data.batch, data.yl, data.yc = x, pos, batch, batch_y_l, batch_y_c
return data return data

View File

@ -3,20 +3,17 @@ import shelve
from pathlib import Path from pathlib import Path
import numpy as np import numpy as np
import torch
from utils.project_config import GlobalVar import random
def to_one_hot(idx_array): def to_one_hot(idx_array, max_classes):
one_hot = np.zeros((idx_array.size, len(GlobalVar.classes))) one_hot = np.zeros((idx_array.size, max_classes))
one_hot[np.arange(idx_array.size), idx_array] = 1 one_hot[np.arange(idx_array.size), idx_array] = 1
return one_hot return one_hot
def fix_all_random_seeds(config_obj): def fix_all_random_seeds(config_obj):
import numpy as np
import torch
import random
np.random.seed(config_obj.main.seed) np.random.seed(config_obj.main.seed)
torch.manual_seed(config_obj.main.seed) torch.manual_seed(config_obj.main.seed)
random.seed(config_obj.main.seed) random.seed(config_obj.main.seed)
@ -39,4 +36,4 @@ def load_from_shelve(file_path, key):
def check_path(file_path): def check_path(file_path):
assert isinstance(file_path, Path) assert isinstance(file_path, Path)
assert str(file_path).endswith('.pik') assert str(file_path).endswith('.pik')