i was silly
This commit is contained in:
parent
1af300988a
commit
376d8f0d7c
@ -134,11 +134,12 @@ class CustomShapeNet(InMemoryDataset):
|
||||
y_all = [y_raw] * points.shape[0]
|
||||
y = torch.as_tensor(y_all, dtype=torch.int)
|
||||
if self.collate_per_element:
|
||||
data = Data(y=y, pos=points[:, :3], points=points, norm=points[:, 3:])
|
||||
data = Data(y=y, pos=points[:, :3]) # , points=points, norm=points[:, 3:])
|
||||
else:
|
||||
if not data:
|
||||
data = defaultdict(list)
|
||||
for key, val in dict(y=y, pos=points[:, :3], points=points, norm=points[:, 3:]).items():
|
||||
# points=points, norm=points[:, 3:]
|
||||
for key, val in dict(y=y, pos=points[:, :3]).items():
|
||||
data[key].append(val)
|
||||
|
||||
data = self._transform_and_filter(data)
|
||||
@ -175,7 +176,7 @@ class ShapeNetPartSegDataset(Dataset):
|
||||
|
||||
def __getitem__(self, index):
|
||||
data = self.dataset[index]
|
||||
points, labels, _, norm = data.pos, data.y, data.points, data.norm
|
||||
points, labels = data.pos, data.y # , data.points, data.norm
|
||||
|
||||
# Resample to fixed number of points
|
||||
try:
|
||||
@ -183,14 +184,13 @@ class ShapeNetPartSegDataset(Dataset):
|
||||
except ValueError:
|
||||
choice = []
|
||||
|
||||
points, labels, norm = points[choice, :], labels[choice], norm[choice]
|
||||
points, labels = points[choice, :], labels[choice]
|
||||
|
||||
labels -= 1 if self.num_classes() in labels else 0 # Map label from [1, C] to [0, C-1]
|
||||
|
||||
sample = {
|
||||
'points': points, # torch.Tensor (n, 3)
|
||||
'labels': labels, # torch.Tensor (n,)
|
||||
'normals': norm # torch.Tensor (n,)
|
||||
'labels': labels # torch.Tensor (n,)
|
||||
}
|
||||
|
||||
return sample
|
||||
|
Loading…
x
Reference in New Issue
Block a user