From 444725f6af9acb268c825949a042af0054af0a4d Mon Sep 17 00:00:00 2001
From: Si11ium <steffen.illium@ifi.lmu.de>
Date: Tue, 19 May 2020 17:15:01 +0200
Subject: [PATCH] Dataset for whole pointclouds with farthest point sampling
 _incomplete_

---
 datasets/_point_dataset.py   | 36 +++++++++++++++++++++++++++++
 datasets/full_pointclouds.py | 45 ++++++++++++++++++++++++++++++++++++
 datasets/grid_clustered.py   |  6 +++++
 datasets/prim_clustered.py   |  8 +++++++
 datasets/template_dataset.py | 13 ++++++++---
 5 files changed, 105 insertions(+), 3 deletions(-)
 create mode 100644 datasets/_point_dataset.py
 create mode 100644 datasets/full_pointclouds.py
 create mode 100644 datasets/grid_clustered.py
 create mode 100644 datasets/prim_clustered.py

diff --git a/datasets/_point_dataset.py b/datasets/_point_dataset.py
new file mode 100644
index 0000000..79a7fc9
--- /dev/null
+++ b/datasets/_point_dataset.py
@@ -0,0 +1,41 @@
+from abc import ABC
+from pathlib import Path
+
+from torch.utils.data import Dataset
+from ml_lib.point_toolset.sampling import FarthestpointSampling
+
+
+class _Point_Dataset(ABC, Dataset):
+    # Abstract base for point-cloud datasets; subclasses define `setting`
+    # (the processed sub-folder / raw-filename tag), __len__ and __getitem__.
+
+    @property
+    def setting(self) -> str:
+        raise NotImplementedError
+
+    # Column layout of the raw .xyz files: position, normal, label, cluster index.
+    headers = ['x', 'y', 'z', 'nx', 'ny', 'nz', 'label', 'cl_idx']
+
+    def __init__(self, root=Path('data'), sampling_k=2048, transforms=None, load_preprocessed=True, *args, **kwargs):
+        super(_Point_Dataset, self).__init__()
+
+        self.load_preprocessed = load_preprocessed
+        self.transforms = transforms if transforms else lambda x: x
+        self.sampling_k = sampling_k
+        self.sampling = FarthestpointSampling(K=self.sampling_k)
+        self.root = Path(root)
+        # Build paths from self.root (a Path) — the raw `root` argument may be a str.
+        self.raw = self.root / 'raw'
+        self.processed_ext = '.pik'
+        self.raw_ext = '.xyz'
+        self.processed = self.root / self.setting
+        # Ensure the cache folder exists before subclasses try to write pickles into it.
+        self.processed.mkdir(parents=True, exist_ok=True)
+
+        self._files = list(self.raw.glob(f'*{self.setting}*'))
+
+    def __len__(self):
+        raise NotImplementedError
+
+    def __getitem__(self, item):
+        raise NotImplementedError
diff --git a/datasets/full_pointclouds.py b/datasets/full_pointclouds.py
new file mode 100644
index 0000000..f7f1b7d
--- /dev/null
+++ b/datasets/full_pointclouds.py
@@ -0,0 +1,52 @@
+import pickle
+from collections import defaultdict
+from pathlib import Path
+
+import numpy as np
+from torch.utils.data import Dataset
+
+from ._point_dataset import _Point_Dataset
+
+
+class FullCloudsDataset(_Point_Dataset):
+    # Dataset over whole point clouds: one raw .xyz file per sample,
+    # parsed once and cached as a pickle under `self.processed`.
+
+    setting = 'pc'
+
+    def __init__(self, *args, **kwargs):
+        super(FullCloudsDataset, self).__init__(*args, **kwargs)
+
+    def __len__(self):
+        return len(self._files)
+
+    def __getitem__(self, item):
+        raw_file_path = self._files[item]
+        processed_file_path = self.processed / raw_file_path.name.replace(self.raw_ext, self.processed_ext)
+        if not self.load_preprocessed:
+            # Drop the cached pickle to force a re-parse of the raw file.
+            processed_file_path.unlink(missing_ok=True)
+        if not processed_file_path.exists():
+            pointcloud = defaultdict(list)
+            with raw_file_path.open('r') as raw_file:
+                for row in raw_file:
+                    values = [float(x) for x in row.split(' ')]
+                    for header, value in zip(self.headers, values):
+                        pointcloud[header].append(value)
+            for key in pointcloud.keys():
+                pointcloud[key] = np.asarray(pointcloud[key])
+            with processed_file_path.open('wb') as processed_file:
+                pickle.dump(pointcloud, processed_file)
+
+        with processed_file_path.open('rb') as processed_file:
+            pointcloud = pickle.load(processed_file)
+        # np.stack takes a sequence of arrays (its second positional arg is `axis`);
+        # stack the coordinate columns into an (N, 3) array.
+        points = np.stack((pointcloud['x'], pointcloud['y'], pointcloud['z']), axis=-1)
+        # Normal keys follow self.headers: 'nx'/'ny'/'nz', not 'xn'/'yn'/'zn'.
+        normal = np.stack((pointcloud['nx'], pointcloud['ny'], pointcloud['nz']), axis=-1)
+        # Labels come from the parsed dict; `points` is a plain ndarray.
+        label = pointcloud['label']
+        samples = self.sampling(points)
+
+        return points[samples], normal[samples], label[samples]
diff --git a/datasets/grid_clustered.py b/datasets/grid_clustered.py
new file mode 100644
index 0000000..e9e96ef
--- /dev/null
+++ b/datasets/grid_clustered.py
@@ -0,0 +1,11 @@
+from torch.utils.data import Dataset
+
+from ._point_dataset import _Point_Dataset
+
+
+class TemplateDataset(_Point_Dataset):
+    # Placeholder for the grid-clustered dataset; still incomplete.
+    # NOTE(review): consider renaming to GridClusteredDataset once implemented.
+    def __init__(self, *args, **kwargs):
+        # Forward args so root/sampling_k/transforms reach the base class.
+        super(TemplateDataset, self).__init__(*args, **kwargs)
diff --git a/datasets/prim_clustered.py b/datasets/prim_clustered.py
new file mode 100644
index 0000000..612c520
--- /dev/null
+++ b/datasets/prim_clustered.py
@@ -0,0 +1,11 @@
+from torch.utils.data import Dataset
+
+from ._point_dataset import _Point_Dataset
+
+
+class TemplateDataset(_Point_Dataset):
+    # Placeholder for the primitive-clustered dataset; still incomplete.
+    # NOTE(review): consider renaming to PrimClusteredDataset once implemented.
+    def __init__(self, *args, **kwargs):
+        # Forward args so root/sampling_k/transforms reach the base class.
+        super(TemplateDataset, self).__init__(*args, **kwargs)
diff --git a/datasets/template_dataset.py b/datasets/template_dataset.py
index 7f5a373..8318b5a 100644
--- a/datasets/template_dataset.py
+++ b/datasets/template_dataset.py
@@ -1,6 +1,17 @@
 from torch.utils.data import Dataset
+
+from ._point_dataset import _Point_Dataset
 
-
-class TemplateDataset(Dataset):
+
+class TemplateDataset(_Point_Dataset):
+    # Minimal example subclass showing the dataset interface to implement.
     def __init__(self, *args, **kwargs):
-        super(TemplateDataset, self).__init__()
\ No newline at end of file
+        # Forward args so root/sampling_k/transforms reach the base class.
+        super(TemplateDataset, self).__init__(*args, **kwargs)
+
+    def __len__(self):
+        pass
+
+    def __getitem__(self, item):
+        return item
+