diff --git a/point_toolset/point_io.py b/point_toolset/point_io.py index 49e1ebc..45ce200 100644 --- a/point_toolset/point_io.py +++ b/point_toolset/point_io.py @@ -6,19 +6,21 @@ class BatchToData(object): def __init__(self): super(BatchToData, self).__init__() - def __call__(self, batch_x: torch.Tensor, batch_pos: torch.Tensor, batch_y: torch.Tensor): + def __call__(self, batch_x: torch.Tensor, batch_pos: torch.Tensor, + batch_y_l: torch.Tensor, batch_y_c: torch.Tensor): # Convert to torch_geometric.data.Data type # data = data.transpose(1, 2).contiguous() batch_size, num_points, _ = batch_x.shape # (batch_size, num_points, 3) x = batch_x.reshape(batch_size * num_points, -1) pos = batch_pos.reshape(batch_size * num_points, -1) - batch_y = batch_y.reshape(batch_size * num_points) + batch_y_l = batch_y_l.reshape(batch_size * num_points) + batch_y_c = batch_y_c.reshape(batch_size * num_points) batch = torch.zeros((batch_size, num_points), device=pos.device, dtype=torch.long) for i in range(batch_size): batch[i] = i batch = batch.view(-1) data = Data() - data.x, data.pos, data.batch, data.y = x, pos, batch, batch_y + data.x, data.pos, data.batch, data.yl, data.yc = x, pos, batch, batch_y_l, batch_y_c return data diff --git a/utils/tools.py b/utils/tools.py index 68f5a21..6e291f8 100644 --- a/utils/tools.py +++ b/utils/tools.py @@ -3,20 +3,17 @@ import shelve from pathlib import Path import numpy as np - -from utils.project_config import GlobalVar +import torch +import random -def to_one_hot(idx_array): - one_hot = np.zeros((idx_array.size, len(GlobalVar.classes))) +def to_one_hot(idx_array, max_classes): + one_hot = np.zeros((idx_array.size, max_classes)) one_hot[np.arange(idx_array.size), idx_array] = 1 return one_hot def fix_all_random_seeds(config_obj): - import numpy as np - import torch - import random np.random.seed(config_obj.main.seed) torch.manual_seed(config_obj.main.seed) random.seed(config_obj.main.seed) @@ -39,4 +36,4 @@ def load_from_shelve(file_path, key): def check_path(file_path): assert isinstance(file_path, Path) - assert str(file_path).endswith('.pik') \ No newline at end of file + assert str(file_path).endswith('.pik')