diff --git a/main_pipeline.py b/main_pipeline.py index 0d62aa8..03ef08a 100644 --- a/main_pipeline.py +++ b/main_pipeline.py @@ -17,6 +17,7 @@ from ml_lib.utils.model_io import SavedLightningModels # Datasets from datasets.shapenet import ShapeNetPartSegDataset +from models import PointNet2 from utils.pointcloud import read_pointcloud, normalize_pointcloud, cluster_cubes, append_onehotencoded_type, \ label2color from utils.project_settings import GlobalVar @@ -31,7 +32,7 @@ def prepare_dataloader(config_obj): def restore_logger_and_model(log_dir): - model = SavedLightningModels.load_checkpoint(models_root_path=log_dir, n=-1) + model = SavedLightningModels.load_checkpoint(models_root_path=log_dir, model=PointNet2, n=-1) model = model.restore() if torch.cuda.is_available(): model.cuda()