diff --git a/dataset/shapenet.py b/dataset/shapenet.py
index 2999cc3..cf224fc 100644
--- a/dataset/shapenet.py
+++ b/dataset/shapenet.py
@@ -24,13 +24,15 @@ class CustomShapeNet(InMemoryDataset):
     modes = {key: val for val, key in enumerate(['train', 'test', 'predict'])}

     def __init__(self, root_dir, collate_per_segment=True, mode='train', transform=None, pre_filter=None,
-                 pre_transform=None, headers=True, has_variations=False, refresh=False, labels_within=False):
+                 pre_transform=None, headers=True, has_variations=False, refresh=False, labels_within=False,
+                 with_normals=False):
        assert mode in self.modes.keys(), f'"mode" must be one of {self.modes.keys()}'
        assert not (collate_per_segment and has_variations), 'Either use each element or pointclouds - with variations'
        #Set the Dataset Parameters
        self.has_headers, self.has_variations, self.labels_within = headers, has_variations, labels_within
        self.collate_per_element, self.mode, self.refresh = collate_per_segment, mode, refresh
+       self.with_normals = with_normals
        super(CustomShapeNet, self).__init__(root_dir, transform, pre_transform, pre_filter)
        self.data, self.slices = self._load_dataset()
        print("Initialized")

@@ -143,8 +145,10 @@ class CustomShapeNet(InMemoryDataset):
                            y_all = [-1] * points.shape[0]
                        y = torch.as_tensor(y_all, dtype=torch.int)

+                   ####################################
                    # This is where you define the keys
-                   attr_dict = dict(y=y, pos=points[:, :3], normals=points[:, 3:6])
+                   attr_dict = dict(y=y, pos=points[:, :3 if not self.with_normals else 6])
+                   ####################################
                    if self.collate_per_element:
                        data = Data(**attr_dict)
                    else:
@@ -197,16 +201,13 @@ class ShapeNetPartSegDataset(Dataset):
        except ValueError:
            choice = []

-       pos, normals, labels = data.pos[choice, :], data.normals[choice, :], data.y[choice]
-       # pos, labels = data.pos[choice, :], data.y[choice]
+       pos, labels = data.pos[choice, :], data.y[choice]

        labels -= 1 if self.num_classes() in labels else 0  # Map label from [1, C] to [0, C-1]

        sample = {
-           'points': torch.cat([pos, normals], dim=1),  # torch.Tensor (n, 6)
-           'labels': labels,  # torch.Tensor (n,)
-           'pos': pos,  # torch.Tensor (n, 3)
-           'normals': normals  # torch.Tensor (n, 3)
+           'points': pos,  # torch.Tensor (n, 3), or (n, 6) when with_normals
+           'labels': labels  # torch.Tensor (n,)
        }

        return sample
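The important change here is the `pos` slice: with `with_normals=False` only the xyz columns survive, while with `with_normals=True` the normals stay packed into the same tensor. A minimal sketch of that slicing, assuming each raw row is laid out as [x, y, z, nx, ny, nz] (the layout implied by the removed `normals=points[:, 3:6]` key):

    import torch

    points = torch.rand(1024, 6)  # hypothetical cloud, one [x, y, z, nx, ny, nz] row per point
    with_normals = False

    pos = points[:, :3 if not with_normals else 6]
    print(pos.shape)  # torch.Size([1024, 3]); torch.Size([1024, 6]) when with_normals=True

Downstream consumers now read everything from `sample['points']`, whose width follows the flag instead of always being a fixed (n, 6) concatenation.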
diff --git a/main.py b/main.py
index be53a28..7782767 100644
--- a/main.py
+++ b/main.py
@@ -38,6 +38,7 @@ parser.add_argument('--batch_size', type=int, default=8, help='input batch size')
 parser.add_argument('--test_per_batches', type=int, default=1000, help='run a test batch per training batches number')
 parser.add_argument('--num_workers', type=int, default=0, help='number of data loading workers')
 parser.add_argument('--headers', type=strtobool, default=True, help='if raw files come with headers')
+parser.add_argument('--with_normals', type=strtobool, default=True, help='if training will include normals')
 parser.add_argument('--collate_per_segment', type=strtobool, default=True, help='whether to look at pointclouds or sub')
 parser.add_argument('--has_variations', type=strtobool, default=False,
                     help='whether a single pointcloud has variations '
@@ -69,7 +70,7 @@ if __name__ == '__main__':
                               )
     TransTransform = GT.RandomTranslate(trans_max_distance)

-    train_transform = GT.Compose([GT.NormalizeScale(), RotTransform, TransTransform])
+    train_transform = GT.Compose([GT.NormalizeScale(), ])
     test_transform = GT.Compose([GT.NormalizeScale(), ])

     params = dict(root_dir=opt.dataset,
@@ -78,7 +79,8 @@ if __name__ == '__main__':
                   npoints=opt.npoints,
                   labels_within=opt.labels_within,
                   has_variations=opt.has_variations,
-                  headers=opt.headers
+                  headers=opt.headers,
+                  with_normals=opt.with_normals
                   )

     dataset = ShapeNetPartSegDataset(mode='train', **params)
@@ -105,7 +107,7 @@ if __name__ == '__main__':
     dtype = torch.float
     print('cudnn.enabled: ', torch.backends.cudnn.enabled)

-    net = PointNet2PartSegmentNet(num_classes)
+    net = PointNet2PartSegmentNet(num_classes, with_normals=opt.with_normals)

     if opt.model != '':
         net.load_state_dict(torch.load(opt.model))
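Note that `train_transform` now matches `test_transform`: `RotTransform` and `TransTransform` are dropped, which is consistent with `pos` possibly carrying six columns that a plain spatial rotation or translation would corrupt. For reference, a sketch of how the flag threads from the CLI into both the dataset and the network, assuming the updated signatures above (path and point count are placeholders):

    from dataset.shapenet import ShapeNetPartSegDataset
    from model.pointnet2_part_seg import PointNet2PartSegmentNet

    with_normals = True  # stands in for opt.with_normals
    dataset = ShapeNetPartSegDataset(mode='train', root_dir='data', npoints=1024,
                                     with_normals=with_normals)
    net = PointNet2PartSegmentNet(dataset.num_classes(), with_normals=with_normals)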
diff --git a/model/pointnet2_part_seg.py b/model/pointnet2_part_seg.py
index f23e5b2..271578a 100644
--- a/model/pointnet2_part_seg.py
+++ b/model/pointnet2_part_seg.py
@@ -8,15 +8,15 @@ from torch_geometric.utils.num_nodes import maybe_num_nodes
 from torch_geometric.data.data import Data
 from torch_scatter import scatter_add, scatter_max

-GLOBAL_POINT_FEATURES = 6

 class PointNet2SAModule(torch.nn.Module):
-    def __init__(self, sample_radio, radius, max_num_neighbors, mlp):
+    def __init__(self, sample_radio, radius, max_num_neighbors, mlp, features=3):
         super(PointNet2SAModule, self).__init__()
         self.sample_ratio = sample_radio
         self.radius = radius
         self.max_num_neighbors = max_num_neighbors
         self.point_conv = PointConv(mlp)
+        self.features = features

     def forward(self, data):
         x, pos, batch = data
@@ -40,9 +40,10 @@ class PointNet2GlobalSAModule(torch.nn.Module):
    One group with all input points, can be viewed as a simple PointNet module.
    It also return the only one output point(set as origin point).
    '''
-    def __init__(self, mlp):
+    def __init__(self, mlp, features=3):
         super(PointNet2GlobalSAModule, self).__init__()
         self.mlp = mlp
+        self.features = features

     def forward(self, data):
         x, pos, batch = data
@@ -52,7 +53,7 @@ class PointNet2GlobalSAModule(torch.nn.Module):
         x1 = scatter_max(x1, batch, dim=0)[0]  # (batch_size, C1)

         batch_size = x1.shape[0]
-        pos1 = x1.new_zeros((batch_size, GLOBAL_POINT_FEATURES))  # set the output point as origin
+        pos1 = x1.new_zeros((batch_size, self.features))  # set the output point as origin
         batch1 = torch.arange(batch_size).to(batch.device, batch.dtype)

         return x1, pos1, batch1
@@ -158,44 +159,47 @@ class PointNet2PartSegmentNet(torch.nn.Module):
    - https://github.com/charlesq34/pointnet2/blob/master/models/pointnet2_part_seg.py
    - https://github.com/rusty1s/pytorch_geometric/blob/master/examples/pointnet++.py
    '''
-    def __init__(self, num_classes):
+    def __init__(self, num_classes, with_normals=False):
         super(PointNet2PartSegmentNet, self).__init__()
         self.num_classes = num_classes
+        self.features = 3 if not with_normals else 6

         # SA1
         sa1_sample_ratio = 0.5
         sa1_radius = 0.2
         sa1_max_num_neighbours = 64
-        sa1_mlp = make_mlp(GLOBAL_POINT_FEATURES, [64, 64, 128])
-        self.sa1_module = PointNet2SAModule(sa1_sample_ratio, sa1_radius, sa1_max_num_neighbours, sa1_mlp)
+        sa1_mlp = make_mlp(self.features, [64, 64, 128])
+        self.sa1_module = PointNet2SAModule(sa1_sample_ratio, sa1_radius, sa1_max_num_neighbours, sa1_mlp,
+                                            features=self.features)

         # SA2
         sa2_sample_ratio = 0.25
         sa2_radius = 0.4
         sa2_max_num_neighbours = 64
-        sa2_mlp = make_mlp(128+GLOBAL_POINT_FEATURES, [128, 128, 256])
-        self.sa2_module = PointNet2SAModule(sa2_sample_ratio, sa2_radius, sa2_max_num_neighbours, sa2_mlp)
+        sa2_mlp = make_mlp(128+self.features, [128, 128, 256])
+        self.sa2_module = PointNet2SAModule(sa2_sample_ratio, sa2_radius, sa2_max_num_neighbours, sa2_mlp,
+                                            features=self.features)

         # SA3
-        sa3_mlp = make_mlp(256+GLOBAL_POINT_FEATURES, [256, 512, 1024])
-        self.sa3_module = PointNet2GlobalSAModule(sa3_mlp)
+        sa3_mlp = make_mlp(256+self.features, [256, 512, 1024])
+        self.sa3_module = PointNet2GlobalSAModule(sa3_mlp, self.features)

         ##
-        knn_num = GLOBAL_POINT_FEATURES
+        knn_num = self.features

         # FP3, reverse of sa3
         fp3_knn_num = 1  # After global sa module, there is only one point in point cloud
-        fp3_mlp = make_mlp(1024+256+GLOBAL_POINT_FEATURES, [256, 256])
+        fp3_mlp = make_mlp(1024+256+self.features, [256, 256])
         self.fp3_module = PointNet2FPModule(fp3_knn_num, fp3_mlp)

         # FP2, reverse of sa2
         fp2_knn_num = knn_num
-        fp2_mlp = make_mlp(256+128+GLOBAL_POINT_FEATURES, [256, 128])
+        fp2_mlp = make_mlp(256+128+self.features, [256, 128])
         self.fp2_module = PointNet2FPModule(fp2_knn_num, fp2_mlp)

         # FP1, reverse of sa1
         fp1_knn_num = knn_num
-        fp1_mlp = make_mlp(128+GLOBAL_POINT_FEATURES, [128, 128, 128])
+        fp1_mlp = make_mlp(128+self.features, [128, 128, 128])
         self.fp1_module = PointNet2FPModule(fp1_knn_num, fp1_mlp)

         self.fc1 = Lin(128, 128)
@@ -252,11 +256,12 @@ class PointNet2PartSegmentNet(torch.nn.Module):

 if __name__ == '__main__':
     num_classes = 10
-    net = PointNet2PartSegmentNet(num_classes)
+    num_features = 6
+    net = PointNet2PartSegmentNet(num_classes, with_normals=True)  # with_normals=True -> 6 input features

     #
     print('Test dense input ..')
-    data1 = torch.rand((2, GLOBAL_POINT_FEATURES, 1024))  # (batch_size, 3, num_points)
+    data1 = torch.rand((2, num_features, 1024))  # (batch_size, num_features, num_points)
     print('data1: ', data1.shape)
     out1 = net(data1)

@@ -272,7 +277,7 @@ if __name__ == '__main__':
         data_batch = Data()

         # data_batch.x = None
-        data_batch.pos = torch.cat([torch.rand(pos_num1, GLOBAL_POINT_FEATURES), torch.rand(pos_num2, GLOBAL_POINT_FEATURES)], dim=0)
+        data_batch.pos = torch.cat([torch.rand(pos_num1, num_features), torch.rand(pos_num2, num_features)], dim=0)
         data_batch.batch = torch.cat([torch.zeros(pos_num1, dtype=torch.long), torch.ones(pos_num2, dtype=torch.long)])
         return data_batch
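The feature width is now derived once from the flag (`3 if not with_normals else 6`) and threaded into every SA/FP MLP, replacing the former module-level `GLOBAL_POINT_FEATURES` constant. A usage sketch mirroring the dense-input test in `__main__` above (class count and shapes are placeholders):

    import torch
    from model.pointnet2_part_seg import PointNet2PartSegmentNet

    net_xyz = PointNet2PartSegmentNet(num_classes=10)                      # expects (B, 3, N) input
    net_xyzn = PointNet2PartSegmentNet(num_classes=10, with_normals=True)  # expects (B, 6, N) input

    out = net_xyzn(torch.rand(2, 6, 1024))  # dense input path, as exercised in __main__
    print(out.shape)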