Added normals to prediction DataObject
This commit is contained in:
7
main.py
7
main.py
@ -36,7 +36,7 @@ parser.add_argument('--outf', type=str, default='checkpoint', help='output folde
|
||||
parser.add_argument('--labels_within', type=strtobool, default=True, help='defines the label location')
|
||||
parser.add_argument('--batch_size', type=int, default=8, help='input batch size')
|
||||
parser.add_argument('--test_per_batches', type=int, default=1000, help='run a test batch per training batches number')
|
||||
parser.add_argument('--num_workers', type=int, default=1, help='number of data loading workers')
|
||||
parser.add_argument('--num_workers', type=int, default=0, help='number of data loading workers')
|
||||
parser.add_argument('--headers', type=strtobool, default=True, help='if raw files come with headers')
|
||||
parser.add_argument('--collate_per_segment', type=strtobool, default=True, help='whether to look at pointclouds or sub')
|
||||
parser.add_argument('--has_variations', type=strtobool, default=False,
|
||||
@ -130,10 +130,11 @@ if __name__ == '__main__':
|
||||
net.train()
|
||||
# ToDo: We need different dataloader here to train the network in multiple iterations, maybe move the loop down
|
||||
for batch_idx, sample in enumerate(dataLoader):
|
||||
# points: (batch_size, n, 3)
|
||||
# points: (batch_size, n, 6)
|
||||
# pos: (batch_size, n, 3)
|
||||
# labels: (batch_size, n)
|
||||
points, labels = sample['points'], sample['labels']
|
||||
points = points.transpose(1, 2).contiguous() # (batch_size, 3, n)
|
||||
points = points.transpose(1, 2).contiguous() # (batch_size, 3/6, n)
|
||||
points, labels = points.to(device, dtype), labels.to(device, torch.long)
|
||||
|
||||
optimizer.zero_grad()
|
||||
|
Reference in New Issue
Block a user