Now includes sorting per cloud variation

This commit is contained in:
Si11ium
2019-08-02 13:03:56 +02:00
parent 9c100b6c43
commit 43b4c8031a
2 changed files with 22 additions and 6 deletions

View File

@@ -17,13 +17,15 @@ def save_names(name_list, path):
with open(path, 'wb') as f:
f.writelines(name_list)
class CustomShapeNet(InMemoryDataset):
categories = {key: val for val, key in enumerate(['Box', 'Cone', 'Cylinder', 'Sphere'])}
def __init__(self, root, collate_per_segment=True, train=True, transform=None, pre_filter=None, pre_transform=None,
headers=True, **kwargs):
headers=True, has_variations=False):
self.has_headers = headers
self.has_variations = has_variations
self.collate_per_element = collate_per_segment
self.train = train
super(CustomShapeNet, self).__init__(root, transform, pre_transform, pre_filter)
@@ -88,6 +90,8 @@ class CustomShapeNet(InMemoryDataset):
pass
for pointcloud in tqdm(os.scandir(path_to_clouds)):
if self.has_variations:
cloud_variations = defaultdict(list)
if not os.path.isdir(pointcloud):
continue
data, paths = None, list()
@@ -131,8 +135,14 @@ class CustomShapeNet(InMemoryDataset):
data = self._transform_and_filter(data)
if self.collate_per_element:
datasets[data_folder].append(data)
if self.has_variations:
cloud_variations[int(element.name.split('_')[0])].append(data)
if not self.collate_per_element:
datasets[data_folder].append(Data(**{key: torch.cat(data[key]) for key in data.keys()}))
if self.has_variations:
for variation in cloud_variations.keys():
datasets[data_folder].append(Data(**{key: torch.cat(data[key]) for key in data.keys()}))
else:
datasets[data_folder].append(Data(**{key: torch.cat(data[key]) for key in data.keys()}))
if datasets[data_folder]:
os.makedirs(self.processed_dir, exist_ok=True)
@@ -147,11 +157,12 @@ class ShapeNetPartSegDataset(Dataset):
Resample raw point cloud to fixed number of points.
Map raw label from range [1, N] to [0, N-1].
"""
def __init__(self, root_dir, collate_per_segment=True, train=True, transform=None, npoints=1024, headers=True):
def __init__(self, root_dir, collate_per_segment=True, train=True, transform=None,
has_variations=False, npoints=1024, headers=True):
super(ShapeNetPartSegDataset, self).__init__()
self.npoints = npoints
self.dataset = CustomShapeNet(root=root_dir, collate_per_segment=collate_per_segment,
train=train, transform=transform, headers=headers)
train=train, transform=transform, headers=headers, has_variations=has_variations)
def __getitem__(self, index):
data = self.dataset[index]