Can now be trained with normals

This commit is contained in:
Si11ium 2019-08-09 13:32:55 +02:00
parent a501dcd6b0
commit 92117328ad
3 changed files with 37 additions and 29 deletions

View File

@ -24,13 +24,15 @@ class CustomShapeNet(InMemoryDataset):
modes = {key: val for val, key in enumerate(['train', 'test', 'predict'])} modes = {key: val for val, key in enumerate(['train', 'test', 'predict'])}
def __init__(self, root_dir, collate_per_segment=True, mode='train', transform=None, pre_filter=None, def __init__(self, root_dir, collate_per_segment=True, mode='train', transform=None, pre_filter=None,
pre_transform=None, headers=True, has_variations=False, refresh=False, labels_within=False): pre_transform=None, headers=True, has_variations=False, refresh=False, labels_within=False,
with_normals=False):
assert mode in self.modes.keys(), f'"mode" must be one of {self.modes.keys()}' assert mode in self.modes.keys(), f'"mode" must be one of {self.modes.keys()}'
assert not (collate_per_segment and has_variations), 'Either use each element or pointclouds - with variations' assert not (collate_per_segment and has_variations), 'Either use each element or pointclouds - with variations'
#Set the Dataset Parameters #Set the Dataset Parameters
self.has_headers, self.has_variations, self.labels_within = headers, has_variations, labels_within self.has_headers, self.has_variations, self.labels_within = headers, has_variations, labels_within
self.collate_per_element, self.mode, self.refresh = collate_per_segment, mode, refresh self.collate_per_element, self.mode, self.refresh = collate_per_segment, mode, refresh
self.with_normals = with_normals
super(CustomShapeNet, self).__init__(root_dir, transform, pre_transform, pre_filter) super(CustomShapeNet, self).__init__(root_dir, transform, pre_transform, pre_filter)
self.data, self.slices = self._load_dataset() self.data, self.slices = self._load_dataset()
print("Initialized") print("Initialized")
@ -143,8 +145,10 @@ class CustomShapeNet(InMemoryDataset):
y_all = [-1] * points.shape[0] y_all = [-1] * points.shape[0]
y = torch.as_tensor(y_all, dtype=torch.int) y = torch.as_tensor(y_all, dtype=torch.int)
####################################
# This is where you define the keys # This is where you define the keys
attr_dict = dict(y=y, pos=points[:, :3], normals=points[:, 3:6]) attr_dict = dict(y=y, pos=points[:, :3 if not self.with_normals else 6])
####################################
if self.collate_per_element: if self.collate_per_element:
data = Data(**attr_dict) data = Data(**attr_dict)
else: else:
@ -197,16 +201,13 @@ class ShapeNetPartSegDataset(Dataset):
except ValueError: except ValueError:
choice = [] choice = []
pos, normals, labels = data.pos[choice, :], data.normals[choice, :], data.y[choice] pos, labels = data.pos[choice, :], data.y[choice]
# pos, labels = data.pos[choice, :], data.y[choice]
labels -= 1 if self.num_classes() in labels else 0 # Map label from [1, C] to [0, C-1] labels -= 1 if self.num_classes() in labels else 0 # Map label from [1, C] to [0, C-1]
sample = { sample = {
'points': torch.cat([pos, normals], dim=1), # torch.Tensor (n, 6) 'points': pos, # torch.Tensor (n, 6)
'labels': labels, # torch.Tensor (n,) 'labels': labels # torch.Tensor (n,)
'pos': pos, # torch.Tensor (n, 3)
'normals': normals # torch.Tensor (n, 3)
} }
return sample return sample

View File

