Can now be trained with normals

This commit is contained in:
Si11ium 2019-08-09 15:37:26 +02:00
parent 92117328ad
commit 167ac4991e
2 changed files with 13 additions and 5 deletions

View File

@ -148,6 +148,8 @@ class CustomShapeNet(InMemoryDataset):
####################################
# This is where you define the keys
attr_dict = dict(y=y, pos=points[:, :3 if not self.with_normals else 6])
if not self.with_normals:
attr_dict.update(normals=points[:, 3:6])
####################################
if self.collate_per_element:
data = Data(**attr_dict)

View File

@ -28,8 +28,9 @@ def eval_sample(net, sample):
# points: (n, 3)
points, gt_label = sample['points'], sample['labels']
n = points.shape[0]
f = points.shape[1]
points = points.view(1, n, 3) # make a batch
points = points.view(1, n, f) # make a batch
points = points.transpose(1, 2).contiguous()
points = points.to(device, dtype)
@ -237,15 +238,16 @@ def draw_sample_data(sample_data, colored_normals = False):
def recreate_folder(folder):
if os.path.exists(folder) and os.path.isdir(folder):
shutil.rmtree(folder)
os.mkdir(folder)
os.makedirs(folder, exist_ok=True)
sys.path.append(os.path.dirname(os.path.abspath(__file__)) + '/../') # add project root directory
parser = argparse.ArgumentParser()
parser.add_argument('--npoints', type=int, default=2048, help='resample points number')
parser.add_argument('--model', type=str, default='./checkpoint/seg_model_custom_3.pth', help='model path')
parser.add_argument('--model', type=str, default='./checkpoint/seg_model_custom_1.pth', help='model path')
parser.add_argument('--sample_idx', type=int, default=0, help='select a sample to segment and view result')
parser.add_argument('--headers', type=strtobool, default=True, help='if raw files come with headers')
parser.add_argument('--with_normals', type=strtobool, default=True, help='if training will include normals')
parser.add_argument('--collate_per_segment', type=strtobool, default=True, help='whether to look at pointclouds or sub')
parser.add_argument('--has_variations', type=strtobool, default=False,
help='whether a single pointcloud has variations '
@ -303,6 +305,7 @@ if __name__ == '__main__':
test_dataset = ShapeNetPartSegDataset(
mode='predict',
root_dir='data',
with_normals=opt.with_normals,
npoints=opt.npoints,
refresh=True,
collate_per_segment=opt.collate_per_segment,
@ -318,7 +321,7 @@ if __name__ == '__main__':
dtype = torch.float
# net = PointNetPartSegmentNet(num_classes)
net = PointNet2PartSegmentNet(num_classes)
net = PointNet2PartSegmentNet(num_classes, with_normals=opt.with_normals)
net.load_state_dict(torch.load(opt.model, map_location=device.type))
net = net.to(device, dtype)
@ -332,7 +335,10 @@ if __name__ == '__main__':
# Predict
pred_label, gt_label = eval_sample(net, sample)
sample_data = np.column_stack((sample["points"].numpy(), sample["normals"].numpy(), pred_label.numpy()))
if opt.with_normals:
sample_data = np.column_stack((sample["points"].numpy(), pred_label.numpy()))
else:
sample_data = np.column_stack((sample["points"].numpy(), sample["normals"], pred_label.numpy()))
draw_sample_data(sample_data, False)