i was silly
This commit is contained in:
parent
1af300988a
commit
376d8f0d7c
@ -134,11 +134,12 @@ class CustomShapeNet(InMemoryDataset):
|
|||||||
y_all = [y_raw] * points.shape[0]
|
y_all = [y_raw] * points.shape[0]
|
||||||
y = torch.as_tensor(y_all, dtype=torch.int)
|
y = torch.as_tensor(y_all, dtype=torch.int)
|
||||||
if self.collate_per_element:
|
if self.collate_per_element:
|
||||||
data = Data(y=y, pos=points[:, :3], points=points, norm=points[:, 3:])
|
data = Data(y=y, pos=points[:, :3]) # , points=points, norm=points[:, 3:])
|
||||||
else:
|
else:
|
||||||
if not data:
|
if not data:
|
||||||
data = defaultdict(list)
|
data = defaultdict(list)
|
||||||
for key, val in dict(y=y, pos=points[:, :3], points=points, norm=points[:, 3:]).items():
|
# points=points, norm=points[:, 3:]
|
||||||
|
for key, val in dict(y=y, pos=points[:, :3]).items():
|
||||||
data[key].append(val)
|
data[key].append(val)
|
||||||
|
|
||||||
data = self._transform_and_filter(data)
|
data = self._transform_and_filter(data)
|
||||||
@ -175,7 +176,7 @@ class ShapeNetPartSegDataset(Dataset):
|
|||||||
|
|
||||||
def __getitem__(self, index):
|
def __getitem__(self, index):
|
||||||
data = self.dataset[index]
|
data = self.dataset[index]
|
||||||
points, labels, _, norm = data.pos, data.y, data.points, data.norm
|
points, labels = data.pos, data.y # , data.points, data.norm
|
||||||
|
|
||||||
# Resample to fixed number of points
|
# Resample to fixed number of points
|
||||||
try:
|
try:
|
||||||
@ -183,14 +184,13 @@ class ShapeNetPartSegDataset(Dataset):
|
|||||||
except ValueError:
|
except ValueError:
|
||||||
choice = []
|
choice = []
|
||||||
|
|
||||||
points, labels, norm = points[choice, :], labels[choice], norm[choice]
|
points, labels = points[choice, :], labels[choice]
|
||||||
|
|
||||||
labels -= 1 if self.num_classes() in labels else 0 # Map label from [1, C] to [0, C-1]
|
labels -= 1 if self.num_classes() in labels else 0 # Map label from [1, C] to [0, C-1]
|
||||||
|
|
||||||
sample = {
|
sample = {
|
||||||
'points': points, # torch.Tensor (n, 3)
|
'points': points, # torch.Tensor (n, 3)
|
||||||
'labels': labels, # torch.Tensor (n,)
|
'labels': labels # torch.Tensor (n,)
|
||||||
'normals': norm # torch.Tensor (n,)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return sample
|
return sample
|
||||||
|
Loading…
x
Reference in New Issue
Block a user