Now with included sorting per cloud variation
This commit is contained in:
parent 9c100b6c43
commit 43b4c8031a
@@ -17,13 +17,15 @@ def save_names(name_list, path):
     with open(path, 'wb') as f:
         f.writelines(name_list)
 
 
 class CustomShapeNet(InMemoryDataset):
 
     categories = {key: val for val, key in enumerate(['Box', 'Cone', 'Cylinder', 'Sphere'])}
 
     def __init__(self, root, collate_per_segment=True, train=True, transform=None, pre_filter=None, pre_transform=None,
-                 headers=True, **kwargs):
+                 headers=True, has_variations=False):
         self.has_headers = headers
+        self.has_variations = has_variations
         self.collate_per_element = collate_per_segment
         self.train = train
         super(CustomShapeNet, self).__init__(root, transform, pre_transform, pre_filter)
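For orientation, a minimal usage sketch of the new constructor flag; the root path below is a made-up placeholder, not something this commit defines:

# Hypothetical call: 'data/clouds' is an assumed dataset root.
dataset = CustomShapeNet(root='data/clouds', collate_per_segment=False,
                         train=True, has_variations=True)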
@@ -88,6 +90,8 @@ class CustomShapeNet(InMemoryDataset):
                 pass
 
         for pointcloud in tqdm(os.scandir(path_to_clouds)):
+            if self.has_variations:
+                cloud_variations = defaultdict(list)
             if not os.path.isdir(pointcloud):
                 continue
             data, paths = None, list()
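The fresh defaultdict is filled per cloud: as the next hunk shows, variations are keyed by the integer prefix of each file name (the int(id)_pc.(xyz|dat) convention named in the new --has_variations help text in main.py). A self-contained sketch of that bucketing, with made-up file names:

from collections import defaultdict

cloud_variations = defaultdict(list)
for name in ['0_pc.xyz', '0_pc.dat', '1_pc.xyz']:  # hypothetical files
    cloud_variations[int(name.split('_')[0])].append(name)
# cloud_variations == {0: ['0_pc.xyz', '0_pc.dat'], 1: ['1_pc.xyz']}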
@@ -131,8 +135,14 @@ class CustomShapeNet(InMemoryDataset):
                 data = self._transform_and_filter(data)
                 if self.collate_per_element:
                     datasets[data_folder].append(data)
+                if self.has_variations:
+                    cloud_variations[int(element.name.split('_')[0])].append(data)
             if not self.collate_per_element:
-                datasets[data_folder].append(Data(**{key: torch.cat(data[key]) for key in data.keys()}))
+                if self.has_variations:
+                    for variation in cloud_variations.keys():
+                        datasets[data_folder].append(Data(**{key: torch.cat(data[key]) for key in data.keys()}))
+                else:
+                    datasets[data_folder].append(Data(**{key: torch.cat(data[key]) for key in data.keys()}))
 
             if datasets[data_folder]:
                 os.makedirs(self.processed_dir, exist_ok=True)
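As committed, the has_variations branch of this hunk appends the same collated data dict once per variation key; the loop variable variation is never read. If the goal is to collate each variation's own clouds separately, a hedged sketch of that alternative (assuming each bucket is a list of dicts mapping attribute names to tensors) could read:

# Sketch only, not what the commit does: collate each bucket on its own.
for variation, bucket in cloud_variations.items():
    keys = bucket[0].keys()
    merged = {key: torch.cat([part[key] for part in bucket]) for key in keys}
    datasets[data_folder].append(Data(**merged))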
@@ -147,11 +157,12 @@ class ShapeNetPartSegDataset(Dataset):
     Resample raw point cloud to fixed number of points.
     Map raw label from range [1, N] to [0, N-1].
     """
-    def __init__(self, root_dir, collate_per_segment=True, train=True, transform=None, npoints=1024, headers=True):
+    def __init__(self, root_dir, collate_per_segment=True, train=True, transform=None,
+                 has_variations=False, npoints=1024, headers=True):
         super(ShapeNetPartSegDataset, self).__init__()
         self.npoints = npoints
         self.dataset = CustomShapeNet(root=root_dir, collate_per_segment=collate_per_segment,
-                                      train=train, transform=transform, headers=headers)
+                                      train=train, transform=transform, headers=headers, has_variations=has_variations)
 
     def __getitem__(self, index):
         data = self.dataset[index]
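The docstring's remap from [1, N] to [0, N-1] is, in the usual implementation, a plain index shift; a minimal sketch with a hypothetical label tensor:

import torch

raw_labels = torch.tensor([1, 3, 2])  # assumed raw labels in [1, N]
labels = raw_labels - 1               # now in [0, N-1]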
9 main.py
@@ -38,6 +38,9 @@ parser.add_argument('--test_per_batches', type=int, default=1000, help='run a te
 parser.add_argument('--num_workers', type=int, default=4, help='number of data loading workers')
 parser.add_argument('--headers', type=strtobool, default=True, help='if raw files come with headers')
 parser.add_argument('--collate_per_segment', type=strtobool, default=True, help='whether to look at pointclouds or sub')
+parser.add_argument('--has_variations', type=strtobool, default=False,
+                    help='whether a single pointcloud has variations '
+                         'named int(id)_pc.(xyz|dat) look at pointclouds or sub')
 
 
 opt = parser.parse_args()
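Like the existing boolean flags, the new one is parsed with distutils' strtobool, so string arguments such as 'True' or 'false' map to 1 or 0; a quick sketch (the command line in the comment is a hypothetical invocation):

from distutils.util import strtobool

strtobool('True')   # -> 1
strtobool('false')  # -> 0
# e.g.: python main.py --has_variations True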
@@ -70,11 +73,13 @@ if __name__ == '__main__':
     test_transform = GT.Compose([GT.NormalizeScale(), ])
 
     dataset = ShapeNetPartSegDataset(root_dir=opt.dataset, collate_per_segment=opt.collate_per_segment,
-                                     train=True, transform=train_transform, npoints=opt.npoints, headers=opt.headers)
+                                     train=True, transform=train_transform, npoints=opt.npoints,
+                                     has_variations=opt.has_variations, headers=opt.headers)
     dataLoader = DataLoader(dataset, batch_size=opt.batch_size, shuffle=True, num_workers=opt.num_workers)
 
     test_dataset = ShapeNetPartSegDataset(root_dir=opt.dataset, collate_per_segment=opt.collate_per_segment,
-                                          train=False, transform=test_transform, npoints=opt.npoints, headers=opt.headers)
+                                          train=False, transform=test_transform, npoints=opt.npoints,
+                                          has_variations=opt.has_variations, headers=opt.headers)
     test_dataLoader = DataLoader(test_dataset, batch_size=opt.batch_size, shuffle=True, num_workers=opt.num_workers)
 
     num_classes = dataset.num_classes()