Classes Fixed and debugging

This commit is contained in:
Si11ium 2020-07-03 14:40:28 +02:00
parent e9d0591b11
commit 5353220890
10 changed files with 66 additions and 59 deletions

View File

@@ -25,7 +25,7 @@ main_arg_parser.add_argument("--data_npoints", type=int, default=1024, help="")
main_arg_parser.add_argument("--data_root", type=str, default='data', help="")
main_arg_parser.add_argument("--data_refresh", type=strtobool, default=False, help="")
main_arg_parser.add_argument("--data_dataset_type", type=str, default='ShapeNetPartSegDataset', help="")
main_arg_parser.add_argument("--data_cluster_type", type=str, default='prim', help="")
main_arg_parser.add_argument("--data_cluster_type", type=str, default='grid', help="")
main_arg_parser.add_argument("--data_normals_as_cords", type=strtobool, default=True, help="")
main_arg_parser.add_argument("--data_poly_as_plane", type=strtobool, default=False, help="")

View File

@@ -13,7 +13,7 @@ import torch
from torch_geometric.data import InMemoryDataset
from torch_geometric.data import Data
from utils.project_settings import Classes, DataSplit, ClusterTypes
from utils.project_settings import classesAll, classesPolyAsPlane, dataSplit, clusterTypes
def save_names(name_list, path):
@@ -34,11 +34,11 @@ class CustomShapeNet(InMemoryDataset):
@property
def modes(self):
return {key: val for val, key in DataSplit().items()}
return {key: val for val, key in dataSplit.items()}
@property
def cluster_types(self):
return {key: val for val, key in ClusterTypes().items()}
return {key: val for val, key in clusterTypes.items()}
@property
def raw_dir(self):
@@ -62,8 +62,8 @@ class CustomShapeNet(InMemoryDataset):
# Set the Dataset Parameters
self.cluster_type = cluster_type if cluster_type else 'pc'
self.classes = Classes()
self.poly_as_plane = poly_as_plane
self.classes = classesAll if not self.poly_as_plane else classesPolyAsPlane
self.collate_per_segment = collate_per_segment
self.mode = mode
self.refresh = refresh
@@ -92,10 +92,10 @@ class CustomShapeNet(InMemoryDataset):
@property
def num_classes(self):
return len(self.categories) if self.poly_as_plane else (len(self.categories) - 2)
return len(self.categories)
@property
def class_map_all(self):
def _class_map_all(self):
return {0: 0,
1: 1,
2: None,
@@ -107,7 +107,7 @@ class CustomShapeNet(InMemoryDataset):
}
@property
def class_map_poly_as_plane(self):
def _class_map_poly_as_plane(self):
return {0: 0,
1: 1,
2: None,
@@ -118,11 +118,15 @@ class CustomShapeNet(InMemoryDataset):
7: None
}
@property
def class_remap(self):
return self._class_map_all if not self.poly_as_plane else self._class_map_poly_as_plane
def _load_dataset(self):
data, slices = None, None
filepath = self.processed_paths[0]
config_path = Path(filepath).parent / f'{self.mode}_params.ini'
if config_path.exists() and not self.refresh and not self.mode == DataSplit().predict:
if config_path.exists() and not self.refresh and not self.mode == dataSplit.predict:
with config_path.open('rb') as f:
config = pickle.load(f)
if config == self._build_config():
@@ -155,7 +159,7 @@ class CustomShapeNet(InMemoryDataset):
break
self.process()
continue
if not self.mode == DataSplit().predict:
if not self.mode == dataSplit.predict:
config = self._build_config()
with config_path.open('wb') as f:
pickle.dump(config, f, pickle.HIGHEST_PROTOCOL)
@@ -178,7 +182,6 @@ class CustomShapeNet(InMemoryDataset):
datasets = defaultdict(list)
path_to_clouds = self.raw_dir / self.mode
found_clouds = list(path_to_clouds.glob('*.xyz'))
class_map = self.class_map_all if not self.poly_as_plane else self.class_map_poly_as_plane
if len(found_clouds):
for pointcloud in tqdm(found_clouds):
if self.cluster_type not in pointcloud.name:
@@ -196,21 +199,28 @@ class CustomShapeNet(InMemoryDataset):
raise ValueError('Check the Input!!!!!!')
# Expand the values from the csv by fake labels if none are provided.
vals = vals + [0] * (8 - len(vals))
vals[-2] = float(class_map[int(vals[-2])])
vals[-2] = float(self.class_remap[int(vals[-2])])
src[vals[-1]].append(vals)
# Switch from un-picklable Defaultdict to Standard Dict
src = dict(src)
# Transform the Dict[List] to Dict[torch.Tensor]
for key, values in src.items():
for key, values in list(src.items()):
src[key] = torch.tensor(values, dtype=torch.double).squeeze()
if src[key].ndim == 2:
pass
else:
del src[key]
# Screw the Sorting and make it a FullCloud rather than a separated one
if not self.collate_per_segment:
src = dict(
all=torch.cat(tuple(src.values()))
)
try:
src = dict(
all=torch.cat(tuple(src.values()))
)
except RuntimeError:
print('debug')
# Transform Box and Polytope to Plane if poly_as_plane is set
for key, tensor in src.items():
@@ -274,6 +284,7 @@ class ShapeNetPartSegDataset(Dataset):
kwargs.update(dict(root_dir=root_dir, mode=self.mode))
# self.npoints = npoints
self.dataset = CustomShapeNet(**kwargs)
self.classes = self.dataset.classes
def __getitem__(self, index):
data = self.dataset[index]
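
The new `class_remap` property centralizes the label translation that the processing loop applies to the raw class column (`vals[-2]`). A minimal sketch of that step; the map entries for ids 3 to 6 are elided in the diff, so the values below are assumptions consistent with `Plane = 4` in `ClassesALL` and `Plane = 2` in `ClassesPolyAsPlane`:

    # class_remap is _class_map_all or _class_map_poly_as_plane, selected by
    # self.poly_as_plane; ids mapped to None mark classes that get dropped/merged.
    class_remap = {0: 0, 1: 1, 2: None, 3: None, 4: 2, 5: None, 6: None, 7: None}

    raw_label = 4                              # e.g. a Plane id in the all-classes scheme
    new_label = float(class_remap[raw_label])  # -> 2.0 in the poly-as-plane scheme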

View File

@@ -17,7 +17,6 @@ from ml_lib.utils.model_io import SavedLightningModels
# Datasets
from datasets.shapenet import ShapeNetPartSegDataset
from utils.project_config import ThisConfig
from utils.project_settings import GlobalVar
def prepare_dataloader(config_obj):

View File

@@ -63,16 +63,16 @@ if __name__ == '__main__':
type_cluster_eps = 0.1
type_cluster_min_pts = 100
model_path = Path('output') / 'PN2' / 'PN_f0d6bc0b9bf95a7e64f31a7df3c820d0' / 'version_0'
model_path = Path('output') / 'PN2' / 'PN_597ecab330b04d977cda8a09ae0e5f6e' / 'version_0'
loaded_model = restore_logger_and_model(model_path)
loaded_model.eval()
transforms = Compose([NormalizeScale(), ])
test_dataset = ShapeNetPartSegDataset('data', mode=GlobalVar.data_split.predict, collate_per_segment=False,
refresh=True, transform=transforms) # , cluster_type='pc')
refresh=True, transform=transforms, cluster_type=None)
grid_clusters = cluster_cubes(test_dataset[1], grid_clusters, max_points_per_cluster=grid_cluster_max_pts)
grid_clusters = cluster_cubes(test_dataset[0], grid_clusters, max_points_per_cluster=grid_cluster_max_pts)
ps.init()

View File

@@ -2,7 +2,6 @@ from abc import ABC
import torch
from torch import nn
from torch_geometric.transforms import Compose, NormalizeScale, RandomFlip
from ml_lib.modules.geometric_blocks import SAModule, GlobalSAModule, MLP, FPModule
from ml_lib.modules.util import LightningBaseModule, F_x

View File

@@ -8,7 +8,6 @@ from datasets.shapenet import ShapeNetPartSegDataset
from models._point_net_2 import _PointNetCore
from utils.module_mixins import BaseValMixin, BaseTrainMixin, BaseOptimizerMixin, BaseDataloadersMixin, DatasetMixin
from utils.project_settings import GlobalVar
class PointNet2(BaseValMixin,
@@ -33,7 +32,7 @@ class PointNet2(BaseValMixin,
# This is not available with 6-dim cords
# RandomRotate(rot_max_angle, 0), RandomRotate(rot_max_angle, 1), RandomRotate(rot_max_angle, 2),
RandomTranslate(trans_max_distance),
NormalizeScale()
# NormalizeScale()
# NormalizePositions()
]
)
@@ -41,7 +40,6 @@ class PointNet2(BaseValMixin,
# Dataset
# =============================================================================
self.dataset = self.build_dataset(ShapeNetPartSegDataset,
collate_per_segment=True,
transform=transforms,
cluster_type=self.params.cluster_type,
refresh=self.params.refresh,
@@ -51,7 +49,7 @@ class PointNet2(BaseValMixin,
# Model Parameters
# =============================================================================
# Additional parameters
self.n_classes = len(GlobalVar.classes) if not self.params.poly_as_plane else (len(GlobalVar.classes) - 2)
self.n_classes = len(self.dataset.train_dataset.classes)
# Modules
self.lin3 = torch.nn.Linear(128, self.n_classes)
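
With `n_classes` now read off the dataset, the segmentation head follows whichever class set `CustomShapeNet` selected. A sketch of the arithmetic implied by this commit's settings (the Box/Polytope ids are elided in the diff, so treat the counts as inferred):

    # classesAll: Sphere=0, Cylinder=1, Box/Polytope (ids elided), Plane=4 -> 5 classes
    # classesPolyAsPlane: Sphere=0, Cylinder=1, Plane=2 -> 3 classes
    # The removed line encoded the same thing as len(GlobalVar.classes) - 2.
    n_classes = len(self.dataset.train_dataset.classes)  # 5, or 3 with poly_as_plane
    self.lin3 = torch.nn.Linear(128, n_classes)          # per-point class logits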

View File

@@ -26,8 +26,9 @@ if __name__ == '__main__':
for poly_as_plane in [True, False]:
for normals_as_cords in [True, False]:
arg_dict.update(main_seed=seed,
normals_as_cords=normals_as_cords, poly_as_plane=poly_as_plane)
data_normals_as_cords=normals_as_cords,
data_poly_as_plane=poly_as_plane
)
config = config.update(arg_dict)
run_lightning_loop(config)
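
The sweep mutates `arg_dict` in nested loops; an equivalent formulation with `itertools.product` keeps the combinations in one place (a sketch: the seed source is hypothetical, only the keys and calls come from the diff):

    from itertools import product

    seeds = range(10)  # hypothetical; the actual seed list is not shown
    for seed, normals, poly in product(seeds, (True, False), (True, False)):
        arg_dict.update(main_seed=seed,
                        data_normals_as_cords=normals,
                        data_poly_as_plane=poly)
        config = config.update(arg_dict)
        run_lightning_loop(config)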

View File

@@ -15,14 +15,14 @@ import matplotlib.pyplot as plt
from torch import nn
from torch.optim import Adam
from torch_geometric.data import Data, DataLoader
from torch_geometric.transforms import Compose, FixedPoints, NormalizeScale
from torchcontrib.optim import SWA
from ml_lib.modules.util import LightningBaseModule
from ml_lib.utils.tools import to_one_hot
from .project_settings import GlobalVar
from utils.project_settings import dataSplit
class BaseOptimizerMixin:
@@ -116,7 +116,7 @@ class BaseValMixin:
y_pred = torch.cat([output['y'] for output in outputs]).squeeze().cpu().float().numpy()
y_pred_max = np.argmax(y_pred, axis=1)
class_names = {val: key for key, val in GlobalVar.classes.items()}
class_names = {val: key for key, val in self.dataset.test_dataset.classes.items()}
######################################################################################
#
# F1 SCORE
@@ -223,21 +223,30 @@ class DatasetMixin:
# Dataset
# =============================================================================
# Data Augmentations or Utility Transformations
transforms = Compose(
[
FixedPoints(8096),
NormalizeScale()
]
)
test_kwargs = kwargs.copy()
test_kwargs.update(transform=transforms)
# Dataset
dataset = Namespace(
**dict(
# TRAIN DATASET
train_dataset=dataset_class(self.params.root, mode=GlobalVar.data_split.train,
train_dataset=dataset_class(self.params.root, mode=dataSplit.train, collate_per_segment=True,
**kwargs),
# VALIDATION DATASET
val_dataset=dataset_class(self.params.root, mode=GlobalVar.data_split.devel,
**kwargs),
val_dataset=dataset_class(self.params.root, mode=dataSplit.devel, collate_per_segment=False,
**test_kwargs),
# TEST DATASET
test_dataset=dataset_class(self.params.root, mode=GlobalVar.data_split.predict,
**kwargs),
test_dataset=dataset_class(self.params.root, mode=dataSplit.predict, collate_per_segment=False,
**test_kwargs),
)
)
return dataset
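
Training still collates per segment, while devel and predict now receive whole clouds plus a `FixedPoints(8096)` and `NormalizeScale()` transform, so un-collated full clouds stay batchable. A minimal sketch of what `FixedPoints` contributes, per torch_geometric's documented behaviour:

    import torch
    from torch_geometric.data import Data
    from torch_geometric.transforms import FixedPoints

    # FixedPoints samples a cloud to exactly N points (with replacement when the
    # cloud is smaller), giving every eval sample the same size.
    cloud = Data(pos=torch.rand(20000, 3))
    fixed = FixedPoints(8096)(cloud)
    assert fixed.pos.size(0) == 8096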

View File

@@ -1,17 +1,17 @@
import numpy as np
from sklearn.cluster import DBSCAN
#import open3d as o3d
# import open3d as o3d
from pyod.models.lof import LOF
from torch_geometric.data import Data
from utils.project_settings import Classes
from utils.project_settings import classesAll
def polytopes_to_planes(pc):
pc[(pc[:, 6] == float(Classes.Box)) or (pc[:, 6] == float(Classes.Polytope)), 6] = float(Classes.Plane)
pc[(pc[:, 6] == float(classesAll.Box)) or (pc[:, 6] == float(classesAll.Polytope)), 6] = float(classesAll.Plane)
return pc
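
Note that the committed line combines the two masks with Python's `or`, which calls `bool()` on the whole comparison array and raises for anything longer than one element; an element-wise version would use `|`. A hedged sketch of that variant, assuming `pc` is a NumPy array with the class id in column 6:

    import numpy as np
    from utils.project_settings import classesAll

    def polytopes_to_planes(pc):
        # Element-wise mask: '|' instead of 'or', with parentheses because
        # '|' binds tighter than '=='.
        mask = (pc[:, 6] == float(classesAll.Box)) | (pc[:, 6] == float(classesAll.Polytope))
        pc[mask, 6] = float(classesAll.Plane)
        return pc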

View File

@@ -19,8 +19,7 @@ class DataClass(Namespace):
return self.__dict__()[item]
class Classes(DataClass):
class ClassesALL(DataClass):
# Object Classes for Point Segmentation
Sphere = 0
Cylinder = 1
@@ -29,10 +28,11 @@ class Classes(DataClass):
Plane = 4 #
class Settings(DataClass):
P2G = 'grid'
P2P = 'prim'
PN2 = 'pc'
class ClassesPolyAsPlane(DataClass):
# Object Classes for Point Segmentation
Sphere = 0
Cylinder = 1
Plane = 2 # All SubTypes of Planes
class ClusterTypes(DataClass):
@@ -40,6 +40,7 @@ class ClusterTypes(DataClass):
grid = 'grid'
none = ''
class DataSplit(DataClass):
# DATA SPLIT OPTIONS
train = 'train'
@@ -47,18 +48,7 @@ class DataSplit(DataClass):
test = 'test'
predict = 'predict'
class GlobalVar(DataClass):
# Variables for plotting
PADDING = 0.25
DPI = 50
data_split = DataSplit()
classes = Classes()
grid_count = 12
prim_count = -1
settings = Settings()
classesAll = ClassesALL()
classesPolyAsPlane = ClassesPolyAsPlane()
clusterTypes = ClusterTypes()
dataSplit = DataSplit()
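
The module now exports pre-built singletons (`classesAll`, `dataSplit`, ...), so call sites drop the ad-hoc `DataSplit()` constructions seen on the removed lines. The `modes` and `cluster_types` properties invert these namespaces; a minimal sketch of how that behaves, assuming `DataClass.items()` yields (attribute, value) pairs (its body is not shown in the diff, and `prim = 'prim'` is assumed alongside the `grid` and `none` entries visible above):

    # clusterTypes.items() -> [('prim', 'prim'), ('grid', 'grid'), ('none', '')]
    cluster_types = {key: val for val, key in clusterTypes.items()}
    # -> {'prim': 'prim', 'grid': 'grid', '': 'none'}: a reverse lookup from the
    #    raw string embedded in point-cloud filenames back to the attribute name.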