explicit model argument
This commit is contained in:
parent
b79141e854
commit
fe2bc131df
@ -17,6 +17,7 @@ from ml_lib.utils.model_io import SavedLightningModels
|
||||
|
||||
# Datasets
|
||||
from datasets.shapenet import ShapeNetPartSegDataset
|
||||
from models import PointNet2
|
||||
from utils.pointcloud import read_pointcloud, normalize_pointcloud, cluster_cubes, append_onehotencoded_type, \
|
||||
label2color
|
||||
from utils.project_settings import GlobalVar
|
||||
@ -31,7 +32,7 @@ def prepare_dataloader(config_obj):
|
||||
|
||||
|
||||
def restore_logger_and_model(log_dir):
|
||||
model = SavedLightningModels.load_checkpoint(models_root_path=log_dir, n=-1)
|
||||
model = SavedLightningModels.load_checkpoint(models_root_path=log_dir, model=PointNet2, n=-1)
|
||||
model = model.restore()
|
||||
if torch.cuda.is_available():
|
||||
model.cuda()
|
||||
|
Loading…
x
Reference in New Issue
Block a user