New Model running
This commit is contained in:
parent
aea34de964
commit
53aa11521d
@ -6,6 +6,7 @@ from torch_geometric.nn import PointConv, fps, radius, global_max_pool, knn_inte
|
||||
|
||||
|
||||
class SAModule(torch.nn.Module):
|
||||
|
||||
def __init__(self, ratio, r, nn):
|
||||
super(SAModule, self).__init__()
|
||||
self.ratio = ratio
|
||||
|
@ -84,12 +84,21 @@ class ShapeMixin:
|
||||
@property
|
||||
def shape(self):
|
||||
assert isinstance(self, (LightningBaseModule, nn.Module))
|
||||
|
||||
def get_out_shape(output):
|
||||
return output.shape[1:] if len(output.shape[1:]) > 1 else output.shape[-1]
|
||||
|
||||
if self.in_shape is not None:
|
||||
x = torch.randn(self.in_shape)
|
||||
# This is needed for BatchNorm shape checking
|
||||
x = torch.stack((x, x))
|
||||
output = self(x)
|
||||
return output.shape[1:] if len(output.shape[1:]) > 1 else output.shape[-1]
|
||||
|
||||
y = self(x)
|
||||
if isinstance(y, tuple):
|
||||
shape = tuple([get_out_shape(y[i]) for i in range(len(y))])
|
||||
else:
|
||||
shape = get_out_shape(y)
|
||||
return shape
|
||||
else:
|
||||
return -1
|
||||
|
||||
|
26
point_toolset/point_transforms.py
Normal file
26
point_toolset/point_transforms.py
Normal file
@ -0,0 +1,26 @@
|
||||
import torch
|
||||
from torch_geometric.transforms import NormalizeScale
|
||||
|
||||
|
||||
class NormalizePositions(NormalizeScale):
|
||||
|
||||
def __init__(self):
|
||||
super(NormalizePositions, self).__init__()
|
||||
|
||||
def __call__(self, data):
|
||||
if torch.isnan(data.pos).any():
|
||||
print('debug')
|
||||
|
||||
data = self.center(data)
|
||||
if torch.isnan(data.pos).any():
|
||||
print('debug')
|
||||
|
||||
scale = (1 / data.pos.abs().max()) * 0.999999
|
||||
if torch.isnan(scale).any() or torch.isinf(scale).any():
|
||||
print('debug')
|
||||
|
||||
data.pos = data.pos * scale
|
||||
if torch.isnan(data.pos).any():
|
||||
print('debug')
|
||||
|
||||
return data
|
@ -52,7 +52,7 @@ class ModelParameters(Namespace, Mapping):
|
||||
if name == 'stretch':
|
||||
return False
|
||||
else:
|
||||
raise AttributeError(e)
|
||||
return None
|
||||
|
||||
_activations = dict(
|
||||
leaky_relu=nn.LeakyReLU,
|
||||
|
Loading…
x
Reference in New Issue
Block a user