diff --git a/point_toolset/point_io.py b/point_toolset/point_io.py index 45ce200..a9a7b7f 100644 --- a/point_toolset/point_io.py +++ b/point_toolset/point_io.py @@ -1,3 +1,5 @@ +from typing import Union + import torch from torch_geometric.data import Data @@ -7,15 +9,15 @@ class BatchToData(object): super(BatchToData, self).__init__() def __call__(self, batch_x: torch.Tensor, batch_pos: torch.Tensor, - batch_y_l: torch.Tensor, batch_y_c: torch.Tensor): + batch_y_l: Union[torch.Tensor, None] = None, batch_y_c: Union[torch.Tensor, None] = None): # Convert to torch_geometric.data.Data type # data = data.transpose(1, 2).contiguous() batch_size, num_points, _ = batch_x.shape # (batch_size, num_points, 3) x = batch_x.reshape(batch_size * num_points, -1) pos = batch_pos.reshape(batch_size * num_points, -1) - batch_y_l = batch_y_l.reshape(batch_size * num_points) - batch_y_c = batch_y_c.reshape(batch_size * num_points) + batch_y_l = batch_y_l.reshape(batch_size * num_points) if batch_y_l is not None else batch_y_l + batch_y_c = batch_y_c.reshape(batch_size * num_points) if batch_y_c is not None else batch_y_c batch = torch.zeros((batch_size, num_points), device=pos.device, dtype=torch.long) for i in range(batch_size): batch[i] = i diff --git a/point_toolset/sampling.py b/point_toolset/sampling.py index 0c2ee8f..ac2184f 100644 --- a/point_toolset/sampling.py +++ b/point_toolset/sampling.py @@ -19,12 +19,8 @@ class RandomSampling(_Sampler): super(RandomSampling, self).__init__(*args, **kwargs) def __call__(self, pts, *args, **kwargs): - if pts.shape[0] < self.k: - return pts - - else: - rnd_indexs = np.random.choice(np.arange(pts.shape[0]), self.k, replace=False) - return rnd_indexs + rnd_indexs = np.random.choice(np.arange(pts.shape[0]), min(self.k, pts.shape[0]), replace=False) + return rnd_indexs class FarthestpointSampling(_Sampler):