Added normals to prediction DataObject
This commit is contained in:
parent
39e5d72226
commit
a501dcd6b0
@ -144,7 +144,7 @@ class CustomShapeNet(InMemoryDataset):
|
|||||||
|
|
||||||
y = torch.as_tensor(y_all, dtype=torch.int)
|
y = torch.as_tensor(y_all, dtype=torch.int)
|
||||||
# This is where you define the keys
|
# This is where you define the keys
|
||||||
attr_dict = dict(y=y, pos=points[:, :3]) # , normals=points[:, 3:6])
|
attr_dict = dict(y=y, pos=points[:, :3], normals=points[:, 3:6])
|
||||||
if self.collate_per_element:
|
if self.collate_per_element:
|
||||||
data = Data(**attr_dict)
|
data = Data(**attr_dict)
|
||||||
else:
|
else:
|
||||||
@ -193,20 +193,20 @@ class ShapeNetPartSegDataset(Dataset):
|
|||||||
# Resample to fixed number of points
|
# Resample to fixed number of points
|
||||||
try:
|
try:
|
||||||
npoints = self.npoints if self.mode != 'predict' else data.pos.shape[0]
|
npoints = self.npoints if self.mode != 'predict' else data.pos.shape[0]
|
||||||
choice = np.random.choice(data.pos.shape[0], npoints, replace=False)
|
choice = np.random.choice(data.pos.shape[0], npoints, replace=False if self.mode == 'predict' else True)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
choice = []
|
choice = []
|
||||||
|
|
||||||
# pos, normals, labels = data.pos[choice, :], data.normals[choice, :], data.y[choice]
|
pos, normals, labels = data.pos[choice, :], data.normals[choice, :], data.y[choice]
|
||||||
pos, labels = data.pos[choice, :], data.y[choice]
|
# pos, labels = data.pos[choice, :], data.y[choice]
|
||||||
|
|
||||||
labels -= 1 if self.num_classes() in labels else 0 # Map label from [1, C] to [0, C-1]
|
labels -= 1 if self.num_classes() in labels else 0 # Map label from [1, C] to [0, C-1]
|
||||||
|
|
||||||
sample = {
|
sample = {
|
||||||
'points': torch.cat([pos], dim=1), # torch.Tensor (n, 6)
|
'points': torch.cat([pos, normals], dim=1), # torch.Tensor (n, 6)
|
||||||
'labels': labels, # torch.Tensor (n,)
|
'labels': labels, # torch.Tensor (n,)
|
||||||
# 'pos': pos, # torch.Tensor (n, 3)
|
'pos': pos, # torch.Tensor (n, 3)
|
||||||
# 'normals': normals # torch.Tensor (n, 3)
|
'normals': normals # torch.Tensor (n, 3)
|
||||||
}
|
}
|
||||||
return sample
|
return sample
|
||||||
|
|
||||||
|
@ -8,7 +8,7 @@ from torch_geometric.utils.num_nodes import maybe_num_nodes
|
|||||||
from torch_geometric.data.data import Data
|
from torch_geometric.data.data import Data
|
||||||
from torch_scatter import scatter_add, scatter_max
|
from torch_scatter import scatter_add, scatter_max
|
||||||
|
|
||||||
GLOBAL_POINT_FEATURES = 3
|
GLOBAL_POINT_FEATURES = 6
|
||||||
|
|
||||||
class PointNet2SAModule(torch.nn.Module):
|
class PointNet2SAModule(torch.nn.Module):
|
||||||
def __init__(self, sample_radio, radius, max_num_neighbors, mlp):
|
def __init__(self, sample_radio, radius, max_num_neighbors, mlp):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user