Inference written
This commit is contained in:
parent
2a767bead2
commit
27ae8467fc
73
main_inference.py
Normal file
73
main_inference.py
Normal file
@ -0,0 +1,73 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch_geometric.data import Data
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
|
# Dataset and Dataloaders
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
# Transforms
|
||||||
|
from ml_lib.point_toolset.point_io import BatchToData
|
||||||
|
from ml_lib.utils.model_io import SavedLightningModels
|
||||||
|
|
||||||
|
|
||||||
|
# Datasets
|
||||||
|
from datasets.full_pointclouds import FullCloudsDataset
|
||||||
|
from utils.project_config import GlobalVar, ThisConfig
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_dataloader(config_obj):
|
||||||
|
dataset = FullCloudsDataset(config_obj.data.root, split=GlobalVar.data_split.test,
|
||||||
|
setting=GlobalVar.settings[config_obj.model.type])
|
||||||
|
# noinspection PyTypeChecker
|
||||||
|
return DataLoader(dataset, batch_size=config_obj.train.batch_size,
|
||||||
|
num_workers=config_obj.data.worker, shuffle=False)
|
||||||
|
|
||||||
|
|
||||||
|
def restore_logger_and_model(log_dir):
|
||||||
|
model = SavedLightningModels.load_checkpoint(models_root_path=log_dir, n=-1)
|
||||||
|
model = model.restore()
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
model.cuda()
|
||||||
|
else:
|
||||||
|
model.cpu()
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
outpath = Path('output')
|
||||||
|
model_path = Path('/home/steffen/projects/point_to_primitive/output/P2G/PG_9f7ac027e3359fa5f5e5bcd32044a167/version_69')
|
||||||
|
config_filename = 'config.ini'
|
||||||
|
inference_out = 'manual_test_out.csv'
|
||||||
|
|
||||||
|
config = ThisConfig()
|
||||||
|
config.read_file((Path(model_path) / config_filename).open('r'))
|
||||||
|
test_dataloader = prepare_dataloader(config)
|
||||||
|
|
||||||
|
loaded_model = restore_logger_and_model(model_path)
|
||||||
|
loaded_model.eval()
|
||||||
|
|
||||||
|
with (model_path / inference_out).open(mode='w') as outfile:
|
||||||
|
outfile.write(f'{",".join(FullCloudsDataset.headers[:6])},class,cluster\n')
|
||||||
|
batch_to_data = BatchToData()
|
||||||
|
for batch_pos_x_n_y_c in tqdm(test_dataloader, total=len(test_dataloader)):
|
||||||
|
data = batch_to_data(*batch_pos_x_n_y_c) if not isinstance(batch_pos_x_n_y_c, Data) else batch_pos_x_n_y_c
|
||||||
|
y = loaded_model(data.to(device='cuda' if torch.cuda.is_available() else 'cpu'))
|
||||||
|
y_primary = torch.argmax(y.main_out, dim=-1).squeeze().cpu().numpy()
|
||||||
|
y_sec = -1
|
||||||
|
try:
|
||||||
|
y_sec = torch.argmax(y.grid_out, dim=-1).squeeze().cpu().numpy()
|
||||||
|
except AttributeError:
|
||||||
|
pass
|
||||||
|
try:
|
||||||
|
y_sec = torch.argmax(y.prim_out, dim=-1).squeeze().cpu().numpy()
|
||||||
|
except AttributeError:
|
||||||
|
pass
|
||||||
|
for row in range(data.num_nodes):
|
||||||
|
outfile.write(f'{",".join(map(str, data.pos[row].tolist()))},' +
|
||||||
|
f'{",".join(map(str, data.x[row].tolist()))},' +
|
||||||
|
f'{y_primary[row]},{y_sec[row]}\n')
|
||||||
|
print('Done')
|
@ -17,6 +17,9 @@ class DataClass(Namespace):
|
|||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return f'{self.__class__.__name__}({self.__dict__().__repr__()})'
|
return f'{self.__class__.__name__}({self.__dict__().__repr__()})'
|
||||||
|
|
||||||
|
def __getitem__(self, item):
|
||||||
|
return self.__getattribute__(item)
|
||||||
|
|
||||||
|
|
||||||
class Classes(DataClass):
|
class Classes(DataClass):
|
||||||
# Object Classes for Point Segmentation
|
# Object Classes for Point Segmentation
|
||||||
@ -29,6 +32,12 @@ class Classes(DataClass):
|
|||||||
Plane = 6
|
Plane = 6
|
||||||
|
|
||||||
|
|
||||||
|
class Settings(DataClass):
|
||||||
|
P2G = 'grid'
|
||||||
|
P2P = 'prim'
|
||||||
|
PN2 = 'pc'
|
||||||
|
|
||||||
|
|
||||||
class DataSplit(DataClass):
|
class DataSplit(DataClass):
|
||||||
# DATA SPLIT OPTIONS
|
# DATA SPLIT OPTIONS
|
||||||
train = 'train'
|
train = 'train'
|
||||||
@ -49,6 +58,8 @@ class GlobalVar(DataClass):
|
|||||||
|
|
||||||
prim_count = -1
|
prim_count = -1
|
||||||
|
|
||||||
|
settings = Settings()
|
||||||
|
|
||||||
|
|
||||||
from models import *
|
from models import *
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user