From 54a5b48ddc11d4447d3b68bf00530b9e9a1df867 Mon Sep 17 00:00:00 2001
From: Si11ium
Date: Mon, 5 Aug 2019 17:53:50 +0200
Subject: [PATCH] redo

---
 dataset/shapenet.py | 176 ++++++--------------------------------------
 main.py             |   4 +-
 predict/predict.py  |   8 +-
 vis/show_seg_res.py |   4 +-
 4 files changed, 31 insertions(+), 161 deletions(-)

diff --git a/dataset/shapenet.py b/dataset/shapenet.py
index 7d079df..4787301 100644
--- a/dataset/shapenet.py
+++ b/dataset/shapenet.py
@@ -21,23 +21,25 @@ def save_names(name_list, path):
 
 class CustomShapeNet(InMemoryDataset):
 
     categories = {key: val for val, key in enumerate(['Box', 'Cone', 'Cylinder', 'Sphere'])}
+    modes = {key: val for val, key in enumerate(['train', 'test', 'predict'])}
 
-    def __init__(self, root, collate_per_segment=True, train=True, transform=None, pre_filter=None, pre_transform=None,
-                 headers=True, has_variations=False):
+    def __init__(self, root, collate_per_segment=True, mode='train', transform=None, pre_filter=None, pre_transform=None,
+                 headers=True, has_variations=False, refresh=False):
+        assert mode in self.modes.keys(), f'"mode" must be one of {self.modes.keys()}'
         assert not (collate_per_segment and has_variations), 'Either use each element or pointclouds - with variations'
         self.has_headers = headers
         self.has_variations = has_variations
         self.collate_per_element = collate_per_segment
-        self.train = train
+        self.mode = mode
+        self.refresh = refresh
         super(CustomShapeNet, self).__init__(root, transform, pre_transform, pre_filter)
-        path = self.processed_paths[0] if train else self.processed_paths[-1]
-        self.data, self.slices = torch.load(path)
+        self.data, self.slices = self._load_dataset()
         print("Initialized")
 
     @property
     def raw_file_names(self):
         # Maybe add more data like validation sets
-        return ['train', 'test']
+        return list(self.modes.keys())
 
     @property
     def processed_file_names(self):
@@ -56,9 +58,18 @@ class CustomShapeNet(InMemoryDataset):
 
     def _load_dataset(self):
         data, slices = None, None
+        filepath = self.processed_paths[self.modes[self.mode]]
+        if self.refresh:
+            try:
+                os.remove(filepath)
+                print('Processed location "refreshed" (files were deleted)')
+            except FileNotFoundError:
+                print('You meant to refresh the already processed dataset, but there was none ...')
+                print('Continuing with processing ...')
+                pass
+
         while True:
             try:
-                filepath = os.path.join(self.root, self.processed_dir, f'{"train" if self.train else "test"}.pt')
                 data, slices = torch.load(filepath)
                 print('Dataset Loaded')
                 break
@@ -80,7 +91,7 @@ class CustomShapeNet(InMemoryDataset):
 
     def process(self, delimiter=' '):
         datasets = defaultdict(list)
-        idx, data_folder = (0, self.raw_file_names[0]) if self.train else (1, self.raw_file_names[1])
+        idx, data_folder = self.modes[self.mode], self.raw_file_names[self.modes[self.mode]]
         path_to_clouds = os.path.join(self.raw_dir, data_folder)
 
         if '.headers' in os.listdir(path_to_clouds):
@@ -131,7 +142,7 @@ class CustomShapeNet(InMemoryDataset):
                     points = torch.tensor(src, dtype=None).squeeze()
                     if not len(points.shape) > 1:
                         continue
-                    y_all = [y_raw] * points.shape[0]
+                    y_all = ([y_raw] if self.mode != 'predict' else [-1]) * points.shape[0]
                     y = torch.as_tensor(y_all, dtype=torch.int)
                     if self.collate_per_element:
                         data = Data(y=y, pos=points[:, :3])  # , points=points, norm=points[:, 3:])
@@ -167,12 +178,12 @@ class ShapeNetPartSegDataset(Dataset):
     """
     Resample raw point cloud to fixed number of points.
    Map raw label from range [1, N] to [0, N-1].
""" - def __init__(self, root_dir, collate_per_segment=True, train=True, transform=None, + def __init__(self, root_dir, collate_per_segment=True, mode='train', transform=None, refresh=False, has_variations=False, npoints=1024, headers=True): super(ShapeNetPartSegDataset, self).__init__() self.npoints = npoints - self.dataset = CustomShapeNet(root=root_dir, collate_per_segment=collate_per_segment, - train=train, transform=transform, headers=headers, has_variations=has_variations) + self.dataset = CustomShapeNet(root=root_dir, collate_per_segment=collate_per_segment, refresh=refresh, + mode=mode, transform=transform, headers=headers, has_variations=has_variations) def __getitem__(self, index): data = self.dataset[index] @@ -200,144 +211,3 @@ class ShapeNetPartSegDataset(Dataset): def num_classes(self): return self.dataset.num_classes - - -class PredictionShapeNet(InMemoryDataset): - - def __init__(self, root, transform=None, pre_filter=None, pre_transform=None, headers=True, refresh=False): - self.has_headers = headers - self.refresh = refresh - super(PredictionShapeNet, self).__init__(root, transform, pre_transform, pre_filter) - path = self.processed_paths[0] - self.data, self.slices = self._load_dataset() - print("Initialized") - - @property - def raw_file_names(self): - # Maybe add more data like validation sets - return ['predict'] - - @property - def processed_file_names(self): - return [f'{x}.pt' for x in self.raw_file_names] - - def download(self): - dir_count = len([name for name in os.listdir(self.raw_dir) if os.path.isdir(os.path.join(self.raw_dir, name))]) - print(f'{dir_count} folders have been found....') - if dir_count: - return dir_count - raise IOError("No raw pointclouds have been found.") - - @property - def num_classes(self): - return len(self.categories) - - def _load_dataset(self): - data, slices = None, None - filepath = os.path.join(self.processed_dir, self.processed_file_names[0]) - if self.refresh: - try: - os.remove(filepath) - print('Processed Location "Refreshed" (We deleted the Files)') - except FileNotFoundError: - print('You meant to refresh the allready processed dataset, but there were none...') - print('continue processing') - pass - - while True: - try: - data, slices = torch.load(filepath) - print('Dataset Loaded') - break - except FileNotFoundError: - self.process() - continue - return data, slices - - def process(self, delimiter=' '): - - datasets, filenames = defaultdict(list), [] - path_to_clouds = os.path.join(self.raw_dir, self.raw_file_names[0]) - - if '.headers' in os.listdir(path_to_clouds): - self.has_headers = True - elif 'no.headers' in os.listdir(path_to_clouds): - self.has_headers = False - else: - pass - - for pointcloud in tqdm(os.scandir(path_to_clouds)): - if not os.path.isdir(pointcloud): - continue - full_cloud_pattern = '(^\d+?_|^)pc\.(xyz|dat)' - pattern = re.compile(full_cloud_pattern) - for file in os.scandir(pointcloud.path): - if not pattern.match(file.name): - continue - with open(file, 'r') as f: - if self.has_headers: - headers = f.__next__() - # Check if there are no useable nodes in this file, header says 0. 
-                        if not int(headers.rstrip().split(delimiter)[0]):
-                            continue
-
-                    # Iterate over all rows
-                    src = [[float(x) if x not in ['-nan(ind)', 'nan(ind)'] else 0
-                            for x in line.rstrip().split(delimiter)[None:None]] for line in f if line != '']
-                    points = torch.tensor(src, dtype=None).squeeze()
-                    if not len(points.shape) > 1:
-                        continue
-                    # pos = points[:, :3]
-                    # norm = points[:, 3:]
-                    y_fake_all = [-1] * points.shape[0]
-                    y = torch.as_tensor(y_fake_all, dtype=torch.int)
-                    # points = torch.as_tensor(points, dtype=torch.float)
-                    # norm = torch.as_tensor(norm, dtype=torch.float)
-                    data = Data(y=y, pos=points[:, :3], points=points, norm=points[:, 3:])
-                    # , points=points, norm=points[:3], )
-                    # ToDo: ANy filter to apply? Then do it here.
-                    if self.pre_filter is not None and not self.pre_filter(data):
-                        data = self.pre_filter(data)
-                        raise NotImplementedError
-                    # ToDo: ANy transformation to apply? Then do it here.
-                    if self.pre_transform is not None:
-                        data = self.pre_transform(data)
-                        raise NotImplementedError
-                    datasets[self.raw_file_names[0]].append(data)
-                    filenames.append(file)
-
-        os.makedirs(self.processed_dir, exist_ok=True)
-        torch.save(self.collate(datasets[self.raw_file_names[0]]), self.processed_paths[0])
-        # save_names(filenames)
-
-    def __repr__(self):
-        return f'{self.__class__.__name__}({len(self)})'
-
-
-class PredictNetPartSegDataset(Dataset):
-    """
-    Resample raw point cloud to fixed number of points.
-    Map raw label from range [1, N] to [0, N-1].
-    """
-    def __init__(self, root_dir, num_classes, transform=None, npoints=2048, headers=True, refresh=False):
-        super(PredictNetPartSegDataset, self).__init__()
-        self.npoints = npoints
-        self._num_classes = num_classes
-        self.dataset = PredictionShapeNet(root=root_dir, transform=transform, headers=headers, refresh=refresh)
-
-    def __getitem__(self, index):
-        data = self.dataset[index]
-        points, labels, _, norm = data.pos, data.y, data.points, data.norm
-
-        sample = {
-            'points': points,   # torch.Tensor (n, 3)
-            'labels': labels,   # torch.Tensor (n,)
-            'normals': norm     # torch.Tensor (n,)
-        }
-        return sample
-
-    def __len__(self):
-        return len(self.dataset)
-
-    def num_classes(self):
-        return self._num_classes
diff --git a/main.py b/main.py
index ad49a03..474a78b 100644
--- a/main.py
+++ b/main.py
@@ -73,12 +73,12 @@ if __name__ == '__main__':
     test_transform = GT.Compose([GT.NormalizeScale(), ])
 
     dataset = ShapeNetPartSegDataset(root_dir=opt.dataset, collate_per_segment=opt.collate_per_segment,
-                                     train=True, transform=train_transform, npoints=opt.npoints,
+                                     mode='train', transform=train_transform, npoints=opt.npoints,
                                      has_variations=opt.has_variations, headers=opt.headers)
     dataLoader = DataLoader(dataset, batch_size=opt.batch_size, shuffle=True, num_workers=opt.num_workers)
 
     test_dataset = ShapeNetPartSegDataset(root_dir=opt.dataset, collate_per_segment=opt.collate_per_segment,
-                                          train=False, transform=test_transform, npoints=opt.npoints,
+                                          mode='test', transform=test_transform, npoints=opt.npoints,
                                           has_variations=opt.has_variations, headers=opt.headers)
     test_dataLoader = DataLoader(test_dataset, batch_size=opt.batch_size, shuffle=True, num_workers=opt.num_workers)
 
diff --git a/predict/predict.py b/predict/predict.py
index a28d5aa..9e14e27 100644
--- a/predict/predict.py
+++ b/predict/predict.py
@@ -2,7 +2,7 @@ import sys
 import os
 sys.path.append(os.path.dirname(os.path.abspath(__file__)) + '/../')  # add project root directory
 
-from dataset.shapenet import PredictNetPartSegDataset, ShapeNetPartSegDataset
+from dataset.shapenet import ShapeNetPartSegDataset
 from model.pointnet2_part_seg import PointNet2PartSegmentNet
 import torch_geometric.transforms as GT
 import torch
@@ -13,7 +13,7 @@ import argparse
 parser = argparse.ArgumentParser()
 parser.add_argument('--dataset', type=str, default='data', help='dataset path')
 parser.add_argument('--npoints', type=int, default=2048, help='resample points number')
-parser.add_argument('--model', type=str, default='./checkpoint/seg_model_custom_249.pth', help='model path')
+parser.add_argument('--model', type=str, default='./checkpoint/seg_model_custom_246.pth', help='model path')
 parser.add_argument('--sample_idx', type=int, default=0, help='select a sample to segment and view result')
 opt = parser.parse_args()
 print(opt)
@@ -24,9 +24,9 @@ if __name__ == '__main__':
     print('Construct dataset ..')
     test_transform = GT.Compose([GT.NormalizeScale(), ])
 
-    test_dataset = PredictNetPartSegDataset(
+    test_dataset = ShapeNetPartSegDataset(
+        mode='predict',
         root_dir=opt.dataset,
-        num_classes=4,
         transform=None,
         npoints=opt.npoints,
         refresh=True
diff --git a/vis/show_seg_res.py b/vis/show_seg_res.py
index 72bc5a3..130d6a7 100644
--- a/vis/show_seg_res.py
+++ b/vis/show_seg_res.py
@@ -28,10 +28,10 @@ if __name__ == '__main__':
     print('Construct dataset ..')
     test_transform = GT.Compose([GT.NormalizeScale(), ])
 
-    test_dataset = PredictNetPartSegDataset(
+    test_dataset = ShapeNetPartSegDataset(
         root_dir=opt.dataset,
         collate_per_segment=False,
-        train=False,
+        mode='predict',
         transform=test_transform,
         npoints=opt.npoints
     )
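
Usage note (outside the diff): after this patch a single ShapeNetPartSegDataset
covers training, testing and prediction via the new 'mode' keyword. The sketch
below is a minimal illustration assembled from the defaults visible above
(root_dir='data', npoints, refresh=True for prediction); the raw directory
layout <root>/raw/train, <root>/raw/test, <root>/raw/predict is assumed from
CustomShapeNet.raw_file_names and its process() lookup.

    import torch_geometric.transforms as GT

    from dataset.shapenet import ShapeNetPartSegDataset

    transform = GT.Compose([GT.NormalizeScale(), ])

    # 'mode' replaces the old train=True/False flag as well as the removed
    # PredictionShapeNet / PredictNetPartSegDataset classes.
    train_dataset = ShapeNetPartSegDataset(root_dir='data', mode='train',
                                           collate_per_segment=True,
                                           transform=transform, npoints=1024)

    # refresh=True deletes the cached predict.pt so the raw clouds are
    # reprocessed; in 'predict' mode every label is filled with -1 (see the
    # y_all change above), i.e. there is no ground truth to evaluate against.
    predict_dataset = ShapeNetPartSegDataset(root_dir='data', mode='predict',
                                             collate_per_segment=False,
                                             transform=None, npoints=2048,
                                             refresh=True)

    print(len(train_dataset), len(predict_dataset))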