New Model running
This commit is contained in:
@ -84,12 +84,21 @@ class ShapeMixin:
|
||||
@property
|
||||
def shape(self):
|
||||
assert isinstance(self, (LightningBaseModule, nn.Module))
|
||||
|
||||
def get_out_shape(output):
|
||||
return output.shape[1:] if len(output.shape[1:]) > 1 else output.shape[-1]
|
||||
|
||||
if self.in_shape is not None:
|
||||
x = torch.randn(self.in_shape)
|
||||
# This is needed for BatchNorm shape checking
|
||||
x = torch.stack((x, x))
|
||||
output = self(x)
|
||||
return output.shape[1:] if len(output.shape[1:]) > 1 else output.shape[-1]
|
||||
|
||||
y = self(x)
|
||||
if isinstance(y, tuple):
|
||||
shape = tuple([get_out_shape(y[i]) for i in range(len(y))])
|
||||
else:
|
||||
shape = get_out_shape(y)
|
||||
return shape
|
||||
else:
|
||||
return -1
|
||||
|
||||
|
Reference in New Issue
Block a user