eval running - offline logger implemented -> Test it!
This commit is contained in:
@@ -2,7 +2,7 @@ import torch
|
||||
from torch import nn
|
||||
from torch.nn import ReLU
|
||||
|
||||
from torch_geometric.nn import PointConv, fps, radius, global_max_pool
|
||||
from torch_geometric.nn import PointConv, fps, radius, global_max_pool, knn_interpolate
|
||||
|
||||
|
||||
class SAModule(torch.nn.Module):
|
||||
@@ -23,14 +23,15 @@ class SAModule(torch.nn.Module):
|
||||
|
||||
|
||||
class GlobalSAModule(nn.Module):
|
||||
def __init__(self, nn):
|
||||
def __init__(self, nn, channels=3):
|
||||
super(GlobalSAModule, self).__init__()
|
||||
self.nn = nn
|
||||
self.channels = channels
|
||||
|
||||
def forward(self, x, pos, batch):
|
||||
x = self.nn(torch.cat([x, pos], dim=1))
|
||||
x = global_max_pool(x, batch)
|
||||
pos = pos.new_zeros((x.size(0), 3))
|
||||
pos = pos.new_zeros((x.size(0), self.channels))
|
||||
batch = torch.arange(x.size(0), device=batch.device)
|
||||
return x, pos, batch
|
||||
|
||||
@@ -45,3 +46,17 @@ class MLP(nn.Module):
|
||||
|
||||
def forward(self, x, *args, **kwargs):
|
||||
return self.net(x)
|
||||
|
||||
|
||||
class FPModule(torch.nn.Module):
|
||||
def __init__(self, k, nn):
|
||||
super(FPModule, self).__init__()
|
||||
self.k = k
|
||||
self.nn = nn
|
||||
|
||||
def forward(self, x, pos, batch, x_skip, pos_skip, batch_skip):
|
||||
x = knn_interpolate(x, pos, pos_skip, batch, batch_skip, k=self.k)
|
||||
if x_skip is not None:
|
||||
x = torch.cat([x, x_skip], dim=1)
|
||||
x = self.nn(x)
|
||||
return x, pos_skip, batch_skip
|
Reference in New Issue
Block a user