Classes Fixed abnd debugging

This commit is contained in:
Si11ium
2020-07-03 14:40:28 +02:00
parent e9d0591b11
commit 5353220890
10 changed files with 66 additions and 59 deletions

View File

@@ -15,14 +15,14 @@ import matplotlib.pyplot as plt
from torch import nn
from torch.optim import Adam
from torch_geometric.data import Data, DataLoader
from torch_geometric.transforms import Compose, FixedPoints, NormalizeScale
from torchcontrib.optim import SWA
from ml_lib.modules.util import LightningBaseModule
from ml_lib.utils.tools import to_one_hot
from .project_settings import GlobalVar
from utils.project_settings import dataSplit
class BaseOptimizerMixin:
@@ -116,7 +116,7 @@ class BaseValMixin:
y_pred = torch.cat([output['y'] for output in outputs]).squeeze().cpu().float().numpy()
y_pred_max = np.argmax(y_pred, axis=1)
class_names = {val: key for key, val in GlobalVar.classes.items()}
class_names = {val: key for key, val in self.dataset.test_dataset.classes.items()}
######################################################################################
#
# F1 SCORE
@@ -223,21 +223,30 @@ class DatasetMixin:
# Dataset
# =============================================================================
# Data Augmentations or Utility Transformations
transforms = Compose(
[
FixedPoints(8096),
NormalizeScale()
]
)
test_kwargs = kwargs.copy()
test_kwargs.update(transform=transforms)
# Dataset
dataset = Namespace(
**dict(
# TRAIN DATASET
train_dataset=dataset_class(self.params.root, mode=GlobalVar.data_split.train,
train_dataset=dataset_class(self.params.root, mode=dataSplit.train, collate_per_segment=True,
**kwargs),
# VALIDATION DATASET
val_dataset=dataset_class(self.params.root, mode=GlobalVar.data_split.devel,
**kwargs),
val_dataset=dataset_class(self.params.root, mode=dataSplit.devel, collate_per_segment=False,
**test_kwargs),
# TEST DATASET
test_dataset=dataset_class(self.params.root, mode=GlobalVar.data_split.predict,
**kwargs),
test_dataset=dataset_class(self.params.root, mode=dataSplit.predict, collate_per_segment=False,
**test_kwargs),
)
)
return dataset

View File

@@ -1,17 +1,17 @@
import numpy as np
from sklearn.cluster import DBSCAN
#import open3d as o3d
# import open3d as o3d
from pyod.models.lof import LOF
from torch_geometric.data import Data
from utils.project_settings import Classes
from utils.project_settings import classesAll
def polytopes_to_planes(pc):
pc[(pc[:, 6] == float(Classes.Box)) or (pc[:, 6] == float(Classes.Polytope)), 6] = float(Classes.Plane)
pc[(pc[:, 6] == float(classesAll.Box)) or (pc[:, 6] == float(classesAll.Polytope)), 6] = float(classesAll.Plane)
return pc

View File

@@ -19,8 +19,7 @@ class DataClass(Namespace):
return self.__dict__()[item]
class Classes(DataClass):
class ClassesALL(DataClass):
# Object Classes for Point Segmentation
Sphere = 0
Cylinder = 1
@@ -29,10 +28,11 @@ class Classes(DataClass):
Plane = 4 #
class Settings(DataClass):
P2G = 'grid'
P2P = 'prim'
PN2 = 'pc'
class ClassesPolyAsPlane(DataClass):
# Object Classes for Point Segmentation
Sphere = 0
Cylinder = 1
Plane = 2 # All SubTypes of Planes
class ClusterTypes(DataClass):
@@ -40,6 +40,7 @@ class ClusterTypes(DataClass):
grid = 'grid'
none = ''
class DataSplit(DataClass):
# DATA SPLIT OPTIONS
train = 'train'
@@ -47,18 +48,7 @@ class DataSplit(DataClass):
test = 'test'
predict = 'predict'
class GlobalVar(DataClass):
# Variables for plotting
PADDING = 0.25
DPI = 50
data_split = DataSplit()
classes = Classes()
grid_count = 12
prim_count = -1
settings = Settings()
classesAll = ClassesALL()
classesPolyAsPlane= ClassesPolyAsPlane()
clusterTypes = ClusterTypes()
dataSplit=DataSplit()