import glob
import os
import re
from collections import defaultdict

import numpy as np
import torch
from torch.utils.data import Dataset
from torch_geometric.data import Data
from torch_geometric.data import InMemoryDataset
from tqdm import tqdm


def save_names(name_list, path):
    """Write all entries of ``name_list`` back-to-back into the file at ``path``.

    The file is opened in binary mode. ``str`` entries are UTF-8 encoded before
    writing; bytes-like entries are written unchanged. No separator or newline
    is inserted between entries.
    """
    with open(path, 'wb') as f:
        f.writelines(
            name.encode('utf-8') if isinstance(name, str) else name
            for name in name_list
        )


class CustomShapeNet(InMemoryDataset):
    """In-memory point cloud dataset built from ``.dat`` / ``.xyz`` text files.

    Raw files are looked up in ``<raw_dir>/<mode>/<cloud folder>/``. The collated
    result is stored in (and re-loaded from) ``<processed_dir>/<mode>.pt``.
    """

    # Class name -> integer label; also used to derive a label from a file path.
    categories = {key: val for val, key in enumerate(['Box', 'Cone', 'Cylinder', 'Sphere'])}
    # Supported dataset splits.
    modes = {key: val for val, key in enumerate(['train', 'test', 'predict'])}

    def __init__(self, root_dir, collate_per_segment=True, mode='train', transform=None, pre_filter=None,
                 pre_transform=None, headers=True, has_variations=False, refresh=False, labels_within=False,
                 with_normals=False):
        """Configure the dataset and load (or build) the processed file.

        Args:
            root_dir: Dataset root, forwarded to ``InMemoryDataset``.
            collate_per_segment: If true, every element file becomes its own
                ``Data`` sample; otherwise all element files of one cloud
                folder are concatenated into a single sample.
            mode: One of the keys of ``modes``.
            transform, pre_filter, pre_transform: Forwarded to ``InMemoryDataset``.
            headers: Whether element files start with a header line. May be
                overridden by a ``.headers`` / ``no.headers`` marker entry in
                the raw folder (see ``process``).
            has_variations: Group samples by the integer prefix of the file
                name. Mutually exclusive with ``collate_per_segment``.
            refresh: Delete an existing processed file before loading.
            labels_within: Treat the last column of each file as the label.
            with_normals: Keep the first six columns as ``pos`` instead of three.
        """
        assert mode in self.modes.keys(), f'"mode" must be one of {self.modes.keys()}'
        assert not (collate_per_segment and has_variations), 'Either use each element or pointclouds - with variations'

        # Set the dataset parameters *before* calling the base constructor,
        # presumably because it may already trigger download()/process(),
        # which read them (see the torch_geometric Dataset documentation).
        self.has_headers, self.has_variations, self.labels_within = headers, has_variations, labels_within
        self.collate_per_element, self.mode, self.refresh = collate_per_segment, mode, refresh
        self.with_normals = with_normals
        super(CustomShapeNet, self).__init__(root_dir, transform, pre_transform, pre_filter)
        self.data, self.slices = self._load_dataset()
        print("Initialized")

    @property
    def raw_file_names(self):
        """Name of the raw sub folder for the current mode."""
        # Maybe add more data like validation sets
        return [self.mode]

    @property
    def processed_file_names(self):
        """Name of the processed file for the current mode."""
        return [f'{self.mode}.pt']

    def download(self):
        """Check that raw data is present; nothing is actually downloaded.

        Returns:
            The number of directories found directly inside ``raw_dir``.

        Raises:
            IOError: If ``raw_dir`` contains no directory.
        """
        dir_count = len([name for name in os.listdir(self.raw_dir) if os.path.isdir(os.path.join(self.raw_dir, name))])

        if dir_count:
            print(f'{dir_count} folders have been found....')
            return dir_count
        raise IOError("No raw pointclouds have been found.")

    @property
    def num_classes(self):
        """Number of entries in ``categories``."""
        return len(self.categories)

    def _load_dataset(self):
        """Load ``(data, slices)`` from the processed file, building it if missing.

        With ``refresh`` set, an existing processed file is deleted first.

        Raises:
            FileNotFoundError: If the processed file is still missing after
                ``process()`` ran, i.e. no usable raw data was found.
        """
        filepath = self.processed_paths[0]
        if self.refresh:
            try:
                os.remove(filepath)
                print('Processed Location "Refreshed" (We deleted the Files)')
            except FileNotFoundError:
                print('You meant to refresh the allready processed dataset, but there were none...')
                print('continue processing')

        try:
            data, slices = torch.load(filepath)
        except FileNotFoundError:
            # Build the file exactly once. If process() produced nothing, the
            # second load raises instead of retrying forever.
            self.process()
            data, slices = torch.load(filepath)
        print('Dataset Loaded')
        return data, slices

    def _transform_and_filter(self, data):
        """Placeholder hook for ``pre_filter`` / ``pre_transform``.

        Returns ``data`` unchanged when neither hook is active.

        Raises:
            NotImplementedError: If ``pre_filter`` is set and rejects ``data``,
                or if ``pre_transform`` is set at all; neither path is
                implemented yet.
        """
        # ToDo: Any filter to apply? Then do it here.
        if self.pre_filter is not None and not self.pre_filter(data):
            data = self.pre_filter(data)
            raise NotImplementedError
        # ToDo: Any transformation to apply? Then do it here.
        if self.pre_transform is not None:
            data = self.pre_transform(data)
            raise NotImplementedError
        return data

    def _read_points(self, element, delimiter):
        """Parse one element file into a 2-D tensor, or ``None`` if unusable.

        ``None`` is returned when the file is empty, when the first header
        field is 0, or when the parsed tensor has fewer than two dimensions
        after ``squeeze()`` (no rows, a single row, or a single column).
        """
        nan_tokens = ('-nan(ind)', 'nan(ind)')
        with open(element, 'r') as f:
            if self.has_headers:
                headers = next(f, None)
                # Empty file, or the header says there are no useable nodes.
                if headers is None or not int(headers.rstrip().split(delimiter)[0]):
                    return None
            # NaN markers are replaced by 0; blank lines are skipped.
            src = [[0.0 if x in nan_tokens else float(x) for x in line.rstrip().split(delimiter)]
                   for line in f if line.strip()]
        points = torch.tensor(src, dtype=None).squeeze()
        if points.dim() <= 1:
            return None
        return points

    def process(self, delimiter=' '):
        """Read all raw clouds of the current mode and save the collated result.

        Nothing is written when no sample could be built.

        Args:
            delimiter: Column separator used inside the raw files.
        """
        datasets = defaultdict(list)
        data_folder = self.raw_file_names[0]
        path_to_clouds = os.path.join(self.raw_dir, data_folder)

        # A marker entry inside the raw folder overrides the ``headers`` argument.
        folder_entries = os.listdir(path_to_clouds)
        if '.headers' in folder_entries:
            self.has_headers = True
        elif 'no.headers' in folder_entries:
            self.has_headers = False

        # This was built to filter all full clouds ("<number>_pc.xyz" / ".dat").
        full_cloud_pattern = re.compile(r'^\d+?_pc\.(xyz|dat)$')

        for pointcloud in tqdm(os.scandir(path_to_clouds)):
            if self.has_variations:
                cloud_variations = defaultdict(list)
            if not os.path.isdir(pointcloud):
                continue
            data, paths = None, list()
            for ext in ['dat', 'xyz']:
                paths.extend(glob.glob(os.path.join(pointcloud.path, f'*.{ext}')))

            for element in paths:
                if full_cloud_pattern.match(os.path.split(element)[-1]):
                    continue

                points = self._read_points(element, delimiter)
                if points is None:
                    continue

                if self.labels_within:
                    # The last column carries the label.
                    y_all = points[:, -1]
                    points = points[:, :-1]
                elif self.mode != 'predict':
                    # TODO: This is a shady function, elaborate on it.
                    # The label is the index of the first category whose name
                    # occurs anywhere in the (lower-cased) file path; raises
                    # StopIteration if no category matches.
                    y_raw = next(i for i, v in enumerate(self.categories.keys()) if v.lower() in element.lower())
                    y_all = [y_raw] * points.shape[0]
                else:
                    # Place fake labels to hold the given structure.
                    y_all = [-1] * points.shape[0]

                y = torch.as_tensor(y_all, dtype=torch.int)
                ####################################
                # This is where you define the keys
                attr_dict = dict(y=y, pos=points[:, :3 if not self.with_normals else 6])
                ####################################
                if self.collate_per_element:
                    data = Data(**attr_dict)
                else:
                    if data is None:
                        data = defaultdict(list)
                    for key, val in attr_dict.items():
                        data[key].append(val)

                data = self._transform_and_filter(data)
                if self.collate_per_element:
                    datasets[data_folder].append(data)
                if self.has_variations:
                    cloud_variations[int(os.path.split(element)[-1].split('_')[0])].append(data)

            # Folders without any usable element file leave ``data`` as None
            # and contribute nothing.
            if not self.collate_per_element and data:
                if self.has_variations:
                    for _ in cloud_variations.keys():
                        datasets[data_folder].append(Data(**{key: torch.cat(data[key]) for key in data.keys()}))
                else:
                    datasets[data_folder].append(Data(**{key: torch.cat(data[key]) for key in data.keys()}))

        if datasets[data_folder]:
            os.makedirs(self.processed_dir, exist_ok=True)
            torch.save(self.collate(datasets[data_folder]), self.processed_paths[0])

    def __repr__(self):
        return f'{self.__class__.__name__}({len(self)})'


