ml_lib/modules/geometric_blocks.py
2020-05-26 21:44:56 +02:00

48 lines
1.4 KiB
Python

import torch
from torch import nn
from torch.nn import ReLU
from torch_geometric.nn import PointConv, fps, radius, global_max_pool
class SAModule(nn.Module):
    """PointNet++ set-abstraction layer.

    Subsamples the input points with farthest point sampling (``fps``),
    groups each sampled centroid with the input points lying within radius
    ``r`` of it (restricted to the same batch element via ``batch``), and
    aggregates every group with a ``PointConv`` wrapping the given network.

    Args:
        ratio: Sampling ratio passed to ``fps``.
        r: Neighbourhood radius passed to ``radius``.
        nn: Network handed to ``PointConv``. The name is kept for backward
            compatibility; it shadows the ``torch.nn`` import inside
            ``__init__`` only.
        max_num_neighbors: Maximum number of neighbours gathered per
            centroid by ``radius``. Defaults to 64, the value previously
            hard-coded in ``forward``.
    """

    def __init__(self, ratio, r, nn, max_num_neighbors=64):
        super().__init__()
        self.ratio = ratio
        self.r = r
        self.max_num_neighbors = max_num_neighbors
        self.conv = PointConv(nn)

    def forward(self, x, pos, batch):
        """Downsample and aggregate one level of the point hierarchy.

        Args:
            x: Per-point features of the input points (forwarded to
                ``PointConv`` as-is).
            pos: Positions of the input points.
            batch: Batch-element index of each input point.

        Returns:
            Tuple ``(x, pos, batch)`` for the sampled centroids: the
            aggregated features, the centroid positions ``pos[idx]`` and
            their batch indices ``batch[idx]``.
        """
        idx = fps(pos, batch, ratio=self.ratio)
        centroid_pos = pos[idx]
        centroid_batch = batch[idx]

        # Per the torch_cluster docs, `row` indexes the query points (the
        # centroids) and `col` indexes the searched points (`pos`).
        row, col = radius(
            pos,
            centroid_pos,
            self.r,
            batch,
            centroid_batch,
            max_num_neighbors=self.max_num_neighbors,
        )
        # Stack as [source, target] so messages flow input point -> centroid.
        edge_index = torch.stack([col, row], dim=0)
        x = self.conv(x, (pos, centroid_pos), edge_index)
        return x, centroid_pos, centroid_batch
class GlobalSAModule(nn.Module):
    """Global set abstraction: encode every point, then max-pool per batch element.

    Args:
        nn: Module applied to the per-point features concatenated with the
            point positions along dim 1, or to the positions alone when no
            features are given.
    """

    def __init__(self, nn):
        super().__init__()
        self.nn = nn

    def forward(self, x, pos, batch):
        """Collapse each batch element to a single feature vector.

        Args:
            x: Per-point features, or ``None`` for feature-less point
                clouds (only ``pos`` is then fed to ``self.nn``).
            pos: Point positions.
            batch: Batch-element index of each point.

        Returns:
            Tuple ``(x, pos, batch)``:
            - ``x``: one max-pooled row per batch element.
            - ``pos``: an all-zero placeholder with 3 columns, keeping the
              same ``(x, pos, batch)`` triple shape that ``SAModule``-style
              layers return.
            - ``batch``: ``arange(num_outputs)``.
        """
        # Support feature-less inputs instead of failing inside torch.cat.
        features = pos if x is None else torch.cat([x, pos], dim=1)
        pooled = global_max_pool(self.nn(features), batch)

        num_outputs = pooled.size(0)
        placeholder_pos = pos.new_zeros((num_outputs, 3))
        out_batch = torch.arange(num_outputs, device=batch.device)
        return pooled, placeholder_pos, out_batch
class MLP(nn.Module):
    """Stack of ``Linear -> ReLU [-> BatchNorm1d]`` blocks, cast to float64.

    Args:
        channels: Layer widths ``[in, hidden..., out]``. One block is built
            for every consecutive pair of widths.
        norm: If True (default), each block ends with ``BatchNorm1d``;
            if False, the normalisation layer is omitted.
    """

    def __init__(self, channels, norm=True):
        super().__init__()
        blocks = []
        for in_ch, out_ch in zip(channels[:-1], channels[1:]):
            layers = [nn.Linear(in_ch, out_ch), ReLU()]
            if norm:
                layers.append(nn.BatchNorm1d(out_ch))
            blocks.append(nn.Sequential(*layers))
        # The nested Sequential layout keeps state_dict keys (net.<i>.<j>.*)
        # identical to earlier versions. `.double()` casts every parameter and
        # buffer to float64, so inputs must be float64 as well.
        self.net = nn.Sequential(*blocks).double()

    def forward(self, x, *args, **kwargs):
        """Apply the layer stack to ``x``; extra arguments are ignored."""
        return self.net(x)