explicit model argument

This commit is contained in:
Si11ium 2020-06-19 09:41:23 +02:00
parent b79141e854
commit fe2bc131df

View File

@ -17,6 +17,7 @@ from ml_lib.utils.model_io import SavedLightningModels
# Datasets # Datasets
from datasets.shapenet import ShapeNetPartSegDataset from datasets.shapenet import ShapeNetPartSegDataset
from models import PointNet2
from utils.pointcloud import read_pointcloud, normalize_pointcloud, cluster_cubes, append_onehotencoded_type, \ from utils.pointcloud import read_pointcloud, normalize_pointcloud, cluster_cubes, append_onehotencoded_type, \
label2color label2color
from utils.project_settings import GlobalVar from utils.project_settings import GlobalVar
@ -31,7 +32,7 @@ def prepare_dataloader(config_obj):
def restore_logger_and_model(log_dir): def restore_logger_and_model(log_dir):
model = SavedLightningModels.load_checkpoint(models_root_path=log_dir, n=-1) model = SavedLightningModels.load_checkpoint(models_root_path=log_dir, model=PointNet2, n=-1)
model = model.restore() model = model.restore()
if torch.cuda.is_available(): if torch.cuda.is_available():
model.cuda() model.cuda()