New Model running

This commit is contained in:
Si11ium
2020-06-23 14:37:33 +02:00
parent aea34de964
commit 53aa11521d
4 changed files with 39 additions and 3 deletions

View File

@@ -6,6 +6,7 @@ from torch_geometric.nn import PointConv, fps, radius, global_max_pool, knn_inte
class SAModule(torch.nn.Module):
def __init__(self, ratio, r, nn):
super(SAModule, self).__init__()
self.ratio = ratio

View File

@@ -84,12 +84,21 @@ class ShapeMixin:
@property
def shape(self):
assert isinstance(self, (LightningBaseModule, nn.Module))
def get_out_shape(output):
return output.shape[1:] if len(output.shape[1:]) > 1 else output.shape[-1]
if self.in_shape is not None:
x = torch.randn(self.in_shape)
# This is needed for BatchNorm shape checking
x = torch.stack((x, x))
output = self(x)
return output.shape[1:] if len(output.shape[1:]) > 1 else output.shape[-1]
y = self(x)
if isinstance(y, tuple):
shape = tuple([get_out_shape(y[i]) for i in range(len(y))])
else:
shape = get_out_shape(y)
return shape
else:
return -1