diff --git a/dataset/shapenet.py b/dataset/shapenet.py index 96d1799..2999cc3 100644 --- a/dataset/shapenet.py +++ b/dataset/shapenet.py @@ -144,7 +144,7 @@ class CustomShapeNet(InMemoryDataset): y = torch.as_tensor(y_all, dtype=torch.int) # This is where you define the keys - attr_dict = dict(y=y, pos=points[:, :3]) # , normals=points[:, 3:6]) + attr_dict = dict(y=y, pos=points[:, :3], normals=points[:, 3:6]) if self.collate_per_element: data = Data(**attr_dict) else: @@ -193,20 +193,20 @@ class ShapeNetPartSegDataset(Dataset): # Resample to fixed number of points try: npoints = self.npoints if self.mode != 'predict' else data.pos.shape[0] - choice = np.random.choice(data.pos.shape[0], npoints, replace=False) + choice = np.random.choice(data.pos.shape[0], npoints, replace=False if self.mode == 'predict' else True) except ValueError: choice = [] - # pos, normals, labels = data.pos[choice, :], data.normals[choice, :], data.y[choice] - pos, labels = data.pos[choice, :], data.y[choice] + pos, normals, labels = data.pos[choice, :], data.normals[choice, :], data.y[choice] + # pos, labels = data.pos[choice, :], data.y[choice] labels -= 1 if self.num_classes() in labels else 0 # Map label from [1, C] to [0, C-1] sample = { - 'points': torch.cat([pos], dim=1), # torch.Tensor (n, 6) + 'points': torch.cat([pos, normals], dim=1), # torch.Tensor (n, 6) 'labels': labels, # torch.Tensor (n,) - # 'pos': pos, # torch.Tensor (n, 3) - # 'normals': normals # torch.Tensor (n, 3) + 'pos': pos, # torch.Tensor (n, 3) + 'normals': normals # torch.Tensor (n, 3) } return sample diff --git a/model/pointnet2_part_seg.py b/model/pointnet2_part_seg.py index 8c66ebf..f23e5b2 100644 --- a/model/pointnet2_part_seg.py +++ b/model/pointnet2_part_seg.py @@ -8,7 +8,7 @@ from torch_geometric.utils.num_nodes import maybe_num_nodes from torch_geometric.data.data import Data from torch_scatter import scatter_add, scatter_max -GLOBAL_POINT_FEATURES = 3 +GLOBAL_POINT_FEATURES = 6 class PointNet2SAModule(torch.nn.Module): def __init__(self, sample_radio, radius, max_num_neighbors, mlp):