2020-06-19 19:00:07 +02:00

229 lines
8.0 KiB
Python

from pathlib import Path
from warnings import warn
import numpy as np
from collections import defaultdict
import os
from torch.utils.data import Dataset
from tqdm import tqdm
import torch
from torch_geometric.data import InMemoryDataset
from torch_geometric.data import Data
from utils.project_settings import Classes, DataSplit
def save_names(name_list, path):
    """Write every entry of ``name_list`` to the file at ``path``.

    The file is opened in binary mode and the entries are written back to
    back; no separator or newline is inserted between them, so callers that
    want one name per line must include the line ending in each entry.

    Args:
        name_list: Iterable of names. ``bytes`` entries are written as-is,
            ``str`` entries are encoded as UTF-8 first (previously a ``str``
            entry raised ``TypeError`` because the file is opened with 'wb').
        path: Destination file; an existing file is overwritten.
    """
    with open(path, 'wb') as f:
        f.writelines(
            name.encode('utf-8') if isinstance(name, str) else name
            for name in name_list
        )
class CustomShapeNet(InMemoryDataset):
    """In-memory point cloud dataset read from ``*.xyz`` text files.

    Raw clouds are expected under ``<root>/raw/<mode>``; only files whose name
    contains ``'grid'`` are used. Each row of a file is a delimiter separated
    list of numbers: the last column is used as the segment key (``y_c``), the
    second to last column as the label (``y``) and the remaining leading
    columns as point features (position, then normals).

    The processed result is cached at ``<root>/processed/<mode>.pt``.
    """

    # Inverted views (value -> key) of the project level settings objects.
    categories = {key: val for val, key in Classes().items()}
    modes = {key: val for val, key in DataSplit().items()}
    name = 'CustomShapeNet'

    @property
    def raw_dir(self):
        """Folder that holds one sub folder of raw clouds per mode."""
        return self.root / 'raw'

    @property
    def raw_file_names(self):
        """Raw entries required by this dataset: the folder named after the mode."""
        return [self.mode]

    @property
    def processed_dir(self):
        """Folder the processed ``.pt`` files are written to."""
        return self.root / 'processed'

    def __init__(self, root_dir, collate_per_segment=True, mode='train', transform=None, pre_filter=None,
                 pre_transform=None, refresh=False, with_normals=False):
        """Set up the dataset and load (or build) the processed file.

        Args:
            root_dir: Dataset root; converted to ``pathlib.Path``.
            collate_per_segment: If true every segment of a cloud becomes its
                own sample, otherwise one sample per cloud is created.
            mode: Data split, must be a key of ``self.modes``.
            transform: Forwarded to ``InMemoryDataset``.
            pre_filter: Forwarded to ``InMemoryDataset``.
            pre_transform: Forwarded to ``InMemoryDataset``.
            refresh: If true the processed file is deleted and rebuilt.
            with_normals: If true the first six feature columns are stored in
                ``pos`` and ``norm`` is ``None``; otherwise ``pos`` holds the
                first three columns and ``norm`` columns three to five.
        """
        assert mode in self.modes.keys(), f'"mode" must be one of {self.modes.keys()}'

        # Set the Dataset Parameters (the properties used by the parent
        # constructor read them, so they have to exist before super().__init__).
        self.collate_per_segment = collate_per_segment
        self.mode = mode
        self.refresh = refresh
        self.with_normals = with_normals
        root_dir = Path(root_dir)
        super(CustomShapeNet, self).__init__(root_dir, transform, pre_transform, pre_filter)
        self.data, self.slices = self._load_dataset()
        print("Initialized")

    @property
    def processed_file_names(self):
        """Name of the processed file belonging to the current mode."""
        return [f'{self.mode}.pt']

    def check_and_resolve_cloud_count(self):
        """Count the raw files available for the current mode.

        Returns:
            The number of files in ``raw_dir / mode`` (``0`` if there are
            none) or ``-1`` if that folder does not exist. A
            ``ResourceWarning`` is issued in the latter two cases.
        """
        cloud_dir = self.raw_dir / self.mode
        if not cloud_dir.exists():
            warn(ResourceWarning("The raw data folder does not exist. Was this intentional?"))
            return -1

        file_count = len([cloud for cloud in cloud_dir.iterdir() if cloud.is_file()])
        if file_count:
            print(f'{file_count} files have been found....')
        else:
            warn(ResourceWarning("No raw pointclouds have been found. Was this intentional?"))
        return file_count

    @property
    def num_classes(self):
        """Number of entries in ``categories``."""
        return len(self.categories)

    def _load_dataset(self):
        """Load the processed file, building it from the raw clouds if needed.

        Returns:
            ``(data, slices)`` as stored in the processed file, or
            ``(None, None)`` if nothing could be loaded.
        """
        data, slices = None, None
        filepath = self.processed_paths[0]
        if self.refresh:
            try:
                os.remove(filepath)
                print('Processed Location "Refreshed" (We deleted the Files)')
            except FileNotFoundError:
                print('You meant to refresh the allready processed dataset, but there were none...')
                print('continue processing')

        already_processed = False
        while True:
            try:
                data, slices = torch.load(filepath)
                print('Dataset Loaded')
                break
            except FileNotFoundError:
                status = self.check_and_resolve_cloud_count()
                if status in [0, -1]:
                    print(f'No dataset was loaded, status: {status}')
                    break
                if already_processed:
                    # process() ran but did not write a file (e.g. no raw file
                    # matched '*grid*.xyz'). Retrying would loop forever.
                    warn(ResourceWarning("Processing did not produce a dataset file. Was this intentional?"))
                    break
                self.process()
                already_processed = True
        return data, slices

    def _pre_transform_and_filter(self, data):
        """Placeholder hook for ``pre_filter`` / ``pre_transform``.

        Currently not called by ``process``. Raises ``NotImplementedError`` as
        soon as a filter rejects ``data`` or a ``pre_transform`` is configured.
        """
        # ToDo: Any filter to apply? Then do it here.
        if self.pre_filter is not None and not self.pre_filter(data):
            data = self.pre_filter(data)
            raise NotImplementedError
        # ToDo: Any transformation to apply? Then do it here.
        if self.pre_transform is not None:
            data = self.pre_transform(data)
            raise NotImplementedError
        return data

    def process(self, delimiter=' '):
        """Parse the raw ``*.xyz`` clouds of the current mode and cache them.

        Args:
            delimiter: Column separator used inside the raw files.

        Nothing is written if no sample could be created.
        """
        datasets = defaultdict(list)
        path_to_clouds = self.raw_dir / self.mode
        for pointcloud in tqdm(path_to_clouds.glob('*.xyz')):
            if 'grid' not in pointcloud.name:
                continue
            data = None

            # Group the rows of the file by the value of their last column.
            src = defaultdict(list)
            with pointcloud.open('r') as f:
                for row in f:
                    row = row.rstrip()
                    if not row:
                        # Blank lines carry no point; float('') would raise.
                        continue
                    vals = [float(x) if x not in ['-nan(ind)', 'nan(ind)'] else 0
                            for x in row.split(delimiter)]
                    src[vals[-1]].append(vals)

            if not src:
                # Empty cloud: nothing to add (torch.cat would fail on an empty tuple).
                continue

            # A segment consisting of a single row collapses to a 1-D tensor
            # here and is skipped further down (IndexError on 2-D indexing).
            src = {key: torch.tensor(values, dtype=torch.double).squeeze() for key, values in src.items()}

            if not self.collate_per_segment:
                src = dict(
                    all=torch.cat(tuple(src.values()))
                )

            for values in src.values():
                try:
                    points = values[:, :-2]
                except IndexError:
                    continue
                y = torch.as_tensor(values[:, -2], dtype=torch.long)
                y_c = torch.as_tensor(values[:, -1], dtype=torch.long)
                ####################################
                # This is where you define the keys
                attr_dict = dict(y=y, y_c=y_c)
                if self.with_normals:
                    attr_dict.update(pos=points[:, :6], norm=None)
                else:
                    attr_dict.update(pos=points[:, :3], norm=points[:, 3:6])
                ####################################
                if self.collate_per_segment:
                    datasets[self.mode].append(Data(**attr_dict))
                else:
                    if data is None:
                        data = defaultdict(list)
                    for attr_name, attr_val in attr_dict.items():
                        data[attr_name].append(attr_val)

            if not self.collate_per_segment and data is not None:
                # This is just to be sure, but should not be needed, since src[all] == all there is in this cloud
                merged = {
                    # Attributes that are None (norm, if with_normals) cannot be concatenated.
                    attr_name: torch.cat(attr_vals) if all(val is not None for val in attr_vals) else None
                    for attr_name, attr_vals in data.items()
                }
                datasets[self.mode].append(Data(**merged))

        if datasets[self.mode]:
            os.makedirs(self.processed_dir, exist_ok=True)
            torch.save(self.collate(datasets[self.mode]), self.processed_paths[0])

    def __repr__(self):
        return f'{self.__class__.__name__}({len(self)})'
