from abc import ABC
from pathlib import Path

from torch.utils.data import Dataset
from ml_lib.point_toolset.sampling import FarthestpointSampling


class _Point_Dataset(ABC, Dataset):
    """Abstract base for point-cloud datasets backed by raw ``.xyz`` files.

    Subclasses define :pyattr:`setting` (the variant tag that appears both in
    the raw file names and as the name of the processed-cache subdirectory)
    and implement ``__len__`` / ``__getitem__``.

    Parameters
    ----------
    root : str or Path
        Dataset root; raw files are expected under ``root/'raw'``.
    sampling_k : int
        Number of points kept by farthest-point sampling.
    transforms : callable or None
        Optional transform; defaults to identity when ``None``.
    load_preprocessed : bool
        When ``False``, cached pickle files are discarded and rebuilt.
    """

    # Column order of one row in the raw .xyz files:
    # position (x, y, z), normal (nx, ny, nz), point label, cluster index.
    headers = ['x', 'y', 'z', 'nx', 'ny', 'nz', 'label', 'cl_idx']

    @property
    def setting(self) -> str:
        """Variant tag (e.g. ``'pc'``) — must be provided by subclasses."""
        raise NotImplementedError

    def __init__(self, root=Path('data'), sampling_k=2048, transforms=None,
                 load_preprocessed=True, *args, **kwargs):
        super(_Point_Dataset, self).__init__()

        self.load_preprocessed = load_preprocessed
        self.transforms = transforms if transforms else lambda x: x
        self.sampling_k = sampling_k
        self.sampling = FarthestpointSampling(K=self.sampling_k)
        self.root = Path(root)
        # BUG FIX: derive sub-paths from the normalized `self.root`, not the
        # raw `root` argument — `root / 'raw'` raises TypeError when callers
        # pass a plain string path.
        self.raw = self.root / 'raw'
        self.processed_ext = '.pik'
        self.raw_ext = '.xyz'
        self.processed = self.root / self.setting

        # All raw files whose name contains the variant tag.
        self._files = list(self.raw.glob(f'*{self.setting}*'))

    def __len__(self):
        raise NotImplementedError

    def __getitem__(self, item):
        raise NotImplementedError
import pickle
from collections import defaultdict
from pathlib import Path

import numpy as np
from torch.utils.data import Dataset

from ._point_dataset import _Point_Dataset


class FullCloudsDataset(_Point_Dataset):
    """Whole point clouds, farthest-point sampled down to ``sampling_k`` points.

    Each raw ``.xyz`` file is parsed once into per-column numpy arrays and
    cached as a pickle under ``self.processed``; subsequent accesses load the
    cache unless ``load_preprocessed`` is ``False``.
    """

    setting = 'pc'

    def __init__(self, *args, **kwargs):
        super(FullCloudsDataset, self).__init__(*args, **kwargs)

    def __len__(self):
        # One sample per raw file.
        return len(self._files)

    def __getitem__(self, item):
        """Return ``(points, normals, labels)`` for the sampled point subset."""
        raw_file_path = self._files[item]
        processed_file_path = self.processed / raw_file_path.name.replace(self.raw_ext, self.processed_ext)
        if not self.load_preprocessed:
            # Force a rebuild of the cache for this sample.
            processed_file_path.unlink(missing_ok=True)
        if not processed_file_path.exists():
            pointcloud = defaultdict(list)
            with raw_file_path.open('r') as raw_file:
                # Each row holds the columns declared in `self.headers`.
                for row in raw_file:
                    values = [float(x) for x in row.split(' ')]
                    for header, value in zip(self.headers, values):
                        pointcloud[header].append(value)
            for key in pointcloud.keys():
                pointcloud[key] = np.asarray(pointcloud[key])
            # ROBUSTNESS FIX: the cache directory may not exist yet.
            self.processed.mkdir(parents=True, exist_ok=True)
            with processed_file_path.open('wb') as processed_file:
                pickle.dump(dict(pointcloud), processed_file)

        with processed_file_path.open('rb') as processed_file:
            pointcloud = pickle.load(processed_file)

        # BUG FIX: np.stack takes a sequence of arrays; the original passed
        # three positional args, so the second was interpreted as `axis`.
        points = np.stack((pointcloud['x'], pointcloud['y'], pointcloud['z']), axis=-1)
        # BUG FIX: the header names are 'nx'/'ny'/'nz', not 'xn'/'yn'/'zn'
        # — the original lookups raised KeyError.
        normal = np.stack((pointcloud['nx'], pointcloud['ny'], pointcloud['nz']), axis=-1)
        # BUG FIX: labels live in the pointcloud dict; indexing the stacked
        # float ndarray with a string raised an error.
        label = pointcloud['label']
        samples = self.sampling(points)

        return points[samples], normal[samples], label[samples]
from torch.utils.data import Dataset

from ._point_dataset import _Point_Dataset


class TemplateDataset(_Point_Dataset):
    """Skeleton for new point-cloud dataset variants.

    Copy this class, set a concrete ``setting`` tag, and implement
    ``__len__`` / ``__getitem__`` for the new variant.
    """

    def __init__(self, *args, **kwargs):
        # BUG FIX: forward constructor arguments to the base class. The
        # original called `super().__init__()` with no arguments, so
        # `root`, `sampling_k`, `transforms` and `load_preprocessed`
        # passed by callers were silently dropped.
        super(TemplateDataset, self).__init__(*args, **kwargs)

    def __len__(self):
        # Template placeholder — returns None; real subclasses must
        # report the number of samples.
        pass

    def __getitem__(self, item):
        # Template placeholder — echoes the index unchanged.
        return item