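"""Point-cloud dataset utilities: CustomShapeNet, a torch_geometric
InMemoryDataset built from raw *.xyz clouds, and ShapeNetPartSegDataset,
a wrapper that resamples each cloud to a fixed number of points."""
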
import os
from collections import defaultdict
from pathlib import Path
from warnings import warn

import numpy as np
import torch
from torch.utils.data import Dataset
from torch_geometric.data import Data, InMemoryDataset
from tqdm import tqdm

from utils.project_settings import Classes, DataSplit


def save_names(name_list, path):
    # Text mode is required because the names are strings; writelines()
    # adds no separators, so append the newlines explicitly.
    with open(path, 'w') as f:
        f.writelines(f'{name}\n' for name in name_list)


class CustomShapeNet(InMemoryDataset):

    # Inverted (value -> key) lookups of the class and split mappings
    # defined in the project settings.
    categories = {key: val for val, key in Classes().items()}
    modes = {key: val for val, key in DataSplit().items()}
    name = 'CustomShapeNet'

    @property
    def raw_dir(self):
        # The base class may normalize self.root to a str, so re-wrap it.
        return Path(self.root) / 'raw'

    @property
    def raw_file_names(self):
        return [self.mode]

    @property
    def processed_dir(self):
        return Path(self.root) / 'processed'

    def __init__(self, root_dir, collate_per_segment=True, mode='train', transform=None, pre_filter=None,
                 pre_transform=None, refresh=False, with_normals=False):
        assert mode in self.modes.keys(), f'"mode" must be one of {list(self.modes.keys())}'

        # Set the dataset parameters. These must exist before the base-class
        # __init__ runs, since it may trigger processing.
        self.collate_per_segment = collate_per_segment
        self.mode = mode
        self.refresh = refresh
        self.with_normals = with_normals
        root_dir = Path(root_dir)
        super(CustomShapeNet, self).__init__(root_dir, transform, pre_transform, pre_filter)
        self.data, self.slices = self._load_dataset()
        print("Initialized")

    @property
    def processed_file_names(self):
        return [f'{self.mode}.pt']

    def check_and_resolve_cloud_count(self):
        if (self.raw_dir / self.mode).exists():
            file_count = len([cloud for cloud in (self.raw_dir / self.mode).iterdir() if cloud.is_file()])
            if file_count:
                print(f'{file_count} files have been found.')
            else:
                warn(ResourceWarning("No raw pointclouds have been found. Was this intentional?"))
            return file_count
        warn(ResourceWarning("The raw data folder does not exist. Was this intentional?"))
        return -1

    @property
    def num_classes(self):
        return len(self.categories)

    def _load_dataset(self):
        data, slices = None, None
        filepath = self.processed_paths[0]
        if self.refresh:
            try:
                os.remove(filepath)
                print('Processed location "refreshed" (the file was deleted).')
            except FileNotFoundError:
                print('Tried to refresh the already processed dataset, but none was found...')
                print('Continuing with processing.')

        # Load the processed dataset; if it does not exist yet, process the
        # raw clouds and try again.
        while True:
            try:
                data, slices = torch.load(filepath)
                print('Dataset loaded.')
                break
            except FileNotFoundError:
                status = self.check_and_resolve_cloud_count()
                if status in [0, -1]:
                    print(f'No dataset was loaded, status: {status}')
                    break
                self.process()
                continue
        return data, slices

    def _pre_transform_and_filter(self, data):
        # Drop samples rejected by the pre_filter (torch_geometric convention:
        # a pre_filter returns False for samples that should be discarded).
        if self.pre_filter is not None and not self.pre_filter(data):
            return None
        # Apply the optional pre_transform before the sample is stored.
        if self.pre_transform is not None:
            data = self.pre_transform(data)
        return data

    def process(self, delimiter=' '):
        datasets = defaultdict(list)
        path_to_clouds = self.raw_dir / self.mode

        for pointcloud in tqdm(path_to_clouds.glob('*.xyz')):
            # Only clouds whose filename contains "grid" are processed.
            if 'grid' not in pointcloud.name:
                continue
            data = None

            with pointcloud.open('r') as f:
                src = defaultdict(list)
                # Iterate over all rows and group them by the segment id in
                # the last column.
                for row in f:
                    if row.strip():
                        vals = row.rstrip().split(delimiter)
                        # Replace MSVC-style NaN markers so float() cannot fail.
                        vals = [float(x) if x not in ['-nan(ind)', 'nan(ind)'] else 0 for x in vals]
                        src[vals[-1]].append(vals)

            src = dict(src)
            for key, values in src.items():
                src[key] = torch.tensor(values, dtype=torch.double).squeeze()

            # Without per-segment collation, merge every segment into a
            # single cloud.
            if not self.collate_per_segment:
                src = dict(
                    all=torch.cat(tuple(src.values()))
                )

            for key, values in src.items():
                try:
                    points = values[:, :-2]
                except IndexError:
                    # A single-row segment was squeezed to 1-D above; skip it.
                    continue
                # The second-to-last column holds the class label, the last
                # one the cluster id.
                y = torch.as_tensor(values[:, -2], dtype=torch.long)
                y_c = torch.as_tensor(values[:, -1], dtype=torch.long)
                ####################################
                # This is where you define the keys
                attr_dict = dict(y=y, y_c=y_c)
                if self.with_normals:
                    # Normals stay packed into pos -> (n, 6).
                    pos = points[:, :6]
                    norm = None
                else:
                    # Plain positions -> (n, 3); normals kept separately.
                    pos = points[:, :3]
                    norm = points[:, 3:6]
                attr_dict.update(pos=pos, norm=norm)
                ####################################
                if self.collate_per_segment:
                    data = Data(**attr_dict)
                else:
                    if data is None:
                        data = defaultdict(list)
                    for attr_key, val in attr_dict.items():
                        data[attr_key].append(val)

                # data = self._pre_transform_and_filter(data)
                if self.collate_per_segment:
                    datasets[self.mode].append(data)

            if not self.collate_per_segment and data is not None:
                # This is just to be sure, but should not be needed, since
                # src['all'] is everything there is in this cloud.
                datasets[self.mode].append(Data(**{key: torch.cat(data[key]) for key in data.keys()}))

        # Persist the collated dataset only if at least one cloud was parsed.
        if datasets[self.mode]:
            os.makedirs(self.processed_dir, exist_ok=True)
            torch.save(self.collate(datasets[self.mode]), self.processed_paths[0])

    def __repr__(self):
        return f'{self.__class__.__name__}({len(self)})'


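# ---------------------------------------------------------------------------
# Raw row layout as read off process() above (a derived sketch, not a spec):
# each whitespace-separated row of a "grid" *.xyz file is expected to hold
#     x y z nx ny nz ... <class label> <cluster id>
# pos takes columns 0-2 (0-5 when with_normals packs the normals in),
# norm takes columns 3-5, y the second-to-last and y_c the last column.
# ---------------------------------------------------------------------------
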
class ShapeNetPartSegDataset(Dataset):
    """
    Resample a raw point cloud to a fixed number of points.
    Map raw labels from the range [1, N] to [0, N - 1].
    """

    name = 'ShapeNetPartSegDataset'

    def __init__(self, root_dir, npoints=1024, mode='train', **kwargs):
        super(ShapeNetPartSegDataset, self).__init__()
        self.mode = mode
        kwargs.update(dict(root_dir=root_dir, mode=self.mode))
        self.npoints = npoints
        self.dataset = CustomShapeNet(**kwargs)

    def __getitem__(self, index):
        data = self.dataset[index]

        # Resample to a fixed number of points. Prediction keeps the full
        # cloud (sampling without replacement); otherwise npoints indices
        # are drawn with replacement.
        try:
            npoints = self.npoints if self.mode != DataSplit.predict else data.pos.shape[0]
            choice = np.random.choice(data.pos.shape[0], npoints,
                                      replace=self.mode != DataSplit.predict)
        except ValueError:
            # The cloud is empty; fall back to an empty selection.
            choice = []

        pos, y = data.pos[choice, :], data.y[choice]
        # norm is None when the normals were packed into pos (with_normals=True).
        norm = data.norm[choice] if data.norm is not None else None

        # y -= 1 if self.num_classes() in y else 0  # Map label from [1, C] to [0, C-1]

        data = Data(**dict(pos=pos,    # torch.Tensor (n, 3/6)
                           y=y,        # torch.Tensor (n,)
                           norm=norm   # torch.Tensor (n, 3) or None
                           )
                    )

        return data

    def __len__(self):
        return len(self.dataset)

    def num_classes(self):
        return self.dataset.num_classes
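

# ---------------------------------------------------------------------------
# Minimal usage sketch, assuming a dataset root 'data' that contains
# raw/train/*.xyz clouds (the path is a placeholder, not part of the module).
# Newer torch_geometric versions expose DataLoader under
# torch_geometric.loader instead of torch_geometric.data.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    from torch_geometric.data import DataLoader

    dataset = ShapeNetPartSegDataset(root_dir='data', npoints=1024, mode='train')
    print(dataset, 'with', dataset.num_classes(), 'classes')

    loader = DataLoader(dataset, batch_size=8, shuffle=True)
    for batch in loader:
        # batch.pos: (points_in_batch, 3 or 6), batch.y: (points_in_batch,)
        print(batch)
        break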