import os
import pickle
from collections import defaultdict
from pathlib import Path
from typing import Union
from warnings import warn

import torch
from torch.utils.data import Dataset
from torch_geometric.data import Data, InMemoryDataset
from tqdm import tqdm

from utils.project_settings import classesAll, classesPolyAsPlane, dataSplit, clusterTypes


def save_names(name_list, path):
    # Write the given names to a text file; name_list is expected to hold strings.
    with open(path, 'w') as f:
        f.writelines(name_list)


class CustomShapeNet(InMemoryDataset):

    name = 'CustomShapeNet'

    def download(self):
        pass

    @property
    def categories(self):
        return {key: val for val, key in self.classes.items()}

    @property
    def modes(self):
        return {key: val for val, key in dataSplit.items()}

    @property
    def cluster_types(self):
        return {key: val for val, key in clusterTypes.items()}

    @property
    def raw_dir(self):
        return self.root / 'raw'

    @property
    def raw_file_names(self):
        return [self.mode]

    @property
    def processed_dir(self):
        return self.root / 'processed'

    def __init__(self, root_dir, collate_per_segment=True, mode='train', transform=None, pre_filter=None,
                 pre_transform=None, refresh=False, cluster_type: Union[str, None] = '', poly_as_plane=False):
        assert mode in self.modes.keys(), \
            f'"mode" must be one of {self.modes.keys()}'
        assert cluster_type in self.cluster_types.keys() or cluster_type is None, \
            f'"cluster_type" must be one of {self.cluster_types.keys()} or None, but was: {cluster_type}'

        # Set the dataset parameters
        self.cluster_type = cluster_type if cluster_type else 'pc'
        self.poly_as_plane = poly_as_plane
        self.classes = classesAll if not self.poly_as_plane else classesPolyAsPlane
        self.collate_per_segment = collate_per_segment
        self.mode = mode
        self.refresh = refresh
        root_dir = Path(root_dir)
        super(CustomShapeNet, self).__init__(root_dir, transform, pre_transform, pre_filter)
        self.data, self.slices = self._load_dataset()
        print("Initialized")

    @property
    def processed_file_names(self):
        return [f'{self.mode}.pt']

    def check_and_resolve_cloud_count(self):
        if (self.raw_dir / self.mode).exists():
            file_count = len([cloud for cloud in (self.raw_dir / self.mode).iterdir() if cloud.is_file()])
            if file_count:
                print(f'{file_count} files have been found....')
                return file_count
            else:
                warn(ResourceWarning("No raw pointclouds have been found. Was this intentional?"))
                return file_count
        warn(ResourceWarning("The raw data folder does not exist. Was this intentional?"))
        return -1

    @property
    def num_classes(self):
        return len(self.categories)

    @property
    def _class_map_all(self):
        return {0: 0, 1: 1, 2: None, 3: 2,
                4: 3, 5: None, 6: 4, 7: None
                }

    @property
    def _class_map_poly_as_plane(self):
        return {0: 0, 1: 1, 2: None, 3: 2,
                4: 2, 5: None, 6: 2, 7: None
                }

    @property
    def class_remap(self):
        return self._class_map_all if not self.poly_as_plane else self._class_map_poly_as_plane

    def _load_dataset(self):
        data, slices = None, None
        filepath = self.processed_paths[0]
        config_path = Path(filepath).parent / f'{self.mode}_params.ini'
        if config_path.exists() and not self.refresh and not self.mode == dataSplit.predict:
            with config_path.open('rb') as f:
                config = pickle.load(f)
            if config == self._build_config():
                pass
            else:
                print('The given data parameters seem to differ from the ones used to process the dataset:')
                self.refresh = True

        if self.refresh:
            try:
                os.remove(filepath)
                try:
                    config_path.unlink()
                except FileNotFoundError:
                    pass
                print('Processed Location "Refreshed" (We deleted the Files)')
            except FileNotFoundError:
                print('The already processed dataset was meant to be refreshed, but there was none...')
                print('continue processing')
                pass

        while True:
            try:
                data, slices = torch.load(filepath)
                print(f'{self.mode.title()}-Dataset Loaded')
                break
            except FileNotFoundError:
                status = self.check_and_resolve_cloud_count()
                if status in [0, -1]:
                    print(f'No dataset was loaded, status: {status}')
                    break
                self.process()
                continue

        if not self.mode == dataSplit.predict:
            config = self._build_config()
            with config_path.open('wb') as f:
                pickle.dump(config, f, pickle.HIGHEST_PROTOCOL)
        return data, slices

    def _build_config(self):
        conf_dict = {key: str(val) for key, val in self.__dict__.items()
                     if '__' not in key and key not in ['classes', 'refresh', 'transform', 'data', 'slices']}
        return conf_dict

    def _pre_transform_and_filter(self, data):
        # Drop samples rejected by the pre_filter, then apply the optional pre_transform.
        if self.pre_filter is not None and not self.pre_filter(data):
            return None
        if self.pre_transform is not None:
            data = self.pre_transform(data)
        return data

    def process(self, delimiter=' '):
        datasets = defaultdict(list)
        path_to_clouds = self.raw_dir / self.mode
        found_clouds = list(path_to_clouds.glob('*.xyz'))
        if len(found_clouds):
            for pointcloud in tqdm(found_clouds):
                if self.cluster_type not in pointcloud.name:
                    continue
                data = None

                with pointcloud.open('r') as f:
                    src = defaultdict(list)
                    # Iterate over all rows
                    for row in f:
                        if row != '':
                            vals = row.rstrip().split(delimiter)
                            vals = [float(x) if x not in ['-nan(ind)', 'nan(ind)'] else 0 for x in vals]
                            if len(vals) < 6:
                                raise ValueError('Each row needs at least 6 values (xyz + normal). Check the input!')
                            # Expand the values from the csv by fake labels if none are provided.
                            vals = vals + [0] * (8 - len(vals))
                            vals[-2] = float(self.class_remap[int(vals[-2])])
                            src[vals[-1]].append(vals)

                # Switch from the un-picklable defaultdict to a standard dict
                src = dict(src)

                # Transform the Dict[List] to Dict[torch.Tensor]; drop entries that are not 2D
                for key, values in list(src.items()):
                    src[key] = torch.tensor(values, dtype=torch.double).squeeze()
                    if src[key].ndim == 2:
                        pass
                    else:
                        del src[key]

                # Ignore the segment split and build one full cloud instead of separate segments
                if not self.collate_per_segment:
                    try:
                        src = dict(
                            all=torch.cat(tuple(src.values()))
                        )
                    except RuntimeError:
                        print('Could not concatenate the segments into a single cloud.')

                # Make sure every remaining segment is 2D; single rows get their row dimension back
                for key, tensor in src.items():
                    if tensor.ndim == 1:
                        if all([x == 0 for x in tensor]):
                            continue
                        tensor = tensor.unsqueeze(0)
                        src[key] = tensor

                for key, values in src.items():
                    try:
                        points = values[:, :-2]
                    except IndexError:
                        continue
                    y = torch.as_tensor(values[:, -2], dtype=torch.long)
                    y_c = torch.as_tensor(values[:, -1], dtype=torch.long)
                    ####################################
                    # This is where you define the keys
                    attr_dict = dict(
                        y=y,
                        y_c=y_c,
                        pos=points[:, :3],
                        norm=points[:, 3:6]
                    )
                    ####################################
                    if self.collate_per_segment:
                        data = Data(**attr_dict)
                    else:
                        if data is None:
                            data = defaultdict(list)
                        for attr_key, val in attr_dict.items():
                            data[attr_key].append(val)

                    # data = self._pre_transform_and_filter(data)
                    if self.collate_per_segment:
                        datasets[self.mode].append(data)
                if not self.collate_per_segment:
                    datasets[self.mode].append(Data(**{key: torch.cat(data[key]) for key in data.keys()}))

        if datasets[self.mode]:
            os.makedirs(self.processed_dir, exist_ok=True)
            collated_dataset = self.collate(datasets[self.mode])
            torch.save(collated_dataset, self.processed_paths[0])

    def __repr__(self):
        return f'{self.__class__.__name__}({len(self)})'


class ShapeNetPartSegDataset(Dataset):
    """
    Resample the raw point cloud to a fixed number of points and map raw labels
    from [1, N] to [0, N-1]. Note: the resampling and label remapping are
    currently disabled (see the commented-out block in __getitem__), so samples
    are returned as provided by CustomShapeNet.
    """
    name = 'ShapeNetPartSegDataset'

    def __init__(self, root_dir, mode='train', **kwargs):
        super(ShapeNetPartSegDataset, self).__init__()
        self.mode = mode
        kwargs.update(dict(root_dir=root_dir, mode=self.mode))
        # self.npoints = npoints
        self.dataset = CustomShapeNet(**kwargs)
        self.classes = self.dataset.classes

    def __getitem__(self, index):
        data = self.dataset[index]

        # Resample to fixed number of points
        '''
        try:
            npoints = self.npoints if self.mode != DataSplit.predict else data.pos.shape[0]
            choice = np.random.choice(data.pos.shape[0], npoints,
                                      replace=False if self.mode == DataSplit.predict else True
                                      )
        except ValueError:
            choice = []

        pos, norm, y = data.pos[choice, :], data.norm[choice], data.y[choice]

        # y -= 1 if self.num_classes() in y else 0  # Map label from [1, C] to [0, C-1]

        data = Data(**dict(pos=pos,    # torch.Tensor (n, 3/6)
                           y=y,        # torch.Tensor (n,)
                           norm=norm   # torch.Tensor (n, 3/0)
                           )
                    )
        '''
        return data

    def __len__(self):
        return len(self.dataset)

    def num_classes(self):
        return self.dataset.num_classes
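

if __name__ == '__main__':
    # Minimal usage sketch, assuming a data root that contains 'raw/<mode>/' folders with
    # '*.xyz' point clouds, as expected by CustomShapeNet.process(). The path 'data/shapenet'
    # is a placeholder and not part of the original module.
    dataset = ShapeNetPartSegDataset(root_dir='data/shapenet', mode='train',
                                     collate_per_segment=True, poly_as_plane=False)
    print(f'{len(dataset)} samples, {dataset.num_classes()} classes')
    if len(dataset):
        # Each sample is a torch_geometric Data object with pos, norm, y and y_c tensors.
        sample = dataset[0]
        print(sample.pos.shape, sample.norm.shape, sample.y.shape)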