Added normals to prediction DataObject
This commit is contained in:
parent
39e5d72226
commit
a501dcd6b0
@ -144,7 +144,7 @@ class CustomShapeNet(InMemoryDataset):
|
||||
|
||||
y = torch.as_tensor(y_all, dtype=torch.int)
|
||||
# This is where you define the keys
|
||||
attr_dict = dict(y=y, pos=points[:, :3]) # , normals=points[:, 3:6])
|
||||
attr_dict = dict(y=y, pos=points[:, :3], normals=points[:, 3:6])
|
||||
if self.collate_per_element:
|
||||
data = Data(**attr_dict)
|
||||
else:
|
||||
@ -193,20 +193,20 @@ class ShapeNetPartSegDataset(Dataset):
|
||||
# Resample to fixed number of points
|
||||
try:
|
||||
npoints = self.npoints if self.mode != 'predict' else data.pos.shape[0]
|
||||
choice = np.random.choice(data.pos.shape[0], npoints, replace=False)
|
||||
choice = np.random.choice(data.pos.shape[0], npoints, replace=False if self.mode == 'predict' else True)
|
||||
except ValueError:
|
||||
choice = []
|
||||
|
||||
# pos, normals, labels = data.pos[choice, :], data.normals[choice, :], data.y[choice]
|
||||
pos, labels = data.pos[choice, :], data.y[choice]
|
||||
pos, normals, labels = data.pos[choice, :], data.normals[choice, :], data.y[choice]
|
||||
# pos, labels = data.pos[choice, :], data.y[choice]
|
||||
|
||||
labels -= 1 if self.num_classes() in labels else 0 # Map label from [1, C] to [0, C-1]
|
||||
|
||||
sample = {
|
||||
'points': torch.cat([pos], dim=1), # torch.Tensor (n, 6)
|
||||
'points': torch.cat([pos, normals], dim=1), # torch.Tensor (n, 6)
|
||||
'labels': labels, # torch.Tensor (n,)
|
||||
# 'pos': pos, # torch.Tensor (n, 3)
|
||||
# 'normals': normals # torch.Tensor (n, 3)
|
||||
'pos': pos, # torch.Tensor (n, 3)
|
||||
'normals': normals # torch.Tensor (n, 3)
|
||||
}
|
||||
return sample
|
||||
|
||||
|
@ -8,7 +8,7 @@ from torch_geometric.utils.num_nodes import maybe_num_nodes
|
||||
from torch_geometric.data.data import Data
|
||||
from torch_scatter import scatter_add, scatter_max
|
||||
|
||||
GLOBAL_POINT_FEATURES = 3
|
||||
GLOBAL_POINT_FEATURES = 6
|
||||
|
||||
class PointNet2SAModule(torch.nn.Module):
|
||||
def __init__(self, sample_radio, radius, max_num_neighbors, mlp):
|
||||
|
Loading…
x
Reference in New Issue
Block a user