Classes Fixed abnd debugging

This commit is contained in:
Si11ium 2020-07-03 14:40:28 +02:00
parent e9d0591b11
commit 5353220890
10 changed files with 66 additions and 59 deletions

View File

@ -25,7 +25,7 @@ main_arg_parser.add_argument("--data_npoints", type=int, default=1024, help="")
main_arg_parser.add_argument("--data_root", type=str, default='data', help="") main_arg_parser.add_argument("--data_root", type=str, default='data', help="")
main_arg_parser.add_argument("--data_refresh", type=strtobool, default=False, help="") main_arg_parser.add_argument("--data_refresh", type=strtobool, default=False, help="")
main_arg_parser.add_argument("--data_dataset_type", type=str, default='ShapeNetPartSegDataset', help="") main_arg_parser.add_argument("--data_dataset_type", type=str, default='ShapeNetPartSegDataset', help="")
main_arg_parser.add_argument("--data_cluster_type", type=str, default='prim', help="") main_arg_parser.add_argument("--data_cluster_type", type=str, default='grid', help="")
main_arg_parser.add_argument("--data_normals_as_cords", type=strtobool, default=True, help="") main_arg_parser.add_argument("--data_normals_as_cords", type=strtobool, default=True, help="")
main_arg_parser.add_argument("--data_poly_as_plane", type=strtobool, default=False, help="") main_arg_parser.add_argument("--data_poly_as_plane", type=strtobool, default=False, help="")

View File

