New Model running
This commit is contained in:
@@ -110,10 +110,10 @@ class BaseValMixin:
|
||||
#######################################################################################
|
||||
#
|
||||
# INIT
|
||||
y_true = torch.cat([output['batch_y'] for output in outputs]) .cpu().numpy()
|
||||
y_true = torch.cat([output['batch_y'] for output in outputs]).cpu().numpy()
|
||||
y_true_one_hot = to_one_hot(y_true, self.n_classes)
|
||||
|
||||
y_pred = torch.cat([output['y'] for output in outputs]).squeeze().cpu().numpy()
|
||||
y_pred = torch.cat([output['y'] for output in outputs]).squeeze().cpu().float().numpy()
|
||||
y_pred_max = np.argmax(y_pred, axis=1)
|
||||
|
||||
class_names = {val: key for key, val in GlobalVar.classes.items()}
|
||||
@@ -134,7 +134,7 @@ class BaseValMixin:
|
||||
fpr = dict()
|
||||
tpr = dict()
|
||||
roc_auc = dict()
|
||||
for i in range(len(GlobalVar.classes)):
|
||||
for i in range(self.n_classes):
|
||||
fpr[i], tpr[i], _ = roc_curve(y_true_one_hot[:, i], y_pred[:, i])
|
||||
roc_auc[i] = auc(fpr[i], tpr[i])
|
||||
|
||||
@@ -143,15 +143,15 @@ class BaseValMixin:
|
||||
roc_auc["micro"] = auc(fpr["micro"], tpr["micro"])
|
||||
|
||||
# First aggregate all false positive rates
|
||||
all_fpr = np.unique(np.concatenate([fpr[i] for i in range(len(GlobalVar.classes))]))
|
||||
all_fpr = np.unique(np.concatenate([fpr[i] for i in range(self.n_classes)]))
|
||||
|
||||
# Then interpolate all ROC curves at this points
|
||||
mean_tpr = np.zeros_like(all_fpr)
|
||||
for i in range(len(GlobalVar.classes)):
|
||||
for i in range(self.n_classes):
|
||||
mean_tpr += interp(all_fpr, fpr[i], tpr[i])
|
||||
|
||||
# Finally average it and compute AUC
|
||||
mean_tpr /= len(GlobalVar.classes)
|
||||
mean_tpr /= self.n_classes
|
||||
|
||||
fpr["macro"] = all_fpr
|
||||
tpr["macro"] = mean_tpr
|
||||
@@ -170,7 +170,7 @@ class BaseValMixin:
|
||||
colors = cycle(['firebrick', 'orangered', 'gold', 'olive', 'limegreen', 'aqua',
|
||||
'dodgerblue', 'slategrey', 'royalblue', 'indigo', 'fuchsia'], )
|
||||
|
||||
for i, color in zip(range(len(GlobalVar.classes)), colors):
|
||||
for i, color in zip(range(self.n_classes), colors):
|
||||
plt.plot(fpr[i], tpr[i], color=color, lw=2, label=f'{class_names[i]} ({round(roc_auc[i],2 )})')
|
||||
|
||||
plt.plot([0, 1], [0, 1], 'k--', lw=2)
|
||||
@@ -236,7 +236,7 @@ class DatasetMixin:
|
||||
**kwargs),
|
||||
|
||||
# TEST DATASET
|
||||
test_dataset=dataset_class(self.params.root, mode=GlobalVar.data_split.test,
|
||||
test_dataset=dataset_class(self.params.root, mode=GlobalVar.data_split.predict,
|
||||
**kwargs),
|
||||
)
|
||||
)
|
||||
|
||||
+14
-16
@@ -3,27 +3,15 @@ from sklearn.cluster import DBSCAN
|
||||
|
||||
import open3d as o3d
|
||||
|
||||
from pyod.models.knn import KNN
|
||||
from pyod.models.sod import SOD
|
||||
from pyod.models.abod import ABOD
|
||||
from pyod.models.sos import SOS
|
||||
from pyod.models.pca import PCA
|
||||
from pyod.models.ocsvm import OCSVM
|
||||
from pyod.models.mcd import MCD
|
||||
from pyod.models.lof import LOF
|
||||
from pyod.models.cof import COF
|
||||
from pyod.models.cblof import CBLOF
|
||||
from pyod.models.loci import LOCI
|
||||
from pyod.models.hbos import HBOS
|
||||
from pyod.models.lscp import LSCP
|
||||
from pyod.models.feature_bagging import FeatureBagging
|
||||
|
||||
from torch_geometric.data import Data
|
||||
|
||||
from utils.project_settings import Classes
|
||||
|
||||
|
||||
def polytopes_to_planes(pc):
|
||||
pc[(pc[:, 6] == float(Classes.Box)) | (pc[:, 6] == float(Classes.Polytope)), 6] = float(Classes.Plane);
|
||||
pc[(pc[:, 6] == float(Classes.Box)) or (pc[:, 6] == float(Classes.Polytope)), 6] = float(Classes.Plane)
|
||||
return pc
|
||||
|
||||
|
||||
@@ -49,7 +37,7 @@ def mini_color_table(index, norm=True):
|
||||
def cluster2Color(cluster, cluster_idx):
|
||||
colors = np.zeros(shape=(len(cluster), 3))
|
||||
point_idx = 0
|
||||
for point in cluster:
|
||||
for _ in cluster:
|
||||
colors[point_idx, :] = mini_color_table(cluster_idx)
|
||||
point_idx += 1
|
||||
|
||||
@@ -87,6 +75,8 @@ def write_pointcloud(file, pc, numCols=6):
|
||||
|
||||
|
||||
def farthest_point_sampling(pts, K):
|
||||
if isinstance(pts, Data):
|
||||
pts = pts.pos.numpy()
|
||||
if pts.shape[0] < K:
|
||||
return pts
|
||||
|
||||
@@ -119,7 +109,15 @@ def cluster_cubes(data, cluster_dims, max_points_per_cluster=-1, min_points_per_
|
||||
|
||||
if isinstance(data, Data):
|
||||
import torch
|
||||
data = torch.cat((data.pos, data.norm, data.y.double().unsqueeze(-1)), dim=-1).numpy()
|
||||
candidate_list = list()
|
||||
if data.pos is not None:
|
||||
candidate_list.append(data.pos)
|
||||
if data.norm is not None:
|
||||
candidate_list.append(data.norm)
|
||||
if data.y is not None:
|
||||
candidate_list.append(data.y.double().unsqueeze(-1))
|
||||
|
||||
data = torch.cat(candidate_list, dim=-1).numpy()
|
||||
|
||||
max = data[:, :3].max(axis=0)
|
||||
max += max * 0.01
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
from argparse import Namespace
|
||||
|
||||
from ml_lib.utils.config import Config
|
||||
|
||||
|
||||
class DataClass(Namespace):
|
||||
|
||||
@@ -18,18 +16,19 @@ class DataClass(Namespace):
|
||||
return f'{self.__class__.__name__}({self.__dict__().__repr__()})'
|
||||
|
||||
def __getitem__(self, item):
|
||||
return self.__getattribute__(item)
|
||||
return self.__dict__()[item]
|
||||
|
||||
|
||||
class Classes(DataClass):
|
||||
|
||||
# Object Classes for Point Segmentation
|
||||
Sphere = 0
|
||||
Cylinder = 1
|
||||
Cone = 2
|
||||
Box = 3
|
||||
Polytope = 4
|
||||
Box = 3 # All SubTypes of Planes
|
||||
Polytope = 4 #
|
||||
Torus = 5
|
||||
Plane = 6
|
||||
Plane = 6 #
|
||||
|
||||
|
||||
class Settings(DataClass):
|
||||
@@ -38,6 +37,11 @@ class Settings(DataClass):
|
||||
PN2 = 'pc'
|
||||
|
||||
|
||||
class ClusterTypes(DataClass):
|
||||
prim = 'prim'
|
||||
grid = 'grid'
|
||||
none = ''
|
||||
|
||||
class DataSplit(DataClass):
|
||||
# DATA SPLIT OPTIONS
|
||||
train = 'train'
|
||||
@@ -59,4 +63,4 @@ class GlobalVar(DataClass):
|
||||
|
||||
prim_count = -1
|
||||
|
||||
settings = Settings()
|
||||
settings = Settings()
|
||||
|
||||
Reference in New Issue
Block a user