Normalization and transforms for batch_to_data class
This commit is contained in:
@ -5,16 +5,17 @@ from torch_geometric.data import Data
|
|||||||
|
|
||||||
|
|
||||||
class BatchToData(object):
|
class BatchToData(object):
|
||||||
def __init__(self):
|
def __init__(self, transforms=None):
|
||||||
super(BatchToData, self).__init__()
|
super(BatchToData, self).__init__()
|
||||||
|
self.transforms = transforms if transforms else lambda x: x
|
||||||
|
|
||||||
def __call__(self, batch_x: torch.Tensor, batch_pos: torch.Tensor,
|
def __call__(self, batch_norm: torch.Tensor, batch_pos: torch.Tensor,
|
||||||
batch_y_l: Union[torch.Tensor, None] = None, batch_y_c: Union[torch.Tensor, None] = None):
|
batch_y_l: Union[torch.Tensor, None] = None, batch_y_c: Union[torch.Tensor, None] = None):
|
||||||
# Convert to torch_geometric.data.Data type
|
# Convert to torch_geometric.data.Data type
|
||||||
# data = data.transpose(1, 2).contiguous()
|
# data = data.transpose(1, 2).contiguous()
|
||||||
batch_size, num_points, _ = batch_x.shape # (batch_size, num_points, 3)
|
batch_size, num_points, _ = batch_norm.shape # (batch_size, num_points, 3)
|
||||||
|
|
||||||
x = batch_x.reshape(batch_size * num_points, -1)
|
norm = batch_norm.reshape(batch_size * num_points, -1)
|
||||||
pos = batch_pos.reshape(batch_size * num_points, -1)
|
pos = batch_pos.reshape(batch_size * num_points, -1)
|
||||||
batch_y_l = batch_y_l.reshape(batch_size * num_points) if batch_y_l is not None else batch_y_l
|
batch_y_l = batch_y_l.reshape(batch_size * num_points) if batch_y_l is not None else batch_y_l
|
||||||
batch_y_c = batch_y_c.reshape(batch_size * num_points) if batch_y_c is not None else batch_y_c
|
batch_y_c = batch_y_c.reshape(batch_size * num_points) if batch_y_c is not None else batch_y_c
|
||||||
@ -24,5 +25,8 @@ class BatchToData(object):
|
|||||||
batch = batch.view(-1)
|
batch = batch.view(-1)
|
||||||
|
|
||||||
data = Data()
|
data = Data()
|
||||||
data.x, data.pos, data.batch, data.yl, data.yc = x, pos, batch, batch_y_l, batch_y_c
|
data.norm, data.pos, data.batch, data.yl, data.yc = norm, pos, batch, batch_y_l, batch_y_c
|
||||||
|
|
||||||
|
data = self.transforms(data)
|
||||||
|
|
||||||
return data
|
return data
|
||||||
|
Reference in New Issue
Block a user