From fe2bc131dfa5afa69335fd7a59fceae6765e28c6 Mon Sep 17 00:00:00 2001 From: Si11ium Date: Fri, 19 Jun 2020 09:41:23 +0200 Subject: [PATCH] explicit model argument --- main_pipeline.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/main_pipeline.py b/main_pipeline.py index 0d62aa8..03ef08a 100644 --- a/main_pipeline.py +++ b/main_pipeline.py @@ -17,6 +17,7 @@ from ml_lib.utils.model_io import SavedLightningModels # Datasets from datasets.shapenet import ShapeNetPartSegDataset +from models import PointNet2 from utils.pointcloud import read_pointcloud, normalize_pointcloud, cluster_cubes, append_onehotencoded_type, \ label2color from utils.project_settings import GlobalVar @@ -31,7 +32,7 @@ def prepare_dataloader(config_obj): def restore_logger_and_model(log_dir): - model = SavedLightningModels.load_checkpoint(models_root_path=log_dir, n=-1) + model = SavedLightningModels.load_checkpoint(models_root_path=log_dir, model=PointNet2, n=-1) model = model.restore() if torch.cuda.is_available(): model.cuda()