From 12d36047eff120753c8c0a3bf2aa08d6a5e3a200 Mon Sep 17 00:00:00 2001 From: Si11ium Date: Fri, 19 Jun 2020 08:17:35 +0200 Subject: [PATCH] Dataset Redone --- point_toolset/point_io.py | 23 +++++++++++++++-------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/point_toolset/point_io.py b/point_toolset/point_io.py index 728b6bf..2c2f7db 100644 --- a/point_toolset/point_io.py +++ b/point_toolset/point_io.py @@ -9,16 +9,23 @@ class BatchToData(object): super(BatchToData, self).__init__() self.transforms = transforms if transforms else lambda x: x - def __call__(self, batch_norm: torch.Tensor, batch_pos: torch.Tensor, - batch_y_l: Union[torch.Tensor, None] = None, batch_y_c: Union[torch.Tensor, None] = None): + def __call__(self, batch_dict): # Convert to torch_geometric.data.Data type - # data = data.transpose(1, 2).contiguous() - batch_size, num_points, _ = batch_norm.shape # (batch_size, num_points, 3) - norm = batch_norm.reshape(batch_size * num_points, -1) - pos = batch_pos.reshape(batch_size * num_points, -1) - batch_y_l = batch_y_l.reshape(batch_size * num_points) if batch_y_l is not None else batch_y_l - batch_y_c = batch_y_c.reshape(batch_size * num_points) if batch_y_c is not None else batch_y_c + batch_pos = batch_dict['pos'] + batch_norm = batch_dict['norm'] + batch_y = batch_dict['y'] + batch_y_c = batch_dict['y_c'] + + batch_size, num_points, _ = batch_pos.shape # (batch_size, num_points, 3) + + batch_size, N, _ = batch_pos.shape # (batch_size, num_points, 3) + pos = batch_pos.view(batch_size * N, -1) + norm = batch_norm.view(batch_size * N, -1) if batch_norm is not None else batch_norm + + batch_y_l = batch_y.view(batch_size * N, -1) if batch_y is not None else batch_y + batch_y_c = batch_y_c.view(batch_size * N, -1) if batch_y_c is not None else batch_y_c + batch = torch.zeros((batch_size, num_points), device=pos.device, dtype=torch.long) for i in range(batch_size): batch[i] = i