Dataset Redone

This commit is contained in:
Si11ium
2020-06-19 08:17:35 +02:00
parent 4898e98851
commit 63605ae33a
14 changed files with 239 additions and 362 deletions

View File

@ -18,14 +18,14 @@ from ml_lib.utils.model_io import SavedLightningModels
# Datasets
from datasets.full_pointclouds import FullCloudsDataset
from datasets.shapenet import ShapeNetPartSegDataset
from utils.pointcloud import read_pointcloud, normalize_pointcloud, cluster_cubes, append_onehotencoded_type, \
label2color, polytopes_to_planes
from utils.project_config import GlobalVar, ThisConfig
def prepare_dataloader(config_obj):
dataset = FullCloudsDataset(config_obj.data.root, split=GlobalVar.data_split.test,
dataset = ShapeNetPartSegDataset(config_obj.data.root, split=GlobalVar.data_split.test,
setting=GlobalVar.settings[config_obj.model.type])
# noinspection PyTypeChecker
return DataLoader(dataset, batch_size=config_obj.train.batch_size,
@ -43,15 +43,14 @@ def restore_logger_and_model(log_dir):
def predict_prim_type(input_pc, model):
input_data = (
torch.tensor(np.array([input_pc[:, 3:6]], np.float)),
torch.tensor(input_pc[:, 0:3]),
np.zeros(input_pc.shape[0])
)
input_data = dict(norm=torch.tensor(np.array([input_pc[:, 3:6]], np.float)),
pos=torch.tensor(input_pc[:, 0:3]),
y=np.zeros(input_pc.shape[0])
)
batch_to_data = BatchToData()
data = batch_to_data(input_data[0], input_data[1], input_data[2])
data = batch_to_data(input_data)
y = loaded_model(data.to(device='cuda' if torch.cuda.is_available() else 'cpu'))
y_primary = torch.argmax(y.main_out, dim=-1).squeeze().cpu().numpy()