pointnet2 working - TODO: Eval!

This commit is contained in:
Si11ium
2020-05-26 21:44:56 +02:00
parent 4b4051c045
commit 77ea043907
5 changed files with 138 additions and 73 deletions

View File

@@ -6,19 +6,23 @@ class FarthestpointSampling():
def __init__(self, K):
self.k = K
@staticmethod
def calc_distances(p0, points):
return ((p0[:3] - points[:, :3]) ** 2).sum(axis=1)
def __call__(self, pts, *args, **kwargs):
if pts.shape[0] < self.k:
return pts
def calc_distances(p0, points):
return ((p0[:3] - points[:, :3]) ** 2).sum(axis=1)
farthest_pts = np.zeros((self.k, pts.shape[1]))
farthest_pts_idx = np.zeros(self.k, dtype=np.int)
farthest_pts[0] = pts[np.random.randint(len(pts))]
distances = calc_distances(farthest_pts[0], pts)
distances = self.calc_distances(farthest_pts[0], pts)
for i in range(1, self.k):
farthest_pts[i] = pts[np.argmax(distances)]
distances = np.minimum(distances, calc_distances(farthest_pts[i], pts))
farthest_pts_idx[i] = np.argmax(distances)
farthest_pts[i] = pts[farthest_pts_idx[i]]
return farthest_pts
distances = np.minimum(distances, self.calc_distances(farthest_pts[i], pts))
return farthest_pts_idx