New Model running

This commit is contained in:
Si11ium
2020-06-23 14:37:34 +02:00
parent a19bd9cafd
commit 1033b26195
12 changed files with 173 additions and 112 deletions
+8 -8
View File
@@ -110,10 +110,10 @@ class BaseValMixin:
#######################################################################################
#
# INIT
y_true = torch.cat([output['batch_y'] for output in outputs]) .cpu().numpy()
y_true = torch.cat([output['batch_y'] for output in outputs]).cpu().numpy()
y_true_one_hot = to_one_hot(y_true, self.n_classes)
y_pred = torch.cat([output['y'] for output in outputs]).squeeze().cpu().numpy()
y_pred = torch.cat([output['y'] for output in outputs]).squeeze().cpu().float().numpy()
y_pred_max = np.argmax(y_pred, axis=1)
class_names = {val: key for key, val in GlobalVar.classes.items()}
@@ -134,7 +134,7 @@ class BaseValMixin:
fpr = dict()
tpr = dict()
roc_auc = dict()
for i in range(len(GlobalVar.classes)):
for i in range(self.n_classes):
fpr[i], tpr[i], _ = roc_curve(y_true_one_hot[:, i], y_pred[:, i])
roc_auc[i] = auc(fpr[i], tpr[i])
@@ -143,15 +143,15 @@ class BaseValMixin:
roc_auc["micro"] = auc(fpr["micro"], tpr["micro"])
# First aggregate all false positive rates
all_fpr = np.unique(np.concatenate([fpr[i] for i in range(len(GlobalVar.classes))]))
all_fpr = np.unique(np.concatenate([fpr[i] for i in range(self.n_classes)]))
# Then interpolate all ROC curves at this points
mean_tpr = np.zeros_like(all_fpr)
for i in range(len(GlobalVar.classes)):
for i in range(self.n_classes):
mean_tpr += interp(all_fpr, fpr[i], tpr[i])
# Finally average it and compute AUC
mean_tpr /= len(GlobalVar.classes)
mean_tpr /= self.n_classes
fpr["macro"] = all_fpr
tpr["macro"] = mean_tpr
@@ -170,7 +170,7 @@ class BaseValMixin:
colors = cycle(['firebrick', 'orangered', 'gold', 'olive', 'limegreen', 'aqua',
'dodgerblue', 'slategrey', 'royalblue', 'indigo', 'fuchsia'], )
for i, color in zip(range(len(GlobalVar.classes)), colors):
for i, color in zip(range(self.n_classes), colors):
plt.plot(fpr[i], tpr[i], color=color, lw=2, label=f'{class_names[i]} ({round(roc_auc[i],2 )})')
plt.plot([0, 1], [0, 1], 'k--', lw=2)
@@ -236,7 +236,7 @@ class DatasetMixin:
**kwargs),
# TEST DATASET
test_dataset=dataset_class(self.params.root, mode=GlobalVar.data_split.test,
test_dataset=dataset_class(self.params.root, mode=GlobalVar.data_split.predict,
**kwargs),
)
)
+14 -16
View File
@@ -3,27 +3,15 @@ from sklearn.cluster import DBSCAN
import open3d as o3d
from pyod.models.knn import KNN
from pyod.models.sod import SOD
from pyod.models.abod import ABOD
from pyod.models.sos import SOS
from pyod.models.pca import PCA
from pyod.models.ocsvm import OCSVM
from pyod.models.mcd import MCD
from pyod.models.lof import LOF
from pyod.models.cof import COF
from pyod.models.cblof import CBLOF
from pyod.models.loci import LOCI
from pyod.models.hbos import HBOS
from pyod.models.lscp import LSCP
from pyod.models.feature_bagging import FeatureBagging
from torch_geometric.data import Data
from utils.project_settings import Classes
def polytopes_to_planes(pc):
pc[(pc[:, 6] == float(Classes.Box)) | (pc[:, 6] == float(Classes.Polytope)), 6] = float(Classes.Plane);
pc[(pc[:, 6] == float(Classes.Box)) or (pc[:, 6] == float(Classes.Polytope)), 6] = float(Classes.Plane)
return pc
@@ -49,7 +37,7 @@ def mini_color_table(index, norm=True):
def cluster2Color(cluster, cluster_idx):
colors = np.zeros(shape=(len(cluster), 3))
point_idx = 0
for point in cluster:
for _ in cluster:
colors[point_idx, :] = mini_color_table(cluster_idx)
point_idx += 1
@@ -87,6 +75,8 @@ def write_pointcloud(file, pc, numCols=6):
def farthest_point_sampling(pts, K):
if isinstance(pts, Data):
pts = pts.pos.numpy()
if pts.shape[0] < K:
return pts
@@ -119,7 +109,15 @@ def cluster_cubes(data, cluster_dims, max_points_per_cluster=-1, min_points_per_
if isinstance(data, Data):
import torch
data = torch.cat((data.pos, data.norm, data.y.double().unsqueeze(-1)), dim=-1).numpy()
candidate_list = list()
if data.pos is not None:
candidate_list.append(data.pos)
if data.norm is not None:
candidate_list.append(data.norm)
if data.y is not None:
candidate_list.append(data.y.double().unsqueeze(-1))
data = torch.cat(candidate_list, dim=-1).numpy()
max = data[:, :3].max(axis=0)
max += max * 0.01
+11 -7
View File
@@ -1,7 +1,5 @@
from argparse import Namespace
from ml_lib.utils.config import Config
class DataClass(Namespace):
@@ -18,18 +16,19 @@ class DataClass(Namespace):
return f'{self.__class__.__name__}({self.__dict__().__repr__()})'
def __getitem__(self, item):
return self.__getattribute__(item)
return self.__dict__()[item]
class Classes(DataClass):
# Object Classes for Point Segmentation
Sphere = 0
Cylinder = 1
Cone = 2
Box = 3
Polytope = 4
Box = 3 # All SubTypes of Planes
Polytope = 4 #
Torus = 5
Plane = 6
Plane = 6 #
class Settings(DataClass):
@@ -38,6 +37,11 @@ class Settings(DataClass):
PN2 = 'pc'
class ClusterTypes(DataClass):
prim = 'prim'
grid = 'grid'
none = ''
class DataSplit(DataClass):
# DATA SPLIT OPTIONS
train = 'train'
@@ -59,4 +63,4 @@ class GlobalVar(DataClass):
prim_count = -1
settings = Settings()
settings = Settings()