Refactured Settings File

This commit is contained in:
Si11ium 2020-06-19 09:39:14 +02:00
parent 63605ae33a
commit b79141e854
9 changed files with 80 additions and 81 deletions

View File

@ -5,16 +5,15 @@ import numpy as np
from collections import defaultdict
import os
from torch.utils.data import Dataset
from tqdm import tqdm
import glob
import torch
from torch_geometric.data import InMemoryDataset
from torch_geometric.data import Data
from torch.utils.data import Dataset
import re
from utils.project_config import Classes, DataSplit
from utils.project_settings import Classes, DataSplit
def save_names(name_list, path):
@ -198,12 +197,13 @@ class ShapeNetPartSegDataset(Dataset):
# y -= 1 if self.num_classes() in y else 0 # Map label from [1, C] to [0, C-1]
sample = Data(**dict(pos=pos, # torch.Tensor (n, 3/6)
data = Data(**dict(pos=pos, # torch.Tensor (n, 3/6)
y=y, # torch.Tensor (n,)
norm=norm # torch.Tensor (n, 3/0)
)
)
return sample
return data
def __len__(self):
return len(self.dataset)

View File

@ -16,7 +16,8 @@ from ml_lib.utils.model_io import SavedLightningModels
# Datasets
from datasets.shapenet import ShapeNetPartSegDataset
from utils.project_config import GlobalVar, ThisConfig
from utils.project_config import ThisConfig
from utils.project_settings import GlobalVar
def prepare_dataloader(config_obj):

View File

@ -1,8 +1,6 @@
from pathlib import Path
import torch
from torch_geometric.data import Data
from tqdm import tqdm
import polyscope as ps
import numpy as np
@ -20,8 +18,8 @@ from ml_lib.utils.model_io import SavedLightningModels
# Datasets
from datasets.shapenet import ShapeNetPartSegDataset
from utils.pointcloud import read_pointcloud, normalize_pointcloud, cluster_cubes, append_onehotencoded_type, \
label2color, polytopes_to_planes
from utils.project_config import GlobalVar, ThisConfig
label2color
from utils.project_settings import GlobalVar
def prepare_dataloader(config_obj):
@ -56,14 +54,15 @@ def predict_prim_type(input_pc, model):
return np.concatenate((input_pc, y_primary.reshape(-1,1)), axis=1)
if __name__ == '__main__':
input_pc_path = 'data/pc/pc.txt'
input_pc_path = Path('data') / 'pc' / 'pc.txt'
model_path = Path('trained_models/version_1')
config_filename = 'config.ini'
config = ThisConfig()
config.read_file((Path(model_path) / config_filename).open('r'))
model_path = Path('output') / 'PN2' / 'PN_26512907a2de0664bfad2349a6bffee3' / 'version_0'
# config_filename = 'config.ini'
# config = ThisConfig()
# config.read_file((Path(model_path) / config_filename).open('r'))
loaded_model = restore_logger_and_model(model_path)
loaded_model.eval()

View File

@ -7,7 +7,7 @@ from datasets.shapenet import ShapeNetPartSegDataset
from models._point_net_2 import _PointNetCore
from utils.module_mixins import BaseValMixin, BaseTrainMixin, BaseOptimizerMixin, BaseDataloadersMixin, DatasetMixin
from utils.project_config import GlobalVar
from utils.project_settings import GlobalVar
class PointNet2(BaseValMixin,

View File

@ -22,7 +22,7 @@ from torchcontrib.optim import SWA
from ml_lib.modules.util import LightningBaseModule
from ml_lib.utils.tools import to_one_hot
from .project_config import GlobalVar
from .project_settings import GlobalVar
class BaseOptimizerMixin:
@ -86,7 +86,7 @@ class BaseValMixin:
def validation_step(self, batch_pos_x_n_y_c, batch_idx, *_, **__):
assert isinstance(self, LightningBaseModule)
data = self.batch_to_data(*batch_pos_x_n_y_c) if not isinstance(batch_pos_x_n_y_c, Data) else batch_pos_x_n_y_c
data = self.batch_to_data(batch_pos_x_n_y_c) if not isinstance(batch_pos_x_n_y_c, Data) else batch_pos_x_n_y_c
y = self(data).main_out
nll_loss = self.nll_loss(y, data.y)
return dict(val_nll_loss=nll_loss,

View File

@ -16,7 +16,7 @@ from pyod.models.hbos import HBOS
from pyod.models.lscp import LSCP
from pyod.models.feature_bagging import FeatureBagging
from utils.project_config import Classes
from utils.project_settings import Classes
def polytopes_to_planes(pc):

View File

@ -1,66 +1,4 @@
from argparse import Namespace
from ml_lib.utils.config import Config
class DataClass(Namespace):
def __len__(self):
return len(self.__dict__())
def __dict__(self):
return {key: val for key, val in self.__class__.__dict__.items() if '__' not in key}
def items(self):
return self.__dict__().items()
def __repr__(self):
return f'{self.__class__.__name__}({self.__dict__().__repr__()})'
def __getitem__(self, item):
return self.__getattribute__(item)
class Classes(DataClass):
# Object Classes for Point Segmentation
Sphere = 0
Cylinder = 1
Cone = 2
Box = 3
Polytope = 4
Torus = 5
Plane = 6
class Settings(DataClass):
P2G = 'grid'
P2P = 'prim'
PN2 = 'pc'
class DataSplit(DataClass):
# DATA SPLIT OPTIONS
train = 'train'
devel = 'devel'
test = 'test'
class GlobalVar(DataClass):
# Variables for plotting
PADDING = 0.25
DPI = 50
data_split = DataSplit()
classes = Classes()
grid_count = 12
prim_count = -1
settings = Settings()
from models import *

61
utils/project_settings.py Normal file
View File

@ -0,0 +1,61 @@
from argparse import Namespace
from ml_lib.utils.config import Config
class DataClass(Namespace):
def __len__(self):
return len(self.__dict__())
def __dict__(self):
return {key: val for key, val in self.__class__.__dict__.items() if '__' not in key}
def items(self):
return self.__dict__().items()
def __repr__(self):
return f'{self.__class__.__name__}({self.__dict__().__repr__()})'
def __getitem__(self, item):
return self.__getattribute__(item)
class Classes(DataClass):
# Object Classes for Point Segmentation
Sphere = 0
Cylinder = 1
Cone = 2
Box = 3
Polytope = 4
Torus = 5
Plane = 6
class Settings(DataClass):
P2G = 'grid'
P2P = 'prim'
PN2 = 'pc'
class DataSplit(DataClass):
# DATA SPLIT OPTIONS
train = 'train'
devel = 'devel'
test = 'test'
class GlobalVar(DataClass):
# Variables for plotting
PADDING = 0.25
DPI = 50
data_split = DataSplit()
classes = Classes()
grid_count = 12
prim_count = -1
settings = Settings()