Grid Clusters.
This commit is contained in:
parent
5987efb169
commit
2acf91335f
@ -6,19 +6,21 @@ class BatchToData(object):
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(BatchToData, self).__init__()
|
super(BatchToData, self).__init__()
|
||||||
|
|
||||||
def __call__(self, batch_x: torch.Tensor, batch_pos: torch.Tensor, batch_y: torch.Tensor):
|
def __call__(self, batch_x: torch.Tensor, batch_pos: torch.Tensor,
|
||||||
|
batch_y_l: torch.Tensor, batch_y_c: torch.Tensor):
|
||||||
# Convert to torch_geometric.data.Data type
|
# Convert to torch_geometric.data.Data type
|
||||||
# data = data.transpose(1, 2).contiguous()
|
# data = data.transpose(1, 2).contiguous()
|
||||||
batch_size, num_points, _ = batch_x.shape # (batch_size, num_points, 3)
|
batch_size, num_points, _ = batch_x.shape # (batch_size, num_points, 3)
|
||||||
|
|
||||||
x = batch_x.reshape(batch_size * num_points, -1)
|
x = batch_x.reshape(batch_size * num_points, -1)
|
||||||
pos = batch_pos.reshape(batch_size * num_points, -1)
|
pos = batch_pos.reshape(batch_size * num_points, -1)
|
||||||
batch_y = batch_y.reshape(batch_size * num_points)
|
batch_y_l = batch_y_l.reshape(batch_size * num_points)
|
||||||
|
batch_y_c = batch_y_c.reshape(batch_size * num_points)
|
||||||
batch = torch.zeros((batch_size, num_points), device=pos.device, dtype=torch.long)
|
batch = torch.zeros((batch_size, num_points), device=pos.device, dtype=torch.long)
|
||||||
for i in range(batch_size):
|
for i in range(batch_size):
|
||||||
batch[i] = i
|
batch[i] = i
|
||||||
batch = batch.view(-1)
|
batch = batch.view(-1)
|
||||||
|
|
||||||
data = Data()
|
data = Data()
|
||||||
data.x, data.pos, data.batch, data.y = x, pos, batch, batch_y
|
data.x, data.pos, data.batch, data.yl, data.yc = x, pos, batch, batch_y_l, batch_y_c
|
||||||
return data
|
return data
|
||||||
|
@ -3,20 +3,17 @@ import shelve
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import torch
|
||||||
from utils.project_config import GlobalVar
|
import random
|
||||||
|
|
||||||
|
|
||||||
def to_one_hot(idx_array):
|
def to_one_hot(idx_array, max_classes):
|
||||||
one_hot = np.zeros((idx_array.size, len(GlobalVar.classes)))
|
one_hot = np.zeros((idx_array.size, max_classes))
|
||||||
one_hot[np.arange(idx_array.size), idx_array] = 1
|
one_hot[np.arange(idx_array.size), idx_array] = 1
|
||||||
return one_hot
|
return one_hot
|
||||||
|
|
||||||
|
|
||||||
def fix_all_random_seeds(config_obj):
|
def fix_all_random_seeds(config_obj):
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
import random
|
|
||||||
np.random.seed(config_obj.main.seed)
|
np.random.seed(config_obj.main.seed)
|
||||||
torch.manual_seed(config_obj.main.seed)
|
torch.manual_seed(config_obj.main.seed)
|
||||||
random.seed(config_obj.main.seed)
|
random.seed(config_obj.main.seed)
|
||||||
@ -39,4 +36,4 @@ def load_from_shelve(file_path, key):
|
|||||||
|
|
||||||
def check_path(file_path):
|
def check_path(file_path):
|
||||||
assert isinstance(file_path, Path)
|
assert isinstance(file_path, Path)
|
||||||
assert str(file_path).endswith('.pik')
|
assert str(file_path).endswith('.pik')
|
||||||
|
Loading…
x
Reference in New Issue
Block a user