Added normals to prediction DataObject
This commit is contained in:
parent
8eb165f76c
commit
39e5d72226
@ -38,11 +38,11 @@ class CustomShapeNet(InMemoryDataset):
|
|||||||
@property
|
@property
|
||||||
def raw_file_names(self):
|
def raw_file_names(self):
|
||||||
# Maybe add more data like validation sets
|
# Maybe add more data like validation sets
|
||||||
return list(self.modes.keys())
|
return [self.mode]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def processed_file_names(self):
|
def processed_file_names(self):
|
||||||
return [f'{x}.pt' for x in self.raw_file_names]
|
return [f'{self.mode}.pt']
|
||||||
|
|
||||||
def download(self):
|
def download(self):
|
||||||
dir_count = len([name for name in os.listdir(self.raw_dir) if os.path.isdir(os.path.join(self.raw_dir, name))])
|
dir_count = len([name for name in os.listdir(self.raw_dir) if os.path.isdir(os.path.join(self.raw_dir, name))])
|
||||||
@ -58,7 +58,7 @@ class CustomShapeNet(InMemoryDataset):
|
|||||||
|
|
||||||
def _load_dataset(self):
|
def _load_dataset(self):
|
||||||
data, slices = None, None
|
data, slices = None, None
|
||||||
filepath = self.processed_paths[self.modes[self.mode]]
|
filepath = self.processed_paths[0]
|
||||||
if self.refresh:
|
if self.refresh:
|
||||||
try:
|
try:
|
||||||
os.remove(filepath)
|
os.remove(filepath)
|
||||||
@ -91,7 +91,7 @@ class CustomShapeNet(InMemoryDataset):
|
|||||||
|
|
||||||
def process(self, delimiter=' '):
|
def process(self, delimiter=' '):
|
||||||
datasets = defaultdict(list)
|
datasets = defaultdict(list)
|
||||||
idx, data_folder = self.modes[self.mode], self.raw_file_names[self.modes[self.mode]]
|
idx, data_folder = self.modes[self.mode], self.raw_file_names[0]
|
||||||
path_to_clouds = os.path.join(self.raw_dir, data_folder)
|
path_to_clouds = os.path.join(self.raw_dir, data_folder)
|
||||||
|
|
||||||
if '.headers' in os.listdir(path_to_clouds):
|
if '.headers' in os.listdir(path_to_clouds):
|
||||||
@ -111,8 +111,8 @@ class CustomShapeNet(InMemoryDataset):
|
|||||||
paths.extend(glob.glob(os.path.join(pointcloud.path, f'*.{ext}')))
|
paths.extend(glob.glob(os.path.join(pointcloud.path, f'*.{ext}')))
|
||||||
|
|
||||||
for element in paths:
|
for element in paths:
|
||||||
# This was build to filter all variations that aregreater then 25
|
# This was build to filter all full clouds
|
||||||
pattern = re.compile('^((6[0-1]|[1-5][0-9])_\w+?\d+?|\d+?_pc)\.(xyz|dat)$')
|
pattern = re.compile('^\d+?_pc\.(xyz|dat)$')
|
||||||
if pattern.match(os.path.split(element)[-1]):
|
if pattern.match(os.path.split(element)[-1]):
|
||||||
continue
|
continue
|
||||||
else:
|
else:
|
||||||
@ -143,9 +143,8 @@ class CustomShapeNet(InMemoryDataset):
|
|||||||
y_all = [-1] * points.shape[0]
|
y_all = [-1] * points.shape[0]
|
||||||
|
|
||||||
y = torch.as_tensor(y_all, dtype=torch.int)
|
y = torch.as_tensor(y_all, dtype=torch.int)
|
||||||
attr_dict = dict(y=y, pos=points[:, :3])
|
# This is where you define the keys
|
||||||
if self.mode == 'predict':
|
attr_dict = dict(y=y, pos=points[:, :3]) # , normals=points[:, 3:6])
|
||||||
attr_dict.update(normals=points[:, 3:6])
|
|
||||||
if self.collate_per_element:
|
if self.collate_per_element:
|
||||||
data = Data(**attr_dict)
|
data = Data(**attr_dict)
|
||||||
else:
|
else:
|
||||||
@ -162,14 +161,14 @@ class CustomShapeNet(InMemoryDataset):
|
|||||||
cloud_variations[int(os.path.split(element)[-1].split('_')[0])].append(data)
|
cloud_variations[int(os.path.split(element)[-1].split('_')[0])].append(data)
|
||||||
if not self.collate_per_element:
|
if not self.collate_per_element:
|
||||||
if self.has_variations:
|
if self.has_variations:
|
||||||
for variation in cloud_variations.keys():
|
for _ in cloud_variations.keys():
|
||||||
datasets[data_folder].append(Data(**{key: torch.cat(data[key]) for key in data.keys()}))
|
datasets[data_folder].append(Data(**{key: torch.cat(data[key]) for key in data.keys()}))
|
||||||
else:
|
else:
|
||||||
datasets[data_folder].append(Data(**{key: torch.cat(data[key]) for key in data.keys()}))
|
datasets[data_folder].append(Data(**{key: torch.cat(data[key]) for key in data.keys()}))
|
||||||
|
|
||||||
if datasets[data_folder]:
|
if datasets[data_folder]:
|
||||||
os.makedirs(self.processed_dir, exist_ok=True)
|
os.makedirs(self.processed_dir, exist_ok=True)
|
||||||
torch.save(self.collate(datasets[data_folder]), self.processed_paths[idx])
|
torch.save(self.collate(datasets[data_folder]), self.processed_paths[0])
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return f'{self.__class__.__name__}({len(self)})'
|
return f'{self.__class__.__name__}({len(self)})'
|
||||||
@ -198,18 +197,17 @@ class ShapeNetPartSegDataset(Dataset):
|
|||||||
except ValueError:
|
except ValueError:
|
||||||
choice = []
|
choice = []
|
||||||
|
|
||||||
points, labels = data.pos[choice, :], data.y[choice]
|
# pos, normals, labels = data.pos[choice, :], data.normals[choice, :], data.y[choice]
|
||||||
|
pos, labels = data.pos[choice, :], data.y[choice]
|
||||||
|
|
||||||
labels -= 1 if self.num_classes() in labels else 0 # Map label from [1, C] to [0, C-1]
|
labels -= 1 if self.num_classes() in labels else 0 # Map label from [1, C] to [0, C-1]
|
||||||
|
|
||||||
sample = {
|
sample = {
|
||||||
'points': points, # torch.Tensor (n, 3)
|
'points': torch.cat([pos], dim=1), # torch.Tensor (n, 6)
|
||||||
'labels': labels # torch.Tensor (n,)
|
'labels': labels, # torch.Tensor (n,)
|
||||||
|
# 'pos': pos, # torch.Tensor (n, 3)
|
||||||
|
# 'normals': normals # torch.Tensor (n, 3)
|
||||||
}
|
}
|
||||||
if self.mode == 'predict':
|
|
||||||
normals = data.normals[choice, :]
|
|
||||||
sample.update(normals=normals)
|
|
||||||
|
|
||||||
return sample
|
return sample
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
|
7
main.py
7
main.py
@ -36,7 +36,7 @@ parser.add_argument('--outf', type=str, default='checkpoint', help='output folde
|
|||||||
parser.add_argument('--labels_within', type=strtobool, default=True, help='defines the label location')
|
parser.add_argument('--labels_within', type=strtobool, default=True, help='defines the label location')
|
||||||
parser.add_argument('--batch_size', type=int, default=8, help='input batch size')
|
parser.add_argument('--batch_size', type=int, default=8, help='input batch size')
|
||||||
parser.add_argument('--test_per_batches', type=int, default=1000, help='run a test batch per training batches number')
|
parser.add_argument('--test_per_batches', type=int, default=1000, help='run a test batch per training batches number')
|
||||||
parser.add_argument('--num_workers', type=int, default=1, help='number of data loading workers')
|
parser.add_argument('--num_workers', type=int, default=0, help='number of data loading workers')
|
||||||
parser.add_argument('--headers', type=strtobool, default=True, help='if raw files come with headers')
|
parser.add_argument('--headers', type=strtobool, default=True, help='if raw files come with headers')
|
||||||
parser.add_argument('--collate_per_segment', type=strtobool, default=True, help='whether to look at pointclouds or sub')
|
parser.add_argument('--collate_per_segment', type=strtobool, default=True, help='whether to look at pointclouds or sub')
|
||||||
parser.add_argument('--has_variations', type=strtobool, default=False,
|
parser.add_argument('--has_variations', type=strtobool, default=False,
|
||||||
@ -130,10 +130,11 @@ if __name__ == '__main__':
|
|||||||
net.train()
|
net.train()
|
||||||
# ToDo: We need different dataloader here to train the network in multiple iterations, maybe move the loop down
|
# ToDo: We need different dataloader here to train the network in multiple iterations, maybe move the loop down
|
||||||
for batch_idx, sample in enumerate(dataLoader):
|
for batch_idx, sample in enumerate(dataLoader):
|
||||||
# points: (batch_size, n, 3)
|
# points: (batch_size, n, 6)
|
||||||
|
# pos: (batch_size, n, 3)
|
||||||
# labels: (batch_size, n)
|
# labels: (batch_size, n)
|
||||||
points, labels = sample['points'], sample['labels']
|
points, labels = sample['points'], sample['labels']
|
||||||
points = points.transpose(1, 2).contiguous() # (batch_size, 3, n)
|
points = points.transpose(1, 2).contiguous() # (batch_size, 3/6, n)
|
||||||
points, labels = points.to(device, dtype), labels.to(device, torch.long)
|
points, labels = points.to(device, dtype), labels.to(device, torch.long)
|
||||||
|
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
|
@ -8,7 +8,7 @@ from torch_geometric.utils.num_nodes import maybe_num_nodes
|
|||||||
from torch_geometric.data.data import Data
|
from torch_geometric.data.data import Data
|
||||||
from torch_scatter import scatter_add, scatter_max
|
from torch_scatter import scatter_add, scatter_max
|
||||||
|
|
||||||
GLOBAL_POINT_FEATURES = 6
|
GLOBAL_POINT_FEATURES = 3
|
||||||
|
|
||||||
class PointNet2SAModule(torch.nn.Module):
|
class PointNet2SAModule(torch.nn.Module):
|
||||||
def __init__(self, sample_radio, radius, max_num_neighbors, mlp):
|
def __init__(self, sample_radio, radius, max_num_neighbors, mlp):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user