DataSet Modifications
This commit is contained in:
parent
6b28519e58
commit
3c1202d5b6
@ -36,7 +36,7 @@ main_arg_parser.add_argument("--data_poly_as_plane", type=strtobool, default=Fal
|
||||
# Training
|
||||
main_arg_parser.add_argument("--train_outpath", type=str, default="output", help="")
|
||||
main_arg_parser.add_argument("--train_version", type=strtobool, required=False, help="")
|
||||
main_arg_parser.add_argument("--train_epochs", type=int, default=200, help="")
|
||||
main_arg_parser.add_argument("--train_epochs", type=int, default=25, help="")
|
||||
main_arg_parser.add_argument("--train_batch_size", type=int, default=10, help="")
|
||||
main_arg_parser.add_argument("--train_lr", type=float, default=1e-3, help="")
|
||||
main_arg_parser.add_argument("--train_weight_decay", type=float, default=1e-8, help="")
|
||||
|
@ -111,7 +111,7 @@ class CustomShapeNet(InMemoryDataset):
|
||||
os.remove(filepath)
|
||||
try:
|
||||
config_path.unlink()
|
||||
except:
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
print('Processed Location "Refreshed" (We deleted the Files)')
|
||||
except FileNotFoundError:
|
||||
@ -194,10 +194,9 @@ class CustomShapeNet(InMemoryDataset):
|
||||
continue
|
||||
tensor = tensor.unsqueeze(0)
|
||||
if self.poly_as_plane:
|
||||
tensor[:, -2][tensor[:, -2] == float(self.classes.Plane)] = 4.0
|
||||
tensor[:, -2][tensor[:, -2] == float(self.classes.Box)] = 4.0
|
||||
tensor[:, -2][tensor[:, -2] == float(self.classes.Polytope)] = 4.0
|
||||
tensor[:, -2][tensor[:, -2] == self.classes.Torus] = 3.0
|
||||
tensor[:, -2][tensor[:, -2] == float(self.classes.Plane)] = 2.0
|
||||
tensor[:, -2][tensor[:, -2] == float(self.classes.Box)] = 2.0
|
||||
tensor[:, -2][tensor[:, -2] == float(self.classes.Polytope)] = 2.0
|
||||
src[key] = tensor
|
||||
|
||||
for key, values in src.items():
|
||||
|
@ -5,22 +5,19 @@ import torch
|
||||
import polyscope as ps
|
||||
import numpy as np
|
||||
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
# Dataset and Dataloaders
|
||||
# =============================================================================
|
||||
|
||||
# Transforms
|
||||
from torch_geometric.transforms import Compose, NormalizeScale, RandomFlip
|
||||
from torch_geometric.transforms import Compose, NormalizeScale
|
||||
|
||||
from ml_lib.point_toolset.point_io import BatchToData
|
||||
from ml_lib.utils.model_io import SavedLightningModels
|
||||
|
||||
|
||||
# Datasets
|
||||
from datasets.shapenet import ShapeNetPartSegDataset
|
||||
from models import PointNet2
|
||||
from utils.pointcloud import cluster_cubes, append_onehotencoded_type, label2color, hierarchical_clustering, \
|
||||
from utils.pointcloud import cluster_cubes, append_onehotencoded_type, label2color, \
|
||||
write_clusters, cluster2Color, cluster_dbscan
|
||||
from utils.project_settings import GlobalVar, DataClass
|
||||
|
||||
@ -66,7 +63,7 @@ if __name__ == '__main__':
|
||||
type_cluster_eps = 0.1
|
||||
type_cluster_min_pts = 100
|
||||
|
||||
model_path = Path('output') / 'PN2' / 'PN_9843bf499399786cfd58fe79fa1b3db8' / 'version_0'
|
||||
model_path = Path('output') / 'PN2' / 'PN_f0d6bc0b9bf95a7e64f31a7df3c820d0' / 'version_0'
|
||||
|
||||
loaded_model = restore_logger_and_model(model_path)
|
||||
loaded_model.eval()
|
||||
@ -75,7 +72,7 @@ if __name__ == '__main__':
|
||||
test_dataset = ShapeNetPartSegDataset('data', mode=GlobalVar.data_split.predict, collate_per_segment=False,
|
||||
refresh=True, transform=transforms) # , cluster_type='pc')
|
||||
|
||||
grid_clusters = cluster_cubes(test_dataset[0], grid_clusters, max_points_per_cluster=grid_cluster_max_pts)
|
||||
grid_clusters = cluster_cubes(test_dataset[1], grid_clusters, max_points_per_cluster=grid_cluster_max_pts)
|
||||
|
||||
ps.init()
|
||||
|
||||
|
@ -24,11 +24,9 @@ class Classes(DataClass):
|
||||
# Object Classes for Point Segmentation
|
||||
Sphere = 0
|
||||
Cylinder = 1
|
||||
Cone = 2
|
||||
Box = 3 # All SubTypes of Planes
|
||||
Polytope = 4 #
|
||||
Torus = 5
|
||||
Plane = 6 #
|
||||
Box = 2 # All SubTypes of Planes
|
||||
Polytope = 3 #
|
||||
Plane = 4 #
|
||||
|
||||
|
||||
class Settings(DataClass):
|
||||
|
Loading…
x
Reference in New Issue
Block a user