Grid Clusters.

Si11ium
2020-06-07 16:47:52 +02:00
parent 8d0577b756
commit 2a767bead2
14 changed files with 278 additions and 149 deletions

View File

@@ -17,10 +17,6 @@ class _Point_Dataset(ABC, Dataset):
# FixMe: This does not work when more than x/y tuples are returned
return self[0][0].shape
@property
def setting(self) -> str:
raise NotImplementedError
headers = ['x', 'y', 'z', 'xn', 'yn', 'zn', 'label', 'cl_idx']
samplers = dict(fps=FarthestpointSampling, rnd=RandomSampling)
@@ -28,6 +24,8 @@ class _Point_Dataset(ABC, Dataset):
transforms=None, load_preprocessed=True, split='train', dense_output=False, *args, **kwargs):
super(_Point_Dataset, self).__init__()
self.setting: str
self.dense_output = dense_output
self.split = split
self.norm_as_feature = norm_as_feature
@@ -67,4 +65,23 @@ class _Point_Dataset(ABC, Dataset):
raise NotImplementedError
def __getitem__(self, item):
raise NotImplementedError
processed_file_path = self._read_or_load(item)
with processed_file_path.open('rb') as processed_file:
pointcloud = pickle.load(processed_file)
position = np.stack((pointcloud['x'], pointcloud['y'], pointcloud['z']), axis=-1)
normal = np.stack((pointcloud['xn'], pointcloud['yn'], pointcloud['zn']), axis=-1)
label = pointcloud['label']
cl_label = pointcloud['cl_idx']
sample_idxs = self.sampling(position)
return (normal[sample_idxs].astype(np.float),
position[sample_idxs].astype(np.float),
label[sample_idxs].astype(np.int),
cl_label[sample_idxs].astype(np.int)
)
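The hunks above move the pickle loading and sampling logic out of the concrete datasets and into the abstract base class, which now also returns the per-point cluster index ('cl_idx') alongside normals, positions and labels. A minimal usage sketch, assuming the import path and constructor keywords below match the refactored code (they are not shown in full in this diff):

from torch.utils.data import DataLoader
from datasets.full_pointclouds import FullCloudsDataset  # module path is an assumption

# setting selects which preprocessed variant to load; split/dense_output as in the new signature
dataset = FullCloudsDataset(setting='pc', split='train', dense_output=False)

# __getitem__ is now implemented in _Point_Dataset and returns four arrays per sample
normals, positions, labels, cluster_labels = dataset[0]

loader = DataLoader(dataset, batch_size=8, shuffle=True)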

View File

@@ -8,29 +8,11 @@ from ._point_dataset import _Point_Dataset
class FullCloudsDataset(_Point_Dataset):
setting = 'pc'
split: str
def __init__(self, *args, **kwargs):
def __init__(self, *args, setting='pc', **kwargs):
self.setting = setting
super(FullCloudsDataset, self).__init__(*args, **kwargs)
def __len__(self):
return len(self._files)
def __getitem__(self, item):
processed_file_path = self._read_or_load(item)
with processed_file_path.open('rb') as processed_file:
pointcloud = pickle.load(processed_file)
position = np.stack((pointcloud['x'], pointcloud['y'], pointcloud['z']), axis=-1)
normal = np.stack((pointcloud['xn'], pointcloud['yn'], pointcloud['zn']), axis=-1)
label = pointcloud['label']
sample_idxs = self.sampling(position)
return (normal[sample_idxs].astype(np.float),
position[sample_idxs].astype(np.float),
label[sample_idxs].astype(np.int))
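With 'setting' now a constructor keyword (defaulting to 'pc'), a single class covers what the deleted grid and prim dataset files below hard-coded. A hedged sketch of the three variants, assuming the keyword values match the old class attributes:

from datasets.full_pointclouds import FullCloudsDataset  # module path is an assumption

pc_ds = FullCloudsDataset(setting='pc', split='train')
grid_ds = FullCloudsDataset(setting='grid', split='train')
prim_ds = FullCloudsDataset(setting='prim', split='train')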

View File

@@ -1,32 +0,0 @@
import pickle
import numpy as np
from ._point_dataset import _Point_Dataset
class FullCloudsDataset(_Point_Dataset):
setting = 'grid'
def __init__(self, *args, **kwargs):
super(FullCloudsDataset, self).__init__(*args, **kwargs)
def __len__(self):
return len(self._files)
def __getitem__(self, item):
processed_file_path = self._read_or_load(item)
with processed_file_path.open('rb') as processed_file:
pointcloud = pickle.load(processed_file)
points = np.stack((pointcloud['x'], pointcloud['y'], pointcloud['z'],
pointcloud['xn'], pointcloud['yn'], pointcloud['zn']
),
axis=-1)
# When you want to return points and normals separately
# normal = np.stack((pointcloud['xn'], pointcloud['yn'], pointcloud['zn']), axis=-1)
label = pointcloud['cl_idx']
sample_idxs = self.sampling(points)
return points[sample_idxs], label[sample_idxs]

View File

@@ -1,32 +0,0 @@
import pickle
import numpy as np
from ._point_dataset import _Point_Dataset
class FullCloudsDataset(_Point_Dataset):
setting = 'prim'
def __init__(self, *args, **kwargs):
super(FullCloudsDataset, self).__init__(*args, **kwargs)
def __len__(self):
return len(self._files)
def __getitem__(self, item):
processed_file_path = self._read_or_load(item)
with processed_file_path.open('rb') as processed_file:
pointcloud = pickle.load(processed_file)
points = np.stack((pointcloud['x'], pointcloud['y'], pointcloud['z'],
pointcloud['xn'], pointcloud['yn'], pointcloud['zn']
),
axis=-1)
# When you want to return points and normals separately
# normal = np.stack((pointcloud['xn'], pointcloud['yn'], pointcloud['zn']), axis=-1)
label = pointcloud['cl_idx']
sample_idxs = self.sampling(points)
return points[sample_idxs], label[sample_idxs]
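The two deleted classes returned a single (N, 6) array of stacked coordinates and normals plus the 'cl_idx' label. A sketch, under the same assumptions as above, of rebuilding that layout from the new four-array return:

import numpy as np
from datasets.full_pointclouds import FullCloudsDataset  # module path is an assumption

dataset = FullCloudsDataset(setting='grid', split='train')
normals, positions, labels, cluster_labels = dataset[0]

# (N, 6) layout of the removed grid/prim datasets: x, y, z, xn, yn, zn
points = np.concatenate((positions, normals), axis=-1)
# cluster_labels plays the role of the old 'cl_idx' label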

View File

@@ -1,6 +1,7 @@
from torch.utils.data import Dataset
from ._point_dataset import _Point_Dataset
class TemplateDataset(_Point_Dataset):
def __init__(self, *args, **kwargs):
super(TemplateDataset, self).__init__()