explicit model argument
This commit is contained in:
parent
fe2bc131df
commit
49b373a8a1
@ -1,4 +1,5 @@
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from warnings import warn
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
@ -55,13 +56,14 @@ class CustomShapeNet(InMemoryDataset):
|
|||||||
def processed_file_names(self):
|
def processed_file_names(self):
|
||||||
return [f'{self.mode}.pt']
|
return [f'{self.mode}.pt']
|
||||||
|
|
||||||
def download(self):
|
def __download(self):
|
||||||
dir_count = len([name for name in os.listdir(self.raw_dir) if os.path.isdir(os.path.join(self.raw_dir, name))])
|
dir_count = len([name for name in os.listdir(self.raw_dir) if os.path.isdir(os.path.join(self.raw_dir, name))])
|
||||||
|
|
||||||
if dir_count:
|
if dir_count:
|
||||||
print(f'{dir_count} folders have been found....')
|
print(f'{dir_count} folders have been found....')
|
||||||
return dir_count
|
return dir_count
|
||||||
raise IOError("No raw pointclouds have been found.")
|
warn(ResourceWarning("No raw pointclouds have been found. Was this intentional?"))
|
||||||
|
return dir_count
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def num_classes(self):
|
def num_classes(self):
|
||||||
|
@ -23,8 +23,10 @@ class PointNet2(BaseValMixin,
|
|||||||
|
|
||||||
# Dataset
|
# Dataset
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
self.dataset = self.build_dataset(ShapeNetPartSegDataset, collate_per_segment=True,
|
self.dataset = self.build_dataset(ShapeNetPartSegDataset,
|
||||||
npoints=self.params.npoints)
|
collate_per_segment=True,
|
||||||
|
npoints=self.params.npoints
|
||||||
|
)
|
||||||
|
|
||||||
# Model Paramters
|
# Model Paramters
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
@ -110,7 +110,7 @@ def cluster_per_column(pc, column):
|
|||||||
|
|
||||||
|
|
||||||
def cluster_cubes(data, cluster_dims, max_points_per_cluster=-1, min_points_per_cluster=-1):
|
def cluster_cubes(data, cluster_dims, max_points_per_cluster=-1, min_points_per_cluster=-1):
|
||||||
if cluster_dims[0] is 1 and cluster_dims[1] is 1 and cluster_dims[2] is 1:
|
if cluster_dims[0] == 1 and cluster_dims[1] == 1 and cluster_dims[2] == 1:
|
||||||
print("no need to cluster.")
|
print("no need to cluster.")
|
||||||
return [farthest_point_sampling(data, max_points_per_cluster)]
|
return [farthest_point_sampling(data, max_points_per_cluster)]
|
||||||
|
|
||||||
@ -141,7 +141,7 @@ def cluster_cubes(data, cluster_dims, max_points_per_cluster=-1, min_points_per_
|
|||||||
final_clusters = []
|
final_clusters = []
|
||||||
for key, cluster in clusters.items():
|
for key, cluster in clusters.items():
|
||||||
c = np.vstack(cluster)
|
c = np.vstack(cluster)
|
||||||
if c.shape[0] < min_points_per_cluster and -1 is not min_points_per_cluster:
|
if c.shape[0] < min_points_per_cluster and -1 != min_points_per_cluster:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if max_points_per_cluster is not -1:
|
if max_points_per_cluster is not -1:
|
||||||
@ -171,7 +171,7 @@ def cluster_dbscan(data, selected_indices, eps, min_samples=5, metric='euclidean
|
|||||||
|
|
||||||
clusters = {}
|
clusters = {}
|
||||||
for idx, l in enumerate(labels):
|
for idx, l in enumerate(labels):
|
||||||
if l is -1:
|
if l == -1:
|
||||||
continue
|
continue
|
||||||
clusters.setdefault(str(l), []).append(data[idx, :])
|
clusters.setdefault(str(l), []).append(data[idx, :])
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user