84 lines
3.0 KiB
Python
84 lines
3.0 KiB
Python
import pickle
|
|
from collections import defaultdict
|
|
|
|
import numpy as np
|
|
from torch.utils.data import ConcatDataset
|
|
from tqdm import trange
|
|
|
|
from ._point_dataset import _Point_Dataset
|
|
|
|
|
|
class GridClusters(_Point_Dataset):
    """Point-cloud dataset that yields one spatial cluster of a file per sample.

    Each raw file is a whitespace-separated text table with one point per row.
    Columns are named, in order, by ``self.headers`` (not defined in this class;
    presumably provided by ``_Point_Dataset``), and the last column of each row
    is cast to ``int`` and used as the cluster id. Parsed files are cached as
    pickles under ``self.processed``.
    """

    split: str

    name = 'GridClusters'

    def __init__(self, *args, n_spatial_clusters=3*3*3, setting='pc', **kwargs):
        """Store the cluster configuration, then defer to ``_Point_Dataset``.

        Args:
            n_spatial_clusters: Number of spatial clusters (stored on the
                instance; not read by the methods of this class).
            setting: Dataset setting tag (stored on the instance; not read by
                the methods of this class).
            *args, **kwargs: Forwarded unchanged to ``_Point_Dataset.__init__``.
        """
        # Assigned before super().__init__, presumably because the base
        # constructor may read them — confirm against _Point_Dataset.
        self.n_spatial_clusters = n_spatial_clusters
        self.setting = setting
        super().__init__(*args, **kwargs)

    def __len__(self):
        """Return the number of raw files; each file is one sample."""
        return len(self._files)

    def _read_or_load(self, item):
        """Ensure the processed pickle for raw file ``item`` exists; return its path.

        If ``self.load_preprocessed`` is false, any existing pickle is deleted
        and rebuilt from the raw text file.

        The pickle holds a plain ``dict`` mapping ``cluster_id`` (the last
        column of each row as ``int``) to ``{header: np.ndarray}``.
        """
        raw_file_path = self._files[item]
        processed_file_path = self.processed / raw_file_path.name.replace(self.raw_ext, self.processed_ext)

        if not self.load_preprocessed:
            processed_file_path.unlink(missing_ok=True)

        if not processed_file_path.exists():
            # cluster_id -> header -> list of that column's values
            pointcloud = defaultdict(lambda: defaultdict(list))

            with raw_file_path.open('r') as raw_file:
                for row in raw_file:
                    # split() tolerates repeated spaces, tabs and the newline;
                    # split(' ') yielded '' tokens that made float() fail.
                    tokens = row.split()
                    if not tokens:
                        continue  # blank line, e.g. a trailing newline at EOF
                    values = [float(x) for x in tokens]
                    columns = pointcloud[int(values[-1])]
                    for header, value in zip(self.headers, values):
                        columns[header].append(value)

            # Convert to plain dicts of arrays: pickle cannot serialise the
            # lambda default_factory of the nested defaultdicts.
            pointcloud = {
                cluster_id: {key: np.asarray(column) for key, column in columns.items()}
                for cluster_id, columns in pointcloud.items()
            }

            with processed_file_path.open('wb') as processed_file:
                pickle.dump(pointcloud, processed_file)

        return processed_file_path

    def __getitem__(self, item):
        """Return a resampled, randomly chosen cluster of raw file ``item``.

        Returns:
            Tuple ``(normal, position, label, cl_label)``:

            * ``normal`` and ``position`` are built as ``(k, 3)`` float64
              arrays, then passed through ``self.transforms``.
            * ``label`` and ``cl_label`` are int64 arrays of length ``k``.

            ``k`` is the number of indices returned by ``self.sampling``,
            padded by cyclic repetition up to ``self.sampling_k`` when shorter.

        Raises:
            ValueError: If the file contains no points, or if ``self.sampling``
                returns no indices.
        """
        processed_file_path = self._read_or_load(item)

        # NOTE: pickle.load can execute arbitrary code; only load caches that
        # _read_or_load itself wrote.
        with processed_file_path.open('rb') as processed_file:
            clusters = pickle.load(processed_file)

        if not clusters:
            raise ValueError(f'{processed_file_path} contains no points')

        # Pick one cluster uniformly at random.
        # NOTE(review): forked DataLoader workers inherit a copy of NumPy's
        # global RNG state, so workers may draw identical cluster sequences
        # unless a worker_init_fn reseeds NumPy — confirm loader setup.
        cl_idx = np.random.randint(0, len(clusters))
        pointcloud = list(clusters.values())[cl_idx]

        position = np.stack((pointcloud['x'], pointcloud['y'], pointcloud['z']), axis=-1)
        normal = np.stack((pointcloud['xn'], pointcloud['yn'], pointcloud['zn']), axis=-1)
        label = pointcloud['label']
        cl_label = pointcloud['cl_idx']

        # Assumes self.sampling returns a 1-D array of row indices — confirm.
        sample_idxs = self.sampling(position)
        n_sampled = sample_idxs.shape[0]
        if n_sampled == 0:
            # Repeating an empty index array can never reach sampling_k.
            raise ValueError(f'sampling returned no points for {processed_file_path}')
        if n_sampled < self.sampling_k:
            # Repeat the indices cyclically up to exactly sampling_k entries.
            sample_idxs = np.resize(sample_idxs, self.sampling_k)

        # np.float / np.int were removed in NumPy 1.24; use explicit widths.
        # NOTE(review): transforms is applied to normals and positions in two
        # separate calls; if it is random (e.g. rotation), the two may receive
        # different transforms — confirm this is intended.
        normal = self.transforms(normal[sample_idxs].astype(np.float64))
        position = self.transforms(position[sample_idxs].astype(np.float64))

        return (normal, position,
                label[sample_idxs].astype(np.int64),
                cl_label[sample_idxs].astype(np.int64))
|