@ -13,7 +13,7 @@ import torch
from torch_geometric.data import InMemoryDataset from torch_geometric.data import InMemoryDataset
from torch_geometric.data import Data from torch_geometric.data import Data
from utils.project_settings import Classes, DataSplit, ClusterTypes from utils.project_settings import classesAll, classesPolyAsPlane, dataSplit, clusterTypes
def save_names(name_list, path): def save_names(name_list, path):
@ -34,11 +34,11 @@ class CustomShapeNet(InMemoryDataset):
@property @property
def modes(self): def modes(self):
return {key: val for val, key in DataSplit().items()} return {key: val for val, key in dataSplit.items()}
@property @property
def cluster_types(self): def cluster_types(self):
return {key: val for val, key in ClusterTypes().items()} return {key: val for val, key in clusterTypes.items()}
@property @property
def raw_dir(self): def raw_dir(self):
@ -62,8 +62,8 @@ class CustomShapeNet(InMemoryDataset):
# Set the Dataset Parameters # Set the Dataset Parameters
self.cluster_type = cluster_type if cluster_type else 'pc' self.cluster_type = cluster_type if cluster_type else 'pc'
self.classes = Classes()
self.poly_as_plane = poly_as_plane self.poly_as_plane = poly_as_plane
self.classes = classesAll if not self.poly_as_plane else classesPolyAsPlane
self.collate_per_segment = collate_per_segment self.collate_per_segment = collate_per_segment
self.mode = mode self.mode = mode
self.refresh = refresh self.refresh = refresh
@ -92,10 +92,10 @@ class CustomShapeNet(InMemoryDataset):
@property @property
def num_classes(self): def num_classes(self):
return len(self.categories) if self.poly_as_plane else (len(self.categories) - 2) return len(self.categories)
@property @property
def class_map_all(self): def _class_map_all(self):
return {0: 0, return {0: 0,
1: 1, 1: 1,
2: None, 2: None,
@ -107,7 +107,7 @@ class CustomShapeNet(InMemoryDataset):
} }
@property @property
def class_map_poly_as_plane(self): def _class_map_poly_as_plane(self):
return {0: 0, return {0: 0,
1: 1, 1: 1,
2: None, 2: None,
@ -118,11 +118,15 @@ class CustomShapeNet(InMemoryDataset):
7: None 7: None
} }
@property
def class_remap(self):
return self._class_map_all if not self.poly_as_plane else self._class_map_poly_as_plane
def _load_dataset(self): def _load_dataset(self):
data, slices = None, None data, slices = None, None
filepath = self.processed_paths[0] filepath = self.processed_paths[0]
config_path = Path(filepath).parent / f'{self.mode}_params.ini' config_path = Path(filepath).parent / f'{self.mode}_params.ini'
if config_path.exists() and not self.refresh and not self.mode == DataSplit().predict: if config_path.exists() and not self.refresh and not self.mode == dataSplit.predict:
with config_path.open('rb') as f: with config_path.open('rb') as f:
config = pickle.load(f) config = pickle.load(f)
if config == self._build_config(): if config == self._build_config():
@ -155,7 +159,7 @@ class CustomShapeNet(InMemoryDataset):
break break
self.process() self.process()
continue continue
if not self.mode == DataSplit().predict: if not self.mode == dataSplit.predict:
config = self._build_config() config = self._build_config()
with config_path.open('wb') as f: with config_path.open('wb') as f:
pickle.dump(config, f, pickle.HIGHEST_PROTOCOL) pickle.dump(config, f, pickle.HIGHEST_PROTOCOL)
@ -178,7 +182,6 @@ class CustomShapeNet(InMemoryDataset):
datasets = defaultdict(list) datasets = defaultdict(list)
path_to_clouds = self.raw_dir / self.mode path_to_clouds = self.raw_dir / self.mode
found_clouds = list(path_to_clouds.glob('*.xyz')) found_clouds = list(path_to_clouds.glob('*.xyz'))
class_map = self.class_map_all if not self.poly_as_plane else self.class_map_poly_as_plane
if len(found_clouds): if len(found_clouds):
for pointcloud in tqdm(found_clouds): for pointcloud in tqdm(found_clouds):
if self.cluster_type not in pointcloud.name: if self.cluster_type not in pointcloud.name:
@ -196,21 +199,28 @@ class CustomShapeNet(InMemoryDataset):
raise ValueError('Check the Input!!!!!!') raise ValueError('Check the Input!!!!!!')
# Expand the values from the csv by fake labels if non are provided. # Expand the values from the csv by fake labels if non are provided.
vals = vals + [0] * (8 - len(vals)) vals = vals + [0] * (8 - len(vals))
vals[-2] = float(class_map[int(vals[-2])]) vals[-2] = float(self.class_remap[int(vals[-2])])
src[vals[-1]].append(vals) src[vals[-1]].append(vals)
# Switch from un-pickable Defaultdict to Standard Dict # Switch from un-pickable Defaultdict to Standard Dict
src = dict(src) src = dict(src)
# Transform the Dict[List] to Dict[torch.Tensor] # Transform the Dict[List] to Dict[torch.Tensor]
for key, values in src.items(): for key, values in list(src.items()):
src[key] = torch.tensor(values, dtype=torch.double).squeeze() src[key] = torch.tensor(values, dtype=torch.double).squeeze()
if src[key].ndim == 2:
pass
else:
del src[key]
# Screw the Sorting and make it a FullCloud rather than a seperated # Screw the Sorting and make it a FullCloud rather than a seperated
if not self.collate_per_segment: if not self.collate_per_segment:
try:
src = dict( src = dict(
all=torch.cat(tuple(src.values())) all=torch.cat(tuple(src.values()))
) )
except RuntimeError:
print('debugg')
# Transform Box and Polytope to Plane if poly_as_plane is set # Transform Box and Polytope to Plane if poly_as_plane is set
for key, tensor in src.items(): for key, tensor in src.items():
@ -274,6 +284,7 @@ class ShapeNetPartSegDataset(Dataset):
kwargs.update(dict(root_dir=root_dir, mode=self.mode)) kwargs.update(dict(root_dir=root_dir, mode=self.mode))
# self.npoints = npoints # self.npoints = npoints
self.dataset = CustomShapeNet(**kwargs) self.dataset = CustomShapeNet(**kwargs)
self.classes = self.dataset.classes
def __getitem__(self, index): def __getitem__(self, index):
data = self.dataset[index] data = self.dataset[index]

View File

@ -17,7 +17,6 @@ from ml_lib.utils.model_io import SavedLightningModels
# Datasets # Datasets
from datasets.shapenet import ShapeNetPartSegDataset from datasets.shapenet import ShapeNetPartSegDataset
from utils.project_config import ThisConfig from utils.project_config import ThisConfig
from utils.project_settings import GlobalVar
def prepare_dataloader(config_obj): def prepare_dataloader(config_obj):

View File

