Can now be trained with normals
This commit is contained in:
8
main.py
8
main.py
@ -38,6 +38,7 @@ parser.add_argument('--batch_size', type=int, default=8, help='input batch size'
|
||||
parser.add_argument('--test_per_batches', type=int, default=1000, help='run a test batch per training batches number')
|
||||
parser.add_argument('--num_workers', type=int, default=0, help='number of data loading workers')
|
||||
parser.add_argument('--headers', type=strtobool, default=True, help='if raw files come with headers')
|
||||
parser.add_argument('--with_normals', type=strtobool, default=True, help='if training will include normals')
|
||||
parser.add_argument('--collate_per_segment', type=strtobool, default=True, help='whether to look at pointclouds or sub')
|
||||
parser.add_argument('--has_variations', type=strtobool, default=False,
|
||||
help='whether a single pointcloud has variations '
|
||||
@ -69,7 +70,7 @@ if __name__ == '__main__':
|
||||
)
|
||||
|
||||
TransTransform = GT.RandomTranslate(trans_max_distance)
|
||||
train_transform = GT.Compose([GT.NormalizeScale(), RotTransform, TransTransform])
|
||||
train_transform = GT.Compose([GT.NormalizeScale(), ])
|
||||
test_transform = GT.Compose([GT.NormalizeScale(), ])
|
||||
|
||||
params = dict(root_dir=opt.dataset,
|
||||
@ -78,7 +79,8 @@ if __name__ == '__main__':
|
||||
npoints=opt.npoints,
|
||||
labels_within=opt.labels_within,
|
||||
has_variations=opt.has_variations,
|
||||
headers=opt.headers
|
||||
headers=opt.headers,
|
||||
with_normals=opt.with_normals
|
||||
)
|
||||
|
||||
dataset = ShapeNetPartSegDataset(mode='train', **params)
|
||||
@ -105,7 +107,7 @@ if __name__ == '__main__':
|
||||
dtype = torch.float
|
||||
print('cudnn.enabled: ', torch.backends.cudnn.enabled)
|
||||
|
||||
net = PointNet2PartSegmentNet(num_classes)
|
||||
net = PointNet2PartSegmentNet(num_classes, with_normals=opt.with_normals)
|
||||
|
||||
if opt.model != '':
|
||||
net.load_state_dict(torch.load(opt.model))
|
||||
|
Reference in New Issue
Block a user