Can now be trained with normals
This commit is contained in:
parent
a501dcd6b0
commit
92117328ad
@ -24,13 +24,15 @@ class CustomShapeNet(InMemoryDataset):
|
|||||||
modes = {key: val for val, key in enumerate(['train', 'test', 'predict'])}
|
modes = {key: val for val, key in enumerate(['train', 'test', 'predict'])}
|
||||||
|
|
||||||
def __init__(self, root_dir, collate_per_segment=True, mode='train', transform=None, pre_filter=None,
|
def __init__(self, root_dir, collate_per_segment=True, mode='train', transform=None, pre_filter=None,
|
||||||
pre_transform=None, headers=True, has_variations=False, refresh=False, labels_within=False):
|
pre_transform=None, headers=True, has_variations=False, refresh=False, labels_within=False,
|
||||||
|
with_normals=False):
|
||||||
assert mode in self.modes.keys(), f'"mode" must be one of {self.modes.keys()}'
|
assert mode in self.modes.keys(), f'"mode" must be one of {self.modes.keys()}'
|
||||||
assert not (collate_per_segment and has_variations), 'Either use each element or pointclouds - with variations'
|
assert not (collate_per_segment and has_variations), 'Either use each element or pointclouds - with variations'
|
||||||
|
|
||||||
#Set the Dataset Parameters
|
#Set the Dataset Parameters
|
||||||
self.has_headers, self.has_variations, self.labels_within = headers, has_variations, labels_within
|
self.has_headers, self.has_variations, self.labels_within = headers, has_variations, labels_within
|
||||||
self.collate_per_element, self.mode, self.refresh = collate_per_segment, mode, refresh
|
self.collate_per_element, self.mode, self.refresh = collate_per_segment, mode, refresh
|
||||||
|
self.with_normals = with_normals
|
||||||
super(CustomShapeNet, self).__init__(root_dir, transform, pre_transform, pre_filter)
|
super(CustomShapeNet, self).__init__(root_dir, transform, pre_transform, pre_filter)
|
||||||
self.data, self.slices = self._load_dataset()
|
self.data, self.slices = self._load_dataset()
|
||||||
print("Initialized")
|
print("Initialized")
|
||||||
@ -143,8 +145,10 @@ class CustomShapeNet(InMemoryDataset):
|
|||||||
y_all = [-1] * points.shape[0]
|
y_all = [-1] * points.shape[0]
|
||||||
|
|
||||||
y = torch.as_tensor(y_all, dtype=torch.int)
|
y = torch.as_tensor(y_all, dtype=torch.int)
|
||||||
|
####################################
|
||||||
# This is where you define the keys
|
# This is where you define the keys
|
||||||
attr_dict = dict(y=y, pos=points[:, :3], normals=points[:, 3:6])
|
attr_dict = dict(y=y, pos=points[:, :3 if not self.with_normals else 6])
|
||||||
|
####################################
|
||||||
if self.collate_per_element:
|
if self.collate_per_element:
|
||||||
data = Data(**attr_dict)
|
data = Data(**attr_dict)
|
||||||
else:
|
else:
|
||||||
@ -197,16 +201,13 @@ class ShapeNetPartSegDataset(Dataset):
|
|||||||
except ValueError:
|
except ValueError:
|
||||||
choice = []
|
choice = []
|
||||||
|
|
||||||
pos, normals, labels = data.pos[choice, :], data.normals[choice, :], data.y[choice]
|
pos, labels = data.pos[choice, :], data.y[choice]
|
||||||
# pos, labels = data.pos[choice, :], data.y[choice]
|
|
||||||
|
|
||||||
labels -= 1 if self.num_classes() in labels else 0 # Map label from [1, C] to [0, C-1]
|
labels -= 1 if self.num_classes() in labels else 0 # Map label from [1, C] to [0, C-1]
|
||||||
|
|
||||||
sample = {
|
sample = {
|
||||||
'points': torch.cat([pos, normals], dim=1), # torch.Tensor (n, 6)
|
'points': pos, # torch.Tensor (n, 6)
|
||||||
'labels': labels, # torch.Tensor (n,)
|
'labels': labels # torch.Tensor (n,)
|
||||||
'pos': pos, # torch.Tensor (n, 3)
|
|
||||||
'normals': normals # torch.Tensor (n, 3)
|
|
||||||
}
|
}
|
||||||
return sample
|
return sample
|
||||||
|
|
||||||
|
8
main.py
8
main.py
@ -38,6 +38,7 @@ parser.add_argument('--batch_size', type=int, default=8, help='input batch size'
|
|||||||
parser.add_argument('--test_per_batches', type=int, default=1000, help='run a test batch per training batches number')
|
parser.add_argument('--test_per_batches', type=int, default=1000, help='run a test batch per training batches number')
|
||||||
parser.add_argument('--num_workers', type=int, default=0, help='number of data loading workers')
|
parser.add_argument('--num_workers', type=int, default=0, help='number of data loading workers')
|
||||||
parser.add_argument('--headers', type=strtobool, default=True, help='if raw files come with headers')
|
parser.add_argument('--headers', type=strtobool, default=True, help='if raw files come with headers')
|
||||||
|
parser.add_argument('--with_normals', type=strtobool, default=True, help='if training will include normals')
|
||||||
parser.add_argument('--collate_per_segment', type=strtobool, default=True, help='whether to look at pointclouds or sub')
|
parser.add_argument('--collate_per_segment', type=strtobool, default=True, help='whether to look at pointclouds or sub')
|
||||||
parser.add_argument('--has_variations', type=strtobool, default=False,
|
parser.add_argument('--has_variations', type=strtobool, default=False,
|
||||||
help='whether a single pointcloud has variations '
|
help='whether a single pointcloud has variations '
|
||||||
@ -69,7 +70,7 @@ if __name__ == '__main__':
|
|||||||
)
|
)
|
||||||
|
|
||||||
TransTransform = GT.RandomTranslate(trans_max_distance)
|
TransTransform = GT.RandomTranslate(trans_max_distance)
|
||||||
train_transform = GT.Compose([GT.NormalizeScale(), RotTransform, TransTransform])
|
train_transform = GT.Compose([GT.NormalizeScale(), ])
|
||||||
test_transform = GT.Compose([GT.NormalizeScale(), ])
|
test_transform = GT.Compose([GT.NormalizeScale(), ])
|
||||||
|
|
||||||
params = dict(root_dir=opt.dataset,
|
params = dict(root_dir=opt.dataset,
|
||||||
@ -78,7 +79,8 @@ if __name__ == '__main__':
|
|||||||
npoints=opt.npoints,
|
npoints=opt.npoints,
|
||||||
labels_within=opt.labels_within,
|
labels_within=opt.labels_within,
|
||||||
has_variations=opt.has_variations,
|
has_variations=opt.has_variations,
|
||||||
headers=opt.headers
|
headers=opt.headers,
|
||||||
|
with_normals=opt.with_normals
|
||||||
)
|
)
|
||||||
|
|
||||||
dataset = ShapeNetPartSegDataset(mode='train', **params)
|
dataset = ShapeNetPartSegDataset(mode='train', **params)
|
||||||
@ -105,7 +107,7 @@ if __name__ == '__main__':
|
|||||||
dtype = torch.float
|
dtype = torch.float
|
||||||
print('cudnn.enabled: ', torch.backends.cudnn.enabled)
|
print('cudnn.enabled: ', torch.backends.cudnn.enabled)
|
||||||
|
|
||||||
net = PointNet2PartSegmentNet(num_classes)
|
net = PointNet2PartSegmentNet(num_classes, with_normals=opt.with_normals)
|
||||||
|
|
||||||
if opt.model != '':
|
if opt.model != '':
|
||||||
net.load_state_dict(torch.load(opt.model))
|
net.load_state_dict(torch.load(opt.model))
|
||||||
|
@ -8,15 +8,15 @@ from torch_geometric.utils.num_nodes import maybe_num_nodes
|
|||||||
from torch_geometric.data.data import Data
|
from torch_geometric.data.data import Data
|
||||||
from torch_scatter import scatter_add, scatter_max
|
from torch_scatter import scatter_add, scatter_max
|
||||||
|
|
||||||
GLOBAL_POINT_FEATURES = 6
|
|
||||||
|
|
||||||
class PointNet2SAModule(torch.nn.Module):
|
class PointNet2SAModule(torch.nn.Module):
|
||||||
def __init__(self, sample_radio, radius, max_num_neighbors, mlp):
|
def __init__(self, sample_radio, radius, max_num_neighbors, mlp, features=3):
|
||||||
super(PointNet2SAModule, self).__init__()
|
super(PointNet2SAModule, self).__init__()
|
||||||
self.sample_ratio = sample_radio
|
self.sample_ratio = sample_radio
|
||||||
self.radius = radius
|
self.radius = radius
|
||||||
self.max_num_neighbors = max_num_neighbors
|
self.max_num_neighbors = max_num_neighbors
|
||||||
self.point_conv = PointConv(mlp)
|
self.point_conv = PointConv(mlp)
|
||||||
|
self.features=features
|
||||||
|
|
||||||
def forward(self, data):
|
def forward(self, data):
|
||||||
x, pos, batch = data
|
x, pos, batch = data
|
||||||
@ -40,9 +40,10 @@ class PointNet2GlobalSAModule(torch.nn.Module):
|
|||||||
One group with all input points, can be viewed as a simple PointNet module.
|
One group with all input points, can be viewed as a simple PointNet module.
|
||||||
It also return the only one output point(set as origin point).
|
It also return the only one output point(set as origin point).
|
||||||
'''
|
'''
|
||||||
def __init__(self, mlp):
|
def __init__(self, mlp, features=3):
|
||||||
super(PointNet2GlobalSAModule, self).__init__()
|
super(PointNet2GlobalSAModule, self).__init__()
|
||||||
self.mlp = mlp
|
self.mlp = mlp
|
||||||
|
self.features = features
|
||||||
|
|
||||||
def forward(self, data):
|
def forward(self, data):
|
||||||
x, pos, batch = data
|
x, pos, batch = data
|
||||||
@ -52,7 +53,7 @@ class PointNet2GlobalSAModule(torch.nn.Module):
|
|||||||
x1 = scatter_max(x1, batch, dim=0)[0] # (batch_size, C1)
|
x1 = scatter_max(x1, batch, dim=0)[0] # (batch_size, C1)
|
||||||
|
|
||||||
batch_size = x1.shape[0]
|
batch_size = x1.shape[0]
|
||||||
pos1 = x1.new_zeros((batch_size, GLOBAL_POINT_FEATURES)) # set the output point as origin
|
pos1 = x1.new_zeros((batch_size, self.features)) # set the output point as origin
|
||||||
batch1 = torch.arange(batch_size).to(batch.device, batch.dtype)
|
batch1 = torch.arange(batch_size).to(batch.device, batch.dtype)
|
||||||
|
|
||||||
return x1, pos1, batch1
|
return x1, pos1, batch1
|
||||||
@ -158,44 +159,47 @@ class PointNet2PartSegmentNet(torch.nn.Module):
|
|||||||
- https://github.com/charlesq34/pointnet2/blob/master/models/pointnet2_part_seg.py
|
- https://github.com/charlesq34/pointnet2/blob/master/models/pointnet2_part_seg.py
|
||||||
- https://github.com/rusty1s/pytorch_geometric/blob/master/examples/pointnet++.py
|
- https://github.com/rusty1s/pytorch_geometric/blob/master/examples/pointnet++.py
|
||||||
'''
|
'''
|
||||||
def __init__(self, num_classes):
|
def __init__(self, num_classes, with_normals=False):
|
||||||
super(PointNet2PartSegmentNet, self).__init__()
|
super(PointNet2PartSegmentNet, self).__init__()
|
||||||
self.num_classes = num_classes
|
self.num_classes = num_classes
|
||||||
|
self.features = 3 if not with_normals else 6
|
||||||
|
|
||||||
# SA1
|
# SA1
|
||||||
sa1_sample_ratio = 0.5
|
sa1_sample_ratio = 0.5
|
||||||
sa1_radius = 0.2
|
sa1_radius = 0.2
|
||||||
sa1_max_num_neighbours = 64
|
sa1_max_num_neighbours = 64
|
||||||
sa1_mlp = make_mlp(GLOBAL_POINT_FEATURES, [64, 64, 128])
|
sa1_mlp = make_mlp(self.features, [64, 64, 128])
|
||||||
self.sa1_module = PointNet2SAModule(sa1_sample_ratio, sa1_radius, sa1_max_num_neighbours, sa1_mlp)
|
self.sa1_module = PointNet2SAModule(sa1_sample_ratio, sa1_radius, sa1_max_num_neighbours, sa1_mlp,
|
||||||
|
features=self.features)
|
||||||
|
|
||||||
# SA2
|
# SA2
|
||||||
sa2_sample_ratio = 0.25
|
sa2_sample_ratio = 0.25
|
||||||
sa2_radius = 0.4
|
sa2_radius = 0.4
|
||||||
sa2_max_num_neighbours = 64
|
sa2_max_num_neighbours = 64
|
||||||
sa2_mlp = make_mlp(128+GLOBAL_POINT_FEATURES, [128, 128, 256])
|
sa2_mlp = make_mlp(128+self.features, [128, 128, 256])
|
||||||
self.sa2_module = PointNet2SAModule(sa2_sample_ratio, sa2_radius, sa2_max_num_neighbours, sa2_mlp)
|
self.sa2_module = PointNet2SAModule(sa2_sample_ratio, sa2_radius, sa2_max_num_neighbours, sa2_mlp,
|
||||||
|
features=self.features)
|
||||||
|
|
||||||
# SA3
|
# SA3
|
||||||
sa3_mlp = make_mlp(256+GLOBAL_POINT_FEATURES, [256, 512, 1024])
|
sa3_mlp = make_mlp(256+self.features, [256, 512, 1024])
|
||||||
self.sa3_module = PointNet2GlobalSAModule(sa3_mlp)
|
self.sa3_module = PointNet2GlobalSAModule(sa3_mlp, self.features)
|
||||||
|
|
||||||
##
|
##
|
||||||
knn_num = GLOBAL_POINT_FEATURES
|
knn_num = self.features
|
||||||
|
|
||||||
# FP3, reverse of sa3
|
# FP3, reverse of sa3
|
||||||
fp3_knn_num = 1 # After global sa module, there is only one point in point cloud
|
fp3_knn_num = 1 # After global sa module, there is only one point in point cloud
|
||||||
fp3_mlp = make_mlp(1024+256+GLOBAL_POINT_FEATURES, [256, 256])
|
fp3_mlp = make_mlp(1024+256+self.features, [256, 256])
|
||||||
self.fp3_module = PointNet2FPModule(fp3_knn_num, fp3_mlp)
|
self.fp3_module = PointNet2FPModule(fp3_knn_num, fp3_mlp)
|
||||||
|
|
||||||
# FP2, reverse of sa2
|
# FP2, reverse of sa2
|
||||||
fp2_knn_num = knn_num
|
fp2_knn_num = knn_num
|
||||||
fp2_mlp = make_mlp(256+128+GLOBAL_POINT_FEATURES, [256, 128])
|
fp2_mlp = make_mlp(256+128+self.features, [256, 128])
|
||||||
self.fp2_module = PointNet2FPModule(fp2_knn_num, fp2_mlp)
|
self.fp2_module = PointNet2FPModule(fp2_knn_num, fp2_mlp)
|
||||||
|
|
||||||
# FP1, reverse of sa1
|
# FP1, reverse of sa1
|
||||||
fp1_knn_num = knn_num
|
fp1_knn_num = knn_num
|
||||||
fp1_mlp = make_mlp(128+GLOBAL_POINT_FEATURES, [128, 128, 128])
|
fp1_mlp = make_mlp(128+self.features, [128, 128, 128])
|
||||||
self.fp1_module = PointNet2FPModule(fp1_knn_num, fp1_mlp)
|
self.fp1_module = PointNet2FPModule(fp1_knn_num, fp1_mlp)
|
||||||
|
|
||||||
self.fc1 = Lin(128, 128)
|
self.fc1 = Lin(128, 128)
|
||||||
@ -252,11 +256,12 @@ class PointNet2PartSegmentNet(torch.nn.Module):
|
|||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
num_classes = 10
|
num_classes = 10
|
||||||
net = PointNet2PartSegmentNet(num_classes)
|
num_features = 6
|
||||||
|
net = PointNet2PartSegmentNet(num_classes, features=num_features)
|
||||||
|
|
||||||
#
|
#
|
||||||
print('Test dense input ..')
|
print('Test dense input ..')
|
||||||
data1 = torch.rand((2, GLOBAL_POINT_FEATURES, 1024)) # (batch_size, 3, num_points)
|
data1 = torch.rand((2, num_features, 1024)) # (batch_size, 3, num_points)
|
||||||
print('data1: ', data1.shape)
|
print('data1: ', data1.shape)
|
||||||
|
|
||||||
out1 = net(data1)
|
out1 = net(data1)
|
||||||
@ -272,7 +277,7 @@ if __name__ == '__main__':
|
|||||||
data_batch = Data()
|
data_batch = Data()
|
||||||
|
|
||||||
# data_batch.x = None
|
# data_batch.x = None
|
||||||
data_batch.pos = torch.cat([torch.rand(pos_num1, GLOBAL_POINT_FEATURES), torch.rand(pos_num2, GLOBAL_POINT_FEATURES)], dim=0)
|
data_batch.pos = torch.cat([torch.rand(pos_num1, num_features), torch.rand(pos_num2, num_features)], dim=0)
|
||||||
data_batch.batch = torch.cat([torch.zeros(pos_num1, dtype=torch.long), torch.ones(pos_num2, dtype=torch.long)])
|
data_batch.batch = torch.cat([torch.zeros(pos_num1, dtype=torch.long), torch.ones(pos_num2, dtype=torch.long)])
|
||||||
|
|
||||||
return data_batch
|
return data_batch
|
||||||
|
Loading…
x
Reference in New Issue
Block a user