DataSet Modifications

Si11ium 2020-07-01 14:15:26 +02:00
parent 6b28519e58
commit 3c1202d5b6
4 changed files with 15 additions and 21 deletions

View File

@@ -36,7 +36,7 @@ main_arg_parser.add_argument("--data_poly_as_plane", type=strtobool, default=Fal
 # Training
 main_arg_parser.add_argument("--train_outpath", type=str, default="output", help="")
 main_arg_parser.add_argument("--train_version", type=strtobool, required=False, help="")
-main_arg_parser.add_argument("--train_epochs", type=int, default=200, help="")
+main_arg_parser.add_argument("--train_epochs", type=int, default=25, help="")
 main_arg_parser.add_argument("--train_batch_size", type=int, default=10, help="")
 main_arg_parser.add_argument("--train_lr", type=float, default=1e-3, help="")
 main_arg_parser.add_argument("--train_weight_decay", type=float, default=1e-8, help="")

View File

@@ -111,7 +111,7 @@ class CustomShapeNet(InMemoryDataset):
                 os.remove(filepath)
             try:
                 config_path.unlink()
-            except:
+            except FileNotFoundError:
                 pass
             print('Processed Location "Refreshed" (We deleted the Files)')
         except FileNotFoundError:
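The bare except around config_path.unlink() is narrowed to FileNotFoundError, so unrelated errors (e.g. PermissionError) are no longer swallowed. A small sketch of two equivalent forms, using a hypothetical path; on Python 3.8+ the try/except can be dropped entirely via missing_ok:

from pathlib import Path

config_path = Path("processed") / "config.ini"  # hypothetical path, for illustration only

# Python 3.8+: ignore a missing file without a try/except.
config_path.unlink(missing_ok=True)

# Older interpreters: the explicit form used in the diff.
try:
    config_path.unlink()
except FileNotFoundError:
    pass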
@@ -138,7 +138,7 @@ class CustomShapeNet(InMemoryDataset):
         return data, slices

     def _build_config(self):
-        conf_dict = {key:str(val) for key, val in self.__dict__.items() if '__' not in key and key not in [
+        conf_dict = {key: str(val) for key, val in self.__dict__.items() if '__' not in key and key not in [
             'classes', 'refresh', 'transform', 'data', 'slices'
         ]}
         return conf_dict
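The change here is purely cosmetic (a space after the colon, per PEP 8). The comprehension serializes the instance's attributes to strings while skipping dunder-style keys and members that are not meaningful as config values. A self-contained sketch of the same filtering pattern, using a made-up stand-in class:

class Dummy:
    # Stand-in object illustrating the _build_config filtering pattern.
    def __init__(self):
        self.mode = 'train'
        self.npoints = 1024
        self.refresh = True        # excluded by the key blacklist below
        self.transform = object()  # excluded as well

d = Dummy()
conf_dict = {key: str(val) for key, val in d.__dict__.items()
             if '__' not in key and key not in ['classes', 'refresh', 'transform', 'data', 'slices']}
print(conf_dict)  # {'mode': 'train', 'npoints': '1024'}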
@@ -194,10 +194,9 @@ class CustomShapeNet(InMemoryDataset):
                     continue
                 tensor = tensor.unsqueeze(0)
                 if self.poly_as_plane:
-                    tensor[:, -2][tensor[:, -2] == float(self.classes.Plane)] = 4.0
-                    tensor[:, -2][tensor[:, -2] == float(self.classes.Box)] = 4.0
-                    tensor[:, -2][tensor[:, -2] == float(self.classes.Polytope)] = 4.0
-                    tensor[:, -2][tensor[:, -2] == self.classes.Torus] = 3.0
+                    tensor[:, -2][tensor[:, -2] == float(self.classes.Plane)] = 2.0
+                    tensor[:, -2][tensor[:, -2] == float(self.classes.Box)] = 2.0
+                    tensor[:, -2][tensor[:, -2] == float(self.classes.Polytope)] = 2.0
                 src[key] = tensor

             for key, values in src.items():
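With poly_as_plane enabled, the planar subtypes (Plane, Box, Polytope) now all collapse onto class id 2.0 instead of 4.0, matching the renumbered Classes further down, and the separate Torus remap is dropped. A minimal torch sketch of the boolean-mask assignment idiom; the tensor contents and shape are invented and need not match the real process() data:

import torch

# Toy (N, 5) rows; column -2 holds the primitive-type label (old enum ids).
pts = torch.tensor([[0., 0., 0., 6., 0.],   # Plane    (old id 6)
                    [0., 0., 0., 3., 0.],   # Box      (old id 3)
                    [0., 0., 0., 4., 0.],   # Polytope (old id 4)
                    [0., 0., 0., 1., 0.]])  # Cylinder, left untouched

# Same masked in-place assignment as the diff: every planar subtype -> 2.0.
for old_id in (6.0, 3.0, 4.0):
    pts[:, -2][pts[:, -2] == old_id] = 2.0

print(pts[:, -2])  # tensor([2., 2., 2., 1.])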

View File

@@ -5,22 +5,19 @@ import torch
 import polyscope as ps
 import numpy as np

-from torch.utils.data import DataLoader
-
 # Dataset and Dataloaders
 # =============================================================================

 # Transforms
-from torch_geometric.transforms import Compose, NormalizeScale, RandomFlip
+from torch_geometric.transforms import Compose, NormalizeScale

 from ml_lib.point_toolset.point_io import BatchToData
 from ml_lib.utils.model_io import SavedLightningModels

 # Datasets
 from datasets.shapenet import ShapeNetPartSegDataset
 from models import PointNet2
-from utils.pointcloud import cluster_cubes, append_onehotencoded_type, label2color, hierarchical_clustering, \
+from utils.pointcloud import cluster_cubes, append_onehotencoded_type, label2color, \
     write_clusters, cluster2Color, cluster_dbscan
 from utils.project_settings import GlobalVar, DataClass
@@ -66,7 +63,7 @@ if __name__ == '__main__':
     type_cluster_eps = 0.1
     type_cluster_min_pts = 100

-    model_path = Path('output') / 'PN2' / 'PN_9843bf499399786cfd58fe79fa1b3db8' / 'version_0'
+    model_path = Path('output') / 'PN2' / 'PN_f0d6bc0b9bf95a7e64f31a7df3c820d0' / 'version_0'

     loaded_model = restore_logger_and_model(model_path)
     loaded_model.eval()
@@ -75,7 +72,7 @@ if __name__ == '__main__':
     test_dataset = ShapeNetPartSegDataset('data', mode=GlobalVar.data_split.predict, collate_per_segment=False,
                                           refresh=True, transform=transforms)  # , cluster_type='pc')

-    grid_clusters = cluster_cubes(test_dataset[0], grid_clusters, max_points_per_cluster=grid_cluster_max_pts)
+    grid_clusters = cluster_cubes(test_dataset[1], grid_clusters, max_points_per_cluster=grid_cluster_max_pts)

     ps.init()
@@ -90,8 +87,8 @@ if __name__ == '__main__':
        pc_with_prim_type = predict_prim_type(grid_cluster_pc, loaded_model)

        # Re-Map Primitive type for 1-hot-encoding.
        pc_with_prim_type[:, 6][pc_with_prim_type[:, 6] == 0.0] = 0.0  # Sphere
        pc_with_prim_type[:, 6][pc_with_prim_type[:, 6] == 1.0] = 1.0  # Cylinder
        pc_with_prim_type[:, 6][pc_with_prim_type[:, 6] == 3.0] = 2.0  # Box
        pc_with_prim_type[:, 6][pc_with_prim_type[:, 6] == 4.0] = 2.0  # Polytope
        pc_with_prim_type[:, 6][pc_with_prim_type[:, 6] == 6.0] = 2.0  # Plane
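This block maps the predicted primitive ids onto the three coarse categories used for the one-hot encoding (Sphere to 0, Cylinder to 1, Box/Polytope/Plane to 2); the first two assignments are effectively no-ops. Because no assigned value collides with a later match value, the chained in-place masking is safe here. The same remap written as an explicit lookup table, with invented predictions:

import numpy as np

# Invented predictions for the primitive-type column (index 6 in the real array).
prim_type = np.array([0., 1., 3., 4., 6., 1.])

# Old id -> coarse class, equivalent to the chained mask assignments above.
remap = {0.0: 0.0, 1.0: 1.0, 3.0: 2.0, 4.0: 2.0, 6.0: 2.0}
coarse = np.array([remap[v] for v in prim_type])
print(coarse)  # [0. 1. 2. 2. 2. 1.]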

View File

@@ -24,11 +24,9 @@ class Classes(DataClass):
     # Object Classes for Point Segmentation
     Sphere = 0
     Cylinder = 1
-    Cone = 2
-    Box = 3  # All SubTypes of Planes
-    Polytope = 4  #
-    Torus = 5
-    Plane = 6  #
+    Box = 2  # All SubTypes of Planes
+    Polytope = 3  #
+    Plane = 4  #


 class Settings(DataClass):
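After this commit the segmentation label set is Sphere = 0, Cylinder = 1, Box = 2, Polytope = 3, Plane = 4, with Cone and Torus removed. A minimal sketch of the resulting class; DataClass is the project's own base from utils.project_settings and is stubbed here only so the snippet runs standalone:

class DataClass:
    # Stub standing in for utils.project_settings.DataClass.
    pass


class Classes(DataClass):
    # Object Classes for Point Segmentation (ids after this commit).
    Sphere = 0
    Cylinder = 1
    Box = 2       # All SubTypes of Planes
    Polytope = 3  #
    Plane = 4     #


# With poly_as_plane set, CustomShapeNet.process() folds ids 2-4 onto 2.0.
print(float(Classes.Plane))  # 4.0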