inference build

This commit is contained in:
Si11ium
2020-06-23 21:05:49 +02:00
parent 1033b26195
commit 85cf3128f1
2 changed files with 4 additions and 7 deletions

@ -25,7 +25,7 @@ from utils.project_settings import GlobalVar
def restore_logger_and_model(log_dir):
model = SavedLightningModels.load_checkpoint(models_root_path=log_dir, model=PointNet2, n=-5)
model = SavedLightningModels.load_checkpoint(models_root_path=log_dir, model=PointNet2, n=-1)
model = model.restore()
if torch.cuda.is_available():
model.cuda()
@ -69,10 +69,10 @@ if __name__ == '__main__':
# TEST DATASET
transforms = Compose([NormalizeScale(), ])
test_dataset = ShapeNetPartSegDataset('data', mode=GlobalVar.data_split.predict, collate_per_segment=True,
test_dataset = ShapeNetPartSegDataset('data', mode=GlobalVar.data_split.predict, collate_per_segment=False,
refresh=True, transform=transforms)
grid_clusters = cluster_cubes(test_dataset[0], [3, 3, 3], max_points_per_cluster=1024)
grid_clusters = cluster_cubes(test_dataset[1], [3, 3, 3], max_points_per_cluster=1024)
ps.init()