Refactured Settings File
This commit is contained in:
parent
63605ae33a
commit
b79141e854
@ -5,16 +5,15 @@ import numpy as np
|
|||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
from torch.utils.data import Dataset
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
import glob
|
import glob
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch_geometric.data import InMemoryDataset
|
from torch_geometric.data import InMemoryDataset
|
||||||
from torch_geometric.data import Data
|
from torch_geometric.data import Data
|
||||||
from torch.utils.data import Dataset
|
|
||||||
import re
|
|
||||||
|
|
||||||
from utils.project_config import Classes, DataSplit
|
from utils.project_settings import Classes, DataSplit
|
||||||
|
|
||||||
|
|
||||||
def save_names(name_list, path):
|
def save_names(name_list, path):
|
||||||
@ -198,12 +197,13 @@ class ShapeNetPartSegDataset(Dataset):
|
|||||||
|
|
||||||
# y -= 1 if self.num_classes() in y else 0 # Map label from [1, C] to [0, C-1]
|
# y -= 1 if self.num_classes() in y else 0 # Map label from [1, C] to [0, C-1]
|
||||||
|
|
||||||
sample = Data(**dict(pos=pos, # torch.Tensor (n, 3/6)
|
data = Data(**dict(pos=pos, # torch.Tensor (n, 3/6)
|
||||||
y=y, # torch.Tensor (n,)
|
y=y, # torch.Tensor (n,)
|
||||||
norm=norm # torch.Tensor (n, 3/0)
|
norm=norm # torch.Tensor (n, 3/0)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
return sample
|
|
||||||
|
return data
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return len(self.dataset)
|
return len(self.dataset)
|
||||||
|
@ -16,7 +16,8 @@ from ml_lib.utils.model_io import SavedLightningModels
|
|||||||
|
|
||||||
# Datasets
|
# Datasets
|
||||||
from datasets.shapenet import ShapeNetPartSegDataset
|
from datasets.shapenet import ShapeNetPartSegDataset
|
||||||
from utils.project_config import GlobalVar, ThisConfig
|
from utils.project_config import ThisConfig
|
||||||
|
from utils.project_settings import GlobalVar
|
||||||
|
|
||||||
|
|
||||||
def prepare_dataloader(config_obj):
|
def prepare_dataloader(config_obj):
|
||||||
|
@ -1,8 +1,6 @@
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch_geometric.data import Data
|
|
||||||
from tqdm import tqdm
|
|
||||||
|
|
||||||
import polyscope as ps
|
import polyscope as ps
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -20,8 +18,8 @@ from ml_lib.utils.model_io import SavedLightningModels
|
|||||||
# Datasets
|
# Datasets
|
||||||
from datasets.shapenet import ShapeNetPartSegDataset
|
from datasets.shapenet import ShapeNetPartSegDataset
|
||||||
from utils.pointcloud import read_pointcloud, normalize_pointcloud, cluster_cubes, append_onehotencoded_type, \
|
from utils.pointcloud import read_pointcloud, normalize_pointcloud, cluster_cubes, append_onehotencoded_type, \
|
||||||
label2color, polytopes_to_planes
|
label2color
|
||||||
from utils.project_config import GlobalVar, ThisConfig
|
from utils.project_settings import GlobalVar
|
||||||
|
|
||||||
|
|
||||||
def prepare_dataloader(config_obj):
|
def prepare_dataloader(config_obj):
|
||||||
@ -56,14 +54,15 @@ def predict_prim_type(input_pc, model):
|
|||||||
|
|
||||||
return np.concatenate((input_pc, y_primary.reshape(-1,1)), axis=1)
|
return np.concatenate((input_pc, y_primary.reshape(-1,1)), axis=1)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
||||||
input_pc_path = 'data/pc/pc.txt'
|
input_pc_path = Path('data') / 'pc' / 'pc.txt'
|
||||||
|
|
||||||
model_path = Path('trained_models/version_1')
|
model_path = Path('output') / 'PN2' / 'PN_26512907a2de0664bfad2349a6bffee3' / 'version_0'
|
||||||
config_filename = 'config.ini'
|
# config_filename = 'config.ini'
|
||||||
config = ThisConfig()
|
# config = ThisConfig()
|
||||||
config.read_file((Path(model_path) / config_filename).open('r'))
|
# config.read_file((Path(model_path) / config_filename).open('r'))
|
||||||
loaded_model = restore_logger_and_model(model_path)
|
loaded_model = restore_logger_and_model(model_path)
|
||||||
loaded_model.eval()
|
loaded_model.eval()
|
||||||
|
|
||||||
|
@ -7,7 +7,7 @@ from datasets.shapenet import ShapeNetPartSegDataset
|
|||||||
from models._point_net_2 import _PointNetCore
|
from models._point_net_2 import _PointNetCore
|
||||||
|
|
||||||
from utils.module_mixins import BaseValMixin, BaseTrainMixin, BaseOptimizerMixin, BaseDataloadersMixin, DatasetMixin
|
from utils.module_mixins import BaseValMixin, BaseTrainMixin, BaseOptimizerMixin, BaseDataloadersMixin, DatasetMixin
|
||||||
from utils.project_config import GlobalVar
|
from utils.project_settings import GlobalVar
|
||||||
|
|
||||||
|
|
||||||
class PointNet2(BaseValMixin,
|
class PointNet2(BaseValMixin,
|
||||||
|
@ -22,7 +22,7 @@ from torchcontrib.optim import SWA
|
|||||||
from ml_lib.modules.util import LightningBaseModule
|
from ml_lib.modules.util import LightningBaseModule
|
||||||
from ml_lib.utils.tools import to_one_hot
|
from ml_lib.utils.tools import to_one_hot
|
||||||
|
|
||||||
from .project_config import GlobalVar
|
from .project_settings import GlobalVar
|
||||||
|
|
||||||
|
|
||||||
class BaseOptimizerMixin:
|
class BaseOptimizerMixin:
|
||||||
@ -86,7 +86,7 @@ class BaseValMixin:
|
|||||||
|
|
||||||
def validation_step(self, batch_pos_x_n_y_c, batch_idx, *_, **__):
|
def validation_step(self, batch_pos_x_n_y_c, batch_idx, *_, **__):
|
||||||
assert isinstance(self, LightningBaseModule)
|
assert isinstance(self, LightningBaseModule)
|
||||||
data = self.batch_to_data(*batch_pos_x_n_y_c) if not isinstance(batch_pos_x_n_y_c, Data) else batch_pos_x_n_y_c
|
data = self.batch_to_data(batch_pos_x_n_y_c) if not isinstance(batch_pos_x_n_y_c, Data) else batch_pos_x_n_y_c
|
||||||
y = self(data).main_out
|
y = self(data).main_out
|
||||||
nll_loss = self.nll_loss(y, data.y)
|
nll_loss = self.nll_loss(y, data.y)
|
||||||
return dict(val_nll_loss=nll_loss,
|
return dict(val_nll_loss=nll_loss,
|
||||||
|
@ -16,7 +16,7 @@ from pyod.models.hbos import HBOS
|
|||||||
from pyod.models.lscp import LSCP
|
from pyod.models.lscp import LSCP
|
||||||
from pyod.models.feature_bagging import FeatureBagging
|
from pyod.models.feature_bagging import FeatureBagging
|
||||||
|
|
||||||
from utils.project_config import Classes
|
from utils.project_settings import Classes
|
||||||
|
|
||||||
|
|
||||||
def polytopes_to_planes(pc):
|
def polytopes_to_planes(pc):
|
||||||
|
@ -1,66 +1,4 @@
|
|||||||
from argparse import Namespace
|
|
||||||
|
|
||||||
from ml_lib.utils.config import Config
|
from ml_lib.utils.config import Config
|
||||||
|
|
||||||
|
|
||||||
class DataClass(Namespace):
|
|
||||||
|
|
||||||
def __len__(self):
|
|
||||||
return len(self.__dict__())
|
|
||||||
|
|
||||||
def __dict__(self):
|
|
||||||
return {key: val for key, val in self.__class__.__dict__.items() if '__' not in key}
|
|
||||||
|
|
||||||
def items(self):
|
|
||||||
return self.__dict__().items()
|
|
||||||
|
|
||||||
def __repr__(self):
|
|
||||||
return f'{self.__class__.__name__}({self.__dict__().__repr__()})'
|
|
||||||
|
|
||||||
def __getitem__(self, item):
|
|
||||||
return self.__getattribute__(item)
|
|
||||||
|
|
||||||
|
|
||||||
class Classes(DataClass):
|
|
||||||
# Object Classes for Point Segmentation
|
|
||||||
Sphere = 0
|
|
||||||
Cylinder = 1
|
|
||||||
Cone = 2
|
|
||||||
Box = 3
|
|
||||||
Polytope = 4
|
|
||||||
Torus = 5
|
|
||||||
Plane = 6
|
|
||||||
|
|
||||||
|
|
||||||
class Settings(DataClass):
|
|
||||||
P2G = 'grid'
|
|
||||||
P2P = 'prim'
|
|
||||||
PN2 = 'pc'
|
|
||||||
|
|
||||||
|
|
||||||
class DataSplit(DataClass):
|
|
||||||
# DATA SPLIT OPTIONS
|
|
||||||
train = 'train'
|
|
||||||
devel = 'devel'
|
|
||||||
test = 'test'
|
|
||||||
|
|
||||||
|
|
||||||
class GlobalVar(DataClass):
|
|
||||||
# Variables for plotting
|
|
||||||
PADDING = 0.25
|
|
||||||
DPI = 50
|
|
||||||
|
|
||||||
data_split = DataSplit()
|
|
||||||
|
|
||||||
classes = Classes()
|
|
||||||
|
|
||||||
grid_count = 12
|
|
||||||
|
|
||||||
prim_count = -1
|
|
||||||
|
|
||||||
settings = Settings()
|
|
||||||
|
|
||||||
|
|
||||||
from models import *
|
from models import *
|
||||||
|
|
||||||
|
|
||||||
|
61
utils/project_settings.py
Normal file
61
utils/project_settings.py
Normal file
@ -0,0 +1,61 @@
|
|||||||
|
from argparse import Namespace
|
||||||
|
|
||||||
|
from ml_lib.utils.config import Config
|
||||||
|
|
||||||
|
|
||||||
|
class DataClass(Namespace):
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.__dict__())
|
||||||
|
|
||||||
|
def __dict__(self):
|
||||||
|
return {key: val for key, val in self.__class__.__dict__.items() if '__' not in key}
|
||||||
|
|
||||||
|
def items(self):
|
||||||
|
return self.__dict__().items()
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return f'{self.__class__.__name__}({self.__dict__().__repr__()})'
|
||||||
|
|
||||||
|
def __getitem__(self, item):
|
||||||
|
return self.__getattribute__(item)
|
||||||
|
|
||||||
|
|
||||||
|
class Classes(DataClass):
|
||||||
|
# Object Classes for Point Segmentation
|
||||||
|
Sphere = 0
|
||||||
|
Cylinder = 1
|
||||||
|
Cone = 2
|
||||||
|
Box = 3
|
||||||
|
Polytope = 4
|
||||||
|
Torus = 5
|
||||||
|
Plane = 6
|
||||||
|
|
||||||
|
|
||||||
|
class Settings(DataClass):
|
||||||
|
P2G = 'grid'
|
||||||
|
P2P = 'prim'
|
||||||
|
PN2 = 'pc'
|
||||||
|
|
||||||
|
|
||||||
|
class DataSplit(DataClass):
|
||||||
|
# DATA SPLIT OPTIONS
|
||||||
|
train = 'train'
|
||||||
|
devel = 'devel'
|
||||||
|
test = 'test'
|
||||||
|
|
||||||
|
|
||||||
|
class GlobalVar(DataClass):
|
||||||
|
# Variables for plotting
|
||||||
|
PADDING = 0.25
|
||||||
|
DPI = 50
|
||||||
|
|
||||||
|
data_split = DataSplit()
|
||||||
|
|
||||||
|
classes = Classes()
|
||||||
|
|
||||||
|
grid_count = 12
|
||||||
|
|
||||||
|
prim_count = -1
|
||||||
|
|
||||||
|
settings = Settings()
|
Loading…
x
Reference in New Issue
Block a user