Refactured Settings File
This commit is contained in:
@ -1,8 +1,6 @@
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from torch_geometric.data import Data
|
||||
from tqdm import tqdm
|
||||
|
||||
import polyscope as ps
|
||||
import numpy as np
|
||||
@ -20,8 +18,8 @@ from ml_lib.utils.model_io import SavedLightningModels
|
||||
# Datasets
|
||||
from datasets.shapenet import ShapeNetPartSegDataset
|
||||
from utils.pointcloud import read_pointcloud, normalize_pointcloud, cluster_cubes, append_onehotencoded_type, \
|
||||
label2color, polytopes_to_planes
|
||||
from utils.project_config import GlobalVar, ThisConfig
|
||||
label2color
|
||||
from utils.project_settings import GlobalVar
|
||||
|
||||
|
||||
def prepare_dataloader(config_obj):
|
||||
@ -56,14 +54,15 @@ def predict_prim_type(input_pc, model):
|
||||
|
||||
return np.concatenate((input_pc, y_primary.reshape(-1,1)), axis=1)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
input_pc_path = 'data/pc/pc.txt'
|
||||
input_pc_path = Path('data') / 'pc' / 'pc.txt'
|
||||
|
||||
model_path = Path('trained_models/version_1')
|
||||
config_filename = 'config.ini'
|
||||
config = ThisConfig()
|
||||
config.read_file((Path(model_path) / config_filename).open('r'))
|
||||
model_path = Path('output') / 'PN2' / 'PN_26512907a2de0664bfad2349a6bffee3' / 'version_0'
|
||||
# config_filename = 'config.ini'
|
||||
# config = ThisConfig()
|
||||
# config.read_file((Path(model_path) / config_filename).open('r'))
|
||||
loaded_model = restore_logger_and_model(model_path)
|
||||
loaded_model.eval()
|
||||
|
||||
|
Reference in New Issue
Block a user