i was silly

This commit is contained in:
Si11ium 2019-08-02 18:41:13 +02:00
parent 1af300988a
commit 376d8f0d7c

View File

@ -134,11 +134,12 @@ class CustomShapeNet(InMemoryDataset):
y_all = [y_raw] * points.shape[0] y_all = [y_raw] * points.shape[0]
y = torch.as_tensor(y_all, dtype=torch.int) y = torch.as_tensor(y_all, dtype=torch.int)
if self.collate_per_element: if self.collate_per_element:
data = Data(y=y, pos=points[:, :3], points=points, norm=points[:, 3:]) data = Data(y=y, pos=points[:, :3]) # , points=points, norm=points[:, 3:])
else: else:
if not data: if not data:
data = defaultdict(list) data = defaultdict(list)
for key, val in dict(y=y, pos=points[:, :3], points=points, norm=points[:, 3:]).items(): # points=points, norm=points[:, 3:]
for key, val in dict(y=y, pos=points[:, :3]).items():
data[key].append(val) data[key].append(val)
data = self._transform_and_filter(data) data = self._transform_and_filter(data)
@ -175,7 +176,7 @@ class ShapeNetPartSegDataset(Dataset):
def __getitem__(self, index): def __getitem__(self, index):
data = self.dataset[index] data = self.dataset[index]
points, labels, _, norm = data.pos, data.y, data.points, data.norm points, labels = data.pos, data.y # , data.points, data.norm
# Resample to fixed number of points # Resample to fixed number of points
try: try:
@ -183,14 +184,13 @@ class ShapeNetPartSegDataset(Dataset):
except ValueError: except ValueError:
choice = [] choice = []
points, labels, norm = points[choice, :], labels[choice], norm[choice] points, labels = points[choice, :], labels[choice]
labels -= 1 if self.num_classes() in labels else 0 # Map label from [1, C] to [0, C-1] labels -= 1 if self.num_classes() in labels else 0 # Map label from [1, C] to [0, C-1]
sample = { sample = {
'points': points, # torch.Tensor (n, 3) 'points': points, # torch.Tensor (n, 3)
'labels': labels, # torch.Tensor (n,) 'labels': labels # torch.Tensor (n,)
'normals': norm # torch.Tensor (n,)
} }
return sample return sample