Added normals to prediction DataObject

Si11ium
2019-08-09 12:35:55 +02:00
parent 8eb165f76c
commit 39e5d72226
3 changed files with 21 additions and 22 deletions


@@ -36,7 +36,7 @@ parser.add_argument('--outf', type=str, default='checkpoint', help='output folde
 parser.add_argument('--labels_within', type=strtobool, default=True, help='defines the label location')
 parser.add_argument('--batch_size', type=int, default=8, help='input batch size')
 parser.add_argument('--test_per_batches', type=int, default=1000, help='run a test batch per training batches number')
-parser.add_argument('--num_workers', type=int, default=1, help='number of data loading workers')
+parser.add_argument('--num_workers', type=int, default=0, help='number of data loading workers')
 parser.add_argument('--headers', type=strtobool, default=True, help='if raw files come with headers')
 parser.add_argument('--collate_per_segment', type=strtobool, default=True, help='whether to look at pointclouds or sub')
 parser.add_argument('--has_variations', type=strtobool, default=False,
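
Note on the num_workers change above: setting num_workers to 0 makes PyTorch's DataLoader fetch batches in the main process rather than in worker subprocesses, which sidesteps multiprocessing pitfalls (e.g. on Windows) at the cost of loading throughput. A minimal, self-contained sketch of the behaviour; the toy dataset below is illustrative and not taken from this repository:

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    # Toy stand-in for the point-cloud dataset: 32 clouds of 1024 points
    # with 6 channels (xyz + normals), plus per-point integer labels.
    points = torch.randn(32, 1024, 6)
    labels = torch.zeros(32, 1024, dtype=torch.long)
    dataset = TensorDataset(points, labels)

    # num_workers=0: batches are loaded in the main process, no subprocesses.
    loader = DataLoader(dataset, batch_size=8, shuffle=True, num_workers=0)
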
@@ -130,10 +130,11 @@ if __name__ == '__main__':
     net.train()
     # ToDo: We need different dataloader here to train the network in multiple iterations, maybe move the loop down
     for batch_idx, sample in enumerate(dataLoader):
-        # points: (batch_size, n, 3)
+        # points: (batch_size, n, 6)
+        # pos: (batch_size, n, 3)
         # labels: (batch_size, n)
         points, labels = sample['points'], sample['labels']
-        points = points.transpose(1, 2).contiguous()  # (batch_size, 3, n)
+        points = points.transpose(1, 2).contiguous()  # (batch_size, 3/6, n)
         points, labels = points.to(device, dtype), labels.to(device, torch.long)
         optimizer.zero_grad()
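
The updated shape comments describe points growing from 3 to 6 channels once normals are attached, with pos keeping the raw xyz coordinates. A minimal sketch of that bookkeeping, assuming the DataObject simply concatenates normals onto the coordinates (tensor names here are illustrative):

    import torch

    batch_size, n = 8, 1024
    pos = torch.randn(batch_size, n, 3)      # xyz coordinates
    normals = torch.randn(batch_size, n, 3)  # per-point surface normals

    # Assumed construction: 6-channel points are xyz followed by normals.
    points = torch.cat([pos, normals], dim=-1)    # (batch_size, n, 6)

    # Channels-first layout for 1D convolutions, as in the training loop.
    points = points.transpose(1, 2).contiguous()  # (batch_size, 6, n)
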