Dataset Redone
This commit is contained in:
@ -18,14 +18,14 @@ from ml_lib.utils.model_io import SavedLightningModels
|
||||
|
||||
|
||||
# Datasets
|
||||
from datasets.full_pointclouds import FullCloudsDataset
|
||||
from datasets.shapenet import ShapeNetPartSegDataset
|
||||
from utils.pointcloud import read_pointcloud, normalize_pointcloud, cluster_cubes, append_onehotencoded_type, \
|
||||
label2color, polytopes_to_planes
|
||||
from utils.project_config import GlobalVar, ThisConfig
|
||||
|
||||
|
||||
def prepare_dataloader(config_obj):
|
||||
dataset = FullCloudsDataset(config_obj.data.root, split=GlobalVar.data_split.test,
|
||||
dataset = ShapeNetPartSegDataset(config_obj.data.root, split=GlobalVar.data_split.test,
|
||||
setting=GlobalVar.settings[config_obj.model.type])
|
||||
# noinspection PyTypeChecker
|
||||
return DataLoader(dataset, batch_size=config_obj.train.batch_size,
|
||||
@ -43,15 +43,14 @@ def restore_logger_and_model(log_dir):
|
||||
|
||||
def predict_prim_type(input_pc, model):
|
||||
|
||||
input_data = (
|
||||
torch.tensor(np.array([input_pc[:, 3:6]], np.float)),
|
||||
torch.tensor(input_pc[:, 0:3]),
|
||||
np.zeros(input_pc.shape[0])
|
||||
)
|
||||
input_data = dict(norm=torch.tensor(np.array([input_pc[:, 3:6]], np.float)),
|
||||
pos=torch.tensor(input_pc[:, 0:3]),
|
||||
y=np.zeros(input_pc.shape[0])
|
||||
)
|
||||
|
||||
batch_to_data = BatchToData()
|
||||
|
||||
data = batch_to_data(input_data[0], input_data[1], input_data[2])
|
||||
data = batch_to_data(input_data)
|
||||
y = loaded_model(data.to(device='cuda' if torch.cuda.is_available() else 'cpu'))
|
||||
y_primary = torch.argmax(y.main_out, dim=-1).squeeze().cpu().numpy()
|
||||
|
||||
|
Reference in New Issue
Block a user