Now with included sorting per cloud variation
This commit is contained in:
@ -17,13 +17,15 @@ def save_names(name_list, path):
|
||||
with open(path, 'wb') as f:
|
||||
f.writelines(name_list)
|
||||
|
||||
|
||||
class CustomShapeNet(InMemoryDataset):
|
||||
|
||||
categories = {key: val for val, key in enumerate(['Box', 'Cone', 'Cylinder', 'Sphere'])}
|
||||
|
||||
def __init__(self, root, collate_per_segment=True, train=True, transform=None, pre_filter=None, pre_transform=None,
|
||||
headers=True, **kwargs):
|
||||
headers=True, has_variations=False):
|
||||
self.has_headers = headers
|
||||
self.has_variations = has_variations
|
||||
self.collate_per_element = collate_per_segment
|
||||
self.train = train
|
||||
super(CustomShapeNet, self).__init__(root, transform, pre_transform, pre_filter)
|
||||
@ -88,6 +90,8 @@ class CustomShapeNet(InMemoryDataset):
|
||||
pass
|
||||
|
||||
for pointcloud in tqdm(os.scandir(path_to_clouds)):
|
||||
if self.has_variations:
|
||||
cloud_variations = defaultdict(list)
|
||||
if not os.path.isdir(pointcloud):
|
||||
continue
|
||||
data, paths = None, list()
|
||||
@ -131,8 +135,14 @@ class CustomShapeNet(InMemoryDataset):
|
||||
data = self._transform_and_filter(data)
|
||||
if self.collate_per_element:
|
||||
datasets[data_folder].append(data)
|
||||
if self.has_variations:
|
||||
cloud_variations[int(element.name.split('_')[0])].append(data)
|
||||
if not self.collate_per_element:
|
||||
datasets[data_folder].append(Data(**{key: torch.cat(data[key]) for key in data.keys()}))
|
||||
if self.has_variations:
|
||||
for variation in cloud_variations.keys():
|
||||
datasets[data_folder].append(Data(**{key: torch.cat(data[key]) for key in data.keys()}))
|
||||
else:
|
||||
datasets[data_folder].append(Data(**{key: torch.cat(data[key]) for key in data.keys()}))
|
||||
|
||||
if datasets[data_folder]:
|
||||
os.makedirs(self.processed_dir, exist_ok=True)
|
||||
@ -147,11 +157,12 @@ class ShapeNetPartSegDataset(Dataset):
|
||||
Resample raw point cloud to fixed number of points.
|
||||
Map raw label from range [1, N] to [0, N-1].
|
||||
"""
|
||||
def __init__(self, root_dir, collate_per_segment=True, train=True, transform=None, npoints=1024, headers=True):
|
||||
def __init__(self, root_dir, collate_per_segment=True, train=True, transform=None,
|
||||
has_variations=False, npoints=1024, headers=True):
|
||||
super(ShapeNetPartSegDataset, self).__init__()
|
||||
self.npoints = npoints
|
||||
self.dataset = CustomShapeNet(root=root_dir, collate_per_segment=collate_per_segment,
|
||||
train=train, transform=transform, headers=headers)
|
||||
train=train, transform=transform, headers=headers, has_variations=has_variations)
|
||||
|
||||
def __getitem__(self, index):
|
||||
data = self.dataset[index]
|
||||
|
Reference in New Issue
Block a user