import pickle
from abc import ABC
from collections import defaultdict
from pathlib import Path

import numpy as np
from torch.utils.data import Dataset, ConcatDataset

from ml_lib.point_toolset.sampling import FarthestpointSampling, RandomSampling


class _Point_Dataset(ABC, Dataset):
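    """Abstract base class for the point-cloud datasets in this package.

    Concrete subclasses are expected to provide ``setting`` (the
    sub-directory name under ``root``) as well as ``name`` and
    ``__len__``; a hypothetical subclass is sketched at the end of
    this module.
    """
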
    @property
    def name(self):
        raise NotImplementedError

    @property
    def sample_shape(self):
        # FixMe: This does not work when more than x/y tuples are returned.
        return self[0][0].shape

    # Column layout of the raw ``.xyz`` files: xyz position, point normal,
    # per-point label and cluster index.
    headers = ['x', 'y', 'z', 'xn', 'yn', 'zn', 'label', 'cl_idx']
    samplers = dict(fps=FarthestpointSampling, rnd=RandomSampling)

    def __init__(self, root=Path('data'), norm_as_feature=True, sampling_k=2048, sampling='rnd',
                 transforms=None, load_preprocessed=True, split='train', dense_output=False,
                 *args, **kwargs):
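        """Shared construction logic for all point datasets.

        ``sampling`` selects one of ``samplers`` ('fps' or 'rnd'),
        ``sampling_k`` sets how many points are kept per cloud, and
        ``load_preprocessed`` controls whether cached ``.pik`` files
        are reused or rebuilt from the raw ``.xyz`` files.
        """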
        super().__init__()

        # ``setting`` is only annotated here; concrete subclasses must define
        # it (e.g. as a class attribute) before this ``__init__`` runs, since
        # it is used to build ``self.processed`` and ``self._files`` below.
        self.setting: str

        self.dense_output = dense_output
        self.split = split
        self.norm_as_feature = norm_as_feature
        self.load_preprocessed = load_preprocessed
        self.transforms = transforms if transforms else lambda x: x  # identity by default
        self.sampling_k = sampling_k
        self.sampling = self.samplers[sampling](K=self.sampling_k)

        self.root = Path(root)
        self.raw = self.root / 'raw' / self.split
        self.processed_ext = '.pik'
        self.raw_ext = '.xyz'
        self.processed = self.root / self.setting
        self.processed.mkdir(parents=True, exist_ok=True)

        self._files = list(self.raw.glob(f'*{self.setting}*'))

    def _read_or_load(self, item):
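        """Return the path of the pickled point cloud for ``item``,
        parsing and caching the raw ``.xyz`` file on first access."""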
        raw_file_path = self._files[item]
        processed_file_path = self.processed / raw_file_path.name.replace(self.raw_ext, self.processed_ext)

        if not self.load_preprocessed:
            # Drop any stale cache file so it is rebuilt below
            # (``missing_ok`` requires Python >= 3.8).
            processed_file_path.unlink(missing_ok=True)
        if not processed_file_path.exists():
            pointcloud = defaultdict(list)
            with raw_file_path.open('r') as raw_file:
                for row in raw_file:
                    values = [float(x) for x in row.strip().split(' ')]
                    for header, value in zip(self.headers, values):
                        pointcloud[header].append(value)
            for key in pointcloud.keys():
                pointcloud[key] = np.asarray(pointcloud[key])
            with processed_file_path.open('wb') as processed_file:
                pickle.dump(pointcloud, processed_file)
        return processed_file_path

    def __len__(self):
        raise NotImplementedError

    def __getitem__(self, item):
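        """Load one cached cloud, sub-sample it, and return the
        ``(normal, position, label, cl_label)`` arrays."""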
        processed_file_path = self._read_or_load(item)

        with processed_file_path.open('rb') as processed_file:
            pointcloud = pickle.load(processed_file)

        position = np.stack((pointcloud['x'], pointcloud['y'], pointcloud['z']), axis=-1)
        normal = np.stack((pointcloud['xn'], pointcloud['yn'], pointcloud['zn']), axis=-1)
        label = pointcloud['label']
        cl_label = pointcloud['cl_idx']

        # Sub-sample the cloud down to ``sampling_k`` points with the
        # configured sampler.
        sample_idxs = self.sampling(position)

        # ``np.float``/``np.int`` were deprecated aliases of the builtins and
        # were removed in NumPy 1.24, so cast with the builtins instead.
        return (normal[sample_idxs].astype(float),
                position[sample_idxs].astype(float),
                label[sample_idxs].astype(int),
                cl_label[sample_idxs].astype(int))
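
# A minimal sketch of how a concrete dataset might subclass
# ``_Point_Dataset``. The class name and the ``setting`` value are
# hypothetical; the real subclasses live elsewhere in this package.
#
# class FullCloudsDataset(_Point_Dataset):
#     setting = 'pointclouds'  # assumed sub-directory name under ``root``
#
#     @property
#     def name(self):
#         return self.__class__.__name__
#
#     def __len__(self):
#         return len(self._files)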