@ -38,6 +38,7 @@ parser.add_argument('--batch_size', type=int, default=8, help='input batch size'
parser.add_argument('--test_per_batches', type=int, default=1000, help='run a test batch per training batches number') parser.add_argument('--test_per_batches', type=int, default=1000, help='run a test batch per training batches number')
parser.add_argument('--num_workers', type=int, default=0, help='number of data loading workers') parser.add_argument('--num_workers', type=int, default=0, help='number of data loading workers')
parser.add_argument('--headers', type=strtobool, default=True, help='if raw files come with headers') parser.add_argument('--headers', type=strtobool, default=True, help='if raw files come with headers')
parser.add_argument('--with_normals', type=strtobool, default=True, help='if training will include normals')
parser.add_argument('--collate_per_segment', type=strtobool, default=True, help='whether to look at pointclouds or sub') parser.add_argument('--collate_per_segment', type=strtobool, default=True, help='whether to look at pointclouds or sub')
parser.add_argument('--has_variations', type=strtobool, default=False, parser.add_argument('--has_variations', type=strtobool, default=False,
help='whether a single pointcloud has variations ' help='whether a single pointcloud has variations '
@ -69,7 +70,7 @@ if __name__ == '__main__':
) )
TransTransform = GT.RandomTranslate(trans_max_distance) TransTransform = GT.RandomTranslate(trans_max_distance)
train_transform = GT.Compose([GT.NormalizeScale(), RotTransform, TransTransform]) train_transform = GT.Compose([GT.NormalizeScale(), ])
test_transform = GT.Compose([GT.NormalizeScale(), ]) test_transform = GT.Compose([GT.NormalizeScale(), ])
params = dict(root_dir=opt.dataset, params = dict(root_dir=opt.dataset,
@ -78,7 +79,8 @@ if __name__ == '__main__':
npoints=opt.npoints, npoints=opt.npoints,
labels_within=opt.labels_within, labels_within=opt.labels_within,
has_variations=opt.has_variations, has_variations=opt.has_variations,
headers=opt.headers headers=opt.headers,
with_normals=opt.with_normals
) )
dataset = ShapeNetPartSegDataset(mode='train', **params) dataset = ShapeNetPartSegDataset(mode='train', **params)
@ -105,7 +107,7 @@ if __name__ == '__main__':
dtype = torch.float dtype = torch.float
print('cudnn.enabled: ', torch.backends.cudnn.enabled) print('cudnn.enabled: ', torch.backends.cudnn.enabled)
net = PointNet2PartSegmentNet(num_classes) net = PointNet2PartSegmentNet(num_classes, with_normals=opt.with_normals)
if opt.model != '': if opt.model != '':
net.load_state_dict(torch.load(opt.model)) net.load_state_dict(torch.load(opt.model))

View File

@ -8,15 +8,15 @@ from torch_geometric.utils.num_nodes import maybe_num_nodes
from torch_geometric.data.data import Data from torch_geometric.data.data import Data
from torch_scatter import scatter_add, scatter_max from torch_scatter import scatter_add, scatter_max
GLOBAL_POINT_FEATURES = 6
class PointNet2SAModule(torch.nn.Module): class PointNet2SAModule(torch.nn.Module):
def __init__(self, sample_radio, radius, max_num_neighbors, mlp): def __init__(self, sample_radio, radius, max_num_neighbors, mlp, features=3):
super(PointNet2SAModule, self).__init__() super(PointNet2SAModule, self).__init__()
self.sample_ratio = sample_radio self.sample_ratio = sample_radio
self.radius = radius self.radius = radius
self.max_num_neighbors = max_num_neighbors self.max_num_neighbors = max_num_neighbors
self.point_conv = PointConv(mlp) self.point_conv = PointConv(mlp)
self.features=features
def forward(self, data): def forward(self, data):
x, pos, batch = data x, pos, batch = data
@ -40,9 +40,10 @@ class PointNet2GlobalSAModule(torch.nn.Module):
One group with all input points, can be viewed as a simple PointNet module. One group with all input points, can be viewed as a simple PointNet module.
It also return the only one output point(set as origin point). It also return the only one output point(set as origin point).
''' '''
def __init__(self, mlp): def __init__(self, mlp, features=3):
super(PointNet2GlobalSAModule, self).__init__() super(PointNet2GlobalSAModule, self).__init__()
self.mlp = mlp self.mlp = mlp
self.features = features
def forward(self, data): def forward(self, data):
x, pos, batch = data x, pos, batch = data
@ -52,7 +53,7 @@ class PointNet2GlobalSAModule(torch.nn.Module):
x1 = scatter_max(x1, batch, dim=0)[0] # (batch_size, C1) x1 = scatter_max(x1, batch, dim=0)[0] # (batch_size, C1)
batch_size = x1.shape[0] batch_size = x1.shape[0]
pos1 = x1.new_zeros((batch_size, GLOBAL_POINT_FEATURES)) # set the output point as origin pos1 = x1.new_zeros((batch_size, self.features)) # set the output point as origin
batch1 = torch.arange(batch_size).to(batch.device, batch.dtype) batch1 = torch.arange(batch_size).to(batch.device, batch.dtype)
return x1, pos1, batch1 return x1, pos1, batch1
@ -158,44 +159,47 @@ class PointNet2PartSegmentNet(torch.nn.Module):
- https://github.com/charlesq34/pointnet2/blob/master/models/pointnet2_part_seg.py - https://github.com/charlesq34/pointnet2/blob/master/models/pointnet2_part_seg.py
- https://github.com/rusty1s/pytorch_geometric/blob/master/examples/pointnet++.py - https://github.com/rusty1s/pytorch_geometric/blob/master/examples/pointnet++.py
''' '''
def __init__(self, num_classes): def __init__(self, num_classes, with_normals=False):
super(PointNet2PartSegmentNet, self).__init__() super(PointNet2PartSegmentNet, self).__init__()
self.num_classes = num_classes self.num_classes = num_classes
self.features = 3 if not with_normals else 6
# SA1 # SA1
sa1_sample_ratio = 0.5 sa1_sample_ratio = 0.5
sa1_radius = 0.2 sa1_radius = 0.2
sa1_max_num_neighbours = 64 sa1_max_num_neighbours = 64
sa1_mlp = make_mlp(GLOBAL_POINT_FEATURES, [64, 64, 128]) sa1_mlp = make_mlp(self.features, [64, 64, 128])
self.sa1_module = PointNet2SAModule(sa1_sample_ratio, sa1_radius, sa1_max_num_neighbours, sa1_mlp) self.sa1_module = PointNet2SAModule(sa1_sample_ratio, sa1_radius, sa1_max_num_neighbours, sa1_mlp,
features=self.features)
# SA2 # SA2
sa2_sample_ratio = 0.25 sa2_sample_ratio = 0.25
sa2_radius = 0.4 sa2_radius = 0.4
sa2_max_num_neighbours = 64 sa2_max_num_neighbours = 64
sa2_mlp = make_mlp(128+GLOBAL_POINT_FEATURES, [128, 128, 256]) sa2_mlp = make_mlp(128+self.features, [128, 128, 256])
self.sa2_module = PointNet2SAModule(sa2_sample_ratio, sa2_radius, sa2_max_num_neighbours, sa2_mlp) self.sa2_module = PointNet2SAModule(sa2_sample_ratio, sa2_radius, sa2_max_num_neighbours, sa2_mlp,
features=self.features)
# SA3 # SA3
sa3_mlp = make_mlp(256+GLOBAL_POINT_FEATURES, [256, 512, 1024]) sa3_mlp = make_mlp(256+self.features, [256, 512, 1024])
self.sa3_module = PointNet2GlobalSAModule(sa3_mlp) self.sa3_module = PointNet2GlobalSAModule(sa3_mlp, self.features)
## ##
knn_num = GLOBAL_POINT_FEATURES knn_num = self.features
# FP3, reverse of sa3 # FP3, reverse of sa3
fp3_knn_num = 1 # After global sa module, there is only one point in point cloud fp3_knn_num = 1 # After global sa module, there is only one point in point cloud
fp3_mlp = make_mlp(1024+256+GLOBAL_POINT_FEATURES, [256, 256]) fp3_mlp = make_mlp(1024+256+self.features, [256, 256])
self.fp3_module = PointNet2FPModule(fp3_knn_num, fp3_mlp) self.fp3_module = PointNet2FPModule(fp3_knn_num, fp3_mlp)
# FP2, reverse of sa2 # FP2, reverse of sa2
fp2_knn_num = knn_num fp2_knn_num = knn_num
fp2_mlp = make_mlp(256+128+GLOBAL_POINT_FEATURES, [256, 128]) fp2_mlp = make_mlp(256+128+self.features, [256, 128])
self.fp2_module = PointNet2FPModule(fp2_knn_num, fp2_mlp) self.fp2_module = PointNet2FPModule(fp2_knn_num, fp2_mlp)
# FP1, reverse of sa1 # FP1, reverse of sa1
fp1_knn_num = knn_num fp1_knn_num = knn_num
fp1_mlp = make_mlp(128+GLOBAL_POINT_FEATURES, [128, 128, 128]) fp1_mlp = make_mlp(128+self.features, [128, 128, 128])
self.fp1_module = PointNet2FPModule(fp1_knn_num, fp1_mlp) self.fp1_module = PointNet2FPModule(fp1_knn_num, fp1_mlp)
self.fc1 = Lin(128, 128) self.fc1 = Lin(128, 128)
@ -252,11 +256,12 @@ class PointNet2PartSegmentNet(torch.nn.Module):
if __name__ == '__main__': if __name__ == '__main__':
num_classes = 10 num_classes = 10
net = PointNet2PartSegmentNet(num_classes) num_features = 6
net = PointNet2PartSegmentNet(num_classes, features=num_features)
# #
print('Test dense input ..') print('Test dense input ..')
data1 = torch.rand((2, GLOBAL_POINT_FEATURES, 1024)) # (batch_size, 3, num_points) data1 = torch.rand((2, num_features, 1024)) # (batch_size, 3, num_points)
print('data1: ', data1.shape) print('data1: ', data1.shape)
out1 = net(data1) out1 = net(data1)
@ -272,7 +277,7 @@ if __name__ == '__main__':
data_batch = Data() data_batch = Data()
# data_batch.x = None # data_batch.x = None
data_batch.pos = torch.cat([torch.rand(pos_num1, GLOBAL_POINT_FEATURES), torch.rand(pos_num2, GLOBAL_POINT_FEATURES)], dim=0) data_batch.pos = torch.cat([torch.rand(pos_num1, num_features), torch.rand(pos_num2, num_features)], dim=0)
data_batch.batch = torch.cat([torch.zeros(pos_num1, dtype=torch.long), torch.ones(pos_num2, dtype=torch.long)]) data_batch.batch = torch.cat([torch.zeros(pos_num1, dtype=torch.long), torch.ones(pos_num2, dtype=torch.long)])
return data_batch return data_batch