import pickle
from collections import defaultdict

import numpy as np

from ._point_dataset import _Point_Dataset


class FullCloudsDataset(_Point_Dataset):
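    """Full point-cloud dataset (the 'pc' setting).

    Each raw text file is parsed once into per-header arrays and cached as
    a pickle; items are returned as sampled (points, normals, labels).
    """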

    setting = 'pc'

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def __len__(self):
        return len(self._files)

    def __getitem__(self, item):
        raw_file_path = self._files[item]
        processed_file_path = self.processed / raw_file_path.name.replace(self.raw_ext, self.processed_ext)

        # When preprocessed loading is disabled, discard any stale cache first.
        if not self.load_preprocessed:
            processed_file_path.unlink(missing_ok=True)

        if not processed_file_path.exists():
            # Parse the raw text file column-wise: one list per header.
            pointcloud = defaultdict(list)
            with raw_file_path.open('r') as raw_file:
                for row in raw_file:
                    # split() (no argument) tolerates repeated whitespace and
                    # strips the trailing newline before the float conversion.
                    values = [float(x) for x in row.split()]
                    for header, value in zip(self.headers, values):
                        pointcloud[header].append(value)
            for key in pointcloud:
                pointcloud[key] = np.asarray(pointcloud[key])
            # Cache the parsed arrays so later epochs skip the text parse.
            with processed_file_path.open('wb') as processed_file:
                pickle.dump(pointcloud, processed_file)

        with processed_file_path.open('rb') as processed_file:
            pointcloud = pickle.load(processed_file)

        # np.stack takes a sequence of arrays; stacking the per-axis columns
        # on the last axis yields (N, 3) coordinate and normal arrays.
        points = np.stack([pointcloud['x'], pointcloud['y'], pointcloud['z']], axis=-1)
        normal = np.stack([pointcloud['xn'], pointcloud['yn'], pointcloud['zn']], axis=-1)
        # The labels live in the parsed dict, not in the stacked coordinates.
        label = pointcloud['label']

        samples = self.sampling(points)

        return points[samples], normal[samples], label[samples]
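
# Example usage (a sketch; the constructor arguments below are hypothetical —
# they are forwarded unchanged to _Point_Dataset, whose signature defines them):
#
#     dataset = FullCloudsDataset(root='data/clouds', split='train')
#     points, normals, labels = dataset[0]  # all indexed by the same samples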