From b79141e8549345813b2c16c17afada48e901f08d Mon Sep 17 00:00:00 2001 From: Si11ium Date: Fri, 19 Jun 2020 09:39:14 +0200 Subject: [PATCH] Refactured Settings File --- datasets/shapenet.py | 10 +++--- main_inference.py | 3 +- main_pipeline.py | 17 +++++------ models/point_net_2.py | 2 +- utils/module_mixins.py | 4 +-- utils/pointcloud.py | 2 +- utils/project_config.py | 62 -------------------------------------- utils/project_settings.py | 61 +++++++++++++++++++++++++++++++++++++ utils/validation_mixins.py | 0 9 files changed, 80 insertions(+), 81 deletions(-) create mode 100644 utils/project_settings.py delete mode 100644 utils/validation_mixins.py diff --git a/datasets/shapenet.py b/datasets/shapenet.py index 0097ce3..8917d79 100644 --- a/datasets/shapenet.py +++ b/datasets/shapenet.py @@ -5,16 +5,15 @@ import numpy as np from collections import defaultdict import os +from torch.utils.data import Dataset from tqdm import tqdm import glob import torch from torch_geometric.data import InMemoryDataset from torch_geometric.data import Data -from torch.utils.data import Dataset -import re -from utils.project_config import Classes, DataSplit +from utils.project_settings import Classes, DataSplit def save_names(name_list, path): @@ -198,12 +197,13 @@ class ShapeNetPartSegDataset(Dataset): # y -= 1 if self.num_classes() in y else 0 # Map label from [1, C] to [0, C-1] - sample = Data(**dict(pos=pos, # torch.Tensor (n, 3/6) + data = Data(**dict(pos=pos, # torch.Tensor (n, 3/6) y=y, # torch.Tensor (n,) norm=norm # torch.Tensor (n, 3/0) ) ) - return sample + + return data def __len__(self): return len(self.dataset) diff --git a/main_inference.py b/main_inference.py index 7e9c86c..859fcb7 100644 --- a/main_inference.py +++ b/main_inference.py @@ -16,7 +16,8 @@ from ml_lib.utils.model_io import SavedLightningModels # Datasets from datasets.shapenet import ShapeNetPartSegDataset -from utils.project_config import GlobalVar, ThisConfig +from utils.project_config import ThisConfig +from utils.project_settings import GlobalVar def prepare_dataloader(config_obj): diff --git a/main_pipeline.py b/main_pipeline.py index 8e33455..0d62aa8 100644 --- a/main_pipeline.py +++ b/main_pipeline.py @@ -1,8 +1,6 @@ from pathlib import Path import torch -from torch_geometric.data import Data -from tqdm import tqdm import polyscope as ps import numpy as np @@ -20,8 +18,8 @@ from ml_lib.utils.model_io import SavedLightningModels # Datasets from datasets.shapenet import ShapeNetPartSegDataset from utils.pointcloud import read_pointcloud, normalize_pointcloud, cluster_cubes, append_onehotencoded_type, \ - label2color, polytopes_to_planes -from utils.project_config import GlobalVar, ThisConfig + label2color +from utils.project_settings import GlobalVar def prepare_dataloader(config_obj): @@ -56,14 +54,15 @@ def predict_prim_type(input_pc, model): return np.concatenate((input_pc, y_primary.reshape(-1,1)), axis=1) + if __name__ == '__main__': - input_pc_path = 'data/pc/pc.txt' + input_pc_path = Path('data') / 'pc' / 'pc.txt' - model_path = Path('trained_models/version_1') - config_filename = 'config.ini' - config = ThisConfig() - config.read_file((Path(model_path) / config_filename).open('r')) + model_path = Path('output') / 'PN2' / 'PN_26512907a2de0664bfad2349a6bffee3' / 'version_0' + # config_filename = 'config.ini' + # config = ThisConfig() + # config.read_file((Path(model_path) / config_filename).open('r')) loaded_model = restore_logger_and_model(model_path) loaded_model.eval() diff --git a/models/point_net_2.py b/models/point_net_2.py index 4b13c39..aad346c 100644 --- a/models/point_net_2.py +++ b/models/point_net_2.py @@ -7,7 +7,7 @@ from datasets.shapenet import ShapeNetPartSegDataset from models._point_net_2 import _PointNetCore from utils.module_mixins import BaseValMixin, BaseTrainMixin, BaseOptimizerMixin, BaseDataloadersMixin, DatasetMixin -from utils.project_config import GlobalVar +from utils.project_settings import GlobalVar class PointNet2(BaseValMixin, diff --git a/utils/module_mixins.py b/utils/module_mixins.py index 3940163..921cba5 100644 --- a/utils/module_mixins.py +++ b/utils/module_mixins.py @@ -22,7 +22,7 @@ from torchcontrib.optim import SWA from ml_lib.modules.util import LightningBaseModule from ml_lib.utils.tools import to_one_hot -from .project_config import GlobalVar +from .project_settings import GlobalVar class BaseOptimizerMixin: @@ -86,7 +86,7 @@ class BaseValMixin: def validation_step(self, batch_pos_x_n_y_c, batch_idx, *_, **__): assert isinstance(self, LightningBaseModule) - data = self.batch_to_data(*batch_pos_x_n_y_c) if not isinstance(batch_pos_x_n_y_c, Data) else batch_pos_x_n_y_c + data = self.batch_to_data(batch_pos_x_n_y_c) if not isinstance(batch_pos_x_n_y_c, Data) else batch_pos_x_n_y_c y = self(data).main_out nll_loss = self.nll_loss(y, data.y) return dict(val_nll_loss=nll_loss, diff --git a/utils/pointcloud.py b/utils/pointcloud.py index 0c91799..8e28b4f 100644 --- a/utils/pointcloud.py +++ b/utils/pointcloud.py @@ -16,7 +16,7 @@ from pyod.models.hbos import HBOS from pyod.models.lscp import LSCP from pyod.models.feature_bagging import FeatureBagging -from utils.project_config import Classes +from utils.project_settings import Classes def polytopes_to_planes(pc): diff --git a/utils/project_config.py b/utils/project_config.py index d0c23d1..59af5b7 100644 --- a/utils/project_config.py +++ b/utils/project_config.py @@ -1,66 +1,4 @@ -from argparse import Namespace - from ml_lib.utils.config import Config - - -class DataClass(Namespace): - - def __len__(self): - return len(self.__dict__()) - - def __dict__(self): - return {key: val for key, val in self.__class__.__dict__.items() if '__' not in key} - - def items(self): - return self.__dict__().items() - - def __repr__(self): - return f'{self.__class__.__name__}({self.__dict__().__repr__()})' - - def __getitem__(self, item): - return self.__getattribute__(item) - - -class Classes(DataClass): - # Object Classes for Point Segmentation - Sphere = 0 - Cylinder = 1 - Cone = 2 - Box = 3 - Polytope = 4 - Torus = 5 - Plane = 6 - - -class Settings(DataClass): - P2G = 'grid' - P2P = 'prim' - PN2 = 'pc' - - -class DataSplit(DataClass): - # DATA SPLIT OPTIONS - train = 'train' - devel = 'devel' - test = 'test' - - -class GlobalVar(DataClass): - # Variables for plotting - PADDING = 0.25 - DPI = 50 - - data_split = DataSplit() - - classes = Classes() - - grid_count = 12 - - prim_count = -1 - - settings = Settings() - - from models import * diff --git a/utils/project_settings.py b/utils/project_settings.py new file mode 100644 index 0000000..6033d8e --- /dev/null +++ b/utils/project_settings.py @@ -0,0 +1,61 @@ +from argparse import Namespace + +from ml_lib.utils.config import Config + + +class DataClass(Namespace): + + def __len__(self): + return len(self.__dict__()) + + def __dict__(self): + return {key: val for key, val in self.__class__.__dict__.items() if '__' not in key} + + def items(self): + return self.__dict__().items() + + def __repr__(self): + return f'{self.__class__.__name__}({self.__dict__().__repr__()})' + + def __getitem__(self, item): + return self.__getattribute__(item) + + +class Classes(DataClass): + # Object Classes for Point Segmentation + Sphere = 0 + Cylinder = 1 + Cone = 2 + Box = 3 + Polytope = 4 + Torus = 5 + Plane = 6 + + +class Settings(DataClass): + P2G = 'grid' + P2P = 'prim' + PN2 = 'pc' + + +class DataSplit(DataClass): + # DATA SPLIT OPTIONS + train = 'train' + devel = 'devel' + test = 'test' + + +class GlobalVar(DataClass): + # Variables for plotting + PADDING = 0.25 + DPI = 50 + + data_split = DataSplit() + + classes = Classes() + + grid_count = 12 + + prim_count = -1 + + settings = Settings() \ No newline at end of file diff --git a/utils/validation_mixins.py b/utils/validation_mixins.py deleted file mode 100644 index e69de29..0000000