From 85cf3128f167869bf02865a8f42f5539bc060906 Mon Sep 17 00:00:00 2001
From: Si11ium
Date: Tue, 23 Jun 2020 21:05:49 +0200
Subject: [PATCH] inference build

---
 datasets/shapenet.py | 5 +----
 main_pipeline.py     | 6 +++---
 2 files changed, 4 insertions(+), 7 deletions(-)

diff --git a/datasets/shapenet.py b/datasets/shapenet.py
index 8f049ef..3f8e0df 100644
--- a/datasets/shapenet.py
+++ b/datasets/shapenet.py
@@ -198,10 +198,7 @@ class CustomShapeNet(InMemoryDataset):
                 if self.collate_per_segment:
                     datasets[self.mode].append(data)
             if not self.collate_per_segment:
-                # This is just to be sure, but should not be needed, since src[all] == all
-                raise TypeError('FIX THIS')
-                # old Code
-                # datasets[self.mode].append(Data(**{key: torch.cat(data[key]) for key in data.keys()}))
+                datasets[self.mode].append(Data(**{key: torch.cat(data[key]) for key in data.keys()}))
 
         if datasets[self.mode]:
             os.makedirs(self.processed_dir, exist_ok=True)

diff --git a/main_pipeline.py b/main_pipeline.py
index 752efaf..82d31c0 100644
--- a/main_pipeline.py
+++ b/main_pipeline.py
@@ -25,7 +25,7 @@ from utils.project_settings import GlobalVar
 
 
 def restore_logger_and_model(log_dir):
-    model = SavedLightningModels.load_checkpoint(models_root_path=log_dir, model=PointNet2, n=-5)
+    model = SavedLightningModels.load_checkpoint(models_root_path=log_dir, model=PointNet2, n=-1)
     model = model.restore()
     if torch.cuda.is_available():
         model.cuda()
@@ -69,10 +69,10 @@ if __name__ == '__main__':
 
     # TEST DATASET
     transforms = Compose([NormalizeScale(), ])
-    test_dataset = ShapeNetPartSegDataset('data', mode=GlobalVar.data_split.predict, collate_per_segment=True,
+    test_dataset = ShapeNetPartSegDataset('data', mode=GlobalVar.data_split.predict, collate_per_segment=False,
                                           refresh=True, transform=transforms)
 
-    grid_clusters = cluster_cubes(test_dataset[0], [3, 3, 3], max_points_per_cluster=1024)
+    grid_clusters = cluster_cubes(test_dataset[1], [3, 3, 3], max_points_per_cluster=1024)
 
     ps.init()