Added normals to prediction DataObject
This commit is contained in:
parent
8eb165f76c
commit
39e5d72226
@ -38,11 +38,11 @@ class CustomShapeNet(InMemoryDataset):
|
||||
@property
|
||||
def raw_file_names(self):
|
||||
# Maybe add more data like validation sets
|
||||
return list(self.modes.keys())
|
||||
return [self.mode]
|
||||
|
||||
@property
|
||||
def processed_file_names(self):
|
||||
return [f'{x}.pt' for x in self.raw_file_names]
|
||||
return [f'{self.mode}.pt']
|
||||
|
||||
def download(self):
|
||||
dir_count = len([name for name in os.listdir(self.raw_dir) if os.path.isdir(os.path.join(self.raw_dir, name))])
|
||||
@ -58,7 +58,7 @@ class CustomShapeNet(InMemoryDataset):
|
||||
|
||||
def _load_dataset(self):
|
||||
data, slices = None, None
|
||||
filepath = self.processed_paths[self.modes[self.mode]]
|
||||
filepath = self.processed_paths[0]
|
||||
if self.refresh:
|
||||
try:
|
||||
os.remove(filepath)
|
||||
@ -91,7 +91,7 @@ class CustomShapeNet(InMemoryDataset):
|
||||
|
||||
def process(self, delimiter=' '):
|
||||
datasets = defaultdict(list)
|
||||
idx, data_folder = self.modes[self.mode], self.raw_file_names[self.modes[self.mode]]
|
||||
idx, data_folder = self.modes[self.mode], self.raw_file_names[0]
|
||||
path_to_clouds = os.path.join(self.raw_dir, data_folder)
|
||||
|
||||
if '.headers' in os.listdir(path_to_clouds):
|
||||
@ -111,8 +111,8 @@ class CustomShapeNet(InMemoryDataset):
|
||||
paths.extend(glob.glob(os.path.join(pointcloud.path, f'*.{ext}')))
|
||||
|
||||
for element in paths:
|
||||
# This was build to filter all variations that aregreater then 25
|
||||
pattern = re.compile('^((6[0-1]|[1-5][0-9])_\w+?\d+?|\d+?_pc)\.(xyz|dat)$')
|
||||
# This was build to filter all full clouds
|
||||
pattern = re.compile('^\d+?_pc\.(xyz|dat)$')
|
||||
if pattern.match(os.path.split(element)[-1]):
|
||||
continue
|
||||
else:
|
||||
@ -143,9 +143,8 @@ class CustomShapeNet(InMemoryDataset):
|
||||
y_all = [-1] * points.shape[0]
|
||||
|
||||
y = torch.as_tensor(y_all, dtype=torch.int)
|
||||
attr_dict = dict(y=y, pos=points[:, :3])
|
||||
if self.mode == 'predict':
|
||||
attr_dict.update(normals=points[:, 3:6])
|
||||
# This is where you define the keys
|
||||
attr_dict = dict(y=y, pos=points[:, :3]) # , normals=points[:, 3:6])
|
||||
if self.collate_per_element:
|
||||
data = Data(**attr_dict)
|
||||
else:
|
||||
@ -162,14 +161,14 @@ class CustomShapeNet(InMemoryDataset):
|
||||
cloud_variations[int(os.path.split(element)[-1].split('_')[0])].append(data)
|
||||
if not self.collate_per_element:
|
||||
if self.has_variations:
|
||||
for variation in cloud_variations.keys():
|
||||
for _ in cloud_variations.keys():
|
||||
datasets[data_folder].append(Data(**{key: torch.cat(data[key]) for key in data.keys()}))
|
||||
else:
|
||||
datasets[data_folder].append(Data(**{key: torch.cat(data[key]) for key in data.keys()}))
|
||||
|
||||
if datasets[data_folder]:
|
||||
os.makedirs(self.processed_dir, exist_ok=True)
|
||||
torch.save(self.collate(datasets[data_folder]), self.processed_paths[idx])
|
||||
torch.save(self.collate(datasets[data_folder]), self.processed_paths[0])
|
||||
|
||||
def __repr__(self):
|
||||
return f'{self.__class__.__name__}({len(self)})'
|
||||
@ -198,18 +197,17 @@ class ShapeNetPartSegDataset(Dataset):
|
||||
except ValueError:
|
||||
choice = []
|
||||
|
||||
points, labels = data.pos[choice, :], data.y[choice]
|
||||
# pos, normals, labels = data.pos[choice, :], data.normals[choice, :], data.y[choice]
|
||||
pos, labels = data.pos[choice, :], data.y[choice]
|
||||
|
||||
labels -= 1 if self.num_classes() in labels else 0 # Map label from [1, C] to [0, C-1]
|
||||
|
||||
sample = {
|
||||
'points': points, # torch.Tensor (n, 3)
|
||||
'labels': labels # torch.Tensor (n,)
|
||||
'points': torch.cat([pos], dim=1), # torch.Tensor (n, 6)
|
||||
'labels': labels, # torch.Tensor (n,)
|
||||
# 'pos': pos, # torch.Tensor (n, 3)
|
||||
# 'normals': normals # torch.Tensor (n, 3)
|
||||
}
|
||||
if self.mode == 'predict':
|
||||
normals = data.normals[choice, :]
|
||||
sample.update(normals=normals)
|
||||
|
||||
return sample
|
||||
|
||||
def __len__(self):
|
||||
|
7
main.py
7
main.py
@ -36,7 +36,7 @@ parser.add_argument('--outf', type=str, default='checkpoint', help='output folde
|
||||
parser.add_argument('--labels_within', type=strtobool, default=True, help='defines the label location')
|
||||
parser.add_argument('--batch_size', type=int, default=8, help='input batch size')
|
||||
parser.add_argument('--test_per_batches', type=int, default=1000, help='run a test batch per training batches number')
|
||||
parser.add_argument('--num_workers', type=int, default=1, help='number of data loading workers')
|
||||
parser.add_argument('--num_workers', type=int, default=0, help='number of data loading workers')
|
||||
parser.add_argument('--headers', type=strtobool, default=True, help='if raw files come with headers')
|
||||
parser.add_argument('--collate_per_segment', type=strtobool, default=True, help='whether to look at pointclouds or sub')
|
||||
parser.add_argument('--has_variations', type=strtobool, default=False,
|
||||
@ -130,10 +130,11 @@ if __name__ == '__main__':
|
||||
net.train()
|
||||
# ToDo: We need different dataloader here to train the network in multiple iterations, maybe move the loop down
|
||||
for batch_idx, sample in enumerate(dataLoader):
|
||||
# points: (batch_size, n, 3)
|
||||
# points: (batch_size, n, 6)
|
||||
# pos: (batch_size, n, 3)
|
||||
# labels: (batch_size, n)
|
||||
points, labels = sample['points'], sample['labels']
|
||||
points = points.transpose(1, 2).contiguous() # (batch_size, 3, n)
|
||||
points = points.transpose(1, 2).contiguous() # (batch_size, 3/6, n)
|
||||
points, labels = points.to(device, dtype), labels.to(device, torch.long)
|
||||
|
||||
optimizer.zero_grad()
|
||||
|
@ -8,7 +8,7 @@ from torch_geometric.utils.num_nodes import maybe_num_nodes
|
||||
from torch_geometric.data.data import Data
|
||||
from torch_scatter import scatter_add, scatter_max
|
||||
|
||||
GLOBAL_POINT_FEATURES = 6
|
||||
GLOBAL_POINT_FEATURES = 3
|
||||
|
||||
class PointNet2SAModule(torch.nn.Module):
|
||||
def __init__(self, sample_radio, radius, max_num_neighbors, mlp):
|
||||
|
Loading…
x
Reference in New Issue
Block a user