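"""Dataset wrappers for custom point-cloud data.

CustomShapeNet reads raw *.xyz point clouds into torch_geometric Data objects
(backed by an InMemoryDataset); ShapeNetPartSegDataset exposes them through
the plain torch.utils.data.Dataset interface.
"""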
from collections import defaultdict
from pathlib import Path
from typing import Union
from warnings import warn
import os

import torch
from torch.utils.data import Dataset
from torch_geometric.data import Data, InMemoryDataset
from tqdm import tqdm

from utils.project_settings import Classes, DataSplit, ClusterTypes


def save_names(name_list, path):
    # Text mode ('w' instead of 'wb'), since the names are strings;
    # writelines() does not add newlines, so append them explicitly.
    with open(path, 'w') as f:
        f.writelines(f'{name}\n' for name in name_list)


class CustomShapeNet(InMemoryDataset):

    name = 'CustomShapeNet'

    def download(self):
        pass

    @property
    def categories(self):
        return {key: val for val, key in self.classes.items()}

    @property
    def modes(self):
        return {key: val for val, key in DataSplit().items()}

    @property
    def cluster_types(self):
        return {key: val for val, key in ClusterTypes().items()}

    @property
    def raw_dir(self):
        return self.root / 'raw'

    @property
    def raw_file_names(self):
        return [self.mode]

    @property
    def processed_dir(self):
        return self.root / 'processed'

    def __init__(self, root_dir, collate_per_segment=True, mode='train', transform=None, pre_filter=None,
                 pre_transform=None, refresh=False, cluster_type: Union[str, None] = '',
                 poly_as_plane=False):
        assert mode in self.modes.keys(), \
            f'"mode" must be one of {self.modes.keys()}'
        assert cluster_type in self.cluster_types.keys() or cluster_type is None, \
            f'"cluster_type" must be one of {self.cluster_types.keys()} or None, but was: {cluster_type}'

        # Set the dataset parameters
        self.cluster_type = cluster_type if cluster_type else 'pc'
        self.classes = Classes()
        self.poly_as_plane = poly_as_plane
        self.collate_per_segment = collate_per_segment
        self.mode = mode
        self.refresh = refresh

        root_dir = Path(root_dir)
        super(CustomShapeNet, self).__init__(root_dir, transform, pre_transform, pre_filter)
        self.data, self.slices = self._load_dataset()
        print("Initialized")

    @property
    def processed_file_names(self):
        return [f'{self.mode}.pt']

    def check_and_resolve_cloud_count(self):
        if (self.raw_dir / self.mode).exists():
            file_count = len([cloud for cloud in (self.raw_dir / self.mode).iterdir() if cloud.is_file()])
            if file_count:
                print(f'{file_count} files have been found....')
                return file_count
            else:
                warn(ResourceWarning("No raw pointclouds have been found. Was this intentional?"))
                return file_count
        warn(ResourceWarning("The raw data folder does not exist. Was this intentional?"))
        return -1

    @property
    def num_classes(self):
        # Merging Plane, Box and Polytope into a single plane class removes
        # two categories from the label set.
        return (len(self.categories) - 2) if self.poly_as_plane else len(self.categories)

    def _load_dataset(self):
        data, slices = None, None
        filepath = self.processed_paths[0]
        if self.refresh:
            try:
                os.remove(filepath)
                print('Processed location "refreshed" (the files were deleted).')
            except FileNotFoundError:
                print('You meant to refresh the already processed dataset, but there was none...')
                print('Continue processing.')
        while True:
            try:
                data, slices = torch.load(filepath)
                print('Dataset loaded.')
                break
            except FileNotFoundError:
                status = self.check_and_resolve_cloud_count()
                if status in [0, -1]:
                    print(f'No dataset was loaded, status: {status}')
                    break
                self.process()
                continue
        return data, slices

    def _pre_transform_and_filter(self, data):
        # Usual PyG semantics: the pre_filter decides whether a sample is kept
        # at all, the pre_transform modifies the samples that survive.
        if self.pre_filter is not None and not self.pre_filter(data):
            return None
        if self.pre_transform is not None:
            data = self.pre_transform(data)
        return data
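
    # Expected layout of one raw *.xyz row (inferred from the parsing below,
    # not from a separate spec): x y z nx ny nz label cluster_id, i.e. three
    # coordinates, three normal components, the per-point class label
    # (second-to-last column) and the segment/cluster id (last column).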
    def process(self, delimiter=' '):
        datasets = defaultdict(list)
        path_to_clouds = self.raw_dir / self.mode

        for pointcloud in tqdm(path_to_clouds.glob('*.xyz')):
            if self.cluster_type not in pointcloud.name:
                continue
            data = None

            with pointcloud.open('r') as f:
                src = defaultdict(list)
                # Iterate over all rows and group the values by their
                # segment id (the last column); skip blank lines.
                for row in f:
                    if row.strip():
                        vals = row.rstrip().split(delimiter)
                        vals = [float(x) if x not in ['-nan(ind)', 'nan(ind)'] else 0 for x in vals]
                        src[vals[-1]].append(vals)

                # Switch from the un-picklable defaultdict to a standard dict
                src = dict(src)

                # Transform the Dict[List] to Dict[torch.Tensor]
                for key, values in src.items():
                    src[key] = torch.tensor(values, dtype=torch.double).squeeze()

                # Ignore the segmentation and merge everything into one full
                # cloud rather than separate segments
                if not self.collate_per_segment:
                    src = dict(
                        all=torch.cat(tuple(src.values()))
                    )

                # Transform Box and Polytope to Plane if poly_as_plane is set
                for key, tensor in src.items():
                    if tensor.ndim == 1:
                        if all([x == 0 for x in tensor]):
                            continue
                        tensor = tensor.unsqueeze(0)
                    if self.poly_as_plane:
                        tensor[:, -2][tensor[:, -2] == float(self.classes.Plane)] = 4.0
                        tensor[:, -2][tensor[:, -2] == float(self.classes.Box)] = 4.0
                        tensor[:, -2][tensor[:, -2] == float(self.classes.Polytope)] = 4.0
                        tensor[:, -2][tensor[:, -2] == float(self.classes.Torus)] = 3.0
                    src[key] = tensor

                for key, values in src.items():
                    try:
                        points = values[:, :-2]
                    except IndexError:
                        continue
                    y = torch.as_tensor(values[:, -2], dtype=torch.long)    # per-point class label
                    y_c = torch.as_tensor(values[:, -1], dtype=torch.long)  # per-point cluster id

                    ####################################
                    # This is where you define the keys
                    attr_dict = dict(
                        y=y,
                        y_c=y_c,
                        pos=points[:, :3],
                        norm=points[:, 3:6]
                    )
                    ####################################

                    if self.collate_per_segment:
                        data = Data(**attr_dict)
                    else:
                        if data is None:
                            data = defaultdict(list)
                        for attr_key, val in attr_dict.items():
                            data[attr_key].append(val)

                    # data = self._pre_transform_and_filter(data)
                    if self.collate_per_segment:
                        datasets[self.mode].append(data)

                if not self.collate_per_segment:
                    datasets[self.mode].append(Data(**{key: torch.cat(data[key]) for key in data.keys()}))

        if datasets[self.mode]:
            os.makedirs(self.processed_dir, exist_ok=True)
            collated_dataset = self.collate(datasets[self.mode])
            torch.save(collated_dataset, self.processed_paths[0])

    def __repr__(self):
        return f'{self.__class__.__name__}({len(self)})'


class ShapeNetPartSegDataset(Dataset):
    """
    Resample raw point clouds to a fixed number of points and map raw labels
    from the range [1, N] to [0, N - 1].
    (The resampling step is currently disabled, see __getitem__.)
    """

    name = 'ShapeNetPartSegDataset'

    def __init__(self, root_dir, mode='train', **kwargs):
        super(ShapeNetPartSegDataset, self).__init__()
        self.mode = mode
        kwargs.update(dict(root_dir=root_dir, mode=self.mode))
        # self.npoints = npoints
        self.dataset = CustomShapeNet(**kwargs)

    def __getitem__(self, index):
        data = self.dataset[index]
        # Resample to a fixed number of points
        '''
        try:
            npoints = self.npoints if self.mode != DataSplit.predict else data.pos.shape[0]
            choice = np.random.choice(data.pos.shape[0], npoints,
                                      replace=False if self.mode == DataSplit.predict else True
                                      )
        except ValueError:
            choice = []
        pos, norm, y = data.pos[choice, :], data.norm[choice], data.y[choice]
        # y -= 1 if self.num_classes() in y else 0  # Map label from [1, C] to [0, C-1]
        data = Data(**dict(pos=pos,    # torch.Tensor (n, 3/6)
                           y=y,        # torch.Tensor (n,)
                           norm=norm   # torch.Tensor (n, 3/0)
                           )
                    )
        '''
        return data

    def __len__(self):
        return len(self.dataset)

    def num_classes(self):
        return self.dataset.num_classes
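

# A minimal usage sketch (assumptions: the dataset root below is a placeholder
# and must contain a 'raw/<mode>' folder with *.xyz files; the DataLoader is
# the torch_geometric one, which knows how to batch Data objects).
if __name__ == '__main__':
    from torch_geometric.data import DataLoader

    dataset = ShapeNetPartSegDataset(root_dir='data/my_pointclouds', mode='train',
                                     collate_per_segment=True, poly_as_plane=True)
    loader = DataLoader(dataset, batch_size=8, shuffle=True)
    for batch in loader:
        # Each batch is a torch_geometric Batch carrying pos, norm, y and y_c.
        print(batch)
        break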