Can now be trained with normals

This commit is contained in:
Si11ium
2019-08-09 13:32:55 +02:00
parent a501dcd6b0
commit 92117328ad
3 changed files with 37 additions and 29 deletions

View File

@ -38,6 +38,7 @@ parser.add_argument('--batch_size', type=int, default=8, help='input batch size'
parser.add_argument('--test_per_batches', type=int, default=1000, help='run a test batch per training batches number')
parser.add_argument('--num_workers', type=int, default=0, help='number of data loading workers')
parser.add_argument('--headers', type=strtobool, default=True, help='if raw files come with headers')
parser.add_argument('--with_normals', type=strtobool, default=True, help='if training will include normals')
parser.add_argument('--collate_per_segment', type=strtobool, default=True, help='whether to look at pointclouds or sub')
parser.add_argument('--has_variations', type=strtobool, default=False,
help='whether a single pointcloud has variations '
@ -69,7 +70,7 @@ if __name__ == '__main__':
)
TransTransform = GT.RandomTranslate(trans_max_distance)
train_transform = GT.Compose([GT.NormalizeScale(), RotTransform, TransTransform])
train_transform = GT.Compose([GT.NormalizeScale(), ])
test_transform = GT.Compose([GT.NormalizeScale(), ])
params = dict(root_dir=opt.dataset,
@ -78,7 +79,8 @@ if __name__ == '__main__':
npoints=opt.npoints,
labels_within=opt.labels_within,
has_variations=opt.has_variations,
headers=opt.headers
headers=opt.headers,
with_normals=opt.with_normals
)
dataset = ShapeNetPartSegDataset(mode='train', **params)
@ -105,7 +107,7 @@ if __name__ == '__main__':
dtype = torch.float
print('cudnn.enabled: ', torch.backends.cudnn.enabled)
net = PointNet2PartSegmentNet(num_classes)
net = PointNet2PartSegmentNet(num_classes, with_normals=opt.with_normals)
if opt.model != '':
net.load_state_dict(torch.load(opt.model))