import pickle
from collections import defaultdict

import numpy as np
from torch.utils.data import ConcatDataset
from tqdm import trange

from ._point_dataset import _Point_Dataset


class GridClusters(_Point_Dataset):

    split: str
    name = 'GridClusters'

    def __init__(self, *args, n_spatial_clusters=3 * 3 * 3, setting='pc', **kwargs):
        self.n_spatial_clusters = n_spatial_clusters
        self.setting = setting
        super(GridClusters, self).__init__(*args, **kwargs)

    def __len__(self):
        return len(self._files)

    def _read_or_load(self, item):
        raw_file_path = self._files[item]
        processed_file_path = self.processed / raw_file_path.name.replace(self.raw_ext, self.processed_ext)

        # Discard any cached file when preprocessing should be redone from the raw data.
        if not self.load_preprocessed:
            processed_file_path.unlink(missing_ok=True)

        if not processed_file_path.exists():
            # Nested default dict: spatial cluster index -> column name -> list of values.
            pointcloud = defaultdict(lambda: defaultdict(list))
            with raw_file_path.open('r') as raw_file:
                for row in raw_file:
                    values = [float(x) for x in row.strip().split(' ')]
                    # The last value of each row is the cluster index the point belongs to.
                    for header, value in zip(self.headers, values):
                        pointcloud[int(values[-1])][header].append(value)

            # Convert the nested lists into plain dicts of numpy arrays before pickling.
            for cluster in pointcloud.keys():
                for key in pointcloud[cluster].keys():
                    pointcloud[cluster][key] = np.asarray(pointcloud[cluster][key])
                pointcloud[cluster] = dict(pointcloud[cluster])
            pointcloud = dict(pointcloud)

            with processed_file_path.open('wb') as processed_file:
                pickle.dump(pointcloud, processed_file)
        return processed_file_path

    def __getitem__(self, item):
        processed_file_path = self._read_or_load(item)
        with processed_file_path.open('rb') as processed_file:
            pointcloud = pickle.load(processed_file)

        # "By number" variant: map the dataset index deterministically onto a cluster.
        # cl_idx_list = np.cumsum([[len(self) // self.n_spatial_clusters, ] * self.n_spatial_clusters])
        # cl_idx = [idx for idx, x in enumerate(cl_idx_list) if item <= x][0]

        # Random variant: pick one spatial cluster uniformly at random.
        cl_idx = np.random.randint(0, len(pointcloud))
        pointcloud = pointcloud[list(pointcloud.keys())[cl_idx]]

        position = np.stack((pointcloud['x'], pointcloud['y'], pointcloud['z']), axis=-1)
        normal = np.stack((pointcloud['xn'], pointcloud['yn'], pointcloud['zn']), axis=-1)
        label = pointcloud['label']
        cl_label = pointcloud['cl_idx']

        sample_idxs = self.sampling(position)
        # Pad by repetition if the cluster holds fewer points than the requested sample size.
        while sample_idxs.shape[0] < self.sampling_k:
            sample_idxs = np.concatenate((sample_idxs, sample_idxs))[:self.sampling_k]

        return (normal[sample_idxs].astype(np.float64),
                position[sample_idxs].astype(np.float64),
                label[sample_idxs].astype(np.int64),
                cl_label[sample_idxs].astype(np.int64)
                )
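

# Usage sketch (assumptions labeled): only 'n_spatial_clusters' and 'setting'
# are handled in this module; the remaining constructor arguments shown below
# ('data_root', 'sampling_k') are hypothetical placeholders for whatever the
# _Point_Dataset base class actually expects, since its __init__ is defined
# elsewhere and not shown here.
#
#   dataset = GridClusters(data_root='path/to/raw_scans',  # hypothetical base-class argument
#                          sampling_k=4096,                # hypothetical base-class argument
#                          n_spatial_clusters=3 * 3 * 3,
#                          setting='pc')
#   normals, positions, labels, cluster_labels = dataset[0]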