Classes Fixed abnd debugging
This commit is contained in:
@@ -15,14 +15,14 @@ import matplotlib.pyplot as plt
|
||||
from torch import nn
|
||||
from torch.optim import Adam
|
||||
from torch_geometric.data import Data, DataLoader
|
||||
from torch_geometric.transforms import Compose, FixedPoints, NormalizeScale
|
||||
|
||||
from torchcontrib.optim import SWA
|
||||
|
||||
|
||||
from ml_lib.modules.util import LightningBaseModule
|
||||
from ml_lib.utils.tools import to_one_hot
|
||||
|
||||
from .project_settings import GlobalVar
|
||||
from utils.project_settings import dataSplit
|
||||
|
||||
|
||||
class BaseOptimizerMixin:
|
||||
@@ -116,7 +116,7 @@ class BaseValMixin:
|
||||
y_pred = torch.cat([output['y'] for output in outputs]).squeeze().cpu().float().numpy()
|
||||
y_pred_max = np.argmax(y_pred, axis=1)
|
||||
|
||||
class_names = {val: key for key, val in GlobalVar.classes.items()}
|
||||
class_names = {val: key for key, val in self.dataset.test_dataset.classes.items()}
|
||||
######################################################################################
|
||||
#
|
||||
# F1 SCORE
|
||||
@@ -223,21 +223,30 @@ class DatasetMixin:
|
||||
# Dataset
|
||||
# =============================================================================
|
||||
# Data Augmentations or Utility Transformations
|
||||
transforms = Compose(
|
||||
[
|
||||
FixedPoints(8096),
|
||||
|
||||
NormalizeScale()
|
||||
]
|
||||
)
|
||||
test_kwargs = kwargs.copy()
|
||||
test_kwargs.update(transform=transforms)
|
||||
# Dataset
|
||||
dataset = Namespace(
|
||||
**dict(
|
||||
# TRAIN DATASET
|
||||
train_dataset=dataset_class(self.params.root, mode=GlobalVar.data_split.train,
|
||||
train_dataset=dataset_class(self.params.root, mode=dataSplit.train, collate_per_segment=True,
|
||||
**kwargs),
|
||||
|
||||
|
||||
# VALIDATION DATASET
|
||||
val_dataset=dataset_class(self.params.root, mode=GlobalVar.data_split.devel,
|
||||
**kwargs),
|
||||
val_dataset=dataset_class(self.params.root, mode=dataSplit.devel, collate_per_segment=False,
|
||||
**test_kwargs),
|
||||
|
||||
# TEST DATASET
|
||||
test_dataset=dataset_class(self.params.root, mode=GlobalVar.data_split.predict,
|
||||
**kwargs),
|
||||
test_dataset=dataset_class(self.params.root, mode=dataSplit.predict, collate_per_segment=False,
|
||||
**test_kwargs),
|
||||
)
|
||||
)
|
||||
return dataset
|
||||
|
||||
@@ -1,17 +1,17 @@
|
||||
import numpy as np
|
||||
from sklearn.cluster import DBSCAN
|
||||
|
||||
#import open3d as o3d
|
||||
# import open3d as o3d
|
||||
|
||||
from pyod.models.lof import LOF
|
||||
|
||||
from torch_geometric.data import Data
|
||||
|
||||
from utils.project_settings import Classes
|
||||
from utils.project_settings import classesAll
|
||||
|
||||
|
||||
def polytopes_to_planes(pc):
|
||||
pc[(pc[:, 6] == float(Classes.Box)) or (pc[:, 6] == float(Classes.Polytope)), 6] = float(Classes.Plane)
|
||||
pc[(pc[:, 6] == float(classesAll.Box)) or (pc[:, 6] == float(classesAll.Polytope)), 6] = float(classesAll.Plane)
|
||||
return pc
|
||||
|
||||
|
||||
|
||||
@@ -19,8 +19,7 @@ class DataClass(Namespace):
|
||||
return self.__dict__()[item]
|
||||
|
||||
|
||||
class Classes(DataClass):
|
||||
|
||||
class ClassesALL(DataClass):
|
||||
# Object Classes for Point Segmentation
|
||||
Sphere = 0
|
||||
Cylinder = 1
|
||||
@@ -29,10 +28,11 @@ class Classes(DataClass):
|
||||
Plane = 4 #
|
||||
|
||||
|
||||
class Settings(DataClass):
|
||||
P2G = 'grid'
|
||||
P2P = 'prim'
|
||||
PN2 = 'pc'
|
||||
class ClassesPolyAsPlane(DataClass):
|
||||
# Object Classes for Point Segmentation
|
||||
Sphere = 0
|
||||
Cylinder = 1
|
||||
Plane = 2 # All SubTypes of Planes
|
||||
|
||||
|
||||
class ClusterTypes(DataClass):
|
||||
@@ -40,6 +40,7 @@ class ClusterTypes(DataClass):
|
||||
grid = 'grid'
|
||||
none = ''
|
||||
|
||||
|
||||
class DataSplit(DataClass):
|
||||
# DATA SPLIT OPTIONS
|
||||
train = 'train'
|
||||
@@ -47,18 +48,7 @@ class DataSplit(DataClass):
|
||||
test = 'test'
|
||||
predict = 'predict'
|
||||
|
||||
|
||||
class GlobalVar(DataClass):
|
||||
# Variables for plotting
|
||||
PADDING = 0.25
|
||||
DPI = 50
|
||||
|
||||
data_split = DataSplit()
|
||||
|
||||
classes = Classes()
|
||||
|
||||
grid_count = 12
|
||||
|
||||
prim_count = -1
|
||||
|
||||
settings = Settings()
|
||||
classesAll = ClassesALL()
|
||||
classesPolyAsPlane= ClassesPolyAsPlane()
|
||||
clusterTypes = ClusterTypes()
|
||||
dataSplit=DataSplit()
|
||||
Reference in New Issue
Block a user