New Dataset for per spatial cluster training
This commit is contained in:
parent
2acf91335f
commit
d3fa32ae7b
@ -1,3 +1,5 @@
|
|||||||
|
from typing import Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch_geometric.data import Data
|
from torch_geometric.data import Data
|
||||||
|
|
||||||
@ -7,15 +9,15 @@ class BatchToData(object):
|
|||||||
super(BatchToData, self).__init__()
|
super(BatchToData, self).__init__()
|
||||||
|
|
||||||
def __call__(self, batch_x: torch.Tensor, batch_pos: torch.Tensor,
|
def __call__(self, batch_x: torch.Tensor, batch_pos: torch.Tensor,
|
||||||
batch_y_l: torch.Tensor, batch_y_c: torch.Tensor):
|
batch_y_l: Union[torch.Tensor, None] = None, batch_y_c: Union[torch.Tensor, None] = None):
|
||||||
# Convert to torch_geometric.data.Data type
|
# Convert to torch_geometric.data.Data type
|
||||||
# data = data.transpose(1, 2).contiguous()
|
# data = data.transpose(1, 2).contiguous()
|
||||||
batch_size, num_points, _ = batch_x.shape # (batch_size, num_points, 3)
|
batch_size, num_points, _ = batch_x.shape # (batch_size, num_points, 3)
|
||||||
|
|
||||||
x = batch_x.reshape(batch_size * num_points, -1)
|
x = batch_x.reshape(batch_size * num_points, -1)
|
||||||
pos = batch_pos.reshape(batch_size * num_points, -1)
|
pos = batch_pos.reshape(batch_size * num_points, -1)
|
||||||
batch_y_l = batch_y_l.reshape(batch_size * num_points)
|
batch_y_l = batch_y_l.reshape(batch_size * num_points) if batch_y_l is not None else batch_y_l
|
||||||
batch_y_c = batch_y_c.reshape(batch_size * num_points)
|
batch_y_c = batch_y_c.reshape(batch_size * num_points) if batch_y_c is not None else batch_y_c
|
||||||
batch = torch.zeros((batch_size, num_points), device=pos.device, dtype=torch.long)
|
batch = torch.zeros((batch_size, num_points), device=pos.device, dtype=torch.long)
|
||||||
for i in range(batch_size):
|
for i in range(batch_size):
|
||||||
batch[i] = i
|
batch[i] = i
|
||||||
|
@ -19,12 +19,8 @@ class RandomSampling(_Sampler):
|
|||||||
super(RandomSampling, self).__init__(*args, **kwargs)
|
super(RandomSampling, self).__init__(*args, **kwargs)
|
||||||
|
|
||||||
def __call__(self, pts, *args, **kwargs):
|
def __call__(self, pts, *args, **kwargs):
|
||||||
if pts.shape[0] < self.k:
|
rnd_indexs = np.random.choice(np.arange(pts.shape[0]), min(self.k, pts.shape[0]), replace=False)
|
||||||
return pts
|
return rnd_indexs
|
||||||
|
|
||||||
else:
|
|
||||||
rnd_indexs = np.random.choice(np.arange(pts.shape[0]), self.k, replace=False)
|
|
||||||
return rnd_indexs
|
|
||||||
|
|
||||||
|
|
||||||
class FarthestpointSampling(_Sampler):
|
class FarthestpointSampling(_Sampler):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user