Dataset Redone

This commit is contained in:
Si11ium 2020-06-19 08:17:35 +02:00
parent 76308888e0
commit 12d36047ef

View File

@ -9,16 +9,23 @@ class BatchToData(object):
super(BatchToData, self).__init__() super(BatchToData, self).__init__()
self.transforms = transforms if transforms else lambda x: x self.transforms = transforms if transforms else lambda x: x
def __call__(self, batch_norm: torch.Tensor, batch_pos: torch.Tensor, def __call__(self, batch_dict):
batch_y_l: Union[torch.Tensor, None] = None, batch_y_c: Union[torch.Tensor, None] = None):
# Convert to torch_geometric.data.Data type # Convert to torch_geometric.data.Data type
# data = data.transpose(1, 2).contiguous()
batch_size, num_points, _ = batch_norm.shape # (batch_size, num_points, 3)
norm = batch_norm.reshape(batch_size * num_points, -1) batch_pos = batch_dict['pos']
pos = batch_pos.reshape(batch_size * num_points, -1) batch_norm = batch_dict['norm']
batch_y_l = batch_y_l.reshape(batch_size * num_points) if batch_y_l is not None else batch_y_l batch_y = batch_dict['y']
batch_y_c = batch_y_c.reshape(batch_size * num_points) if batch_y_c is not None else batch_y_c batch_y_c = batch_dict['y_c']
batch_size, num_points, _ = batch_pos.shape # (batch_size, num_points, 3)
batch_size, N, _ = batch_pos.shape # (batch_size, num_points, 3)
pos = batch_pos.view(batch_size * N, -1)
norm = batch_norm.view(batch_size * N, -1) if batch_norm is not None else batch_norm
batch_y_l = batch_y.view(batch_size * N, -1) if batch_y is not None else batch_y
batch_y_c = batch_y_c.view(batch_size * N, -1) if batch_y_c is not None else batch_y_c
batch = torch.zeros((batch_size, num_points), device=pos.device, dtype=torch.long) batch = torch.zeros((batch_size, num_points), device=pos.device, dtype=torch.long)
for i in range(batch_size): for i in range(batch_size):
batch[i] = i batch[i] = i