Dataset for whole pointclouds with farthest point sampling _incomplete_
This commit is contained in:
parent
fcd5ee4d29
commit
196b1af7ae
@ -3,4 +3,10 @@ from torch.utils.data import Dataset
|
|||||||
|
|
||||||
class TemplateDataset(Dataset):
|
class TemplateDataset(Dataset):
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super(TemplateDataset, self).__init__()
|
super(TemplateDataset, self).__init__()
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def __getitem__(self, item):
|
||||||
|
return item
|
||||||
|
0
point_toolset/__init__.py
Normal file
0
point_toolset/__init__.py
Normal file
24
point_toolset/sampling.py
Normal file
24
point_toolset/sampling.py
Normal file
@ -0,0 +1,24 @@
|
|||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
class FarthestpointSampling():
|
||||||
|
|
||||||
|
def __init__(self, K):
|
||||||
|
self.k = K
|
||||||
|
|
||||||
|
def __call__(self, pts, *args, **kwargs):
|
||||||
|
|
||||||
|
if pts.shape[0] < self.k:
|
||||||
|
return pts
|
||||||
|
|
||||||
|
def calc_distances(p0, points):
|
||||||
|
return ((p0[:3] - points[:, :3]) ** 2).sum(axis=1)
|
||||||
|
|
||||||
|
farthest_pts = np.zeros((self.k, pts.shape[1]))
|
||||||
|
farthest_pts[0] = pts[np.random.randint(len(pts))]
|
||||||
|
distances = calc_distances(farthest_pts[0], pts)
|
||||||
|
for i in range(1, self.k):
|
||||||
|
farthest_pts[i] = pts[np.argmax(distances)]
|
||||||
|
distances = np.minimum(distances, calc_distances(farthest_pts[i], pts))
|
||||||
|
|
||||||
|
return farthest_pts
|
Loading…
x
Reference in New Issue
Block a user