from abc import ABC
from pathlib import Path

from torch.utils.data import Dataset

from ml_lib.point_toolset.sampling import FarthestpointSampling


class _Point_Dataset(ABC, Dataset):
    """Abstract base class for point-cloud datasets rooted at one directory.

    Subclasses must override ``setting``, ``__len__`` and ``__getitem__``.

    Directory layout derived from ``root``::

        root/raw/         raw inputs; entries whose name contains ``setting``
                          are collected into ``self._files``
        root/<setting>/   directory for processed output (``self.processed``)
    """

    @property
    def setting(self) -> str:
        """Name of this dataset's setting/split; subclasses must override.

        Used both as the processed sub-directory name and as the substring
        that selects raw files.
        """
        raise NotImplementedError

    # Per-point column names. NOTE(review): presumably the column layout of
    # the raw '.xyz' rows (coordinates, normals, label, cluster index) --
    # confirm against the subclass loaders.
    headers = ['x', 'y', 'z', 'nx', 'ny', 'nz', 'label', 'cl_idx']

    def __init__(self, root=Path('data'), sampling_k=2048, transforms=None,
                 load_preprocessed=True, *args, **kwargs):
        """Set up paths, transforms and the point sampler.

        Args:
            root: Dataset root directory; ``str`` or ``Path`` (normalized to
                ``Path``).
            sampling_k: Passed as ``K`` to ``FarthestpointSampling``;
                presumably the number of points kept per sample -- confirm.
            transforms: Callable applied to samples by subclasses; when
                falsy, an identity function is used instead.
            load_preprocessed: Stored on the instance for subclasses to
                consult; not used by this base class.
            *args, **kwargs: Accepted and ignored here.

        Raises:
            NotImplementedError: If the subclass does not override ``setting``
                (it is read below to build paths).
        """
        super().__init__()
        self.load_preprocessed = load_preprocessed
        self.transforms = transforms if transforms else lambda x: x
        self.sampling_k = sampling_k
        self.sampling = FarthestpointSampling(K=self.sampling_k)

        # Build every derived path from the normalized Path, not the raw
        # argument: ``root`` may be a plain str, and ``str / str`` raises
        # TypeError.
        self.root = Path(root)
        self.raw = self.root / 'raw'
        self.processed_ext = '.pik'
        self.raw_ext = '.xyz'
        self.processed = self.root / self.setting

        # Non-recursive match on any entry in raw/ whose name contains the
        # setting. NOTE(review): Path.glob order is filesystem-dependent, so
        # index -> file mapping may differ across machines.
        self._files = list(self.raw.glob(f'*{self.setting}*'))

    def __len__(self):
        """Return the number of samples; subclasses must override."""
        raise NotImplementedError

    def __getitem__(self, item):
        """Return the sample at ``item``; subclasses must override."""
        raise NotImplementedError