Offline Datasets res net optionality
This commit is contained in:
@@ -22,7 +22,7 @@ class Flatten(nn.Module):
|
||||
try:
|
||||
x = torch.randn(self.in_shape).unsqueeze(0)
|
||||
output = self(x)
|
||||
return output.shape[1:]
|
||||
return output.shape[1:] if len(output.shape[1:]) > 1 else output.shape[-1]
|
||||
except Exception as e:
|
||||
print(e)
|
||||
return -1
|
||||
|
||||
Reference in New Issue
Block a user