61 lines
1.1 KiB
Python
61 lines
1.1 KiB
Python
from argparse import Namespace
|
|
|
|
from ml_lib.utils.config import Config
|
|
|
|
|
|
class DataClass(Namespace):
    """Namespace whose contents are the attributes declared on its own class.

    Subclasses list their values as plain class attributes. The mapping
    exposed here is built from the class body of the instance's concrete
    class only: attributes inherited from a parent class and attributes set
    on the instance are not included. Any name containing a double
    underscore is skipped. Functions defined in a subclass body would be
    included, since they are ordinary class attributes.
    """

    def __len__(self):
        """Return the number of declared attributes."""
        return len(self.__dict__())

    def __contains__(self, key):
        """Return True if ``key`` is one of the declared attribute names.

        ``Namespace.__contains__`` evaluates ``key in self.__dict__``; because
        ``__dict__`` is a method on this class, the inherited implementation
        raises ``TypeError``. Membership is answered from the same mapping
        that ``items()`` and ``len()`` use.
        """
        return key in self.__dict__()

    def __dict__(self):
        """Return a new dict of the attributes declared on the concrete class.

        NOTE: this deliberately shadows the usual ``__dict__`` attribute, so
        it has to be *called* (``obj.__dict__()``).
        """
        return {key: val for key, val in type(self).__dict__.items() if '__' not in key}

    def items(self):
        """Return the ``(name, value)`` pairs of the declared attributes."""
        return self.__dict__().items()

    def __repr__(self):
        """Return ``ClassName({...})`` showing the declared attributes."""
        return f'{type(self).__name__}({self.__dict__()!r})'
|
|
|
|
|
|
class Classes(DataClass):
    """Integer codes of the object classes used for point segmentation."""

    # Object Classes for Point Segmentation
    Sphere = 0
    Cylinder = 1
    Cone = 2
    Box = 3
    Polytope = 4
    Torus = 5
    Plane = 6
|
|
|
|
|
|
class DataSplit(DataClass):
    """Names of the available data splits; each value equals its attribute name."""

    # DATA SPLIT OPTIONS
    train = 'train'
    devel = 'devel'
    test = 'test'
|
|
|
|
|
|
class GlobalVar(DataClass):
    """Project-wide constants gathered in one place."""

    # Variables for plotting
    PADDING = 0.25
    DPI = 50

    # Ready-made instances of the option containers, created once when this
    # class body is executed.
    data_split = DataSplit()

    classes = Classes()

    grid_count = 12

    # NOTE(review): -1 looks like a "not set" sentinel; confirm with the callers.
    prim_count = -1
|
|
|
|
|
|
from models import *
|
|
|
|
|
|
class ThisConfig(Config):
    """Config subclass that supplies this project's model lookup table."""

    @property
    def _model_map(self):
        """Return a new dict mapping short model keys to model classes.

        The classes come from the star import of ``models``.
        NOTE(review): presumably the ``Config`` base class reads this mapping
        to resolve a configured model name; confirm in ml_lib.
        """
        return {
            'PN2': PointNet2,
            'P2P': PointNet2PrimClusters,
            'P2G': PointNet2GridClusters,
        }
|