point_to_primitive/utils/project_config.py
2020-06-07 17:30:47 +02:00

72 lines
1.2 KiB
Python

from argparse import Namespace
from ml_lib.utils.config import Config
class DataClass(Namespace):
    """Container exposing the attributes declared on a subclass body as a mapping.

    Only names stored directly in the concrete class' own ``__dict__`` are
    reported (inherited attributes and attributes set on the instance are
    not), and any name containing a double underscore is skipped.

    NOTE: ``__dict__`` is deliberately kept as a *method* because callers use
    ``obj.__dict__()``. As a consequence ``obj.__dict__`` is a bound method,
    not a mapping, so helpers inherited from ``argparse.Namespace`` that read
    ``self.__dict__`` directly do not work unless overridden here.
    """

    def __len__(self):
        """Return the number of declared attributes."""
        return len(self.__dict__())

    def __dict__(self):
        """Return ``{name: value}`` for attributes declared on the concrete class.

        Names containing ``'__'`` (dunder/mangled names) are filtered out.
        """
        declared = self.__class__.__dict__
        return {key: val for key, val in declared.items() if '__' not in key}

    def __contains__(self, key):
        """Return True if ``key`` is one of the declared attribute names.

        ``Namespace.__contains__`` evaluates ``key in self.__dict__``; since
        ``__dict__`` is a method on this class that would raise ``TypeError``,
        so membership is answered from the declared-attribute mapping instead.
        """
        return key in self.__dict__()

    def items(self):
        """Return a ``(name, value)`` view over the declared attributes."""
        return self.__dict__().items()

    def __repr__(self):
        """Return ``ClassName({...})`` with the declared attributes."""
        return f'{self.__class__.__name__}({self.__dict__()!r})'

    def __getitem__(self, item):
        """Return the attribute named ``item`` (``obj['x']`` is ``obj.x``).

        Raises:
            AttributeError: if no attribute with that name exists.
        """
        return getattr(self, item)
class Classes(DataClass):
    """Object classes for point segmentation, each mapped to an integer id.

    Declaration order is significant: ``DataClass.__dict__()`` / ``items()``
    report the entries in the order they are defined here.
    """

    # ids 0-3
    Sphere, Cylinder, Cone, Box = 0, 1, 2, 3
    # ids 4-6
    Polytope, Torus, Plane = 4, 5, 6
class Settings(DataClass):
    """Short setting codes mapped to their string identifiers.

    The attribute names (``P2G``, ``P2P``, ``PN2``) are the same codes used
    as keys in ``ThisConfig._model_map``.
    """

    P2G, P2P, PN2 = 'grid', 'prim', 'pc'
class DataSplit(DataClass):
    """Data split options; each attribute's value equals its own name."""

    train, devel, test = 'train', 'devel', 'test'
class GlobalVar(DataClass):
    """Project-wide constants bundled with instances of the option containers.

    Declaration order is preserved because ``DataClass.items()`` reports
    entries in definition order.
    """

    # Variables for plotting
    PADDING, DPI = 0.25, 50

    # Shared instances of the option containers defined in this module
    data_split, classes = DataSplit(), Classes()

    # NOTE(review): the meaning of these counts is not visible in this file;
    # -1 is presumably a sentinel value -- confirm against the callers.
    grid_count, prim_count = 12, -1

    settings = Settings()
from models import *
class ThisConfig(Config):
    """Project-specific ``Config`` that supplies the model lookup table."""

    @property
    def _model_map(self):
        """Return a fresh dict mapping model codes to model classes.

        The model classes are presumably brought into scope by the
        module-level ``from models import *`` -- verify against ``models``.
        """
        return {
            'PN2': PointNet2,
            'P2P': PointNet2PrimClusters,
            'P2G': PointNet2GridClusters,
        }