Now with included sorting per cloud variation

This commit is contained in:
Si11ium 2019-08-02 13:03:56 +02:00
parent 9c100b6c43
commit 43b4c8031a
2 changed files with 22 additions and 6 deletions

View File

@ -17,13 +17,15 @@ def save_names(name_list, path):
with open(path, 'wb') as f: with open(path, 'wb') as f:
f.writelines(name_list) f.writelines(name_list)
class CustomShapeNet(InMemoryDataset): class CustomShapeNet(InMemoryDataset):
categories = {key: val for val, key in enumerate(['Box', 'Cone', 'Cylinder', 'Sphere'])} categories = {key: val for val, key in enumerate(['Box', 'Cone', 'Cylinder', 'Sphere'])}
def __init__(self, root, collate_per_segment=True, train=True, transform=None, pre_filter=None, pre_transform=None, def __init__(self, root, collate_per_segment=True, train=True, transform=None, pre_filter=None, pre_transform=None,
headers=True, **kwargs): headers=True, has_variations=False):
self.has_headers = headers self.has_headers = headers
self.has_variations = has_variations
self.collate_per_element = collate_per_segment self.collate_per_element = collate_per_segment
self.train = train self.train = train
super(CustomShapeNet, self).__init__(root, transform, pre_transform, pre_filter) super(CustomShapeNet, self).__init__(root, transform, pre_transform, pre_filter)
@ -88,6 +90,8 @@ class CustomShapeNet(InMemoryDataset):
pass pass
for pointcloud in tqdm(os.scandir(path_to_clouds)): for pointcloud in tqdm(os.scandir(path_to_clouds)):
if self.has_variations:
cloud_variations = defaultdict(list)
if not os.path.isdir(pointcloud): if not os.path.isdir(pointcloud):
continue continue
data, paths = None, list() data, paths = None, list()
@ -131,8 +135,14 @@ class CustomShapeNet(InMemoryDataset):
data = self._transform_and_filter(data) data = self._transform_and_filter(data)
if self.collate_per_element: if self.collate_per_element:
datasets[data_folder].append(data) datasets[data_folder].append(data)
if self.has_variations:
cloud_variations[int(element.name.split('_')[0])].append(data)
if not self.collate_per_element: if not self.collate_per_element:
datasets[data_folder].append(Data(**{key: torch.cat(data[key]) for key in data.keys()})) if self.has_variations:
for variation in cloud_variations.keys():
datasets[data_folder].append(Data(**{key: torch.cat(data[key]) for key in data.keys()}))
else:
datasets[data_folder].append(Data(**{key: torch.cat(data[key]) for key in data.keys()}))
if datasets[data_folder]: if datasets[data_folder]:
os.makedirs(self.processed_dir, exist_ok=True) os.makedirs(self.processed_dir, exist_ok=True)
@ -147,11 +157,12 @@ class ShapeNetPartSegDataset(Dataset):
Resample raw point cloud to fixed number of points. Resample raw point cloud to fixed number of points.
Map raw label from range [1, N] to [0, N-1]. Map raw label from range [1, N] to [0, N-1].
""" """
def __init__(self, root_dir, collate_per_segment=True, train=True, transform=None, npoints=1024, headers=True): def __init__(self, root_dir, collate_per_segment=True, train=True, transform=None,
has_variations=False, npoints=1024, headers=True):
super(ShapeNetPartSegDataset, self).__init__() super(ShapeNetPartSegDataset, self).__init__()
self.npoints = npoints self.npoints = npoints
self.dataset = CustomShapeNet(root=root_dir, collate_per_segment=collate_per_segment, self.dataset = CustomShapeNet(root=root_dir, collate_per_segment=collate_per_segment,
train=train, transform=transform, headers=headers) train=train, transform=transform, headers=headers, has_variations=has_variations)
def __getitem__(self, index): def __getitem__(self, index):
data = self.dataset[index] data = self.dataset[index]

View File

@ -38,6 +38,9 @@ parser.add_argument('--test_per_batches', type=int, default=1000, help='run a te
parser.add_argument('--num_workers', type=int, default=4, help='number of data loading workers') parser.add_argument('--num_workers', type=int, default=4, help='number of data loading workers')
parser.add_argument('--headers', type=strtobool, default=True, help='if raw files come with headers') parser.add_argument('--headers', type=strtobool, default=True, help='if raw files come with headers')
parser.add_argument('--collate_per_segment', type=strtobool, default=True, help='whether to look at pointclouds or sub') parser.add_argument('--collate_per_segment', type=strtobool, default=True, help='whether to look at pointclouds or sub')
parser.add_argument('--has_variations', type=strtobool, default=False,
help='whether a single pointcloud has variations '
'named int(id)_pc.(xyz|dat) look at pointclouds or sub')
opt = parser.parse_args() opt = parser.parse_args()
@ -70,11 +73,13 @@ if __name__ == '__main__':
test_transform = GT.Compose([GT.NormalizeScale(), ]) test_transform = GT.Compose([GT.NormalizeScale(), ])
dataset = ShapeNetPartSegDataset(root_dir=opt.dataset, collate_per_segment=opt.collate_per_segment, dataset = ShapeNetPartSegDataset(root_dir=opt.dataset, collate_per_segment=opt.collate_per_segment,
train=True, transform=train_transform, npoints=opt.npoints, headers=opt.headers) train=True, transform=train_transform, npoints=opt.npoints,
has_variations=opt.has_variations, headers=opt.headers)
dataLoader = DataLoader(dataset, batch_size=opt.batch_size, shuffle=True, num_workers=opt.num_workers) dataLoader = DataLoader(dataset, batch_size=opt.batch_size, shuffle=True, num_workers=opt.num_workers)
test_dataset = ShapeNetPartSegDataset(root_dir=opt.dataset, collate_per_segment=opt.collate_per_segment, test_dataset = ShapeNetPartSegDataset(root_dir=opt.dataset, collate_per_segment=opt.collate_per_segment,
train=False, transform=test_transform, npoints=opt.npoints, headers=opt.headers) train=False, transform=test_transform, npoints=opt.npoints,
has_variations=opt.has_variations, headers=opt.headers)
test_dataLoader = DataLoader(test_dataset, batch_size=opt.batch_size, shuffle=True, num_workers=opt.num_workers) test_dataLoader = DataLoader(test_dataset, batch_size=opt.batch_size, shuffle=True, num_workers=opt.num_workers)
num_classes = dataset.num_classes() num_classes = dataset.num_classes()