6D prediction files now working
This commit is contained in:
parent
2a7a236b89
commit
358d692699
@ -26,8 +26,8 @@ main_arg_parser.add_argument("--data_root", type=str, default='data', help="")
|
|||||||
main_arg_parser.add_argument("--data_refresh", type=strtobool, default=False, help="")
|
main_arg_parser.add_argument("--data_refresh", type=strtobool, default=False, help="")
|
||||||
main_arg_parser.add_argument("--data_dataset_type", type=str, default='ShapeNetPartSegDataset', help="")
|
main_arg_parser.add_argument("--data_dataset_type", type=str, default='ShapeNetPartSegDataset', help="")
|
||||||
main_arg_parser.add_argument("--data_cluster_type", type=str, default='grid', help="")
|
main_arg_parser.add_argument("--data_cluster_type", type=str, default='grid', help="")
|
||||||
main_arg_parser.add_argument("--data_normals_as_cords", type=strtobool, default=False, help="")
|
main_arg_parser.add_argument("--data_normals_as_cords", type=strtobool, default=True, help="")
|
||||||
main_arg_parser.add_argument("--data_poly_as_plane", type=strtobool, default=True, help="")
|
main_arg_parser.add_argument("--data_poly_as_plane", type=strtobool, default=False, help="")
|
||||||
|
|
||||||
# Transformations
|
# Transformations
|
||||||
# main_arg_parser.add_argument("--transformations_to_tensor", type=strtobool, default=False, help="")
|
# main_arg_parser.add_argument("--transformations_to_tensor", type=strtobool, default=False, help="")
|
||||||
|
2
main.py
2
main.py
@ -27,7 +27,7 @@ def run_lightning_loop(config_obj):
|
|||||||
checkpoint_callback = ModelCheckpoint(
|
checkpoint_callback = ModelCheckpoint(
|
||||||
monitor='mean_loss',
|
monitor='mean_loss',
|
||||||
filepath=str(logger.log_dir / 'ckpt_weights'),
|
filepath=str(logger.log_dir / 'ckpt_weights'),
|
||||||
verbose=True, save_top_k=10,
|
verbose=True, save_top_k=3,
|
||||||
)
|
)
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
@ -56,7 +56,7 @@ if __name__ == '__main__':
|
|||||||
|
|
||||||
# input_pc_path = Path('data') / 'pc' / 'test.xyz'
|
# input_pc_path = Path('data') / 'pc' / 'test.xyz'
|
||||||
|
|
||||||
model_path = Path('output') / 'PN2' / 'PN_14628b734c5b651b013ad9e36c406934' / 'version_0'
|
model_path = Path('output') / 'PN2' / 'PN_9843bf499399786cfd58fe79fa1b3db8' / 'version_0'
|
||||||
# config_filename = 'config.ini'
|
# config_filename = 'config.ini'
|
||||||
# config = ThisConfig()
|
# config = ThisConfig()
|
||||||
# config.read_file((Path(model_path) / config_filename).open('r'))
|
# config.read_file((Path(model_path) / config_filename).open('r'))
|
||||||
@ -70,7 +70,7 @@ if __name__ == '__main__':
|
|||||||
# TEST DATASET
|
# TEST DATASET
|
||||||
transforms = Compose([NormalizeScale(), ])
|
transforms = Compose([NormalizeScale(), ])
|
||||||
test_dataset = ShapeNetPartSegDataset('data', mode=GlobalVar.data_split.predict, collate_per_segment=False,
|
test_dataset = ShapeNetPartSegDataset('data', mode=GlobalVar.data_split.predict, collate_per_segment=False,
|
||||||
refresh=True, transform=transforms)
|
refresh=True, transform=transforms) # , cluster_type='pc')
|
||||||
|
|
||||||
grid_clusters = cluster_cubes(test_dataset[0], [1, 1, 1], max_points_per_cluster=8192)
|
grid_clusters = cluster_cubes(test_dataset[0], [1, 1, 1], max_points_per_cluster=8192)
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user