Dataset for whole pointclouds with farthest point sampling _incomplete_

This commit is contained in:
Si11ium 2020-05-19 17:15:01 +02:00
parent 9ccbec9d7c
commit 444725f6af
5 changed files with 105 additions and 3 deletions

View File

@@ -0,0 +1,36 @@
from abc import ABC
from pathlib import Path

from torch.utils.data import Dataset

from ml_lib.point_toolset.sampling import FarthestpointSampling


class _Point_Dataset(ABC, Dataset):

    @property
    def setting(self) -> str:
        # Subclasses name their sub-setting here, e.g. 'pc' for full point clouds.
        raise NotImplementedError

    # Column layout of the raw .xyz files: position, normal, label, cluster index.
    headers = ['x', 'y', 'z', 'nx', 'ny', 'nz', 'label', 'cl_idx']

    def __init__(self, root=Path('data'), sampling_k=2048, transforms=None,
                 load_preprocessed=True, *args, **kwargs):
        super(_Point_Dataset, self).__init__()
        self.load_preprocessed = load_preprocessed
        self.transforms = transforms if transforms else lambda x: x
        self.sampling_k = sampling_k
        self.sampling = FarthestpointSampling(K=self.sampling_k)

        # Derive all paths from self.root so a plain string also works as root.
        self.root = Path(root)
        self.raw = self.root / 'raw'
        self.processed_ext = '.pik'
        self.raw_ext = '.xyz'
        self.processed = self.root / self.setting
        # Make sure the folder for the pickled clouds exists before it is written to.
        self.processed.mkdir(parents=True, exist_ok=True)

        self._files = list(self.raw.glob(f'*{self.setting}*'))

    def __len__(self):
        raise NotImplementedError

    def __getitem__(self, item):
        raise NotImplementedError
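
FarthestpointSampling is imported from ml_lib.point_toolset.sampling and is not part of this commit. For orientation, a minimal NumPy sketch of how such a sampler is commonly implemented; only the K= keyword matches the constructor call above, while the callable interface, the random start point, and the index-array return value are assumptions, so the real ml_lib version may differ.

import numpy as np

class FarthestpointSampling:
    # Greedy farthest point sampling, sketched here for reference only:
    # repeatedly pick the point farthest from everything selected so far.
    def __init__(self, K):
        self.K = K

    def __call__(self, points):
        # points: (N, 3) array of xyz coordinates; returns K indices into it.
        n_points = points.shape[0]
        if n_points <= self.K:
            return np.arange(n_points)
        indices = np.zeros(self.K, dtype=np.int64)
        distances = np.full(n_points, np.inf)  # squared distance to nearest pick
        indices[0] = np.random.randint(n_points)
        for i in range(1, self.K):
            # Update each point's distance to the newest pick, then take the max.
            delta = points - points[indices[i - 1]]
            distances = np.minimum(distances, np.einsum('ij,ij->i', delta, delta))
            indices[i] = np.argmax(distances)
        return indices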

View File

@@ -0,0 +1,45 @@
import pickle
from collections import defaultdict
from pathlib import Path

import numpy as np
from torch.utils.data import Dataset

from ._point_dataset import _Point_Dataset


class FullCloudsDataset(_Point_Dataset):

    setting = 'pc'

    def __init__(self, *args, **kwargs):
        super(FullCloudsDataset, self).__init__(*args, **kwargs)

    def __len__(self):
        return len(self._files)

    def __getitem__(self, item):
        raw_file_path = self._files[item]
        processed_file_path = self.processed / raw_file_path.name.replace(self.raw_ext, self.processed_ext)
        if not self.load_preprocessed:
            # Discard a stale pickle so the raw file is parsed again below.
            processed_file_path.unlink(missing_ok=True)
        if not processed_file_path.exists():
            # Parse the raw .xyz file into one flat array per column header.
            pointcloud = defaultdict(list)
            with raw_file_path.open('r') as raw_file:
                for row in raw_file:
                    # split() without an argument also swallows trailing
                    # newlines and repeated spaces.
                    values = [float(x) for x in row.split()]
                    for header, value in zip(self.headers, values):
                        pointcloud[header].append(value)
            for key in pointcloud.keys():
                pointcloud[key] = np.asarray(pointcloud[key])
            with processed_file_path.open('wb') as processed_file:
                pickle.dump(pointcloud, processed_file)
        with processed_file_path.open('rb') as processed_file:
            pointcloud = pickle.load(processed_file)
        # np.stack needs a sequence of arrays; axis=-1 yields an (N, 3) layout.
        points = np.stack((pointcloud['x'], pointcloud['y'], pointcloud['z']), axis=-1)
        # The normal columns are named 'nx', 'ny', 'nz' in self.headers.
        normal = np.stack((pointcloud['nx'], pointcloud['ny'], pointcloud['nz']), axis=-1)
        label = pointcloud['label']
        samples = self.sampling(points)
        return points[samples], normal[samples], label[samples]
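
To see how the pieces fit together, a short usage sketch that is not part of the commit; the module path datasets.full_pointclouds is hypothetical, and it assumes data/raw holds .xyz files whose names contain the setting string 'pc'.

from torch.utils.data import DataLoader

from datasets.full_pointclouds import FullCloudsDataset  # hypothetical module path

dataset = FullCloudsDataset(root='data', sampling_k=2048)
points, normals, labels = dataset[0]  # arrays of 2048 sampled points each

# Batching works because every cloud is reduced to the same sampling_k points.
loader = DataLoader(dataset, batch_size=4, shuffle=True)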

View File

@@ -0,0 +1,6 @@
from torch.utils.data import Dataset
from ._point_dataset import _Point_Dataset

class TemplateDataset(_Point_Dataset):
    def __init__(self, *args, **kwargs):
        super(TemplateDataset, self).__init__(*args, **kwargs)

View File

@@ -0,0 +1,8 @@
from torch.utils.data import Dataset

from ._point_dataset import _Point_Dataset


class TemplateDataset(_Point_Dataset):
    def __init__(self, *args, **kwargs):
        super(TemplateDataset, self).__init__(*args, **kwargs)

View File

@@ -1,6 +1,13 @@
 from torch.utils.data import Dataset
+from ._point_dataset import _Point_Dataset


-class TemplateDataset(Dataset):
+class TemplateDataset(_Point_Dataset):
     def __init__(self, *args, **kwargs):
-        super(TemplateDataset, self).__init__()
+        super(TemplateDataset, self).__init__(*args, **kwargs)
+
+    def __len__(self):
+        pass
+
+    def __getitem__(self, item):
+        return item