Dataset for whole pointclouds with farthest point sampling _incomplete_

This commit is contained in:
Si11ium 2020-05-19 17:15:01 +02:00
parent fcd5ee4d29
commit 196b1af7ae
3 changed files with 31 additions and 1 deletions

View File

@ -3,4 +3,10 @@ from torch.utils.data import Dataset
class TemplateDataset(Dataset): class TemplateDataset(Dataset):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super(TemplateDataset, self).__init__() super(TemplateDataset, self).__init__()
def __len__(self):
pass
def __getitem__(self, item):
return item

View File

24
point_toolset/sampling.py Normal file
View File

@ -0,0 +1,24 @@
import numpy as np
class FarthestpointSampling():
def __init__(self, K):
self.k = K
def __call__(self, pts, *args, **kwargs):
if pts.shape[0] < self.k:
return pts
def calc_distances(p0, points):
return ((p0[:3] - points[:, :3]) ** 2).sum(axis=1)
farthest_pts = np.zeros((self.k, pts.shape[1]))
farthest_pts[0] = pts[np.random.randint(len(pts))]
distances = calc_distances(farthest_pts[0], pts)
for i in range(1, self.k):
farthest_pts[i] = pts[np.argmax(distances)]
distances = np.minimum(distances, calc_distances(farthest_pts[i], pts))
return farthest_pts