New Model running

This commit is contained in:
Si11ium 2020-06-23 14:37:33 +02:00
parent aea34de964
commit 53aa11521d
4 changed files with 39 additions and 3 deletions

View File

@ -6,6 +6,7 @@ from torch_geometric.nn import PointConv, fps, radius, global_max_pool, knn_inte
class SAModule(torch.nn.Module): class SAModule(torch.nn.Module):
def __init__(self, ratio, r, nn): def __init__(self, ratio, r, nn):
super(SAModule, self).__init__() super(SAModule, self).__init__()
self.ratio = ratio self.ratio = ratio

View File

@ -84,12 +84,21 @@ class ShapeMixin:
@property @property
def shape(self): def shape(self):
assert isinstance(self, (LightningBaseModule, nn.Module)) assert isinstance(self, (LightningBaseModule, nn.Module))
def get_out_shape(output):
return output.shape[1:] if len(output.shape[1:]) > 1 else output.shape[-1]
if self.in_shape is not None: if self.in_shape is not None:
x = torch.randn(self.in_shape) x = torch.randn(self.in_shape)
# This is needed for BatchNorm shape checking # This is needed for BatchNorm shape checking
x = torch.stack((x, x)) x = torch.stack((x, x))
output = self(x)
return output.shape[1:] if len(output.shape[1:]) > 1 else output.shape[-1] y = self(x)
if isinstance(y, tuple):
shape = tuple([get_out_shape(y[i]) for i in range(len(y))])
else:
shape = get_out_shape(y)
return shape
else: else:
return -1 return -1

View File

@ -0,0 +1,26 @@
import torch
from torch_geometric.transforms import NormalizeScale
class NormalizePositions(NormalizeScale):
def __init__(self):
super(NormalizePositions, self).__init__()
def __call__(self, data):
if torch.isnan(data.pos).any():
print('debug')
data = self.center(data)
if torch.isnan(data.pos).any():
print('debug')
scale = (1 / data.pos.abs().max()) * 0.999999
if torch.isnan(scale).any() or torch.isinf(scale).any():
print('debug')
data.pos = data.pos * scale
if torch.isnan(data.pos).any():
print('debug')
return data

View File

@ -52,7 +52,7 @@ class ModelParameters(Namespace, Mapping):
if name == 'stretch': if name == 'stretch':
return False return False
else: else:
raise AttributeError(e) return None
_activations = dict( _activations = dict(
leaky_relu=nn.LeakyReLU, leaky_relu=nn.LeakyReLU,