DataSet Modifications
This commit is contained in:
parent
6b28519e58
commit
3c1202d5b6
@ -36,7 +36,7 @@ main_arg_parser.add_argument("--data_poly_as_plane", type=strtobool, default=Fal
|
|||||||
# Training
|
# Training
|
||||||
main_arg_parser.add_argument("--train_outpath", type=str, default="output", help="")
|
main_arg_parser.add_argument("--train_outpath", type=str, default="output", help="")
|
||||||
main_arg_parser.add_argument("--train_version", type=strtobool, required=False, help="")
|
main_arg_parser.add_argument("--train_version", type=strtobool, required=False, help="")
|
||||||
main_arg_parser.add_argument("--train_epochs", type=int, default=200, help="")
|
main_arg_parser.add_argument("--train_epochs", type=int, default=25, help="")
|
||||||
main_arg_parser.add_argument("--train_batch_size", type=int, default=10, help="")
|
main_arg_parser.add_argument("--train_batch_size", type=int, default=10, help="")
|
||||||
main_arg_parser.add_argument("--train_lr", type=float, default=1e-3, help="")
|
main_arg_parser.add_argument("--train_lr", type=float, default=1e-3, help="")
|
||||||
main_arg_parser.add_argument("--train_weight_decay", type=float, default=1e-8, help="")
|
main_arg_parser.add_argument("--train_weight_decay", type=float, default=1e-8, help="")
|
||||||
|
@ -111,7 +111,7 @@ class CustomShapeNet(InMemoryDataset):
|
|||||||
os.remove(filepath)
|
os.remove(filepath)
|
||||||
try:
|
try:
|
||||||
config_path.unlink()
|
config_path.unlink()
|
||||||
except:
|
except FileNotFoundError:
|
||||||
pass
|
pass
|
||||||
print('Processed Location "Refreshed" (We deleted the Files)')
|
print('Processed Location "Refreshed" (We deleted the Files)')
|
||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
@ -138,7 +138,7 @@ class CustomShapeNet(InMemoryDataset):
|
|||||||
return data, slices
|
return data, slices
|
||||||
|
|
||||||
def _build_config(self):
|
def _build_config(self):
|
||||||
conf_dict = {key:str(val) for key, val in self.__dict__.items() if '__' not in key and key not in [
|
conf_dict = {key: str(val) for key, val in self.__dict__.items() if '__' not in key and key not in [
|
||||||
'classes', 'refresh', 'transform', 'data', 'slices'
|
'classes', 'refresh', 'transform', 'data', 'slices'
|
||||||
]}
|
]}
|
||||||
return conf_dict
|
return conf_dict
|
||||||
@ -194,10 +194,9 @@ class CustomShapeNet(InMemoryDataset):
|
|||||||
continue
|
continue
|
||||||
tensor = tensor.unsqueeze(0)
|
tensor = tensor.unsqueeze(0)
|
||||||
if self.poly_as_plane:
|
if self.poly_as_plane:
|
||||||
tensor[:, -2][tensor[:, -2] == float(self.classes.Plane)] = 4.0
|
tensor[:, -2][tensor[:, -2] == float(self.classes.Plane)] = 2.0
|
||||||
tensor[:, -2][tensor[:, -2] == float(self.classes.Box)] = 4.0
|
tensor[:, -2][tensor[:, -2] == float(self.classes.Box)] = 2.0
|
||||||
tensor[:, -2][tensor[:, -2] == float(self.classes.Polytope)] = 4.0
|
tensor[:, -2][tensor[:, -2] == float(self.classes.Polytope)] = 2.0
|
||||||
tensor[:, -2][tensor[:, -2] == self.classes.Torus] = 3.0
|
|
||||||
src[key] = tensor
|
src[key] = tensor
|
||||||
|
|
||||||
for key, values in src.items():
|
for key, values in src.items():
|
||||||
|
@ -5,22 +5,19 @@ import torch
|
|||||||
import polyscope as ps
|
import polyscope as ps
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from torch.utils.data import DataLoader
|
|
||||||
|
|
||||||
# Dataset and Dataloaders
|
# Dataset and Dataloaders
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
||||||
# Transforms
|
# Transforms
|
||||||
from torch_geometric.transforms import Compose, NormalizeScale, RandomFlip
|
from torch_geometric.transforms import Compose, NormalizeScale
|
||||||
|
|
||||||
from ml_lib.point_toolset.point_io import BatchToData
|
from ml_lib.point_toolset.point_io import BatchToData
|
||||||
from ml_lib.utils.model_io import SavedLightningModels
|
from ml_lib.utils.model_io import SavedLightningModels
|
||||||
|
|
||||||
|
|
||||||
# Datasets
|
# Datasets
|
||||||
from datasets.shapenet import ShapeNetPartSegDataset
|
from datasets.shapenet import ShapeNetPartSegDataset
|
||||||
from models import PointNet2
|
from models import PointNet2
|
||||||
from utils.pointcloud import cluster_cubes, append_onehotencoded_type, label2color, hierarchical_clustering, \
|
from utils.pointcloud import cluster_cubes, append_onehotencoded_type, label2color, \
|
||||||
write_clusters, cluster2Color, cluster_dbscan
|
write_clusters, cluster2Color, cluster_dbscan
|
||||||
from utils.project_settings import GlobalVar, DataClass
|
from utils.project_settings import GlobalVar, DataClass
|
||||||
|
|
||||||
@ -66,7 +63,7 @@ if __name__ == '__main__':
|
|||||||
type_cluster_eps = 0.1
|
type_cluster_eps = 0.1
|
||||||
type_cluster_min_pts = 100
|
type_cluster_min_pts = 100
|
||||||
|
|
||||||
model_path = Path('output') / 'PN2' / 'PN_9843bf499399786cfd58fe79fa1b3db8' / 'version_0'
|
model_path = Path('output') / 'PN2' / 'PN_f0d6bc0b9bf95a7e64f31a7df3c820d0' / 'version_0'
|
||||||
|
|
||||||
loaded_model = restore_logger_and_model(model_path)
|
loaded_model = restore_logger_and_model(model_path)
|
||||||
loaded_model.eval()
|
loaded_model.eval()
|
||||||
@ -75,7 +72,7 @@ if __name__ == '__main__':
|
|||||||
test_dataset = ShapeNetPartSegDataset('data', mode=GlobalVar.data_split.predict, collate_per_segment=False,
|
test_dataset = ShapeNetPartSegDataset('data', mode=GlobalVar.data_split.predict, collate_per_segment=False,
|
||||||
refresh=True, transform=transforms) # , cluster_type='pc')
|
refresh=True, transform=transforms) # , cluster_type='pc')
|
||||||
|
|
||||||
grid_clusters = cluster_cubes(test_dataset[0], grid_clusters, max_points_per_cluster=grid_cluster_max_pts)
|
grid_clusters = cluster_cubes(test_dataset[1], grid_clusters, max_points_per_cluster=grid_cluster_max_pts)
|
||||||
|
|
||||||
ps.init()
|
ps.init()
|
||||||
|
|
||||||
@ -90,8 +87,8 @@ if __name__ == '__main__':
|
|||||||
pc_with_prim_type = predict_prim_type(grid_cluster_pc, loaded_model)
|
pc_with_prim_type = predict_prim_type(grid_cluster_pc, loaded_model)
|
||||||
|
|
||||||
# Re-Map Primitive type for 1-hot-encoding.
|
# Re-Map Primitive type for 1-hot-encoding.
|
||||||
pc_with_prim_type[:, 6][pc_with_prim_type[:, 6] == 0.0] = 0.0 # Sphere
|
pc_with_prim_type[:, 6][pc_with_prim_type[:, 6] == 0.0] = 0.0 # Sphere
|
||||||
pc_with_prim_type[:, 6][pc_with_prim_type[:, 6] == 1.0] = 1.0 # Cylinder
|
pc_with_prim_type[:, 6][pc_with_prim_type[:, 6] == 1.0] = 1.0 # Cylinder
|
||||||
pc_with_prim_type[:, 6][pc_with_prim_type[:, 6] == 3.0] = 2.0 # Box
|
pc_with_prim_type[:, 6][pc_with_prim_type[:, 6] == 3.0] = 2.0 # Box
|
||||||
pc_with_prim_type[:, 6][pc_with_prim_type[:, 6] == 4.0] = 2.0 # Polytope
|
pc_with_prim_type[:, 6][pc_with_prim_type[:, 6] == 4.0] = 2.0 # Polytope
|
||||||
pc_with_prim_type[:, 6][pc_with_prim_type[:, 6] == 6.0] = 2.0 # Plane
|
pc_with_prim_type[:, 6][pc_with_prim_type[:, 6] == 6.0] = 2.0 # Plane
|
||||||
|
@ -24,11 +24,9 @@ class Classes(DataClass):
|
|||||||
# Object Classes for Point Segmentation
|
# Object Classes for Point Segmentation
|
||||||
Sphere = 0
|
Sphere = 0
|
||||||
Cylinder = 1
|
Cylinder = 1
|
||||||
Cone = 2
|
Box = 2 # All SubTypes of Planes
|
||||||
Box = 3 # All SubTypes of Planes
|
Polytope = 3 #
|
||||||
Polytope = 4 #
|
Plane = 4 #
|
||||||
Torus = 5
|
|
||||||
Plane = 6 #
|
|
||||||
|
|
||||||
|
|
||||||
class Settings(DataClass):
|
class Settings(DataClass):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user