@ -63,16 +63,16 @@ if __name__ == '__main__':
type_cluster_eps = 0.1 type_cluster_eps = 0.1
type_cluster_min_pts = 100 type_cluster_min_pts = 100
model_path = Path('output') / 'PN2' / 'PN_f0d6bc0b9bf95a7e64f31a7df3c820d0' / 'version_0' model_path = Path('output') / 'PN2' / 'PN_597ecab330b04d977cda8a09ae0e5f6e' / 'version_0'
loaded_model = restore_logger_and_model(model_path) loaded_model = restore_logger_and_model(model_path)
loaded_model.eval() loaded_model.eval()
transforms = Compose([NormalizeScale(), ]) transforms = Compose([NormalizeScale(), ])
test_dataset = ShapeNetPartSegDataset('data', mode=GlobalVar.data_split.predict, collate_per_segment=False, test_dataset = ShapeNetPartSegDataset('data', mode=GlobalVar.data_split.predict, collate_per_segment=False,
refresh=True, transform=transforms) # , cluster_type='pc') refresh=True, transform=transforms, cluster_type=None)
grid_clusters = cluster_cubes(test_dataset[1], grid_clusters, max_points_per_cluster=grid_cluster_max_pts) grid_clusters = cluster_cubes(test_dataset[0], grid_clusters, max_points_per_cluster=grid_cluster_max_pts)
ps.init() ps.init()

View File

@ -2,7 +2,6 @@ from abc import ABC
import torch import torch
from torch import nn from torch import nn
from torch_geometric.transforms import Compose, NormalizeScale, RandomFlip
from ml_lib.modules.geometric_blocks import SAModule, GlobalSAModule, MLP, FPModule from ml_lib.modules.geometric_blocks import SAModule, GlobalSAModule, MLP, FPModule
from ml_lib.modules.util import LightningBaseModule, F_x from ml_lib.modules.util import LightningBaseModule, F_x

View File

@ -8,7 +8,6 @@ from datasets.shapenet import ShapeNetPartSegDataset
from models._point_net_2 import _PointNetCore from models._point_net_2 import _PointNetCore
from utils.module_mixins import BaseValMixin, BaseTrainMixin, BaseOptimizerMixin, BaseDataloadersMixin, DatasetMixin from utils.module_mixins import BaseValMixin, BaseTrainMixin, BaseOptimizerMixin, BaseDataloadersMixin, DatasetMixin
from utils.project_settings import GlobalVar
class PointNet2(BaseValMixin, class PointNet2(BaseValMixin,
@ -33,7 +32,7 @@ class PointNet2(BaseValMixin,
# This is not available with 6-dim cords # This is not available with 6-dim cords
# RandomRotate(rot_max_angle, 0), RandomRotate(rot_max_angle, 1), RandomRotate(rot_max_angle, 2), # RandomRotate(rot_max_angle, 0), RandomRotate(rot_max_angle, 1), RandomRotate(rot_max_angle, 2),
RandomTranslate(trans_max_distance), RandomTranslate(trans_max_distance),
NormalizeScale() # NormalizeScale()
# NormalizePositions() # NormalizePositions()
] ]
) )
@ -41,7 +40,6 @@ class PointNet2(BaseValMixin,
# Dataset # Dataset
# ============================================================================= # =============================================================================
self.dataset = self.build_dataset(ShapeNetPartSegDataset, self.dataset = self.build_dataset(ShapeNetPartSegDataset,
collate_per_segment=True,
transform=transforms, transform=transforms,
cluster_type=self.params.cluster_type, cluster_type=self.params.cluster_type,
refresh=self.params.refresh, refresh=self.params.refresh,
@ -51,7 +49,7 @@ class PointNet2(BaseValMixin,
# Model Paramters # Model Paramters
# ============================================================================= # =============================================================================
# Additional parameters # Additional parameters
self.n_classes = len(GlobalVar.classes) if not self.params.poly_as_plane else (len(GlobalVar.classes) - 2) self.n_classes = len(self.dataset.train_dataset.classes)
# Modules # Modules
self.lin3 = torch.nn.Linear(128, self.n_classes) self.lin3 = torch.nn.Linear(128, self.n_classes)

View File

@ -26,8 +26,9 @@ if __name__ == '__main__':
for poly_as_plane in [True, False]: for poly_as_plane in [True, False]:
for normals_as_cords in [True, False]: for normals_as_cords in [True, False]:
arg_dict.update(main_seed=seed, arg_dict.update(main_seed=seed,
normals_as_cords=normals_as_cords, poly_as_plane=poly_as_plane) data_normals_as_cords=normals_as_cords,
data_poly_as_plane=poly_as_plane
)
config = config.update(arg_dict) config = config.update(arg_dict)
run_lightning_loop(config) run_lightning_loop(config)

View File

@ -15,14 +15,14 @@ import matplotlib.pyplot as plt
from torch import nn from torch import nn
from torch.optim import Adam from torch.optim import Adam
from torch_geometric.data import Data, DataLoader from torch_geometric.data import Data, DataLoader
from torch_geometric.transforms import Compose, FixedPoints, NormalizeScale
from torchcontrib.optim import SWA from torchcontrib.optim import SWA
from ml_lib.modules.util import LightningBaseModule from ml_lib.modules.util import LightningBaseModule
from ml_lib.utils.tools import to_one_hot from ml_lib.utils.tools import to_one_hot
from utils.project_settings import dataSplit
from .project_settings import GlobalVar
class BaseOptimizerMixin: class BaseOptimizerMixin:
@ -116,7 +116,7 @@ class BaseValMixin:
y_pred = torch.cat([output['y'] for output in outputs]).squeeze().cpu().float().numpy() y_pred = torch.cat([output['y'] for output in outputs]).squeeze().cpu().float().numpy()
y_pred_max = np.argmax(y_pred, axis=1) y_pred_max = np.argmax(y_pred, axis=1)
class_names = {val: key for key, val in GlobalVar.classes.items()} class_names = {val: key for key, val in self.dataset.test_dataset.classes.items()}
###################################################################################### ######################################################################################
# #
# F1 SCORE # F1 SCORE
@ -223,21 +223,30 @@ class DatasetMixin:
# Dataset # Dataset
# ============================================================================= # =============================================================================
# Data Augmentations or Utility Transformations # Data Augmentations or Utility Transformations
transforms = Compose(
[
FixedPoints(8096),
NormalizeScale()
]
)
test_kwargs = kwargs.copy()
test_kwargs.update(transform=transforms)
# Dataset # Dataset
dataset = Namespace( dataset = Namespace(
**dict( **dict(
# TRAIN DATASET # TRAIN DATASET
train_dataset=dataset_class(self.params.root, mode=GlobalVar.data_split.train, train_dataset=dataset_class(self.params.root, mode=dataSplit.train, collate_per_segment=True,
**kwargs), **kwargs),
# VALIDATION DATASET # VALIDATION DATASET
val_dataset=dataset_class(self.params.root, mode=GlobalVar.data_split.devel, val_dataset=dataset_class(self.params.root, mode=dataSplit.devel, collate_per_segment=False,
**kwargs), **test_kwargs),
# TEST DATASET # TEST DATASET
test_dataset=dataset_class(self.params.root, mode=GlobalVar.data_split.predict, test_dataset=dataset_class(self.params.root, mode=dataSplit.predict, collate_per_segment=False,
**kwargs), **test_kwargs),
) )
) )
return dataset return dataset

View File

@ -1,17 +1,17 @@
import numpy as np import numpy as np
from sklearn.cluster import DBSCAN from sklearn.cluster import DBSCAN
#import open3d as o3d # import open3d as o3d
from pyod.models.lof import LOF from pyod.models.lof import LOF
from torch_geometric.data import Data from torch_geometric.data import Data
from utils.project_settings import Classes from utils.project_settings import classesAll
def polytopes_to_planes(pc): def polytopes_to_planes(pc):
pc[(pc[:, 6] == float(Classes.Box)) or (pc[:, 6] == float(Classes.Polytope)), 6] = float(Classes.Plane) pc[(pc[:, 6] == float(classesAll.Box)) or (pc[:, 6] == float(classesAll.Polytope)), 6] = float(classesAll.Plane)
return pc return pc

View File

@ -19,8 +19,7 @@ class DataClass(Namespace):
return self.__dict__()[item] return self.__dict__()[item]
class Classes(DataClass): class ClassesALL(DataClass):
# Object Classes for Point Segmentation # Object Classes for Point Segmentation
Sphere = 0 Sphere = 0
Cylinder = 1 Cylinder = 1
@ -29,10 +28,11 @@ class Classes(DataClass):
Plane = 4 # Plane = 4 #
class Settings(DataClass): class ClassesPolyAsPlane(DataClass):
P2G = 'grid' # Object Classes for Point Segmentation
P2P = 'prim' Sphere = 0
PN2 = 'pc' Cylinder = 1
Plane = 2 # All SubTypes of Planes
class ClusterTypes(DataClass): class ClusterTypes(DataClass):
@ -40,6 +40,7 @@ class ClusterTypes(DataClass):
grid = 'grid' grid = 'grid'
none = '' none = ''
class DataSplit(DataClass): class DataSplit(DataClass):
# DATA SPLIT OPTIONS # DATA SPLIT OPTIONS
train = 'train' train = 'train'
@ -47,18 +48,7 @@ class DataSplit(DataClass):
test = 'test' test = 'test'
predict = 'predict' predict = 'predict'
classesAll = ClassesALL()
class GlobalVar(DataClass): classesPolyAsPlane= ClassesPolyAsPlane()
# Variables for plotting clusterTypes = ClusterTypes()
PADDING = 0.25 dataSplit=DataSplit()
DPI = 50
data_split = DataSplit()
classes = Classes()
grid_count = 12
prim_count = -1
settings = Settings()