New Model running
This commit is contained in:
parent
aea34de964
commit
53aa11521d
@ -6,6 +6,7 @@ from torch_geometric.nn import PointConv, fps, radius, global_max_pool, knn_inte
|
|||||||
|
|
||||||
|
|
||||||
class SAModule(torch.nn.Module):
|
class SAModule(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self, ratio, r, nn):
|
def __init__(self, ratio, r, nn):
|
||||||
super(SAModule, self).__init__()
|
super(SAModule, self).__init__()
|
||||||
self.ratio = ratio
|
self.ratio = ratio
|
||||||
|
@ -84,12 +84,21 @@ class ShapeMixin:
|
|||||||
@property
|
@property
|
||||||
def shape(self):
|
def shape(self):
|
||||||
assert isinstance(self, (LightningBaseModule, nn.Module))
|
assert isinstance(self, (LightningBaseModule, nn.Module))
|
||||||
|
|
||||||
|
def get_out_shape(output):
|
||||||
|
return output.shape[1:] if len(output.shape[1:]) > 1 else output.shape[-1]
|
||||||
|
|
||||||
if self.in_shape is not None:
|
if self.in_shape is not None:
|
||||||
x = torch.randn(self.in_shape)
|
x = torch.randn(self.in_shape)
|
||||||
# This is needed for BatchNorm shape checking
|
# This is needed for BatchNorm shape checking
|
||||||
x = torch.stack((x, x))
|
x = torch.stack((x, x))
|
||||||
output = self(x)
|
|
||||||
return output.shape[1:] if len(output.shape[1:]) > 1 else output.shape[-1]
|
y = self(x)
|
||||||
|
if isinstance(y, tuple):
|
||||||
|
shape = tuple([get_out_shape(y[i]) for i in range(len(y))])
|
||||||
|
else:
|
||||||
|
shape = get_out_shape(y)
|
||||||
|
return shape
|
||||||
else:
|
else:
|
||||||
return -1
|
return -1
|
||||||
|
|
||||||
|
26
point_toolset/point_transforms.py
Normal file
26
point_toolset/point_transforms.py
Normal file
@ -0,0 +1,26 @@
|
|||||||
|
import torch
|
||||||
|
from torch_geometric.transforms import NormalizeScale
|
||||||
|
|
||||||
|
|
||||||
|
class NormalizePositions(NormalizeScale):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super(NormalizePositions, self).__init__()
|
||||||
|
|
||||||
|
def __call__(self, data):
|
||||||
|
if torch.isnan(data.pos).any():
|
||||||
|
print('debug')
|
||||||
|
|
||||||
|
data = self.center(data)
|
||||||
|
if torch.isnan(data.pos).any():
|
||||||
|
print('debug')
|
||||||
|
|
||||||
|
scale = (1 / data.pos.abs().max()) * 0.999999
|
||||||
|
if torch.isnan(scale).any() or torch.isinf(scale).any():
|
||||||
|
print('debug')
|
||||||
|
|
||||||
|
data.pos = data.pos * scale
|
||||||
|
if torch.isnan(data.pos).any():
|
||||||
|
print('debug')
|
||||||
|
|
||||||
|
return data
|
@ -52,7 +52,7 @@ class ModelParameters(Namespace, Mapping):
|
|||||||
if name == 'stretch':
|
if name == 'stretch':
|
||||||
return False
|
return False
|
||||||
else:
|
else:
|
||||||
raise AttributeError(e)
|
return None
|
||||||
|
|
||||||
_activations = dict(
|
_activations = dict(
|
||||||
leaky_relu=nn.LeakyReLU,
|
leaky_relu=nn.LeakyReLU,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user