New Model running
This commit is contained in:
@ -11,6 +11,8 @@ from torch.utils.data import DataLoader
|
||||
# =============================================================================
|
||||
|
||||
# Transforms
|
||||
from torch_geometric.transforms import Compose, NormalizeScale, RandomFlip
|
||||
|
||||
from ml_lib.point_toolset.point_io import BatchToData
|
||||
from ml_lib.utils.model_io import SavedLightningModels
|
||||
|
||||
@ -18,21 +20,12 @@ from ml_lib.utils.model_io import SavedLightningModels
|
||||
# Datasets
|
||||
from datasets.shapenet import ShapeNetPartSegDataset
|
||||
from models import PointNet2
|
||||
from utils.pointcloud import read_pointcloud, normalize_pointcloud, cluster_cubes, append_onehotencoded_type, \
|
||||
label2color
|
||||
from utils.pointcloud import cluster_cubes, append_onehotencoded_type, label2color
|
||||
from utils.project_settings import GlobalVar
|
||||
|
||||
|
||||
def prepare_dataloader(config_obj):
|
||||
dataset = ShapeNetPartSegDataset(config_obj.data.root, split=GlobalVar.data_split.test,
|
||||
setting=GlobalVar.settings[config_obj.model.type])
|
||||
# noinspection PyTypeChecker
|
||||
return DataLoader(dataset, batch_size=config_obj.train.batch_size,
|
||||
num_workers=config_obj.data.worker, shuffle=False)
|
||||
|
||||
|
||||
def restore_logger_and_model(log_dir):
|
||||
model = SavedLightningModels.load_checkpoint(models_root_path=log_dir, model=PointNet2, n=-1)
|
||||
model = SavedLightningModels.load_checkpoint(models_root_path=log_dir, model=PointNet2, n=-5)
|
||||
model = model.restore()
|
||||
if torch.cuda.is_available():
|
||||
model.cuda()
|
||||
@ -40,26 +33,30 @@ def restore_logger_and_model(log_dir):
|
||||
model.cpu()
|
||||
return model
|
||||
|
||||
|
||||
def predict_prim_type(input_pc, model):
|
||||
|
||||
input_data = dict(norm=torch.tensor(np.array([input_pc[:, 3:6]], np.float)).unsqueeze(0),
|
||||
pos=torch.tensor(input_pc[:, 0:3]).unsqueeze(0),
|
||||
)
|
||||
input_data = dict(
|
||||
norm=torch.tensor(np.array([input_pc[:, 3:6]], np.float)).unsqueeze(0),
|
||||
pos=torch.tensor(input_pc[:, 0:3]).unsqueeze(0),
|
||||
)
|
||||
|
||||
batch_to_data = BatchToData()
|
||||
|
||||
data = batch_to_data(input_data)
|
||||
y = loaded_model(data.to(device='cuda' if torch.cuda.is_available() else 'cpu'))
|
||||
y_primary = torch.argmax(y.main_out, dim=-1).squeeze().cpu().numpy()
|
||||
y_primary = torch.argmax(y.main_out, dim=-1).cpu().numpy()
|
||||
|
||||
return np.concatenate((input_pc, y_primary.reshape(-1,1)), axis=1)
|
||||
if input_pc.shape[1] > 6:
|
||||
input_pc = input_pc[:, :6]
|
||||
return np.concatenate((input_pc, y_primary.reshape(-1, 1)), axis=-1)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
input_pc_path = Path('data') / 'pc' / 'test.xyz'
|
||||
|
||||
model_path = Path('output') / 'PN2' / 'PN_26512907a2de0664bfad2349a6bffee3' / 'version_0'
|
||||
model_path = Path('output') / 'PN2' / 'PN_9843bf499399786cfd58fe79fa1b3db8' / 'version_0'
|
||||
# config_filename = 'config.ini'
|
||||
# config = ThisConfig()
|
||||
# config.read_file((Path(model_path) / config_filename).open('r'))
|
||||
@ -71,8 +68,9 @@ if __name__ == '__main__':
|
||||
# input_pc = normalize_pointcloud(input_pc)
|
||||
|
||||
# TEST DATASET
|
||||
test_dataset = ShapeNetPartSegDataset('data', mode=GlobalVar.data_split.predict, collate_per_segment=False,
|
||||
npoints=1024, refresh=True)
|
||||
transforms = Compose([NormalizeScale(), ])
|
||||
test_dataset = ShapeNetPartSegDataset('data', mode=GlobalVar.data_split.predict, collate_per_segment=True,
|
||||
refresh=True, transform=transforms)
|
||||
|
||||
grid_clusters = cluster_cubes(test_dataset[0], [3, 3, 3], max_points_per_cluster=1024)
|
||||
|
||||
@ -84,8 +82,7 @@ if __name__ == '__main__':
|
||||
|
||||
pc_with_prim_type = predict_prim_type(grid_cluster_pc, loaded_model)
|
||||
|
||||
#pc_with_prim_type = polytopes_to_planes(pc_with_prim_type)
|
||||
|
||||
# pc_with_prim_type = polytopes_to_planes(pc_with_prim_type)
|
||||
pc_with_prim_type = append_onehotencoded_type(pc_with_prim_type)
|
||||
|
||||
pc = ps.register_point_cloud("points_" + str(i), pc_with_prim_type[:, :3], radius=0.01)
|
||||
|
Reference in New Issue
Block a user