inference build
This commit is contained in:
parent
1033b26195
commit
85cf3128f1
@ -198,10 +198,7 @@ class CustomShapeNet(InMemoryDataset):
|
||||
if self.collate_per_segment:
|
||||
datasets[self.mode].append(data)
|
||||
if not self.collate_per_segment:
|
||||
# This is just to be sure, but should not be needed, since src[all] == all
|
||||
raise TypeError('FIX THIS')
|
||||
# old Code
|
||||
# datasets[self.mode].append(Data(**{key: torch.cat(data[key]) for key in data.keys()}))
|
||||
datasets[self.mode].append(Data(**{key: torch.cat(data[key]) for key in data.keys()}))
|
||||
|
||||
if datasets[self.mode]:
|
||||
os.makedirs(self.processed_dir, exist_ok=True)
|
||||
|
@ -25,7 +25,7 @@ from utils.project_settings import GlobalVar
|
||||
|
||||
|
||||
def restore_logger_and_model(log_dir):
|
||||
model = SavedLightningModels.load_checkpoint(models_root_path=log_dir, model=PointNet2, n=-5)
|
||||
model = SavedLightningModels.load_checkpoint(models_root_path=log_dir, model=PointNet2, n=-1)
|
||||
model = model.restore()
|
||||
if torch.cuda.is_available():
|
||||
model.cuda()
|
||||
@ -69,10 +69,10 @@ if __name__ == '__main__':
|
||||
|
||||
# TEST DATASET
|
||||
transforms = Compose([NormalizeScale(), ])
|
||||
test_dataset = ShapeNetPartSegDataset('data', mode=GlobalVar.data_split.predict, collate_per_segment=True,
|
||||
test_dataset = ShapeNetPartSegDataset('data', mode=GlobalVar.data_split.predict, collate_per_segment=False,
|
||||
refresh=True, transform=transforms)
|
||||
|
||||
grid_clusters = cluster_cubes(test_dataset[0], [3, 3, 3], max_points_per_cluster=1024)
|
||||
grid_clusters = cluster_cubes(test_dataset[1], [3, 3, 3], max_points_per_cluster=1024)
|
||||
|
||||
ps.init()
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user