eval running - offline logger implemented -> Test it!
This commit is contained in:
24
point_toolset/point_io.py
Normal file
24
point_toolset/point_io.py
Normal file
@@ -0,0 +1,24 @@
|
||||
import torch
|
||||
from torch_geometric.data import Data
|
||||
|
||||
|
||||
class BatchToData(object):
|
||||
def __init__(self):
|
||||
super(BatchToData, self).__init__()
|
||||
|
||||
def __call__(self, batch_x: torch.Tensor, batch_pos: torch.Tensor, batch_y: torch.Tensor):
|
||||
# Convert to torch_geometric.data.Data type
|
||||
# data = data.transpose(1, 2).contiguous()
|
||||
batch_size, num_points, _ = batch_x.shape # (batch_size, num_points, 3)
|
||||
|
||||
x = batch_x.reshape(batch_size * num_points, -1)
|
||||
pos = batch_pos.reshape(batch_size * num_points, -1)
|
||||
batch_y = batch_y.reshape(batch_size * num_points)
|
||||
batch = torch.zeros((batch_size, num_points), device=pos.device, dtype=torch.long)
|
||||
for i in range(batch_size):
|
||||
batch[i] = i
|
||||
batch = batch.view(-1)
|
||||
|
||||
data = Data()
|
||||
data.x, data.pos, data.batch, data.y = x, pos, batch, batch_y
|
||||
return data
|
||||
@@ -1,10 +1,36 @@
|
||||
from abc import ABC
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
class FarthestpointSampling():
|
||||
class _Sampler(ABC):
|
||||
|
||||
def __init__(self, K):
|
||||
def __init__(self, K, **kwargs):
|
||||
self.k = K
|
||||
self.kwargs = kwargs
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class RandomSampling(_Sampler):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(RandomSampling, self).__init__(*args, **kwargs)
|
||||
|
||||
def __call__(self, pts, *args, **kwargs):
|
||||
if pts.shape[0] < self.k:
|
||||
return pts
|
||||
|
||||
else:
|
||||
rnd_indexs = np.random.choice(np.arange(pts.shape[0]), self.k, replace=False)
|
||||
return rnd_indexs
|
||||
|
||||
|
||||
class FarthestpointSampling(_Sampler):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(FarthestpointSampling, self).__init__(*args, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def calc_distances(p0, points):
|
||||
@@ -15,14 +41,15 @@ class FarthestpointSampling():
|
||||
if pts.shape[0] < self.k:
|
||||
return pts
|
||||
|
||||
farthest_pts = np.zeros((self.k, pts.shape[1]))
|
||||
farthest_pts_idx = np.zeros(self.k, dtype=np.int)
|
||||
farthest_pts[0] = pts[np.random.randint(len(pts))]
|
||||
distances = self.calc_distances(farthest_pts[0], pts)
|
||||
for i in range(1, self.k):
|
||||
farthest_pts_idx[i] = np.argmax(distances)
|
||||
farthest_pts[i] = pts[farthest_pts_idx[i]]
|
||||
else:
|
||||
farthest_pts = np.zeros((self.k, pts.shape[1]))
|
||||
farthest_pts_idx = np.zeros(self.k, dtype=np.int)
|
||||
farthest_pts[0] = pts[np.random.randint(len(pts))]
|
||||
distances = self.calc_distances(farthest_pts[0], pts)
|
||||
for i in range(1, self.k):
|
||||
farthest_pts_idx[i] = np.argmax(distances)
|
||||
farthest_pts[i] = pts[farthest_pts_idx[i]]
|
||||
|
||||
distances = np.minimum(distances, self.calc_distances(farthest_pts[i], pts))
|
||||
distances = np.minimum(distances, self.calc_distances(farthest_pts[i], pts))
|
||||
|
||||
return farthest_pts_idx
|
||||
return farthest_pts_idx
|
||||
|
||||
Reference in New Issue
Block a user