New Model running

This commit is contained in:
Si11ium
2020-06-23 14:37:33 +02:00
parent aea34de964
commit 53aa11521d
4 changed files with 39 additions and 3 deletions

View File

@ -84,12 +84,21 @@ class ShapeMixin:
@property
def shape(self):
assert isinstance(self, (LightningBaseModule, nn.Module))
def get_out_shape(output):
return output.shape[1:] if len(output.shape[1:]) > 1 else output.shape[-1]
if self.in_shape is not None:
x = torch.randn(self.in_shape)
# This is needed for BatchNorm shape checking
x = torch.stack((x, x))
output = self(x)
return output.shape[1:] if len(output.shape[1:]) > 1 else output.shape[-1]
y = self(x)
if isinstance(y, tuple):
shape = tuple([get_out_shape(y[i]) for i in range(len(y))])
else:
shape = get_out_shape(y)
return shape
else:
return -1