37 lines
1.1 KiB
Python
37 lines
1.1 KiB
Python
from abc import ABC
|
|
from pathlib import Path
|
|
|
|
from torch.utils.data import Dataset
|
|
from ml_lib.point_toolset.sampling import FarthestpointSampling
|
|
|
|
|
|
class _Point_Dataset(ABC, Dataset):
    """Abstract base class for point-cloud datasets kept under a root directory.

    Subclasses must override ``setting`` (it is read during ``__init__``),
    ``__len__`` and ``__getitem__``; the base implementations raise
    ``NotImplementedError``.

    Paths derived in ``__init__``:
        ``<root>/raw``        -- globbed for entries whose name contains ``setting``
        ``<root>/<setting>``  -- stored as ``self.processed``

    Args:
        root: Dataset root directory, given as ``str`` or ``Path``.
        sampling_k: Value passed as ``K`` to ``FarthestpointSampling``.
        transforms: Optional callable stored as ``self.transforms``; when it
            is falsy (e.g. ``None``) the identity function is stored instead.
        load_preprocessed: Stored as ``self.load_preprocessed``; not read by
            this base class.
        *args, **kwargs: Accepted and ignored, so subclasses can forward
            their own arguments.
    """

    # Per-point column names. Presumably the column layout of the raw
    # ``.xyz`` files -- NOTE(review): confirm against the subclass loaders.
    headers = ['x', 'y', 'z', 'nx', 'ny', 'nz', 'label', 'cl_idx']

    @property
    def setting(self) -> str:
        """Identifier used to select raw files (substring match) and to name
        the processed-data directory. Must be overridden by subclasses."""
        raise NotImplementedError

    def __init__(self, root=Path('data'), sampling_k=2048, transforms=None,
                 load_preprocessed=True, *args, **kwargs):
        super().__init__()

        self.load_preprocessed = load_preprocessed
        self.transforms = transforms if transforms else lambda x: x
        self.sampling_k = sampling_k
        self.sampling = FarthestpointSampling(K=self.sampling_k)

        # Normalise once and derive every sub-path from ``self.root``; joining
        # onto the raw argument fails for a ``str`` root (``str / str``).
        self.root = Path(root)
        self.raw = self.root / 'raw'
        self.processed_ext = '.pik'
        self.raw_ext = '.xyz'
        self.processed = self.root / self.setting

        # Path.glob yields entries in filesystem order, which varies between
        # machines; sort so index-based access in subclasses is reproducible.
        self._files = sorted(self.raw.glob(f'*{self.setting}*'))

    def __len__(self):
        raise NotImplementedError

    def __getitem__(self, item):
        raise NotImplementedError
|