New Model running
This commit is contained in:
@@ -6,6 +6,7 @@ from torch_geometric.nn import PointConv, fps, radius, global_max_pool, knn_inte
|
||||
|
||||
|
||||
class SAModule(torch.nn.Module):
|
||||
|
||||
def __init__(self, ratio, r, nn):
|
||||
super(SAModule, self).__init__()
|
||||
self.ratio = ratio
|
||||
|
@@ -84,12 +84,21 @@ class ShapeMixin:
|
||||
@property
|
||||
def shape(self):
|
||||
assert isinstance(self, (LightningBaseModule, nn.Module))
|
||||
|
||||
def get_out_shape(output):
|
||||
return output.shape[1:] if len(output.shape[1:]) > 1 else output.shape[-1]
|
||||
|
||||
if self.in_shape is not None:
|
||||
x = torch.randn(self.in_shape)
|
||||
# This is needed for BatchNorm shape checking
|
||||
x = torch.stack((x, x))
|
||||
output = self(x)
|
||||
return output.shape[1:] if len(output.shape[1:]) > 1 else output.shape[-1]
|
||||
|
||||
y = self(x)
|
||||
if isinstance(y, tuple):
|
||||
shape = tuple([get_out_shape(y[i]) for i in range(len(y))])
|
||||
else:
|
||||
shape = get_out_shape(y)
|
||||
return shape
|
||||
else:
|
||||
return -1
|
||||
|
||||
|
Reference in New Issue
Block a user