import pickle
from abc import ABC
from collections import defaultdict
from pathlib import Path

import numpy as np
from torch.utils.data import Dataset, ConcatDataset

from ml_lib.point_toolset.sampling import FarthestpointSampling, RandomSampling


class _Point_Dataset(ABC, Dataset):

    @property
    def name(self):
        raise NotImplementedError

    @property
    def sample_shape(self):
        # FixMe: This does not work when more than x/y tuples are returned
        return self[0][0].shape

    # Column layout of the raw .xyz files and the available sampling strategies.
    headers = ['x', 'y', 'z', 'xn', 'yn', 'zn', 'label', 'cl_idx']
    samplers = dict(fps=FarthestpointSampling, rnd=RandomSampling)

    def __init__(self, root=Path('data'), norm_as_feature=True, sampling_k=2048, sampling='rnd',
                 transforms=None, load_preprocessed=True, split='train', dense_output=False, *args, **kwargs):
        super(_Point_Dataset, self).__init__()
        # Subclasses have to provide `setting` (as a class attribute or before this
        # initializer runs); it selects the processed folder and the raw file glob.
        self.setting: str

        self.dense_output = dense_output
        self.split = split
        self.norm_as_feature = norm_as_feature
        self.load_preprocessed = load_preprocessed
        self.transforms = transforms if transforms else lambda x: x
        self.sampling_k = sampling_k
        self.sampling = self.samplers[sampling](K=self.sampling_k)
        self.root = Path(root)
        self.raw = self.root / 'raw' / self.split
        self.processed_ext = '.pik'
        self.raw_ext = '.xyz'
        self.processed = self.root / self.setting
        self.processed.mkdir(parents=True, exist_ok=True)

        self._files = list(self.raw.glob(f'*{self.setting}*'))

    def _read_or_load(self, item):
        # Parse the raw .xyz file once, cache it as a pickle, and return the cache path.
        raw_file_path = self._files[item]
        processed_file_path = self.processed / raw_file_path.name.replace(self.raw_ext, self.processed_ext)

        if not self.load_preprocessed:
            processed_file_path.unlink(missing_ok=True)
        if not processed_file_path.exists():
            pointcloud = defaultdict(list)
            with raw_file_path.open('r') as raw_file:
                for row in raw_file:
                    values = [float(x) for x in row.strip().split(' ')]
                    for header, value in zip(self.headers, values):
                        pointcloud[header].append(value)
            for key in pointcloud.keys():
                pointcloud[key] = np.asarray(pointcloud[key])
            with processed_file_path.open('wb') as processed_file:
                pickle.dump(pointcloud, processed_file)
        return processed_file_path

    def __len__(self):
        raise NotImplementedError

    def __getitem__(self, item):
        processed_file_path = self._read_or_load(item)
        with processed_file_path.open('rb') as processed_file:
            pointcloud = pickle.load(processed_file)

        position = np.stack((pointcloud['x'], pointcloud['y'], pointcloud['z']), axis=-1)
        normal = np.stack((pointcloud['xn'], pointcloud['yn'], pointcloud['zn']), axis=-1)
        label = pointcloud['label']
        cl_label = pointcloud['cl_idx']

        sample_idxs = self.sampling(position)

        # np.float / np.int were removed from NumPy; the builtin types are equivalent here.
        return (normal[sample_idxs].astype(float),
                position[sample_idxs].astype(float),
                label[sample_idxs].astype(int),
                cl_label[sample_idxs].astype(int)
                )
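

# --- Illustrative sketch (an assumption, not part of the original module) ---
# A minimal example of how _Point_Dataset appears intended to be subclassed:
# `setting` must resolve before super().__init__() runs, since it determines the
# processed directory and the raw-file glob, and __len__ must report how many raw
# files were found. The class name ExamplePointDataset and the setting string
# 'example' are hypothetical placeholders, not names taken from the original code.
class ExamplePointDataset(_Point_Dataset):
    setting = 'example'  # hypothetical setting tag; real subclasses use their own

    @property
    def name(self):
        return self.__class__.__name__

    def __len__(self):
        return len(self._files)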