eval running - offline logger implemented -> Test it!

This commit is contained in:
Si11ium
2020-05-30 18:12:41 +02:00
parent 77ea043907
commit 5987efb169
9 changed files with 626 additions and 17 deletions

24
point_toolset/point_io.py Normal file
View File

@ -0,0 +1,24 @@
import torch
from torch_geometric.data import Data
class BatchToData(object):
def __init__(self):
super(BatchToData, self).__init__()
def __call__(self, batch_x: torch.Tensor, batch_pos: torch.Tensor, batch_y: torch.Tensor):
# Convert to torch_geometric.data.Data type
# data = data.transpose(1, 2).contiguous()
batch_size, num_points, _ = batch_x.shape # (batch_size, num_points, 3)
x = batch_x.reshape(batch_size * num_points, -1)
pos = batch_pos.reshape(batch_size * num_points, -1)
batch_y = batch_y.reshape(batch_size * num_points)
batch = torch.zeros((batch_size, num_points), device=pos.device, dtype=torch.long)
for i in range(batch_size):
batch[i] = i
batch = batch.view(-1)
data = Data()
data.x, data.pos, data.batch, data.y = x, pos, batch, batch_y
return data