explicit model argument
This commit is contained in:
		| @@ -17,6 +17,7 @@ from ml_lib.utils.model_io import SavedLightningModels | |||||||
|  |  | ||||||
| # Datasets | # Datasets | ||||||
| from datasets.shapenet import ShapeNetPartSegDataset | from datasets.shapenet import ShapeNetPartSegDataset | ||||||
|  | from models import PointNet2 | ||||||
| from utils.pointcloud import read_pointcloud, normalize_pointcloud, cluster_cubes, append_onehotencoded_type, \ | from utils.pointcloud import read_pointcloud, normalize_pointcloud, cluster_cubes, append_onehotencoded_type, \ | ||||||
|     label2color |     label2color | ||||||
| from utils.project_settings import GlobalVar | from utils.project_settings import GlobalVar | ||||||
| @@ -31,7 +32,7 @@ def prepare_dataloader(config_obj): | |||||||
|  |  | ||||||
|  |  | ||||||
| def restore_logger_and_model(log_dir): | def restore_logger_and_model(log_dir): | ||||||
|     model = SavedLightningModels.load_checkpoint(models_root_path=log_dir, n=-1) |     model = SavedLightningModels.load_checkpoint(models_root_path=log_dir, model=PointNet2, n=-1) | ||||||
|     model = model.restore() |     model = model.restore() | ||||||
|     if torch.cuda.is_available(): |     if torch.cuda.is_available(): | ||||||
|         model.cuda() |         model.cuda() | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 Si11ium
					Si11ium