Added normals to prediction DataObject

Si11ium 2019-08-09 12:35:55 +02:00
parent 8eb165f76c
commit 39e5d72226
3 changed files with 21 additions and 22 deletions


@@ -38,11 +38,11 @@ class CustomShapeNet(InMemoryDataset):
     @property
     def raw_file_names(self):
         # Maybe add more data like validation sets
-        return list(self.modes.keys())
+        return [self.mode]

     @property
     def processed_file_names(self):
-        return [f'{x}.pt' for x in self.raw_file_names]
+        return [f'{self.mode}.pt']

     def download(self):
         dir_count = len([name for name in os.listdir(self.raw_dir) if os.path.isdir(os.path.join(self.raw_dir, name))])
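Both properties now key off the single active mode instead of the full modes dict, so one dataset instance owns exactly one split and one processed cache file. A runnable sketch of that naming contract, using a hypothetical stand-in rather than the repo class:

    # Hypothetical stand-in for the per-mode naming scheme above; this is
    # not the repo's CustomShapeNet, just the same property contract.
    class ModeNames:
        def __init__(self, mode):
            self.mode = mode

        @property
        def raw_file_names(self):
            return [self.mode]                  # e.g. ['train']

        @property
        def processed_file_names(self):
            return [f'{self.mode}.pt']          # e.g. ['train.pt']

    print(ModeNames('predict').processed_file_names)  # ['predict.pt']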
@@ -58,7 +58,7 @@ class CustomShapeNet(InMemoryDataset):

     def _load_dataset(self):
         data, slices = None, None
-        filepath = self.processed_paths[self.modes[self.mode]]
+        filepath = self.processed_paths[0]
         if self.refresh:
             try:
                 os.remove(filepath)
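With a single processed file, the cache path is simply processed_paths[0]; refresh deletes it so the next access re-runs process(). A minimal sketch of the pattern; the path and the exception type are placeholders, since the matching except clause lies outside this hunk:

    import os

    # Delete the processed cache so it gets rebuilt on the next load.
    filepath = os.path.join('processed', 'predict.pt')  # placeholder path
    refresh = True
    if refresh:
        try:
            os.remove(filepath)
        except FileNotFoundError:
            pass  # nothing cached yet, so there is nothing to refresh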
@@ -91,7 +91,7 @@ class CustomShapeNet(InMemoryDataset):

     def process(self, delimiter=' '):
         datasets = defaultdict(list)
-        idx, data_folder = self.modes[self.mode], self.raw_file_names[self.modes[self.mode]]
+        idx, data_folder = self.modes[self.mode], self.raw_file_names[0]
         path_to_clouds = os.path.join(self.raw_dir, data_folder)

         if '.headers' in os.listdir(path_to_clouds):
@@ -111,8 +111,8 @@ class CustomShapeNet(InMemoryDataset):
                 paths.extend(glob.glob(os.path.join(pointcloud.path, f'*.{ext}')))

             for element in paths:
-                # This was build to filter all variations that aregreater then 25
-                pattern = re.compile('^((6[0-1]|[1-5][0-9])_\w+?\d+?|\d+?_pc)\.(xyz|dat)$')
+                # This was build to filter all full clouds
+                pattern = re.compile('^\d+?_pc\.(xyz|dat)$')
                 if pattern.match(os.path.split(element)[-1]):
                     continue
                 else:
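The new pattern matches only full clouds named like '<id>_pc.xyz' or '<id>_pc.dat', and the continue above skips exactly those; per-segment variation files fall through to processing. A quick self-check with assumed file names:

    import re

    # Files matching the pattern are full clouds and get skipped by `continue`.
    pattern = re.compile(r'^\d+?_pc\.(xyz|dat)$')
    assert pattern.match('7_pc.xyz')                # full cloud -> skipped
    assert pattern.match('12_pc.dat')               # full cloud -> skipped
    assert pattern.match('60_grasp5.xyz') is None   # assumed variation name -> processed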
@@ -143,9 +143,8 @@ class CustomShapeNet(InMemoryDataset):
             y_all = [-1] * points.shape[0]
             y = torch.as_tensor(y_all, dtype=torch.int)
-            attr_dict = dict(y=y, pos=points[:, :3])
-            if self.mode == 'predict':
-                attr_dict.update(normals=points[:, 3:6])
+            # This is where you define the keys
+            attr_dict = dict(y=y, pos=points[:, :3])  # , normals=points[:, 3:6])

             if self.collate_per_element:
                 data = Data(**attr_dict)
             else:
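The removed branch attached normals to the Data object only in predict mode; the surviving line keeps that option as a trailing comment instead. A runnable sketch of what re-enabling the hint would produce, assuming each raw point is laid out as x, y, z, nx, ny, nz:

    import torch
    from torch_geometric.data import Data

    points = torch.rand(128, 6)  # assumed layout: x, y, z, nx, ny, nz
    y = torch.full((points.shape[0],), -1, dtype=torch.int)
    data = Data(y=y, pos=points[:, :3], normals=points[:, 3:6])
    print(data)  # e.g. Data(y=[128], pos=[128, 3], normals=[128, 3])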
@@ -162,14 +161,14 @@ class CustomShapeNet(InMemoryDataset):
                 cloud_variations[int(os.path.split(element)[-1].split('_')[0])].append(data)

         if not self.collate_per_element:
             if self.has_variations:
-                for variation in cloud_variations.keys():
+                for _ in cloud_variations.keys():
                     datasets[data_folder].append(Data(**{key: torch.cat(data[key]) for key in data.keys()}))
             else:
                 datasets[data_folder].append(Data(**{key: torch.cat(data[key]) for key in data.keys()}))
         if datasets[data_folder]:
             os.makedirs(self.processed_dir, exist_ok=True)
-            torch.save(self.collate(datasets[data_folder]), self.processed_paths[idx])
+            torch.save(self.collate(datasets[data_folder]), self.processed_paths[0])

     def __repr__(self):
         return f'{self.__class__.__name__}({len(self)})'
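The dict comprehension collates a list of per-cloud tensors into one Data object by concatenating each key. A toy illustration; the keys mirror the attr_dict above, but the tensors are made up:

    import torch
    from torch_geometric.data import Data

    # Two chunks per key are concatenated into a single Data object.
    chunks = {'y': [torch.zeros(4, dtype=torch.int), torch.ones(2, dtype=torch.int)],
              'pos': [torch.rand(4, 3), torch.rand(2, 3)]}
    merged = Data(**{key: torch.cat(chunks[key]) for key in chunks.keys()})
    print(merged)  # e.g. Data(y=[6], pos=[6, 3])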
@@ -198,18 +197,17 @@ class ShapeNetPartSegDataset(Dataset):
         except ValueError:
             choice = []

-        points, labels = data.pos[choice, :], data.y[choice]
+        # pos, normals, labels = data.pos[choice, :], data.normals[choice, :], data.y[choice]
+        pos, labels = data.pos[choice, :], data.y[choice]
         labels -= 1 if self.num_classes() in labels else 0  # Map label from [1, C] to [0, C-1]

         sample = {
-            'points': points,  # torch.Tensor (n, 3)
-            'labels': labels  # torch.Tensor (n,)
+            'points': torch.cat([pos], dim=1),  # torch.Tensor (n, 6)
+            'labels': labels,  # torch.Tensor (n,)
+            # 'pos': pos,  # torch.Tensor (n, 3)
+            # 'normals': normals  # torch.Tensor (n, 3)
         }
-
-        if self.mode == 'predict':
-            normals = data.normals[choice, :]
-            sample.update(normals=normals)

         return sample

     def __len__(self):
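After this hunk the sample dict exposes only 'points' and 'labels'; torch.cat([pos], dim=1) is a no-op for now, and the (n, 6) comment anticipates normals joining that cat once the commented lines are revived. A toy sketch of the intended shape contract:

    import torch

    pos = torch.rand(1024, 3)
    labels = torch.randint(0, 4, (1024,))       # made-up label count
    sample = {
        'points': torch.cat([pos], dim=1),      # (n, 3) today; (n, 6) once normals join
        'labels': labels,                       # (n,)
    }
    print(sample['points'].shape, sample['labels'].shape)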


@@ -36,7 +36,7 @@ parser.add_argument('--outf', type=str, default='checkpoint', help='output folder')
 parser.add_argument('--labels_within', type=strtobool, default=True, help='defines the label location')
 parser.add_argument('--batch_size', type=int, default=8, help='input batch size')
 parser.add_argument('--test_per_batches', type=int, default=1000, help='run a test batch per training batches number')
-parser.add_argument('--num_workers', type=int, default=1, help='number of data loading workers')
+parser.add_argument('--num_workers', type=int, default=0, help='number of data loading workers')
 parser.add_argument('--headers', type=strtobool, default=True, help='if raw files come with headers')
 parser.add_argument('--collate_per_segment', type=strtobool, default=True, help='whether to look at pointclouds or sub')
 parser.add_argument('--has_variations', type=strtobool, default=False,
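Dropping num_workers to 0 keeps batch loading in the main process: slower, but it sidesteps worker pickling and fork issues and makes stack traces easier to read. A self-contained illustration with a toy dataset:

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    toy = TensorDataset(torch.rand(32, 3), torch.randint(0, 2, (32,)))
    loader = DataLoader(toy, batch_size=8, num_workers=0, shuffle=True)
    print(len(list(loader)))  # 4 batches, loaded without worker processes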
@@ -130,10 +130,11 @@ if __name__ == '__main__':
         net.train()
         # ToDo: We need different dataloader here to train the network in multiple iterations, maybe move the loop down
         for batch_idx, sample in enumerate(dataLoader):
-            # points: (batch_size, n, 3)
+            # points: (batch_size, n, 6)
+            # pos: (batch_size, n, 3)
             # labels: (batch_size, n)
             points, labels = sample['points'], sample['labels']
-            points = points.transpose(1, 2).contiguous()  # (batch_size, 3, n)
+            points = points.transpose(1, 2).contiguous()  # (batch_size, 3/6, n)
             points, labels = points.to(device, dtype), labels.to(device, torch.long)

             optimizer.zero_grad()
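The transpose exists because conv-style point networks expect channels first; the '3/6' comment covers both the bare-xyz and the xyz-plus-normals layouts. A minimal check:

    import torch

    points = torch.rand(8, 1024, 3)               # (batch_size, n, features)
    points = points.transpose(1, 2).contiguous()  # channels first
    print(points.shape)                           # torch.Size([8, 3, 1024])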


@ -8,7 +8,7 @@ from torch_geometric.utils.num_nodes import maybe_num_nodes
from torch_geometric.data.data import Data from torch_geometric.data.data import Data
from torch_scatter import scatter_add, scatter_max from torch_scatter import scatter_add, scatter_max
GLOBAL_POINT_FEATURES = 6 GLOBAL_POINT_FEATURES = 3
class PointNet2SAModule(torch.nn.Module): class PointNet2SAModule(torch.nn.Module):
def __init__(self, sample_radio, radius, max_num_neighbors, mlp): def __init__(self, sample_radio, radius, max_num_neighbors, mlp):