diff --git a/_parameters.py b/_parameters.py index d686425..b39745b 100644 --- a/_parameters.py +++ b/_parameters.py @@ -25,9 +25,9 @@ main_arg_parser.add_argument("--data_npoints", type=int, default=1024, help="") main_arg_parser.add_argument("--data_root", type=str, default='data', help="") main_arg_parser.add_argument("--data_refresh", type=strtobool, default=False, help="") main_arg_parser.add_argument("--data_dataset_type", type=str, default='ShapeNetPartSegDataset', help="") -main_arg_parser.add_argument("--data_cluster_type", type=str, default='grid', help="") -main_arg_parser.add_argument("--data_normals_as_cords", type=strtobool, default=False, help="") -main_arg_parser.add_argument("--data_poly_as_plane", type=strtobool, default=True, help="") +main_arg_parser.add_argument("--data_cluster_type", type=str, default='prim', help="") +main_arg_parser.add_argument("--data_normals_as_cords", type=strtobool, default=True, help="") +main_arg_parser.add_argument("--data_poly_as_plane", type=strtobool, default=False, help="") # Transformations # main_arg_parser.add_argument("--transformations_to_tensor", type=strtobool, default=False, help="") diff --git a/main.py b/main.py index 4f7c3e4..53f78f1 100644 --- a/main.py +++ b/main.py @@ -27,7 +27,7 @@ def run_lightning_loop(config_obj): checkpoint_callback = ModelCheckpoint( monitor='mean_loss', filepath=str(logger.log_dir / 'ckpt_weights'), - verbose=True, save_top_k=10, + verbose=True, save_top_k=3, ) # ============================================================================= diff --git a/main_pipeline.py b/main_pipeline.py index 68d5909..8d6a71b 100644 --- a/main_pipeline.py +++ b/main_pipeline.py @@ -67,15 +67,17 @@ if __name__ == '__main__': type_cluster_min_pts = 50 model_path = Path('output') / 'PN2' / 'PN_9843bf499399786cfd58fe79fa1b3db8' / 'version_0' + loaded_model = restore_logger_and_model(model_path) loaded_model.eval() transforms = Compose([NormalizeScale(), ]) test_dataset = ShapeNetPartSegDataset('data', mode=GlobalVar.data_split.predict, collate_per_segment=False, - refresh=True, transform=transforms) + refresh=True, transform=transforms) # , cluster_type='pc') grid_clusters = cluster_cubes(test_dataset[0], grid_clusters, max_points_per_cluster=grid_cluster_max_pts) + ps.init() # ========================== Grid Clustering ========================== diff --git a/requirements.txt b/requirements.txt index b01caa8..a456de7 100644 Binary files a/requirements.txt and b/requirements.txt differ