class ShapeNetPartSegDataset(Dataset):
    """
    Wrap ``CustomShapeNet`` and resample every cloud on access.

    In every mode except ``DataSplit.predict`` a sample is drawn with
    replacement to a fixed number of ``npoints`` points; in predict mode all
    points of the cloud are returned in random order.
    Labels are passed through unchanged (the remapping from [1, N] to
    [0, N-1] is currently disabled).
    """
    name = 'ShapeNetPartSegDataset'

    def __init__(self, root_dir, npoints=1024, mode='train', **kwargs):
        """Create the wrapped ``CustomShapeNet``.

        Args:
            root_dir: Dataset root, forwarded to ``CustomShapeNet``.
            npoints: Number of points drawn per sample outside predict mode.
            mode: Data split, forwarded to ``CustomShapeNet``.
            **kwargs: Further keyword arguments for ``CustomShapeNet``.
        """
        super(ShapeNetPartSegDataset, self).__init__()
        self.mode = mode
        kwargs.update(dict(root_dir=root_dir, mode=self.mode))
        self.npoints = npoints
        self.dataset = CustomShapeNet(**kwargs)

    def __getitem__(self, index):
        """Return the resampled cloud at ``index`` as a ``Data`` object with ``pos``, ``y`` and ``norm``."""
        data = self.dataset[index]
        is_predict = self.mode == DataSplit.predict
        n_available = data.pos.shape[0]

        # Resample to fixed number of points
        try:
            npoints = n_available if is_predict else self.npoints
            choice = np.random.choice(n_available, npoints, replace=not is_predict)
        except ValueError:
            # np.random.choice refused to sample; fall back to an empty selection.
            choice = []

        pos, y = data.pos[choice, :], data.y[choice]
        # norm is None when the underlying dataset stores the normals inside pos.
        norm = getattr(data, 'norm', None)
        if norm is not None:
            norm = norm[choice]

        # y -= 1 if self.num_classes() in y else 0   # Map label from [1, C] to [0, C-1]

        data = Data(**dict(pos=pos,    # torch.Tensor (n, 3/6)
                           y=y,        # torch.Tensor (n,)
                           norm=norm   # torch.Tensor (n, 3) or None
                           )
                    )
        return data

    def __len__(self):
        return len(self.dataset)

    def num_classes(self):
        """Number of classes of the wrapped dataset."""
        return self.dataset.num_classes