Classes fixed and debugging
commit 5353220890 (parent e9d0591b11)
@@ -25,7 +25,7 @@ main_arg_parser.add_argument("--data_npoints", type=int, default=1024, help="")
 main_arg_parser.add_argument("--data_root", type=str, default='data', help="")
 main_arg_parser.add_argument("--data_refresh", type=strtobool, default=False, help="")
 main_arg_parser.add_argument("--data_dataset_type", type=str, default='ShapeNetPartSegDataset', help="")
-main_arg_parser.add_argument("--data_cluster_type", type=str, default='prim', help="")
+main_arg_parser.add_argument("--data_cluster_type", type=str, default='grid', help="")
 main_arg_parser.add_argument("--data_normals_as_cords", type=strtobool, default=True, help="")
 main_arg_parser.add_argument("--data_poly_as_plane", type=strtobool, default=False, help="")
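For readers unfamiliar with the strtobool-typed flags in this hunk, a minimal sketch (not part of this commit) of how such options parse; the argument names mirror the diff, the parse_args input is illustrative.

from argparse import ArgumentParser
from distutils.util import strtobool

parser = ArgumentParser()
parser.add_argument("--data_cluster_type", type=str, default='grid', help="")
parser.add_argument("--data_poly_as_plane", type=strtobool, default=False, help="")

# strtobool maps 'true'/'1'/'yes' to 1 and 'false'/'0'/'no' to 0.
args = parser.parse_args(['--data_poly_as_plane', 'true'])
print(args.data_cluster_type, bool(args.data_poly_as_plane))   # grid True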
@@ -13,7 +13,7 @@ import torch
 from torch_geometric.data import InMemoryDataset
 from torch_geometric.data import Data
 
-from utils.project_settings import Classes, DataSplit, ClusterTypes
+from utils.project_settings import classesAll, classesPolyAsPlane, dataSplit, clusterTypes
 
 
 def save_names(name_list, path):
@@ -34,11 +34,11 @@ class CustomShapeNet(InMemoryDataset):
 
     @property
     def modes(self):
-        return {key: val for val, key in DataSplit().items()}
+        return {key: val for val, key in dataSplit.items()}
 
     @property
     def cluster_types(self):
-        return {key: val for val, key in ClusterTypes().items()}
+        return {key: val for val, key in clusterTypes.items()}
 
     @property
     def raw_dir(self):
@@ -62,8 +62,8 @@ class CustomShapeNet(InMemoryDataset):
 
         # Set the Dataset Parameters
         self.cluster_type = cluster_type if cluster_type else 'pc'
-        self.classes = Classes()
         self.poly_as_plane = poly_as_plane
+        self.classes = classesAll if not self.poly_as_plane else classesPolyAsPlane
        self.collate_per_segment = collate_per_segment
         self.mode = mode
         self.refresh = refresh
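A minimal sketch (not part of this commit) of what the new one-liner selects, using stand-ins for the two class sets defined later in project_settings; the Box/Polytope ids are assumed here, and len() support on the real DataClass instances is an assumption based on the old len(GlobalVar.classes) usage.

# Stand-ins for classesAll / classesPolyAsPlane from utils.project_settings.
classes_all = {'Sphere': 0, 'Cylinder': 1, 'Box': 2, 'Polytope': 3, 'Plane': 4}   # ids 2/3 assumed
classes_poly_as_plane = {'Sphere': 0, 'Cylinder': 1, 'Plane': 2}

poly_as_plane = True
classes = classes_all if not poly_as_plane else classes_poly_as_plane
print(len(classes))   # 3 when poly_as_plane is set, 5 otherwise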
@@ -92,10 +92,10 @@ class CustomShapeNet(InMemoryDataset):
 
     @property
     def num_classes(self):
-        return len(self.categories) if self.poly_as_plane else (len(self.categories) - 2)
+        return len(self.categories)
 
     @property
-    def class_map_all(self):
+    def _class_map_all(self):
         return {0: 0,
                 1: 1,
                 2: None,
@@ -107,7 +107,7 @@ class CustomShapeNet(InMemoryDataset):
                 }
 
     @property
-    def class_map_poly_as_plane(self):
+    def _class_map_poly_as_plane(self):
         return {0: 0,
                 1: 1,
                 2: None,
@@ -118,11 +118,15 @@ class CustomShapeNet(InMemoryDataset):
                 7: None
                 }
 
+    @property
+    def class_remap(self):
+        return self._class_map_all if not self.poly_as_plane else self._class_map_poly_as_plane
 
     def _load_dataset(self):
         data, slices = None, None
         filepath = self.processed_paths[0]
         config_path = Path(filepath).parent / f'{self.mode}_params.ini'
-        if config_path.exists() and not self.refresh and not self.mode == DataSplit().predict:
+        if config_path.exists() and not self.refresh and not self.mode == dataSplit.predict:
             with config_path.open('rb') as f:
                 config = pickle.load(f)
                 if config == self._build_config():
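A minimal sketch (not part of this commit) of how the new class_remap property is consumed when the label column of an .xyz row is parsed; only the 0, 1 and 2 entries of the maps are visible in this diff, the row contents are illustrative.

# class_remap stands in for self._class_map_all or self._class_map_poly_as_plane.
class_remap = {0: 0, 1: 1, 2: None}

vals = [0.1, 0.2, 0.3, 0.0, 0.0, 1.0, 1, 'segment_07']    # illustrative row, label second-to-last
raw_label = int(vals[-2])
if class_remap[raw_label] is not None:
    vals[-2] = float(class_remap[raw_label])               # mirrors vals[-2] = float(self.class_remap[...])
print(vals[-2])                                            # 1.0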
@@ -155,7 +159,7 @@ class CustomShapeNet(InMemoryDataset):
                 break
             self.process()
             continue
-        if not self.mode == DataSplit().predict:
+        if not self.mode == dataSplit.predict:
             config = self._build_config()
             with config_path.open('wb') as f:
                 pickle.dump(config, f, pickle.HIGHEST_PROTOCOL)
@@ -178,7 +182,6 @@ class CustomShapeNet(InMemoryDataset):
         datasets = defaultdict(list)
         path_to_clouds = self.raw_dir / self.mode
         found_clouds = list(path_to_clouds.glob('*.xyz'))
-        class_map = self.class_map_all if not self.poly_as_plane else self.class_map_poly_as_plane
         if len(found_clouds):
             for pointcloud in tqdm(found_clouds):
                 if self.cluster_type not in pointcloud.name:
@@ -196,21 +199,28 @@ class CustomShapeNet(InMemoryDataset):
                         raise ValueError('Check the Input!!!!!!')
                     # Expand the values from the csv by fake labels if non are provided.
                     vals = vals + [0] * (8 - len(vals))
-                    vals[-2] = float(class_map[int(vals[-2])])
+                    vals[-2] = float(self.class_remap[int(vals[-2])])
                     src[vals[-1]].append(vals)
 
                 # Switch from un-pickable Defaultdict to Standard Dict
                 src = dict(src)
 
                 # Transform the Dict[List] to Dict[torch.Tensor]
-                for key, values in src.items():
+                for key, values in list(src.items()):
                     src[key] = torch.tensor(values, dtype=torch.double).squeeze()
+                    if src[key].ndim == 2:
+                        pass
+                    else:
+                        del src[key]
 
                 # Screw the Sorting and make it a FullCloud rather than a seperated
                 if not self.collate_per_segment:
+                    try:
                         src = dict(
                             all=torch.cat(tuple(src.values()))
                         )
+                    except RuntimeError:
+                        print('debugg')
 
                 # Transform Box and Polytope to Plane if poly_as_plane is set
                 for key, tensor in src.items():
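A minimal sketch (not part of this commit) of why the new ndim check matters before the torch.cat in the collate_per_segment=False branch; shapes and keys are illustrative.

import torch

src = {
    'seg_a': torch.rand(128, 8, dtype=torch.double),   # a normal segment: (points, features)
    'seg_b': torch.rand(8, dtype=torch.double),        # a 1-point segment collapsed to 1-D by .squeeze()
}

# After .squeeze(), single-point segments lose their point dimension; keeping
# only 2-D tensors avoids a RuntimeError when concatenating along dim 0.
src = {k: v for k, v in src.items() if v.ndim == 2}
full_cloud = torch.cat(tuple(src.values()))
print(full_cloud.shape)   # torch.Size([128, 8])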
@@ -274,6 +284,7 @@ class ShapeNetPartSegDataset(Dataset):
         kwargs.update(dict(root_dir=root_dir, mode=self.mode))
         # self.npoints = npoints
         self.dataset = CustomShapeNet(**kwargs)
+        self.classes = self.dataset.classes
 
     def __getitem__(self, index):
         data = self.dataset[index]
@@ -17,7 +17,6 @@ from ml_lib.utils.model_io import SavedLightningModels
 # Datasets
 from datasets.shapenet import ShapeNetPartSegDataset
 from utils.project_config import ThisConfig
-from utils.project_settings import GlobalVar
 
 
 def prepare_dataloader(config_obj):
@@ -63,16 +63,16 @@ if __name__ == '__main__':
     type_cluster_eps = 0.1
     type_cluster_min_pts = 100
 
-    model_path = Path('output') / 'PN2' / 'PN_f0d6bc0b9bf95a7e64f31a7df3c820d0' / 'version_0'
+    model_path = Path('output') / 'PN2' / 'PN_597ecab330b04d977cda8a09ae0e5f6e' / 'version_0'
 
     loaded_model = restore_logger_and_model(model_path)
     loaded_model.eval()
 
     transforms = Compose([NormalizeScale(), ])
     test_dataset = ShapeNetPartSegDataset('data', mode=GlobalVar.data_split.predict, collate_per_segment=False,
-                                          refresh=True, transform=transforms)  # , cluster_type='pc')
+                                          refresh=True, transform=transforms, cluster_type=None)
 
-    grid_clusters = cluster_cubes(test_dataset[1], grid_clusters, max_points_per_cluster=grid_cluster_max_pts)
+    grid_clusters = cluster_cubes(test_dataset[0], grid_clusters, max_points_per_cluster=grid_cluster_max_pts)
 
     ps.init()
@@ -2,7 +2,6 @@ from abc import ABC
 
 import torch
 from torch import nn
-from torch_geometric.transforms import Compose, NormalizeScale, RandomFlip
 
 from ml_lib.modules.geometric_blocks import SAModule, GlobalSAModule, MLP, FPModule
 from ml_lib.modules.util import LightningBaseModule, F_x
@@ -8,7 +8,6 @@ from datasets.shapenet import ShapeNetPartSegDataset
 from models._point_net_2 import _PointNetCore
 
 from utils.module_mixins import BaseValMixin, BaseTrainMixin, BaseOptimizerMixin, BaseDataloadersMixin, DatasetMixin
-from utils.project_settings import GlobalVar
 
 
 class PointNet2(BaseValMixin,
@@ -33,7 +32,7 @@ class PointNet2(BaseValMixin,
                 # This is not available with 6-dim cords
                 # RandomRotate(rot_max_angle, 0), RandomRotate(rot_max_angle, 1), RandomRotate(rot_max_angle, 2),
                 RandomTranslate(trans_max_distance),
-                NormalizeScale()
+                # NormalizeScale()
                 # NormalizePositions()
                 ]
             )
@@ -41,7 +40,6 @@ class PointNet2(BaseValMixin,
         # Dataset
         # =============================================================================
         self.dataset = self.build_dataset(ShapeNetPartSegDataset,
-                                          collate_per_segment=True,
                                           transform=transforms,
                                           cluster_type=self.params.cluster_type,
                                           refresh=self.params.refresh,
@@ -51,7 +49,7 @@ class PointNet2(BaseValMixin,
         # Model Paramters
         # =============================================================================
         # Additional parameters
-        self.n_classes = len(GlobalVar.classes) if not self.params.poly_as_plane else (len(GlobalVar.classes) - 2)
+        self.n_classes = len(self.dataset.train_dataset.classes)
 
         # Modules
         self.lin3 = torch.nn.Linear(128, self.n_classes)
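A minimal sketch (not part of this commit) of why the segmentation head is now sized from the dataset's class set: the last linear layer produces one logit per class for every point, so its width must match whatever label space the dataset exposes.

import torch
from torch import nn

n_classes = 3                                  # e.g. len(train_dataset.classes) with poly_as_plane=True
lin3 = nn.Linear(128, n_classes)

per_point_features = torch.rand(1024, 128)     # illustrative: 1024 points, 128 features each
logits = lin3(per_point_features)
print(logits.shape)                            # torch.Size([1024, 3])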
@@ -26,8 +26,9 @@ if __name__ == '__main__':
         for poly_as_plane in [True, False]:
             for normals_as_cords in [True, False]:
                 arg_dict.update(main_seed=seed,
-                                normals_as_cords=normals_as_cords, poly_as_plane=poly_as_plane)
+                                data_normals_as_cords=normals_as_cords,
+                                data_poly_as_plane=poly_as_plane
+                                )
 
                 config = config.update(arg_dict)
 
                 run_lightning_loop(config)
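A minimal sketch (not part of this commit) of the grid this multi-run loop walks; the "data_" prefix matches the argparse option names from the first hunk, while the config/update call itself is the project's own API and only mirrored here.

from itertools import product

for poly_as_plane, normals_as_cords in product([True, False], repeat=2):
    overrides = dict(data_normals_as_cords=normals_as_cords,
                     data_poly_as_plane=poly_as_plane)
    print(overrides)   # four combinations -> four training runs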
@@ -15,14 +15,14 @@ import matplotlib.pyplot as plt
 from torch import nn
 from torch.optim import Adam
 from torch_geometric.data import Data, DataLoader
+from torch_geometric.transforms import Compose, FixedPoints, NormalizeScale
 
 from torchcontrib.optim import SWA
 
 
 from ml_lib.modules.util import LightningBaseModule
 from ml_lib.utils.tools import to_one_hot
-from .project_settings import GlobalVar
+from utils.project_settings import dataSplit
 
 
 class BaseOptimizerMixin:
@@ -116,7 +116,7 @@ class BaseValMixin:
         y_pred = torch.cat([output['y'] for output in outputs]).squeeze().cpu().float().numpy()
         y_pred_max = np.argmax(y_pred, axis=1)
 
-        class_names = {val: key for key, val in GlobalVar.classes.items()}
+        class_names = {val: key for key, val in self.dataset.test_dataset.classes.items()}
         ######################################################################################
         #
         # F1 SCORE
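A minimal sketch (not part of this commit) of how an id-to-name map like class_names can be used when computing the F1 score; y_true here is a stand-in, the real targets come from the validation outputs collected by the mixin, and the sklearn calls are only an assumed illustration.

import numpy as np
from sklearn.metrics import f1_score, classification_report

class_names = {0: 'Sphere', 1: 'Cylinder', 2: 'Plane'}    # inverted {name: id} dict
y_true = np.array([0, 1, 2, 2, 1])
y_pred_max = np.array([0, 1, 2, 1, 1])

print(f1_score(y_true, y_pred_max, average='macro'))
print(classification_report(y_true, y_pred_max,
                            target_names=[class_names[i] for i in sorted(class_names)]))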
@@ -223,21 +223,30 @@ class DatasetMixin:
         # Dataset
         # =============================================================================
         # Data Augmentations or Utility Transformations
+        transforms = Compose(
+            [
+                FixedPoints(8096),
+                NormalizeScale()
+            ]
+        )
+        test_kwargs = kwargs.copy()
+        test_kwargs.update(transform=transforms)
         # Dataset
         dataset = Namespace(
             **dict(
                 # TRAIN DATASET
-                train_dataset=dataset_class(self.params.root, mode=GlobalVar.data_split.train,
+                train_dataset=dataset_class(self.params.root, mode=dataSplit.train, collate_per_segment=True,
                                             **kwargs),
 
 
                 # VALIDATION DATASET
-                val_dataset=dataset_class(self.params.root, mode=GlobalVar.data_split.devel,
-                                          **kwargs),
+                val_dataset=dataset_class(self.params.root, mode=dataSplit.devel, collate_per_segment=False,
+                                          **test_kwargs),
 
                 # TEST DATASET
-                test_dataset=dataset_class(self.params.root, mode=GlobalVar.data_split.predict,
-                                           **kwargs),
+                test_dataset=dataset_class(self.params.root, mode=dataSplit.predict, collate_per_segment=False,
+                                           **test_kwargs),
                 )
             )
         return dataset
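A minimal sketch (not part of this commit) of what the new val/test transform pipeline does to a single sample; the input point count and labels are illustrative.

import torch
from torch_geometric.data import Data
from torch_geometric.transforms import Compose, FixedPoints, NormalizeScale

transforms = Compose([FixedPoints(8096), NormalizeScale()])

sample = Data(pos=torch.rand(20000, 3), y=torch.randint(0, 3, (20000,)))
sample = transforms(sample)
print(sample.pos.shape)          # torch.Size([8096, 3]) -- fixed-size cloud, labels sampled alongside
print(sample.pos.abs().max())    # <= 1.0 after NormalizeScale centers and rescales the positions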
@@ -1,17 +1,17 @@
 import numpy as np
 from sklearn.cluster import DBSCAN
 
-#import open3d as o3d
+# import open3d as o3d
 
 from pyod.models.lof import LOF
 
 from torch_geometric.data import Data
 
-from utils.project_settings import Classes
+from utils.project_settings import classesAll
 
 
 def polytopes_to_planes(pc):
-    pc[(pc[:, 6] == float(Classes.Box)) or (pc[:, 6] == float(Classes.Polytope)), 6] = float(Classes.Plane)
+    pc[(pc[:, 6] == float(classesAll.Box)) or (pc[:, 6] == float(classesAll.Polytope)), 6] = float(classesAll.Plane)
     return pc
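One thing worth noting for readers of this hunk: plain `or` between two NumPy (or torch) boolean masks raises an "ambiguous truth value" error; an elementwise combination uses `|`. A minimal sketch (not part of this commit) of that variant, with an illustrative array and assumed class ids:

import numpy as np

# Illustrative cloud: column 6 holds the class id; the concrete ids for Box and
# Polytope are not visible in this diff, so 3.0 and 2.0 are assumptions.
pc = np.zeros((4, 8))
pc[:, 6] = [3.0, 2.0, 0.0, 4.0]

# Elementwise OR of the two masks; `or` on whole arrays raises a ValueError.
mask = (pc[:, 6] == 3.0) | (pc[:, 6] == 2.0)
pc[mask, 6] = 4.0
print(pc[:, 6])   # [4. 4. 0. 4.]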
@@ -19,8 +19,7 @@ class DataClass(Namespace):
         return self.__dict__()[item]
 
 
-class Classes(DataClass):
-
+class ClassesALL(DataClass):
     # Object Classes for Point Segmentation
     Sphere = 0
     Cylinder = 1
@@ -29,10 +28,11 @@ class Classes(DataClass):
     Plane = 4  #
 
 
-class Settings(DataClass):
-    P2G = 'grid'
-    P2P = 'prim'
-    PN2 = 'pc'
+class ClassesPolyAsPlane(DataClass):
+    # Object Classes for Point Segmentation
+    Sphere = 0
+    Cylinder = 1
+    Plane = 2  # All SubTypes of Planes
 
 
 class ClusterTypes(DataClass):
@@ -40,6 +40,7 @@ class ClusterTypes(DataClass):
     grid = 'grid'
     none = ''
 
+
 class DataSplit(DataClass):
     # DATA SPLIT OPTIONS
     train = 'train'
@@ -47,18 +48,7 @@ class DataSplit(DataClass):
     test = 'test'
     predict = 'predict'
 
-
-class GlobalVar(DataClass):
-    # Variables for plotting
-    PADDING = 0.25
-    DPI = 50
-
-    data_split = DataSplit()
-
-    classes = Classes()
-
-    grid_count = 12
-
-    prim_count = -1
-
-    settings = Settings()
+classesAll = ClassesALL()
+classesPolyAsPlane= ClassesPolyAsPlane()
+clusterTypes = ClusterTypes()
+dataSplit=DataSplit()
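A minimal sketch (not part of this commit) of the intent behind these module-level instances: one importable singleton per setting group instead of the removed GlobalVar container. DataClass is the project's own Namespace subclass; the items() behaviour shown here is an assumption inferred from the modes property in the dataset hunk.

class DataSplit:                      # stand-in for the project's DataClass-based DataSplit
    train = 'train'
    devel = 'devel'
    test = 'test'
    predict = 'predict'

    def items(self):
        # Assumed behaviour: expose the class attributes as (name, value) pairs.
        return [(k, v) for k, v in vars(type(self)).items()
                if not k.startswith('_') and not callable(v)]

dataSplit = DataSplit()                                    # importable singleton
print(dataSplit.train)                                     # 'train'
print({key: val for val, key in dataSplit.items()})        # reversed view, as used by CustomShapeNet.modes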