class ShapeNetPartSegDataset(Dataset):
    """
    Resample raw point cloud to fixed number of points.
    Map raw label from range [1, N] to [0, N-1].
    """

    def __init__(self, root_dir, npoints=1024, mode='train', **kwargs):
        """Wrap a ``CustomShapeNet`` built from ``root_dir``, ``mode`` and ``kwargs``.

        Args:
            root_dir: Forwarded to ``CustomShapeNet``.
            npoints: Number of points drawn per sample (ignored in 'predict' mode).
            mode: Forwarded to ``CustomShapeNet``.
            **kwargs: Any further ``CustomShapeNet`` keyword arguments.
        """
        super(ShapeNetPartSegDataset, self).__init__()
        self.mode = mode
        kwargs.update(dict(root_dir=root_dir, mode=self.mode))
        self.npoints = npoints
        self.dataset = CustomShapeNet(**kwargs)

    def __getitem__(self, index):
        """Return ``{'points': ..., 'labels': ...}`` for the sample at ``index``."""
        data = self.dataset[index]
        predicting = self.mode == 'predict'

        # Resample to a fixed number of points (with replacement). In 'predict'
        # mode every point is drawn exactly once, in random order.
        try:
            npoints = data.pos.shape[0] if predicting else self.npoints
            choice = np.random.choice(data.pos.shape[0], npoints, replace=not predicting)
        except ValueError:
            # Sampling failed (presumably an empty cloud): select nothing.
            choice = []

        pos, labels = data.pos[choice, :], data.y[choice]

        # Map label from [1, C] to [0, C-1].
        # NOTE(review): the shift only happens when label C is present in this
        # particular sample - confirm this is intended.
        if self.num_classes() in labels:
            labels -= 1

        sample = {
            'points': pos,  # torch.Tensor, one row per selected point
            'labels': labels  # torch.Tensor (n,)
        }

        return sample

    def __len__(self):
        return len(self.dataset)

    def num_classes(self):
        """Number of classes of the wrapped dataset."""
        return self.dataset.num_classes