"""Manual inference script for a saved Lightning model.

Restores a checkpoint from a run directory, predicts on the test split of
``ShapeNetPartSegDataset`` and writes one CSV row per point. Each row holds the
point's position values, its feature values, the primary predicted class and a
secondary prediction (``-1`` when the model output provides none).
"""
import csv
from pathlib import Path

import torch
from torch_geometric.data import Data
from tqdm import tqdm
from torch.utils.data import DataLoader

# Dataset and Dataloaders
# =============================================================================
# Transforms
from ml_lib.point_toolset.point_io import BatchToData
from ml_lib.utils.model_io import SavedLightningModels

# Datasets
from datasets.shapenet import ShapeNetPartSegDataset
from utils.project_config import ThisConfig

# TODO: `GlobalVar` (used in prepare_dataloader) and `FullCloudsDataset` (used
# for the CSV header in main) are referenced but never imported in this file.
# Add the correct imports, then remove this guard.
raise BrokenPipeError('There are Imports that need to be fixed first!!!!')


def prepare_dataloader(config_obj):
    """Build a non-shuffled DataLoader over the test split of ShapeNetPartSegDataset.

    The dataset root, model type (used to pick the dataset setting), batch size
    and worker count are all read from *config_obj*.
    """
    dataset = ShapeNetPartSegDataset(config_obj.data.root, mode=GlobalVar.data_split.test,
                                     setting=GlobalVar.settings[config_obj.model.type])
    # noinspection PyTypeChecker
    return DataLoader(dataset, batch_size=config_obj.train.batch_size,
                      num_workers=config_obj.data.worker, shuffle=False)


def restore_logger_and_model(log_dir):
    """Load the checkpoint selected with ``n=-1`` under *log_dir* and restore the model.

    The model is moved to CUDA when available, otherwise to CPU. Despite the
    name, only the model is returned.
    """
    model = SavedLightningModels.load_checkpoint(models_root_path=log_dir, n=-1)
    model = model.restore()
    if torch.cuda.is_available():
        model.cuda()
    else:
        model.cpu()
    return model


def _argmax_labels(logits):
    """Return the argmax over the last dimension of *logits* as a flat list of ints.

    ``reshape(-1)`` is used instead of ``squeeze()`` so a single-point prediction
    still yields a 1-element list rather than a 0-d value that cannot be indexed.
    """
    return torch.argmax(logits, dim=-1).reshape(-1).cpu().tolist()


def _secondary_labels(output, num_nodes):
    """Return the secondary per-point prediction of *output*, or ``-1`` per point.

    ``prim_out`` takes precedence over ``grid_out`` when both are present,
    matching the original assignment order. When neither exists, a list of
    ``-1`` placeholders is returned, so every row still gets a cluster value.
    """
    for attr_name in ('prim_out', 'grid_out'):
        logits = getattr(output, attr_name, None)
        if logits is not None:
            return _argmax_labels(logits)
    return [-1] * num_nodes


def main():
    """Run the saved model over the test split and write per-point predictions to CSV."""
    model_path = Path('/home/steffen/projects/point_to_primitive/output/P2G/PG_9f7ac027e3359fa5f5e5bcd32044a167/version_69')
    config_filename = 'config.ini'
    inference_out = 'manual_test_out.csv'

    config = ThisConfig()
    # Close the config file once it has been parsed instead of leaking the handle.
    with (model_path / config_filename).open('r') as config_file:
        config.read_file(config_file)

    test_dataloader = prepare_dataloader(config)
    loaded_model = restore_logger_and_model(model_path)
    loaded_model.eval()
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # newline='' plus an explicit '\n' terminator keeps the output byte-compatible
    # with the previous hand-built rows. no_grad() avoids building autograd
    # graphs during pure inference.
    with (model_path / inference_out).open(mode='w', newline='') as outfile, torch.no_grad():
        writer = csv.writer(outfile, lineterminator='\n')
        writer.writerow([*FullCloudsDataset.headers[:6], 'class', 'cluster'])
        batch_to_data = BatchToData()
        for batch_pos_x_n_y_c in tqdm(test_dataloader, total=len(test_dataloader)):
            if isinstance(batch_pos_x_n_y_c, Data):
                data = batch_pos_x_n_y_c
            else:
                data = batch_to_data(*batch_pos_x_n_y_c)
            data = data.to(device=device)
            y = loaded_model(data)

            num_nodes = data.num_nodes
            y_primary = _argmax_labels(y.main_out)
            y_sec = _secondary_labels(y, num_nodes)
            if len(y_primary) < num_nodes or len(y_sec) < num_nodes:
                raise ValueError(f'Model produced {len(y_primary)} primary / {len(y_sec)} secondary '
                                 f'labels for a batch with {num_nodes} points.')

            # One device->host transfer per batch instead of one per point.
            positions = data.pos.cpu().tolist()[:num_nodes]
            features = data.x.cpu().tolist()[:num_nodes]
            writer.writerows(
                [*point, *point_features, cls, cluster]
                for point, point_features, cls, cluster in zip(positions, features, y_primary, y_sec)
            )
    print('Done')


if __name__ == '__main__':
    main()