56 lines
1.5 KiB
Python
56 lines
1.5 KiB
Python
from abc import ABC
|
|
|
|
import numpy as np
|
|
|
|
|
|
class _Sampler(ABC):
    """Common base for samplers: remembers the sample count and extra options.

    The base ``__call__`` only raises ``NotImplementedError``; concrete
    samplers are expected to override it.
    """

    def __init__(self, K, **kwargs):
        """Store the configuration.

        Args:
            K: Value kept as ``self.k`` (the number of samples subclasses draw).
            **kwargs: Extra options, kept verbatim as ``self.kwargs``.
        """
        self.kwargs = kwargs
        self.k = K

    def __call__(self, *args, **kwargs):
        """Placeholder entry point; always raises ``NotImplementedError``."""
        raise NotImplementedError()
|
|
|
|
|
|
class RandomSampling(_Sampler):
    """Sampler that picks ``k`` distinct row indices at random."""

    def __init__(self, *args, **kwargs):
        """Forward all arguments to the ``_Sampler`` constructor."""
        super().__init__(*args, **kwargs)

    def __call__(self, pts, *args, **kwargs):
        """Draw ``self.k`` distinct row indices of ``pts`` without replacement.

        Args:
            pts: Array whose first axis enumerates the candidates
                (only ``pts.shape[0]`` is read).
            *args: Ignored.
            **kwargs: Ignored.

        Returns:
            An array of ``self.k`` indices into the first axis of ``pts``.
            NOTE: when ``pts`` has fewer than ``self.k`` rows, ``pts`` itself
            is returned unchanged (the data, not indices).
        """
        n_candidates = pts.shape[0]
        if n_candidates >= self.k:
            candidates = np.arange(n_candidates)
            return np.random.choice(candidates, self.k, replace=False)

        # Not enough rows to draw k distinct samples: hand the input back.
        return pts
|
|
|
|
|
|
class FarthestpointSampling(_Sampler):
    """Greedy farthest-point sampler returning ``k`` row indices.

    Starting from a random row, each further sample is the row whose squared
    distance to the already selected rows (minimum over them) is largest.
    """

    def __init__(self, *args, **kwargs):
        """Forward all arguments to the ``_Sampler`` constructor."""
        super(FarthestpointSampling, self).__init__(*args, **kwargs)

    @staticmethod
    def calc_distances(p0, points):
        """Return squared Euclidean distances from ``p0`` to every row of ``points``.

        Only the first three components of ``p0`` and the first three columns
        of ``points`` take part in the distance; further columns are ignored.
        """
        return ((p0[:3] - points[:, :3]) ** 2).sum(axis=1)

    def __call__(self, pts, *args, **kwargs):
        """Select ``self.k`` row indices of ``pts`` by farthest-point sampling.

        Args:
            pts: 2-D array with one point per row (indexed as ``pts[:, :3]``).
            *args: Ignored.
            **kwargs: Ignored.

        Returns:
            Integer array of length ``self.k`` holding the selected row
            indices in selection order; element 0 is the random seed row.
            NOTE: when ``pts`` has fewer than ``self.k`` rows, ``pts`` itself
            is returned unchanged (the data, not indices).
        """
        if pts.shape[0] < self.k:
            return pts

        # ``np.int`` was removed from NumPy; the builtin ``int`` is the same dtype.
        farthest_pts_idx = np.zeros(self.k, dtype=int)

        # Record the random seed row so the returned indices really describe
        # the selected points (previously slot 0 was left at 0).
        farthest_pts_idx[0] = np.random.randint(len(pts))

        # Promote the reference point to float64 so distances are computed in
        # floating point regardless of the dtype of ``pts``.
        seed_point = np.asarray(pts[farthest_pts_idx[0]], dtype=np.float64)
        distances = self.calc_distances(seed_point, pts)

        for i in range(1, self.k):
            # Next sample: the row farthest from everything chosen so far.
            farthest_pts_idx[i] = np.argmax(distances)
            chosen = np.asarray(pts[farthest_pts_idx[i]], dtype=np.float64)
            # Keep, per row, the distance to its nearest selected point.
            distances = np.minimum(distances, self.calc_distances(chosen, pts))

        return farthest_pts_idx
|