diff --git a/dataset/shapenet.py b/dataset/shapenet.py
index 480c04e..6d6cd57 100644
--- a/dataset/shapenet.py
+++ b/dataset/shapenet.py
@@ -17,13 +17,15 @@ def save_names(name_list, path):
     with open(path, 'wb') as f:
         f.writelines(name_list)
 
+
 class CustomShapeNet(InMemoryDataset):
 
     categories = {key: val for val, key in enumerate(['Box', 'Cone', 'Cylinder', 'Sphere'])}
 
     def __init__(self, root, collate_per_segment=True, train=True, transform=None, pre_filter=None, pre_transform=None,
-                 headers=True, **kwargs):
+                 headers=True, has_variations=False):
         self.has_headers = headers
+        self.has_variations = has_variations
         self.collate_per_element = collate_per_segment
         self.train = train
         super(CustomShapeNet, self).__init__(root, transform, pre_transform, pre_filter)
@@ -88,6 +90,8 @@ class CustomShapeNet(InMemoryDataset):
                 pass
 
             for pointcloud in tqdm(os.scandir(path_to_clouds)):
+                if self.has_variations:
+                    cloud_variations = defaultdict(list)
                 if not os.path.isdir(pointcloud):
                     continue
                 data, paths = None, list()
@@ -131,8 +135,14 @@ class CustomShapeNet(InMemoryDataset):
                     data = self._transform_and_filter(data)
                     if self.collate_per_element:
                         datasets[data_folder].append(data)
+                    if self.has_variations:
+                        cloud_variations[int(element.name.split('_')[0])].append(data)
                 if not self.collate_per_element:
-                    datasets[data_folder].append(Data(**{key: torch.cat(data[key]) for key in data.keys()}))
+                    if self.has_variations:
+                        for variation in cloud_variations.keys():
+                            datasets[data_folder].append(Data(**{key: torch.cat(data[key]) for key in data.keys()}))
+                    else:
+                        datasets[data_folder].append(Data(**{key: torch.cat(data[key]) for key in data.keys()}))
 
         if datasets[data_folder]:
             os.makedirs(self.processed_dir, exist_ok=True)
@@ -147,11 +157,12 @@ class ShapeNetPartSegDataset(Dataset):
         Resample raw point cloud to fixed number of points.
         Map raw label from range [1, N] to [0, N-1].
     """
-    def __init__(self, root_dir, collate_per_segment=True, train=True, transform=None, npoints=1024, headers=True):
+    def __init__(self, root_dir, collate_per_segment=True, train=True, transform=None,
+                 has_variations=False, npoints=1024, headers=True):
         super(ShapeNetPartSegDataset, self).__init__()
         self.npoints = npoints
         self.dataset = CustomShapeNet(root=root_dir, collate_per_segment=collate_per_segment,
-                                      train=train, transform=transform, headers=headers)
+                                      train=train, transform=transform, headers=headers, has_variations=has_variations)
 
     def __getitem__(self, index):
         data = self.dataset[index]
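Note on the `has_variations` branch in the last hunk above: `cloud_variations` groups elements per variation id, but the loop over `cloud_variations.keys()` never reads the group it iterates over, so every variation appends the same concatenated `data`. If the intent is one merged `Data` object per variation, a minimal sketch of that branch could look as follows (assumptions, not confirmed by the patch: `defaultdict` is already imported from `collections` in this module, and each entry appended to `cloud_variations[variation]` is a per-element mapping with tensor values):

    # Hedged sketch, not part of the patch: build one Data object per
    # variation id instead of re-concatenating the same `data` every time.
    if not self.collate_per_element:
        if self.has_variations:
            for variation, elements in cloud_variations.items():
                # Merge all elements that share this variation id.
                merged = {key: torch.cat([element[key] for element in elements])
                          for key in elements[0].keys()}
                datasets[data_folder].append(Data(**merged))
        else:
            datasets[data_folder].append(Data(**{key: torch.cat(data[key]) for key in data.keys()}))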
""" - def __init__(self, root_dir, collate_per_segment=True, train=True, transform=None, npoints=1024, headers=True): + def __init__(self, root_dir, collate_per_segment=True, train=True, transform=None, + has_variations=False, npoints=1024, headers=True): super(ShapeNetPartSegDataset, self).__init__() self.npoints = npoints self.dataset = CustomShapeNet(root=root_dir, collate_per_segment=collate_per_segment, - train=train, transform=transform, headers=headers) + train=train, transform=transform, headers=headers, has_variations=has_variations) def __getitem__(self, index): data = self.dataset[index] diff --git a/main.py b/main.py index ba572bd..f2ad6dd 100644 --- a/main.py +++ b/main.py @@ -38,6 +38,9 @@ parser.add_argument('--test_per_batches', type=int, default=1000, help='run a te parser.add_argument('--num_workers', type=int, default=4, help='number of data loading workers') parser.add_argument('--headers', type=strtobool, default=True, help='if raw files come with headers') parser.add_argument('--collate_per_segment', type=strtobool, default=True, help='whether to look at pointclouds or sub') +parser.add_argument('--has_variations', type=strtobool, default=False, + help='whether a single pointcloud has variations ' + 'named int(id)_pc.(xyz|dat) look at pointclouds or sub') opt = parser.parse_args() @@ -70,11 +73,13 @@ if __name__ == '__main__': test_transform = GT.Compose([GT.NormalizeScale(), ]) dataset = ShapeNetPartSegDataset(root_dir=opt.dataset, collate_per_segment=opt.collate_per_segment, - train=True, transform=train_transform, npoints=opt.npoints, headers=opt.headers) + train=True, transform=train_transform, npoints=opt.npoints, + has_variations=opt.has_variations, headers=opt.headers) dataLoader = DataLoader(dataset, batch_size=opt.batch_size, shuffle=True, num_workers=opt.num_workers) test_dataset = ShapeNetPartSegDataset(root_dir=opt.dataset, collate_per_segment=opt.collate_per_segment, - train=False, transform=test_transform, npoints=opt.npoints, headers=opt.headers) + train=False, transform=test_transform, npoints=opt.npoints, + has_variations=opt.has_variations, headers=opt.headers) test_dataLoader = DataLoader(test_dataset, batch_size=opt.batch_size, shuffle=True, num_workers=opt.num_workers) num_classes = dataset.num_classes()