Refactored Settings File

Si11ium 2020-06-19 09:39:14 +02:00
parent 63605ae33a
commit b79141e854
9 changed files with 80 additions and 81 deletions

View File

@@ -5,16 +5,15 @@ import numpy as np
 from collections import defaultdict
 import os
+from torch.utils.data import Dataset
 from tqdm import tqdm
 import glob
 import torch
 from torch_geometric.data import InMemoryDataset
 from torch_geometric.data import Data
-from torch.utils.data import Dataset
-import re
-from utils.project_config import Classes, DataSplit
+from utils.project_settings import Classes, DataSplit


 def save_names(name_list, path):
@@ -198,12 +197,13 @@ class ShapeNetPartSegDataset(Dataset):
         # y -= 1 if self.num_classes() in y else 0  # Map label from [1, C] to [0, C-1]

-        sample = Data(**dict(pos=pos,    # torch.Tensor (n, 3/6)
+        data = Data(**dict(pos=pos,      # torch.Tensor (n, 3/6)
                              y=y,        # torch.Tensor (n,)
                              norm=norm   # torch.Tensor (n, 3/0)
                              )
                     )
-        return sample
+        return data

     def __len__(self):
         return len(self.dataset)
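
For orientation (not part of the diff): a minimal sketch of how a Data sample with the pos/y/norm layout returned above can be constructed and batched with torch_geometric. The tensor sizes and the DataLoader usage here are assumptions for illustration only.

import torch
from torch_geometric.data import Data, DataLoader

n = 1024                                           # assumed number of points per sample
sample = Data(pos=torch.rand(n, 3),                # point coordinates, (n, 3)
              y=torch.zeros(n, dtype=torch.long),  # per-point class labels, (n,)
              norm=torch.rand(n, 3))               # per-point normals, (n, 3)

loader = DataLoader([sample, sample], batch_size=2)
batch = next(iter(loader))                         # one Data object with a .batch index vector
print(batch.pos.shape, batch.y.shape, batch.batch.shape)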

View File

@@ -16,7 +16,8 @@ from ml_lib.utils.model_io import SavedLightningModels
 # Datasets
 from datasets.shapenet import ShapeNetPartSegDataset
-from utils.project_config import GlobalVar, ThisConfig
+from utils.project_config import ThisConfig
+from utils.project_settings import GlobalVar


 def prepare_dataloader(config_obj):

View File

@@ -1,8 +1,6 @@
 from pathlib import Path

 import torch
-from torch_geometric.data import Data
-from tqdm import tqdm
 import polyscope as ps
 import numpy as np
@@ -20,8 +18,8 @@ from ml_lib.utils.model_io import SavedLightningModels
 # Datasets
 from datasets.shapenet import ShapeNetPartSegDataset
 from utils.pointcloud import read_pointcloud, normalize_pointcloud, cluster_cubes, append_onehotencoded_type, \
-    label2color, polytopes_to_planes
-from utils.project_config import GlobalVar, ThisConfig
+    label2color
+from utils.project_settings import GlobalVar


 def prepare_dataloader(config_obj):
@@ -56,14 +54,15 @@ def predict_prim_type(input_pc, model):
     return np.concatenate((input_pc, y_primary.reshape(-1,1)), axis=1)


 if __name__ == '__main__':
-    input_pc_path = 'data/pc/pc.txt'
-    model_path = Path('trained_models/version_1')
-    config_filename = 'config.ini'
-    config = ThisConfig()
-    config.read_file((Path(model_path) / config_filename).open('r'))
+    input_pc_path = Path('data') / 'pc' / 'pc.txt'
+    model_path = Path('output') / 'PN2' / 'PN_26512907a2de0664bfad2349a6bffee3' / 'version_0'
+
+    # config_filename = 'config.ini'
+    # config = ThisConfig()
+    # config.read_file((Path(model_path) / config_filename).open('r'))

     loaded_model = restore_logger_and_model(model_path)
     loaded_model.eval()
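
As a rough, hedged illustration of how the resulting prediction might be inspected (none of this is in the commit; input_pc stands for the point cloud loaded from input_pc_path, and the column layout is assumed):

import numpy as np
import polyscope as ps

# predict_prim_type returns the input points with the predicted primitive type appended as a last column
pc_with_type = predict_prim_type(input_pc, loaded_model)

ps.init()
cloud = ps.register_point_cloud('prediction', pc_with_type[:, :3])   # xyz assumed in the first three columns
cloud.add_scalar_quantity('primitive type', pc_with_type[:, -1])
ps.show()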

View File

@@ -7,7 +7,7 @@ from datasets.shapenet import ShapeNetPartSegDataset
 from models._point_net_2 import _PointNetCore
 from utils.module_mixins import BaseValMixin, BaseTrainMixin, BaseOptimizerMixin, BaseDataloadersMixin, DatasetMixin
-from utils.project_config import GlobalVar
+from utils.project_settings import GlobalVar


 class PointNet2(BaseValMixin,

View File

@@ -22,7 +22,7 @@ from torchcontrib.optim import SWA
 from ml_lib.modules.util import LightningBaseModule
 from ml_lib.utils.tools import to_one_hot
-from .project_config import GlobalVar
+from .project_settings import GlobalVar


 class BaseOptimizerMixin:
@@ -86,7 +86,7 @@ class BaseValMixin:
     def validation_step(self, batch_pos_x_n_y_c, batch_idx, *_, **__):
         assert isinstance(self, LightningBaseModule)
-        data = self.batch_to_data(*batch_pos_x_n_y_c) if not isinstance(batch_pos_x_n_y_c, Data) else batch_pos_x_n_y_c
+        data = self.batch_to_data(batch_pos_x_n_y_c) if not isinstance(batch_pos_x_n_y_c, Data) else batch_pos_x_n_y_c
         y = self(data).main_out
         nll_loss = self.nll_loss(y, data.y)
         return dict(val_nll_loss=nll_loss,
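
For context on the loss computed in this step, a standalone sketch of an NLL loss over per-point class log-probabilities; the shapes, the seven classes, and the assumption that main_out holds log-probabilities are illustrative, not confirmed by the diff.

import torch
import torch.nn.functional as F

num_points, num_classes = 5, 7                      # cf. the seven entries in Classes
log_probs = torch.log_softmax(torch.randn(num_points, num_classes), dim=-1)  # stand-in for main_out
labels = torch.randint(0, num_classes, (num_points,))                        # stand-in for data.y
nll = F.nll_loss(log_probs, labels)
print(nll.item())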

View File

@@ -16,7 +16,7 @@ from pyod.models.hbos import HBOS
 from pyod.models.lscp import LSCP
 from pyod.models.feature_bagging import FeatureBagging
-from utils.project_config import Classes
+from utils.project_settings import Classes


 def polytopes_to_planes(pc):

View File

@@ -1,66 +1,4 @@
-from argparse import Namespace
-
 from ml_lib.utils.config import Config
 
 
-class DataClass(Namespace):
-
-    def __len__(self):
-        return len(self.__dict__())
-
-    def __dict__(self):
-        return {key: val for key, val in self.__class__.__dict__.items() if '__' not in key}
-
-    def items(self):
-        return self.__dict__().items()
-
-    def __repr__(self):
-        return f'{self.__class__.__name__}({self.__dict__().__repr__()})'
-
-    def __getitem__(self, item):
-        return self.__getattribute__(item)
-
-
-class Classes(DataClass):
-    # Object Classes for Point Segmentation
-    Sphere = 0
-    Cylinder = 1
-    Cone = 2
-    Box = 3
-    Polytope = 4
-    Torus = 5
-    Plane = 6
-
-
-class Settings(DataClass):
-    P2G = 'grid'
-    P2P = 'prim'
-    PN2 = 'pc'
-
-
-class DataSplit(DataClass):
-    # DATA SPLIT OPTIONS
-    train = 'train'
-    devel = 'devel'
-    test = 'test'
-
-
-class GlobalVar(DataClass):
-    # Variables for plotting
-    PADDING = 0.25
-    DPI = 50
-
-    data_split = DataSplit()
-
-    classes = Classes()
-
-    grid_count = 12
-
-    prim_count = -1
-
-    settings = Settings()
-
 from models import *

utils/project_settings.py Normal file
View File

@@ -0,0 +1,61 @@
+from argparse import Namespace
+
+from ml_lib.utils.config import Config
+
+
+class DataClass(Namespace):
+
+    def __len__(self):
+        return len(self.__dict__())
+
+    def __dict__(self):
+        return {key: val for key, val in self.__class__.__dict__.items() if '__' not in key}
+
+    def items(self):
+        return self.__dict__().items()
+
+    def __repr__(self):
+        return f'{self.__class__.__name__}({self.__dict__().__repr__()})'
+
+    def __getitem__(self, item):
+        return self.__getattribute__(item)
+
+
+class Classes(DataClass):
+    # Object Classes for Point Segmentation
+    Sphere = 0
+    Cylinder = 1
+    Cone = 2
+    Box = 3
+    Polytope = 4
+    Torus = 5
+    Plane = 6
+
+
+class Settings(DataClass):
+    P2G = 'grid'
+    P2P = 'prim'
+    PN2 = 'pc'
+
+
+class DataSplit(DataClass):
+    # DATA SPLIT OPTIONS
+    train = 'train'
+    devel = 'devel'
+    test = 'test'
+
+
+class GlobalVar(DataClass):
+    # Variables for plotting
+    PADDING = 0.25
+    DPI = 50
+
+    data_split = DataSplit()
+
+    classes = Classes()
+
+    grid_count = 12
+
+    prim_count = -1
+
+    settings = Settings()
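
A minimal usage sketch of the containers defined in the new file (not part of the commit; it only assumes that utils is importable as a package). The printed values follow directly from the definitions above.

from utils.project_settings import Classes, GlobalVar

classes = Classes()
print(len(classes))            # 7 segmentation classes
print(classes['Torus'])        # 5, attribute lookup via __getitem__
print(dict(classes.items()))   # {'Sphere': 0, ..., 'Plane': 6}

global_var = GlobalVar()
print(global_var.classes.Box)        # 3, nested DataClass instance
print(global_var.data_split.train)   # 'train'
print(global_var.settings['PN2'])    # 'pc'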