A Lot
@@ -6,7 +6,7 @@ from pathlib import Path
 import torch
 from operator import mul
 from torch import nn
-from torch import functional as F
+from torch.nn import functional as F
 
 import pytorch_lightning as pl
 
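Side note on the one-line fix above: the layer-level ops (relu, softmax, dropout, ...) are defined in torch.nn.functional, while torch.functional is a different module that does not contain them, so the old alias would break at call time. A minimal sketch of the difference (the tensor and the relu call are purely illustrative):

    import torch
    from torch.nn import functional as F  # the corrected import

    x = torch.randn(2, 3)
    y = F.relu(x)  # works: relu lives in torch.nn.functional

    # With the old alias the same call fails, roughly:
    #   from torch import functional as F
    #   F.relu(x)  # AttributeError: torch.functional has no attribute 'relu'
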
@@ -92,7 +92,14 @@ class ShapeMixin:
             return output.shape[1:] if len(output.shape[1:]) > 1 else output.shape[-1]
 
         if self.in_shape is not None:
-            x = torch.randn(self.in_shape)
+            try:
+                device = self.device
+            except AttributeError:
+                try:
+                    device = next(self.parameters()).device
+                except StopIteration:
+                    device = 'cuda' if torch.cuda.is_available() else 'cpu'
+            x = torch.randn(self.in_shape, device=device)
             # This is needed for BatchNorm shape checking
             x = torch.stack((x, x))
 
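The ShapeMixin hunk above resolves which device the random probe tensor should be created on: first the module's own device attribute (available on pl.LightningModule), then the device of its first parameter, and finally a CUDA/CPU default for parameter-free modules. The same fallback chain as a standalone sketch (an illustrative helper, not code from this repository):

    import torch
    from torch import nn

    def resolve_device(module):
        # Same fallback chain as in the diff above (hypothetical helper name).
        try:
            return module.device  # e.g. pl.LightningModule exposes .device
        except AttributeError:
            try:
                # Plain nn.Module: infer the device from its parameters.
                return next(module.parameters()).device
            except StopIteration:
                # Parameter-free module: fall back to a sensible default.
                return 'cuda' if torch.cuda.is_available() else 'cpu'

    layer = nn.Linear(4, 2)
    x = torch.randn((4,), device=resolve_device(layer))
    x = torch.stack((x, x))  # fake batch of two, so BatchNorm shape checks pass
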
@@ -248,7 +255,7 @@ class Splitter(nn.Module):
     def __init__(self, in_shape, n, dim=-1):
         super(Splitter, self).__init__()
 
-        self.in_shape = in_shape
+        self.in_shape = (in_shape, ) if isinstance(in_shape, int) else in_shape
         self.n = n
         self.dim = dim if dim > 0 else len(self.in_shape) - abs(dim)
 
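The Splitter change makes in_shape uniform: a bare int such as 16 becomes the tuple (16,), so len(self.in_shape) is well defined when the next line converts a negative dim into a positive index. The arithmetic, traced with hypothetical values:

    in_shape = 16  # scalar shape, e.g. taken from a config
    in_shape = (in_shape, ) if isinstance(in_shape, int) else in_shape
    assert in_shape == (16,)  # len() now works

    dim = -1  # the default split dimension
    dim = dim if dim > 0 else len(in_shape) - abs(dim)
    assert dim == 0  # last axis of a one-dimensional shape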