Can now be trained with normals
This commit is contained in:
parent
92117328ad
commit
167ac4991e
@ -148,6 +148,8 @@ class CustomShapeNet(InMemoryDataset):
|
|||||||
####################################
|
####################################
|
||||||
# This is where you define the keys
|
# This is where you define the keys
|
||||||
attr_dict = dict(y=y, pos=points[:, :3 if not self.with_normals else 6])
|
attr_dict = dict(y=y, pos=points[:, :3 if not self.with_normals else 6])
|
||||||
|
if not self.with_normals:
|
||||||
|
attr_dict.update(normals=points[:, 3:6])
|
||||||
####################################
|
####################################
|
||||||
if self.collate_per_element:
|
if self.collate_per_element:
|
||||||
data = Data(**attr_dict)
|
data = Data(**attr_dict)
|
||||||
|
@ -28,8 +28,9 @@ def eval_sample(net, sample):
|
|||||||
# points: (n, 3)
|
# points: (n, 3)
|
||||||
points, gt_label = sample['points'], sample['labels']
|
points, gt_label = sample['points'], sample['labels']
|
||||||
n = points.shape[0]
|
n = points.shape[0]
|
||||||
|
f = points.shape[1]
|
||||||
|
|
||||||
points = points.view(1, n, 3) # make a batch
|
points = points.view(1, n, f) # make a batch
|
||||||
points = points.transpose(1, 2).contiguous()
|
points = points.transpose(1, 2).contiguous()
|
||||||
points = points.to(device, dtype)
|
points = points.to(device, dtype)
|
||||||
|
|
||||||
@ -237,15 +238,16 @@ def draw_sample_data(sample_data, colored_normals = False):
|
|||||||
def recreate_folder(folder):
|
def recreate_folder(folder):
|
||||||
if os.path.exists(folder) and os.path.isdir(folder):
|
if os.path.exists(folder) and os.path.isdir(folder):
|
||||||
shutil.rmtree(folder)
|
shutil.rmtree(folder)
|
||||||
os.mkdir(folder)
|
os.makedirs(folder, exist_ok=True)
|
||||||
|
|
||||||
sys.path.append(os.path.dirname(os.path.abspath(__file__)) + '/../') # add project root directory
|
sys.path.append(os.path.dirname(os.path.abspath(__file__)) + '/../') # add project root directory
|
||||||
|
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('--npoints', type=int, default=2048, help='resample points number')
|
parser.add_argument('--npoints', type=int, default=2048, help='resample points number')
|
||||||
parser.add_argument('--model', type=str, default='./checkpoint/seg_model_custom_3.pth', help='model path')
|
parser.add_argument('--model', type=str, default='./checkpoint/seg_model_custom_1.pth', help='model path')
|
||||||
parser.add_argument('--sample_idx', type=int, default=0, help='select a sample to segment and view result')
|
parser.add_argument('--sample_idx', type=int, default=0, help='select a sample to segment and view result')
|
||||||
parser.add_argument('--headers', type=strtobool, default=True, help='if raw files come with headers')
|
parser.add_argument('--headers', type=strtobool, default=True, help='if raw files come with headers')
|
||||||
|
parser.add_argument('--with_normals', type=strtobool, default=True, help='if training will include normals')
|
||||||
parser.add_argument('--collate_per_segment', type=strtobool, default=True, help='whether to look at pointclouds or sub')
|
parser.add_argument('--collate_per_segment', type=strtobool, default=True, help='whether to look at pointclouds or sub')
|
||||||
parser.add_argument('--has_variations', type=strtobool, default=False,
|
parser.add_argument('--has_variations', type=strtobool, default=False,
|
||||||
help='whether a single pointcloud has variations '
|
help='whether a single pointcloud has variations '
|
||||||
@ -303,6 +305,7 @@ if __name__ == '__main__':
|
|||||||
test_dataset = ShapeNetPartSegDataset(
|
test_dataset = ShapeNetPartSegDataset(
|
||||||
mode='predict',
|
mode='predict',
|
||||||
root_dir='data',
|
root_dir='data',
|
||||||
|
with_normals=opt.with_normals,
|
||||||
npoints=opt.npoints,
|
npoints=opt.npoints,
|
||||||
refresh=True,
|
refresh=True,
|
||||||
collate_per_segment=opt.collate_per_segment,
|
collate_per_segment=opt.collate_per_segment,
|
||||||
@ -318,7 +321,7 @@ if __name__ == '__main__':
|
|||||||
dtype = torch.float
|
dtype = torch.float
|
||||||
|
|
||||||
# net = PointNetPartSegmentNet(num_classes)
|
# net = PointNetPartSegmentNet(num_classes)
|
||||||
net = PointNet2PartSegmentNet(num_classes)
|
net = PointNet2PartSegmentNet(num_classes, with_normals=opt.with_normals)
|
||||||
|
|
||||||
net.load_state_dict(torch.load(opt.model, map_location=device.type))
|
net.load_state_dict(torch.load(opt.model, map_location=device.type))
|
||||||
net = net.to(device, dtype)
|
net = net.to(device, dtype)
|
||||||
@ -332,7 +335,10 @@ if __name__ == '__main__':
|
|||||||
# Predict
|
# Predict
|
||||||
|
|
||||||
pred_label, gt_label = eval_sample(net, sample)
|
pred_label, gt_label = eval_sample(net, sample)
|
||||||
sample_data = np.column_stack((sample["points"].numpy(), sample["normals"].numpy(), pred_label.numpy()))
|
if opt.with_normals:
|
||||||
|
sample_data = np.column_stack((sample["points"].numpy(), pred_label.numpy()))
|
||||||
|
else:
|
||||||
|
sample_data = np.column_stack((sample["points"].numpy(), sample["normals"], pred_label.numpy()))
|
||||||
|
|
||||||
draw_sample_data(sample_data, False)
|
draw_sample_data(sample_data, False)
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user