New Dataset for per spatial cluster training
This commit is contained in:
		| @@ -1,3 +1,5 @@ | ||||
| from typing import Union | ||||
|  | ||||
| import torch | ||||
| from torch_geometric.data import Data | ||||
|  | ||||
| @@ -7,15 +9,15 @@ class BatchToData(object): | ||||
|         super(BatchToData, self).__init__() | ||||
|  | ||||
|     def __call__(self, batch_x: torch.Tensor, batch_pos: torch.Tensor, | ||||
|                  batch_y_l: torch.Tensor, batch_y_c: torch.Tensor): | ||||
|                  batch_y_l: Union[torch.Tensor, None] = None, batch_y_c: Union[torch.Tensor, None] = None): | ||||
|         # Convert to torch_geometric.data.Data type | ||||
|         # data = data.transpose(1, 2).contiguous() | ||||
|         batch_size, num_points, _ = batch_x.shape  # (batch_size, num_points, 3) | ||||
|  | ||||
|         x = batch_x.reshape(batch_size * num_points, -1) | ||||
|         pos = batch_pos.reshape(batch_size * num_points, -1) | ||||
|         batch_y_l = batch_y_l.reshape(batch_size * num_points) | ||||
|         batch_y_c = batch_y_c.reshape(batch_size * num_points) | ||||
|         batch_y_l = batch_y_l.reshape(batch_size * num_points) if batch_y_l is not None else batch_y_l | ||||
|         batch_y_c = batch_y_c.reshape(batch_size * num_points) if batch_y_c is not None else batch_y_c | ||||
|         batch = torch.zeros((batch_size, num_points), device=pos.device, dtype=torch.long) | ||||
|         for i in range(batch_size): | ||||
|             batch[i] = i | ||||
|   | ||||
| @@ -19,12 +19,8 @@ class RandomSampling(_Sampler): | ||||
|         super(RandomSampling, self).__init__(*args, **kwargs) | ||||
|  | ||||
|     def __call__(self, pts, *args, **kwargs): | ||||
|         if pts.shape[0] < self.k: | ||||
|             return pts | ||||
|  | ||||
|         else: | ||||
|             rnd_indexs = np.random.choice(np.arange(pts.shape[0]), self.k, replace=False) | ||||
|             return rnd_indexs | ||||
|         rnd_indexs = np.random.choice(np.arange(pts.shape[0]), min(self.k, pts.shape[0]), replace=False) | ||||
|         return rnd_indexs | ||||
|  | ||||
|  | ||||
| class FarthestpointSampling(_Sampler): | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Si11ium
					Si11ium