Refactured Settings File
This commit is contained in:
@@ -22,7 +22,7 @@ from torchcontrib.optim import SWA
|
||||
from ml_lib.modules.util import LightningBaseModule
|
||||
from ml_lib.utils.tools import to_one_hot
|
||||
|
||||
from .project_config import GlobalVar
|
||||
from .project_settings import GlobalVar
|
||||
|
||||
|
||||
class BaseOptimizerMixin:
|
||||
@@ -86,7 +86,7 @@ class BaseValMixin:
|
||||
|
||||
def validation_step(self, batch_pos_x_n_y_c, batch_idx, *_, **__):
|
||||
assert isinstance(self, LightningBaseModule)
|
||||
data = self.batch_to_data(*batch_pos_x_n_y_c) if not isinstance(batch_pos_x_n_y_c, Data) else batch_pos_x_n_y_c
|
||||
data = self.batch_to_data(batch_pos_x_n_y_c) if not isinstance(batch_pos_x_n_y_c, Data) else batch_pos_x_n_y_c
|
||||
y = self(data).main_out
|
||||
nll_loss = self.nll_loss(y, data.y)
|
||||
return dict(val_nll_loss=nll_loss,
|
||||
|
||||
@@ -16,7 +16,7 @@ from pyod.models.hbos import HBOS
|
||||
from pyod.models.lscp import LSCP
|
||||
from pyod.models.feature_bagging import FeatureBagging
|
||||
|
||||
from utils.project_config import Classes
|
||||
from utils.project_settings import Classes
|
||||
|
||||
|
||||
def polytopes_to_planes(pc):
|
||||
|
||||
@@ -1,66 +1,4 @@
|
||||
from argparse import Namespace
|
||||
|
||||
from ml_lib.utils.config import Config
|
||||
|
||||
|
||||
class DataClass(Namespace):
|
||||
|
||||
def __len__(self):
|
||||
return len(self.__dict__())
|
||||
|
||||
def __dict__(self):
|
||||
return {key: val for key, val in self.__class__.__dict__.items() if '__' not in key}
|
||||
|
||||
def items(self):
|
||||
return self.__dict__().items()
|
||||
|
||||
def __repr__(self):
|
||||
return f'{self.__class__.__name__}({self.__dict__().__repr__()})'
|
||||
|
||||
def __getitem__(self, item):
|
||||
return self.__getattribute__(item)
|
||||
|
||||
|
||||
class Classes(DataClass):
|
||||
# Object Classes for Point Segmentation
|
||||
Sphere = 0
|
||||
Cylinder = 1
|
||||
Cone = 2
|
||||
Box = 3
|
||||
Polytope = 4
|
||||
Torus = 5
|
||||
Plane = 6
|
||||
|
||||
|
||||
class Settings(DataClass):
|
||||
P2G = 'grid'
|
||||
P2P = 'prim'
|
||||
PN2 = 'pc'
|
||||
|
||||
|
||||
class DataSplit(DataClass):
|
||||
# DATA SPLIT OPTIONS
|
||||
train = 'train'
|
||||
devel = 'devel'
|
||||
test = 'test'
|
||||
|
||||
|
||||
class GlobalVar(DataClass):
|
||||
# Variables for plotting
|
||||
PADDING = 0.25
|
||||
DPI = 50
|
||||
|
||||
data_split = DataSplit()
|
||||
|
||||
classes = Classes()
|
||||
|
||||
grid_count = 12
|
||||
|
||||
prim_count = -1
|
||||
|
||||
settings = Settings()
|
||||
|
||||
|
||||
from models import *
|
||||
|
||||
|
||||
|
||||
61
utils/project_settings.py
Normal file
61
utils/project_settings.py
Normal file
@@ -0,0 +1,61 @@
|
||||
from argparse import Namespace
|
||||
|
||||
from ml_lib.utils.config import Config
|
||||
|
||||
|
||||
class DataClass(Namespace):
|
||||
|
||||
def __len__(self):
|
||||
return len(self.__dict__())
|
||||
|
||||
def __dict__(self):
|
||||
return {key: val for key, val in self.__class__.__dict__.items() if '__' not in key}
|
||||
|
||||
def items(self):
|
||||
return self.__dict__().items()
|
||||
|
||||
def __repr__(self):
|
||||
return f'{self.__class__.__name__}({self.__dict__().__repr__()})'
|
||||
|
||||
def __getitem__(self, item):
|
||||
return self.__getattribute__(item)
|
||||
|
||||
|
||||
class Classes(DataClass):
|
||||
# Object Classes for Point Segmentation
|
||||
Sphere = 0
|
||||
Cylinder = 1
|
||||
Cone = 2
|
||||
Box = 3
|
||||
Polytope = 4
|
||||
Torus = 5
|
||||
Plane = 6
|
||||
|
||||
|
||||
class Settings(DataClass):
|
||||
P2G = 'grid'
|
||||
P2P = 'prim'
|
||||
PN2 = 'pc'
|
||||
|
||||
|
||||
class DataSplit(DataClass):
|
||||
# DATA SPLIT OPTIONS
|
||||
train = 'train'
|
||||
devel = 'devel'
|
||||
test = 'test'
|
||||
|
||||
|
||||
class GlobalVar(DataClass):
|
||||
# Variables for plotting
|
||||
PADDING = 0.25
|
||||
DPI = 50
|
||||
|
||||
data_split = DataSplit()
|
||||
|
||||
classes = Classes()
|
||||
|
||||
grid_count = 12
|
||||
|
||||
prim_count = -1
|
||||
|
||||
settings = Settings()
|
||||
Reference in New Issue
Block a user