diff --git a/modules/geometric_blocks.py b/modules/geometric_blocks.py index dfd77d0..6a05068 100644 --- a/modules/geometric_blocks.py +++ b/modules/geometric_blocks.py @@ -6,6 +6,7 @@ from torch_geometric.nn import PointConv, fps, radius, global_max_pool, knn_inte class SAModule(torch.nn.Module): + def __init__(self, ratio, r, nn): super(SAModule, self).__init__() self.ratio = ratio diff --git a/modules/util.py b/modules/util.py index 0b34277..1329ab2 100644 --- a/modules/util.py +++ b/modules/util.py @@ -84,12 +84,21 @@ class ShapeMixin: @property def shape(self): assert isinstance(self, (LightningBaseModule, nn.Module)) + + def get_out_shape(output): + return output.shape[1:] if len(output.shape[1:]) > 1 else output.shape[-1] + if self.in_shape is not None: x = torch.randn(self.in_shape) # This is needed for BatchNorm shape checking x = torch.stack((x, x)) - output = self(x) - return output.shape[1:] if len(output.shape[1:]) > 1 else output.shape[-1] + + y = self(x) + if isinstance(y, tuple): + shape = tuple([get_out_shape(y[i]) for i in range(len(y))]) + else: + shape = get_out_shape(y) + return shape else: return -1 diff --git a/point_toolset/point_transforms.py b/point_toolset/point_transforms.py new file mode 100644 index 0000000..e6ffcb9 --- /dev/null +++ b/point_toolset/point_transforms.py @@ -0,0 +1,26 @@ +import torch +from torch_geometric.transforms import NormalizeScale + + +class NormalizePositions(NormalizeScale): + + def __init__(self): + super(NormalizePositions, self).__init__() + + def __call__(self, data): + if torch.isnan(data.pos).any(): + print('debug') + + data = self.center(data) + if torch.isnan(data.pos).any(): + print('debug') + + scale = (1 / data.pos.abs().max()) * 0.999999 + if torch.isnan(scale).any() or torch.isinf(scale).any(): + print('debug') + + data.pos = data.pos * scale + if torch.isnan(data.pos).any(): + print('debug') + + return data diff --git a/utils/model_io.py b/utils/model_io.py index 8724c11..9e025fe 100644 --- a/utils/model_io.py +++ b/utils/model_io.py @@ -52,7 +52,7 @@ class ModelParameters(Namespace, Mapping): if name == 'stretch': return False else: - raise AttributeError(e) + return None _activations = dict( leaky_relu=nn.LeakyReLU,