Classes Fixed abnd debugging
This commit is contained in:
parent
e9d0591b11
commit
5353220890
@ -25,7 +25,7 @@ main_arg_parser.add_argument("--data_npoints", type=int, default=1024, help="")
|
||||
main_arg_parser.add_argument("--data_root", type=str, default='data', help="")
|
||||
main_arg_parser.add_argument("--data_refresh", type=strtobool, default=False, help="")
|
||||
main_arg_parser.add_argument("--data_dataset_type", type=str, default='ShapeNetPartSegDataset', help="")
|
||||
main_arg_parser.add_argument("--data_cluster_type", type=str, default='prim', help="")
|
||||
main_arg_parser.add_argument("--data_cluster_type", type=str, default='grid', help="")
|
||||
main_arg_parser.add_argument("--data_normals_as_cords", type=strtobool, default=True, help="")
|
||||
main_arg_parser.add_argument("--data_poly_as_plane", type=strtobool, default=False, help="")
|
||||
|
||||
|
@ -13,7 +13,7 @@ import torch
|
||||
from torch_geometric.data import InMemoryDataset
|
||||
from torch_geometric.data import Data
|
||||
|
||||
from utils.project_settings import Classes, DataSplit, ClusterTypes
|
||||
from utils.project_settings import classesAll, classesPolyAsPlane, dataSplit, clusterTypes
|
||||
|
||||
|
||||
def save_names(name_list, path):
|
||||
@ -34,11 +34,11 @@ class CustomShapeNet(InMemoryDataset):
|
||||
|
||||
@property
|
||||
def modes(self):
|
||||
return {key: val for val, key in DataSplit().items()}
|
||||
return {key: val for val, key in dataSplit.items()}
|
||||
|
||||
@property
|
||||
def cluster_types(self):
|
||||
return {key: val for val, key in ClusterTypes().items()}
|
||||
return {key: val for val, key in clusterTypes.items()}
|
||||
|
||||
@property
|
||||
def raw_dir(self):
|
||||
@ -62,8 +62,8 @@ class CustomShapeNet(InMemoryDataset):
|
||||
|
||||
# Set the Dataset Parameters
|
||||
self.cluster_type = cluster_type if cluster_type else 'pc'
|
||||
self.classes = Classes()
|
||||
self.poly_as_plane = poly_as_plane
|
||||
self.classes = classesAll if not self.poly_as_plane else classesPolyAsPlane
|
||||
self.collate_per_segment = collate_per_segment
|
||||
self.mode = mode
|
||||
self.refresh = refresh
|
||||
@ -92,10 +92,10 @@ class CustomShapeNet(InMemoryDataset):
|
||||
|
||||
@property
|
||||
def num_classes(self):
|
||||
return len(self.categories) if self.poly_as_plane else (len(self.categories) - 2)
|
||||
return len(self.categories)
|
||||
|
||||
@property
|
||||
def class_map_all(self):
|
||||
def _class_map_all(self):
|
||||
return {0: 0,
|
||||
1: 1,
|
||||
2: None,
|
||||
@ -107,7 +107,7 @@ class CustomShapeNet(InMemoryDataset):
|
||||
}
|
||||
|
||||
@property
|
||||
def class_map_poly_as_plane(self):
|
||||
def _class_map_poly_as_plane(self):
|
||||
return {0: 0,
|
||||
1: 1,
|
||||
2: None,
|
||||
@ -118,11 +118,15 @@ class CustomShapeNet(InMemoryDataset):
|
||||
7: None
|
||||
}
|
||||
|
||||
@property
|
||||
def class_remap(self):
|
||||
return self._class_map_all if not self.poly_as_plane else self._class_map_poly_as_plane
|
||||
|
||||
def _load_dataset(self):
|
||||
data, slices = None, None
|
||||
filepath = self.processed_paths[0]
|
||||
config_path = Path(filepath).parent / f'{self.mode}_params.ini'
|
||||
if config_path.exists() and not self.refresh and not self.mode == DataSplit().predict:
|
||||
if config_path.exists() and not self.refresh and not self.mode == dataSplit.predict:
|
||||
with config_path.open('rb') as f:
|
||||
config = pickle.load(f)
|
||||
if config == self._build_config():
|
||||
@ -155,7 +159,7 @@ class CustomShapeNet(InMemoryDataset):
|
||||
break
|
||||
self.process()
|
||||
continue
|
||||
if not self.mode == DataSplit().predict:
|
||||
if not self.mode == dataSplit.predict:
|
||||
config = self._build_config()
|
||||
with config_path.open('wb') as f:
|
||||
pickle.dump(config, f, pickle.HIGHEST_PROTOCOL)
|
||||
@ -178,7 +182,6 @@ class CustomShapeNet(InMemoryDataset):
|
||||
datasets = defaultdict(list)
|
||||
path_to_clouds = self.raw_dir / self.mode
|
||||
found_clouds = list(path_to_clouds.glob('*.xyz'))
|
||||
class_map = self.class_map_all if not self.poly_as_plane else self.class_map_poly_as_plane
|
||||
if len(found_clouds):
|
||||
for pointcloud in tqdm(found_clouds):
|
||||
if self.cluster_type not in pointcloud.name:
|
||||
@ -196,21 +199,28 @@ class CustomShapeNet(InMemoryDataset):
|
||||
raise ValueError('Check the Input!!!!!!')
|
||||
# Expand the values from the csv by fake labels if non are provided.
|
||||
vals = vals + [0] * (8 - len(vals))
|
||||
vals[-2] = float(class_map[int(vals[-2])])
|
||||
vals[-2] = float(self.class_remap[int(vals[-2])])
|
||||
src[vals[-1]].append(vals)
|
||||
|
||||
# Switch from un-pickable Defaultdict to Standard Dict
|
||||
src = dict(src)
|
||||
|
||||
# Transform the Dict[List] to Dict[torch.Tensor]
|
||||
for key, values in src.items():
|
||||
for key, values in list(src.items()):
|
||||
src[key] = torch.tensor(values, dtype=torch.double).squeeze()
|
||||
if src[key].ndim == 2:
|
||||
pass
|
||||
else:
|
||||
del src[key]
|
||||
|
||||
# Screw the Sorting and make it a FullCloud rather than a seperated
|
||||
if not self.collate_per_segment:
|
||||
src = dict(
|
||||
all=torch.cat(tuple(src.values()))
|
||||
)
|
||||
try:
|
||||
src = dict(
|
||||
all=torch.cat(tuple(src.values()))
|
||||
)
|
||||
except RuntimeError:
|
||||
print('debugg')
|
||||
|
||||
# Transform Box and Polytope to Plane if poly_as_plane is set
|
||||
for key, tensor in src.items():
|
||||
@ -274,6 +284,7 @@ class ShapeNetPartSegDataset(Dataset):
|
||||
kwargs.update(dict(root_dir=root_dir, mode=self.mode))
|
||||
# self.npoints = npoints
|
||||
self.dataset = CustomShapeNet(**kwargs)
|
||||
self.classes = self.dataset.classes
|
||||
|
||||
def __getitem__(self, index):
|
||||
data = self.dataset[index]
|
||||
|
@ -17,7 +17,6 @@ from ml_lib.utils.model_io import SavedLightningModels
|
||||
# Datasets
|
||||
from datasets.shapenet import ShapeNetPartSegDataset
|
||||
from utils.project_config import ThisConfig
|
||||
from utils.project_settings import GlobalVar
|
||||
|
||||
|
||||
def prepare_dataloader(config_obj):
|
||||
|
@ -63,16 +63,16 @@ if __name__ == '__main__':
|
||||
type_cluster_eps = 0.1
|
||||
type_cluster_min_pts = 100
|
||||
|
||||
model_path = Path('output') / 'PN2' / 'PN_f0d6bc0b9bf95a7e64f31a7df3c820d0' / 'version_0'
|
||||
model_path = Path('output') / 'PN2' / 'PN_597ecab330b04d977cda8a09ae0e5f6e' / 'version_0'
|
||||
|
||||
loaded_model = restore_logger_and_model(model_path)
|
||||
loaded_model.eval()
|
||||
|
||||
transforms = Compose([NormalizeScale(), ])
|
||||
test_dataset = ShapeNetPartSegDataset('data', mode=GlobalVar.data_split.predict, collate_per_segment=False,
|
||||
refresh=True, transform=transforms) # , cluster_type='pc')
|
||||
refresh=True, transform=transforms, cluster_type=None)
|
||||
|
||||
grid_clusters = cluster_cubes(test_dataset[1], grid_clusters, max_points_per_cluster=grid_cluster_max_pts)
|
||||
grid_clusters = cluster_cubes(test_dataset[0], grid_clusters, max_points_per_cluster=grid_cluster_max_pts)
|
||||
|
||||
ps.init()
|
||||
|
||||
|
@ -2,7 +2,6 @@ from abc import ABC
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch_geometric.transforms import Compose, NormalizeScale, RandomFlip
|
||||
|
||||
from ml_lib.modules.geometric_blocks import SAModule, GlobalSAModule, MLP, FPModule
|
||||
from ml_lib.modules.util import LightningBaseModule, F_x
|
||||
|
@ -8,7 +8,6 @@ from datasets.shapenet import ShapeNetPartSegDataset
|
||||
from models._point_net_2 import _PointNetCore
|
||||
|
||||
from utils.module_mixins import BaseValMixin, BaseTrainMixin, BaseOptimizerMixin, BaseDataloadersMixin, DatasetMixin
|
||||
from utils.project_settings import GlobalVar
|
||||
|
||||
|
||||
class PointNet2(BaseValMixin,
|
||||
@ -33,7 +32,7 @@ class PointNet2(BaseValMixin,
|
||||
# This is not available with 6-dim cords
|
||||
# RandomRotate(rot_max_angle, 0), RandomRotate(rot_max_angle, 1), RandomRotate(rot_max_angle, 2),
|
||||
RandomTranslate(trans_max_distance),
|
||||
NormalizeScale()
|
||||
# NormalizeScale()
|
||||
# NormalizePositions()
|
||||
]
|
||||
)
|
||||
@ -41,7 +40,6 @@ class PointNet2(BaseValMixin,
|
||||
# Dataset
|
||||
# =============================================================================
|
||||
self.dataset = self.build_dataset(ShapeNetPartSegDataset,
|
||||
collate_per_segment=True,
|
||||
transform=transforms,
|
||||
cluster_type=self.params.cluster_type,
|
||||
refresh=self.params.refresh,
|
||||
@ -51,7 +49,7 @@ class PointNet2(BaseValMixin,
|
||||
# Model Paramters
|
||||
# =============================================================================
|
||||
# Additional parameters
|
||||
self.n_classes = len(GlobalVar.classes) if not self.params.poly_as_plane else (len(GlobalVar.classes) - 2)
|
||||
self.n_classes = len(self.dataset.train_dataset.classes)
|
||||
|
||||
# Modules
|
||||
self.lin3 = torch.nn.Linear(128, self.n_classes)
|
||||
|
@ -26,8 +26,9 @@ if __name__ == '__main__':
|
||||
for poly_as_plane in [True, False]:
|
||||
for normals_as_cords in [True, False]:
|
||||
arg_dict.update(main_seed=seed,
|
||||
normals_as_cords=normals_as_cords, poly_as_plane=poly_as_plane)
|
||||
data_normals_as_cords=normals_as_cords,
|
||||
data_poly_as_plane=poly_as_plane
|
||||
)
|
||||
|
||||
config = config.update(arg_dict)
|
||||
|
||||
run_lightning_loop(config)
|
||||
|
@ -15,14 +15,14 @@ import matplotlib.pyplot as plt
|
||||
from torch import nn
|
||||
from torch.optim import Adam
|
||||
from torch_geometric.data import Data, DataLoader
|
||||
from torch_geometric.transforms import Compose, FixedPoints, NormalizeScale
|
||||
|
||||
from torchcontrib.optim import SWA
|
||||
|
||||
|
||||
from ml_lib.modules.util import LightningBaseModule
|
||||
from ml_lib.utils.tools import to_one_hot
|
||||
|
||||
from .project_settings import GlobalVar
|
||||
from utils.project_settings import dataSplit
|
||||
|
||||
|
||||
class BaseOptimizerMixin:
|
||||
@ -116,7 +116,7 @@ class BaseValMixin:
|
||||
y_pred = torch.cat([output['y'] for output in outputs]).squeeze().cpu().float().numpy()
|
||||
y_pred_max = np.argmax(y_pred, axis=1)
|
||||
|
||||
class_names = {val: key for key, val in GlobalVar.classes.items()}
|
||||
class_names = {val: key for key, val in self.dataset.test_dataset.classes.items()}
|
||||
######################################################################################
|
||||
#
|
||||
# F1 SCORE
|
||||
@ -223,21 +223,30 @@ class DatasetMixin:
|
||||
# Dataset
|
||||
# =============================================================================
|
||||
# Data Augmentations or Utility Transformations
|
||||
transforms = Compose(
|
||||
[
|
||||
FixedPoints(8096),
|
||||
|
||||
NormalizeScale()
|
||||
]
|
||||
)
|
||||
test_kwargs = kwargs.copy()
|
||||
test_kwargs.update(transform=transforms)
|
||||
# Dataset
|
||||
dataset = Namespace(
|
||||
**dict(
|
||||
# TRAIN DATASET
|
||||
train_dataset=dataset_class(self.params.root, mode=GlobalVar.data_split.train,
|
||||
train_dataset=dataset_class(self.params.root, mode=dataSplit.train, collate_per_segment=True,
|
||||
**kwargs),
|
||||
|
||||
|
||||
# VALIDATION DATASET
|
||||
val_dataset=dataset_class(self.params.root, mode=GlobalVar.data_split.devel,
|
||||
**kwargs),
|
||||
val_dataset=dataset_class(self.params.root, mode=dataSplit.devel, collate_per_segment=False,
|
||||
**test_kwargs),
|
||||
|
||||
# TEST DATASET
|
||||
test_dataset=dataset_class(self.params.root, mode=GlobalVar.data_split.predict,
|
||||
**kwargs),
|
||||
test_dataset=dataset_class(self.params.root, mode=dataSplit.predict, collate_per_segment=False,
|
||||
**test_kwargs),
|
||||
)
|
||||
)
|
||||
return dataset
|
||||
|
@ -1,17 +1,17 @@
|
||||
import numpy as np
|
||||
from sklearn.cluster import DBSCAN
|
||||
|
||||
#import open3d as o3d
|
||||
# import open3d as o3d
|
||||
|
||||
from pyod.models.lof import LOF
|
||||
|
||||
from torch_geometric.data import Data
|
||||
|
||||
from utils.project_settings import Classes
|
||||
from utils.project_settings import classesAll
|
||||
|
||||
|
||||
def polytopes_to_planes(pc):
|
||||
pc[(pc[:, 6] == float(Classes.Box)) or (pc[:, 6] == float(Classes.Polytope)), 6] = float(Classes.Plane)
|
||||
pc[(pc[:, 6] == float(classesAll.Box)) or (pc[:, 6] == float(classesAll.Polytope)), 6] = float(classesAll.Plane)
|
||||
return pc
|
||||
|
||||
|
||||
|
@ -19,8 +19,7 @@ class DataClass(Namespace):
|
||||
return self.__dict__()[item]
|
||||
|
||||
|
||||
class Classes(DataClass):
|
||||
|
||||
class ClassesALL(DataClass):
|
||||
# Object Classes for Point Segmentation
|
||||
Sphere = 0
|
||||
Cylinder = 1
|
||||
@ -29,10 +28,11 @@ class Classes(DataClass):
|
||||
Plane = 4 #
|
||||
|
||||
|
||||
class Settings(DataClass):
|
||||
P2G = 'grid'
|
||||
P2P = 'prim'
|
||||
PN2 = 'pc'
|
||||
class ClassesPolyAsPlane(DataClass):
|
||||
# Object Classes for Point Segmentation
|
||||
Sphere = 0
|
||||
Cylinder = 1
|
||||
Plane = 2 # All SubTypes of Planes
|
||||
|
||||
|
||||
class ClusterTypes(DataClass):
|
||||
@ -40,6 +40,7 @@ class ClusterTypes(DataClass):
|
||||
grid = 'grid'
|
||||
none = ''
|
||||
|
||||
|
||||
class DataSplit(DataClass):
|
||||
# DATA SPLIT OPTIONS
|
||||
train = 'train'
|
||||
@ -47,18 +48,7 @@ class DataSplit(DataClass):
|
||||
test = 'test'
|
||||
predict = 'predict'
|
||||
|
||||
|
||||
class GlobalVar(DataClass):
|
||||
# Variables for plotting
|
||||
PADDING = 0.25
|
||||
DPI = 50
|
||||
|
||||
data_split = DataSplit()
|
||||
|
||||
classes = Classes()
|
||||
|
||||
grid_count = 12
|
||||
|
||||
prim_count = -1
|
||||
|
||||
settings = Settings()
|
||||
classesAll = ClassesALL()
|
||||
classesPolyAsPlane= ClassesPolyAsPlane()
|
||||
clusterTypes = ClusterTypes()
|
||||
dataSplit=DataSplit()
|
Loading…
x
Reference in New Issue
Block a user