import torch
from torch_geometric.data import Data


def _identity(data):
    """Return *data* unchanged.

    Used as the default transform. A module-level function (unlike a lambda)
    can be pickled, so ``BatchToData`` instances stay picklable.
    """
    return data


def _flatten_points(tensor, num_rows):
    """Reshape an optional batched tensor to ``(num_rows, -1)``.

    ``None`` is passed through unchanged. ``reshape`` is used instead of
    ``view`` so non-contiguous inputs are copied rather than rejected.
    """
    if tensor is None:
        return None
    return tensor.reshape(num_rows, -1)


class BatchToData(object):
    """Convert a dense batch dict of point clouds into a ``torch_geometric`` ``Data``.

    The per-sample point dimension is flattened into one node dimension, and a
    ``batch`` vector records which sample each node came from.

    Args:
        transforms: Optional callable applied to the resulting ``Data`` object.
            A falsy value means "no transform" (identity).
    """

    def __init__(self, transforms=None):
        super(BatchToData, self).__init__()
        # Fall back to a module-level identity function rather than a lambda
        # so the instance remains picklable.
        self.transforms = transforms if transforms else _identity

    def __call__(self, batch_dict):
        """Flatten ``batch_dict`` into a single ``Data`` object.

        Args:
            batch_dict: Mapping with a required ``'pos'`` tensor of shape
                ``(batch_size, num_points, D)`` (``D`` expected to be 3), and
                optional ``'norm'``, ``'y'`` and ``'y_c'`` tensors. Each
                optional tensor is reshaped to ``(batch_size * num_points, -1)``,
                so its element count must be a multiple of
                ``batch_size * num_points``.

        Returns:
            The output of ``self.transforms`` applied to a ``Data`` object with
            attributes ``pos``, ``norm``, ``batch`` (``torch.long``, on
            ``pos.device``), ``yl`` (from ``'y'``) and ``yc`` (from ``'y_c'``).
            Absent optional entries are stored as ``None``.

        Raises:
            KeyError: If ``'pos'`` is missing from ``batch_dict``.
            ValueError: If ``'pos'`` is not a rank-3 tensor.
        """
        batch_pos = batch_dict['pos']
        if batch_pos.dim() != 3:
            raise ValueError(
                "expected 'pos' of shape (batch_size, num_points, D), "
                "got shape {}".format(tuple(batch_pos.shape))
            )
        batch_size, num_points, _ = batch_pos.shape
        num_rows = batch_size * num_points

        pos = _flatten_points(batch_pos, num_rows)
        norm = _flatten_points(batch_dict.get('norm', None), num_rows)
        batch_y_l = _flatten_points(batch_dict.get('y', None), num_rows)
        # NOTE(review): y_c is flattened per point like the other entries. If
        # callers pass a per-shape label of shape (batch_size,), this reshape
        # raises. Confirm the expected layout against the data loader.
        batch_y_c = _flatten_points(batch_dict.get('y_c', None), num_rows)

        # Graph-membership vector [0]*N + [1]*N + ... + [B-1]*N, on pos.device.
        batch = torch.arange(
            batch_size, device=pos.device, dtype=torch.long
        ).repeat_interleave(num_points)

        data = Data()
        data.norm = norm
        data.pos = pos
        data.batch = batch
        data.yl = batch_y_l
        data.yc = batch_y_c
        return self.transforms(data)