explicit model argument
This commit is contained in:
parent
b79141e854
commit
fe2bc131df
@ -17,6 +17,7 @@ from ml_lib.utils.model_io import SavedLightningModels
|
|||||||
|
|
||||||
# Datasets
|
# Datasets
|
||||||
from datasets.shapenet import ShapeNetPartSegDataset
|
from datasets.shapenet import ShapeNetPartSegDataset
|
||||||
|
from models import PointNet2
|
||||||
from utils.pointcloud import read_pointcloud, normalize_pointcloud, cluster_cubes, append_onehotencoded_type, \
|
from utils.pointcloud import read_pointcloud, normalize_pointcloud, cluster_cubes, append_onehotencoded_type, \
|
||||||
label2color
|
label2color
|
||||||
from utils.project_settings import GlobalVar
|
from utils.project_settings import GlobalVar
|
||||||
@ -31,7 +32,7 @@ def prepare_dataloader(config_obj):
|
|||||||
|
|
||||||
|
|
||||||
def restore_logger_and_model(log_dir):
|
def restore_logger_and_model(log_dir):
|
||||||
model = SavedLightningModels.load_checkpoint(models_root_path=log_dir, n=-1)
|
model = SavedLightningModels.load_checkpoint(models_root_path=log_dir, model=PointNet2, n=-1)
|
||||||
model = model.restore()
|
model = model.restore()
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
model.cuda()
|
model.cuda()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user