diff --git a/main_inference.py b/main_inference.py new file mode 100644 index 0000000..7b72616 --- /dev/null +++ b/main_inference.py @@ -0,0 +1,73 @@ +from pathlib import Path + +import torch +from torch_geometric.data import Data +from tqdm import tqdm + +from torch.utils.data import DataLoader + +# Dataset and Dataloaders +# ============================================================================= + +# Transforms +from ml_lib.point_toolset.point_io import BatchToData +from ml_lib.utils.model_io import SavedLightningModels + + +# Datasets +from datasets.full_pointclouds import FullCloudsDataset +from utils.project_config import GlobalVar, ThisConfig + + +def prepare_dataloader(config_obj): + dataset = FullCloudsDataset(config_obj.data.root, split=GlobalVar.data_split.test, + setting=GlobalVar.settings[config_obj.model.type]) + # noinspection PyTypeChecker + return DataLoader(dataset, batch_size=config_obj.train.batch_size, + num_workers=config_obj.data.worker, shuffle=False) + + +def restore_logger_and_model(log_dir): + model = SavedLightningModels.load_checkpoint(models_root_path=log_dir, n=-1) + model = model.restore() + if torch.cuda.is_available(): + model.cuda() + else: + model.cpu() + return model + + +if __name__ == '__main__': + outpath = Path('output') + model_path = Path('/home/steffen/projects/point_to_primitive/output/P2G/PG_9f7ac027e3359fa5f5e5bcd32044a167/version_69') + config_filename = 'config.ini' + inference_out = 'manual_test_out.csv' + + config = ThisConfig() + config.read_file((Path(model_path) / config_filename).open('r')) + test_dataloader = prepare_dataloader(config) + + loaded_model = restore_logger_and_model(model_path) + loaded_model.eval() + + with (model_path / inference_out).open(mode='w') as outfile: + outfile.write(f'{",".join(FullCloudsDataset.headers[:6])},class,cluster\n') + batch_to_data = BatchToData() + for batch_pos_x_n_y_c in tqdm(test_dataloader, total=len(test_dataloader)): + data = batch_to_data(*batch_pos_x_n_y_c) if not isinstance(batch_pos_x_n_y_c, Data) else batch_pos_x_n_y_c + y = loaded_model(data.to(device='cuda' if torch.cuda.is_available() else 'cpu')) + y_primary = torch.argmax(y.main_out, dim=-1).squeeze().cpu().numpy() + y_sec = -1 + try: + y_sec = torch.argmax(y.grid_out, dim=-1).squeeze().cpu().numpy() + except AttributeError: + pass + try: + y_sec = torch.argmax(y.prim_out, dim=-1).squeeze().cpu().numpy() + except AttributeError: + pass + for row in range(data.num_nodes): + outfile.write(f'{",".join(map(str, data.pos[row].tolist()))},' + + f'{",".join(map(str, data.x[row].tolist()))},' + + f'{y_primary[row]},{y_sec[row]}\n') + print('Done') diff --git a/utils/project_config.py b/utils/project_config.py index 5cd0b11..c2c96d1 100644 --- a/utils/project_config.py +++ b/utils/project_config.py @@ -17,6 +17,9 @@ class DataClass(Namespace): def __repr__(self): return f'{self.__class__.__name__}({self.__dict__().__repr__()})' + def __getitem__(self, item): + return self.__getattribute__(item) + class Classes(DataClass): # Object Classes for Point Segmentation @@ -29,6 +32,12 @@ class Classes(DataClass): Plane = 6 +class Settings(DataClass): + P2G = 'grid' + P2P = 'prim' + PN2 = 'pc' + + class DataSplit(DataClass): # DATA SPLIT OPTIONS train = 'train' @@ -49,6 +58,8 @@ class GlobalVar(DataClass): prim_count = -1 + settings = Settings() + from models import *