Can now be trained with normals
This commit is contained in:
parent
a501dcd6b0
commit
92117328ad
@ -24,13 +24,15 @@ class CustomShapeNet(InMemoryDataset):
|
||||
modes = {key: val for val, key in enumerate(['train', 'test', 'predict'])}
|
||||
|
||||
def __init__(self, root_dir, collate_per_segment=True, mode='train', transform=None, pre_filter=None,
|
||||
pre_transform=None, headers=True, has_variations=False, refresh=False, labels_within=False):
|
||||
pre_transform=None, headers=True, has_variations=False, refresh=False, labels_within=False,
|
||||
with_normals=False):
|
||||
assert mode in self.modes.keys(), f'"mode" must be one of {self.modes.keys()}'
|
||||
assert not (collate_per_segment and has_variations), 'Either use each element or pointclouds - with variations'
|
||||
|
||||
#Set the Dataset Parameters
|
||||
self.has_headers, self.has_variations, self.labels_within = headers, has_variations, labels_within
|
||||
self.collate_per_element, self.mode, self.refresh = collate_per_segment, mode, refresh
|
||||
self.with_normals = with_normals
|
||||
super(CustomShapeNet, self).__init__(root_dir, transform, pre_transform, pre_filter)
|
||||
self.data, self.slices = self._load_dataset()
|
||||
print("Initialized")
|
||||
@ -143,8 +145,10 @@ class CustomShapeNet(InMemoryDataset):
|
||||
y_all = [-1] * points.shape[0]
|
||||
|
||||
y = torch.as_tensor(y_all, dtype=torch.int)
|
||||
####################################
|
||||
# This is where you define the keys
|
||||
attr_dict = dict(y=y, pos=points[:, :3], normals=points[:, 3:6])
|
||||
attr_dict = dict(y=y, pos=points[:, :3 if not self.with_normals else 6])
|
||||
####################################
|
||||
if self.collate_per_element:
|
||||
data = Data(**attr_dict)
|
||||
else:
|
||||
@ -197,16 +201,13 @@ class ShapeNetPartSegDataset(Dataset):
|
||||
except ValueError:
|
||||
choice = []
|
||||
|
||||
pos, normals, labels = data.pos[choice, :], data.normals[choice, :], data.y[choice]
|
||||
# pos, labels = data.pos[choice, :], data.y[choice]
|
||||
pos, labels = data.pos[choice, :], data.y[choice]
|
||||
|
||||
labels -= 1 if self.num_classes() in labels else 0 # Map label from [1, C] to [0, C-1]
|
||||
|
||||
sample = {
|
||||
'points': torch.cat([pos, normals], dim=1), # torch.Tensor (n, 6)
|
||||
'labels': labels, # torch.Tensor (n,)
|
||||
'pos': pos, # torch.Tensor (n, 3)
|
||||
'normals': normals # torch.Tensor (n, 3)
|
||||
'points': pos, # torch.Tensor (n, 6)
|
||||
'labels': labels # torch.Tensor (n,)
|
||||
}
|
||||
return sample
|
||||
|
||||
|
8
main.py
8
main.py
@ -38,6 +38,7 @@ parser.add_argument('--batch_size', type=int, default=8, help='input batch size'
|
||||
parser.add_argument('--test_per_batches', type=int, default=1000, help='run a test batch per training batches number')
|
||||
parser.add_argument('--num_workers', type=int, default=0, help='number of data loading workers')
|
||||
parser.add_argument('--headers', type=strtobool, default=True, help='if raw files come with headers')
|
||||
parser.add_argument('--with_normals', type=strtobool, default=True, help='if training will include normals')
|
||||
parser.add_argument('--collate_per_segment', type=strtobool, default=True, help='whether to look at pointclouds or sub')
|
||||
parser.add_argument('--has_variations', type=strtobool, default=False,
|
||||
help='whether a single pointcloud has variations '
|
||||
@ -69,7 +70,7 @@ if __name__ == '__main__':
|
||||
)
|
||||
|
||||
TransTransform = GT.RandomTranslate(trans_max_distance)
|
||||
train_transform = GT.Compose([GT.NormalizeScale(), RotTransform, TransTransform])
|
||||
train_transform = GT.Compose([GT.NormalizeScale(), ])
|
||||
test_transform = GT.Compose([GT.NormalizeScale(), ])
|
||||
|
||||
params = dict(root_dir=opt.dataset,
|
||||
@ -78,7 +79,8 @@ if __name__ == '__main__':
|
||||
npoints=opt.npoints,
|
||||
labels_within=opt.labels_within,
|
||||
has_variations=opt.has_variations,
|
||||
headers=opt.headers
|
||||
headers=opt.headers,
|
||||
with_normals=opt.with_normals
|
||||
)
|
||||
|
||||
dataset = ShapeNetPartSegDataset(mode='train', **params)
|
||||
@ -105,7 +107,7 @@ if __name__ == '__main__':
|
||||
dtype = torch.float
|
||||
print('cudnn.enabled: ', torch.backends.cudnn.enabled)
|
||||
|
||||
net = PointNet2PartSegmentNet(num_classes)
|
||||
net = PointNet2PartSegmentNet(num_classes, with_normals=opt.with_normals)
|
||||
|
||||
if opt.model != '':
|
||||
net.load_state_dict(torch.load(opt.model))
|
||||
|
@ -8,15 +8,15 @@ from torch_geometric.utils.num_nodes import maybe_num_nodes
|
||||
from torch_geometric.data.data import Data
|
||||
from torch_scatter import scatter_add, scatter_max
|
||||
|
||||
GLOBAL_POINT_FEATURES = 6
|
||||
|
||||
class PointNet2SAModule(torch.nn.Module):
|
||||
def __init__(self, sample_radio, radius, max_num_neighbors, mlp):
|
||||
def __init__(self, sample_radio, radius, max_num_neighbors, mlp, features=3):
|
||||
super(PointNet2SAModule, self).__init__()
|
||||
self.sample_ratio = sample_radio
|
||||
self.radius = radius
|
||||
self.max_num_neighbors = max_num_neighbors
|
||||
self.point_conv = PointConv(mlp)
|
||||
self.features=features
|
||||
|
||||
def forward(self, data):
|
||||
x, pos, batch = data
|
||||
@ -40,9 +40,10 @@ class PointNet2GlobalSAModule(torch.nn.Module):
|
||||
One group with all input points, can be viewed as a simple PointNet module.
|
||||
It also return the only one output point(set as origin point).
|
||||
'''
|
||||
def __init__(self, mlp):
|
||||
def __init__(self, mlp, features=3):
|
||||
super(PointNet2GlobalSAModule, self).__init__()
|
||||
self.mlp = mlp
|
||||
self.features = features
|
||||
|
||||
def forward(self, data):
|
||||
x, pos, batch = data
|
||||
@ -52,7 +53,7 @@ class PointNet2GlobalSAModule(torch.nn.Module):
|
||||
x1 = scatter_max(x1, batch, dim=0)[0] # (batch_size, C1)
|
||||
|
||||
batch_size = x1.shape[0]
|
||||
pos1 = x1.new_zeros((batch_size, GLOBAL_POINT_FEATURES)) # set the output point as origin
|
||||
pos1 = x1.new_zeros((batch_size, self.features)) # set the output point as origin
|
||||
batch1 = torch.arange(batch_size).to(batch.device, batch.dtype)
|
||||
|
||||
return x1, pos1, batch1
|
||||
@ -158,44 +159,47 @@ class PointNet2PartSegmentNet(torch.nn.Module):
|
||||
- https://github.com/charlesq34/pointnet2/blob/master/models/pointnet2_part_seg.py
|
||||
- https://github.com/rusty1s/pytorch_geometric/blob/master/examples/pointnet++.py
|
||||
'''
|
||||
def __init__(self, num_classes):
|
||||
def __init__(self, num_classes, with_normals=False):
|
||||
super(PointNet2PartSegmentNet, self).__init__()
|
||||
self.num_classes = num_classes
|
||||
self.features = 3 if not with_normals else 6
|
||||
|
||||
# SA1
|
||||
sa1_sample_ratio = 0.5
|
||||
sa1_radius = 0.2
|
||||
sa1_max_num_neighbours = 64
|
||||
sa1_mlp = make_mlp(GLOBAL_POINT_FEATURES, [64, 64, 128])
|
||||
self.sa1_module = PointNet2SAModule(sa1_sample_ratio, sa1_radius, sa1_max_num_neighbours, sa1_mlp)
|
||||
sa1_mlp = make_mlp(self.features, [64, 64, 128])
|
||||
self.sa1_module = PointNet2SAModule(sa1_sample_ratio, sa1_radius, sa1_max_num_neighbours, sa1_mlp,
|
||||
features=self.features)
|
||||
|
||||
# SA2
|
||||
sa2_sample_ratio = 0.25
|
||||
sa2_radius = 0.4
|
||||
sa2_max_num_neighbours = 64
|
||||
sa2_mlp = make_mlp(128+GLOBAL_POINT_FEATURES, [128, 128, 256])
|
||||
self.sa2_module = PointNet2SAModule(sa2_sample_ratio, sa2_radius, sa2_max_num_neighbours, sa2_mlp)
|
||||
sa2_mlp = make_mlp(128+self.features, [128, 128, 256])
|
||||
self.sa2_module = PointNet2SAModule(sa2_sample_ratio, sa2_radius, sa2_max_num_neighbours, sa2_mlp,
|
||||
features=self.features)
|
||||
|
||||
# SA3
|
||||
sa3_mlp = make_mlp(256+GLOBAL_POINT_FEATURES, [256, 512, 1024])
|
||||
self.sa3_module = PointNet2GlobalSAModule(sa3_mlp)
|
||||
sa3_mlp = make_mlp(256+self.features, [256, 512, 1024])
|
||||
self.sa3_module = PointNet2GlobalSAModule(sa3_mlp, self.features)
|
||||
|
||||
##
|
||||
knn_num = GLOBAL_POINT_FEATURES
|
||||
knn_num = self.features
|
||||
|
||||
# FP3, reverse of sa3
|
||||
fp3_knn_num = 1 # After global sa module, there is only one point in point cloud
|
||||
fp3_mlp = make_mlp(1024+256+GLOBAL_POINT_FEATURES, [256, 256])
|
||||
fp3_mlp = make_mlp(1024+256+self.features, [256, 256])
|
||||
self.fp3_module = PointNet2FPModule(fp3_knn_num, fp3_mlp)
|
||||
|
||||
# FP2, reverse of sa2
|
||||
fp2_knn_num = knn_num
|
||||
fp2_mlp = make_mlp(256+128+GLOBAL_POINT_FEATURES, [256, 128])
|
||||
fp2_mlp = make_mlp(256+128+self.features, [256, 128])
|
||||
self.fp2_module = PointNet2FPModule(fp2_knn_num, fp2_mlp)
|
||||
|
||||
# FP1, reverse of sa1
|
||||
fp1_knn_num = knn_num
|
||||
fp1_mlp = make_mlp(128+GLOBAL_POINT_FEATURES, [128, 128, 128])
|
||||
fp1_mlp = make_mlp(128+self.features, [128, 128, 128])
|
||||
self.fp1_module = PointNet2FPModule(fp1_knn_num, fp1_mlp)
|
||||
|
||||
self.fc1 = Lin(128, 128)
|
||||
@ -252,11 +256,12 @@ class PointNet2PartSegmentNet(torch.nn.Module):
|
||||
|
||||
if __name__ == '__main__':
|
||||
num_classes = 10
|
||||
net = PointNet2PartSegmentNet(num_classes)
|
||||
num_features = 6
|
||||
net = PointNet2PartSegmentNet(num_classes, features=num_features)
|
||||
|
||||
#
|
||||
print('Test dense input ..')
|
||||
data1 = torch.rand((2, GLOBAL_POINT_FEATURES, 1024)) # (batch_size, 3, num_points)
|
||||
data1 = torch.rand((2, num_features, 1024)) # (batch_size, 3, num_points)
|
||||
print('data1: ', data1.shape)
|
||||
|
||||
out1 = net(data1)
|
||||
@ -272,7 +277,7 @@ if __name__ == '__main__':
|
||||
data_batch = Data()
|
||||
|
||||
# data_batch.x = None
|
||||
data_batch.pos = torch.cat([torch.rand(pos_num1, GLOBAL_POINT_FEATURES), torch.rand(pos_num2, GLOBAL_POINT_FEATURES)], dim=0)
|
||||
data_batch.pos = torch.cat([torch.rand(pos_num1, num_features), torch.rand(pos_num2, num_features)], dim=0)
|
||||
data_batch.batch = torch.cat([torch.zeros(pos_num1, dtype=torch.long), torch.ones(pos_num2, dtype=torch.long)])
|
||||
|
||||
return data_batch
|
||||
|
Loading…
x
Reference in New Issue
Block a user