explicit model argument

This commit is contained in:
Si11ium 2020-06-19 13:35:37 +02:00
parent fe2bc131df
commit 49b373a8a1
3 changed files with 11 additions and 7 deletions

View File

@ -1,4 +1,5 @@
from pathlib import Path
from warnings import warn
import numpy as np
@ -55,13 +56,14 @@ class CustomShapeNet(InMemoryDataset):
def processed_file_names(self):
return [f'{self.mode}.pt']
def download(self):
def __download(self):
dir_count = len([name for name in os.listdir(self.raw_dir) if os.path.isdir(os.path.join(self.raw_dir, name))])
if dir_count:
print(f'{dir_count} folders have been found....')
return dir_count
raise IOError("No raw pointclouds have been found.")
warn(ResourceWarning("No raw pointclouds have been found. Was this intentional?"))
return dir_count
@property
def num_classes(self):

View File

@ -23,8 +23,10 @@ class PointNet2(BaseValMixin,
# Dataset
# =============================================================================
self.dataset = self.build_dataset(ShapeNetPartSegDataset, collate_per_segment=True,
npoints=self.params.npoints)
self.dataset = self.build_dataset(ShapeNetPartSegDataset,
collate_per_segment=True,
npoints=self.params.npoints
)
# Model Paramters
# =============================================================================

View File

@ -110,7 +110,7 @@ def cluster_per_column(pc, column):
def cluster_cubes(data, cluster_dims, max_points_per_cluster=-1, min_points_per_cluster=-1):
if cluster_dims[0] is 1 and cluster_dims[1] is 1 and cluster_dims[2] is 1:
if cluster_dims[0] == 1 and cluster_dims[1] == 1 and cluster_dims[2] == 1:
print("no need to cluster.")
return [farthest_point_sampling(data, max_points_per_cluster)]
@ -141,7 +141,7 @@ def cluster_cubes(data, cluster_dims, max_points_per_cluster=-1, min_points_per_
final_clusters = []
for key, cluster in clusters.items():
c = np.vstack(cluster)
if c.shape[0] < min_points_per_cluster and -1 is not min_points_per_cluster:
if c.shape[0] < min_points_per_cluster and -1 != min_points_per_cluster:
continue
if max_points_per_cluster is not -1:
@ -171,7 +171,7 @@ def cluster_dbscan(data, selected_indices, eps, min_samples=5, metric='euclidean
clusters = {}
for idx, l in enumerate(labels):
if l is -1:
if l == -1:
continue
clusters.setdefault(str(l), []).append(data[idx, :])