Si11ium
2020-10-07 15:21:45 +02:00
parent 5848b528f0
commit f296ba78b9
6 changed files with 78 additions and 39 deletions


@@ -6,7 +6,7 @@ from pathlib import Path
 import torch
 from operator import mul
 from torch import nn
-from torch import functional as F
+from torch.nn import functional as F
 import pytorch_lightning as pl
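
The original import bound `torch.functional` to `F`, but that module does not expose the layer-level functional API (`relu`, `softmax`, `pad`, ...); those live in `torch.nn.functional`. A minimal sketch of the difference:

```python
import torch
from torch.nn import functional as F  # correct: the layer-level functional API

x = torch.randn(2, 3)
print(F.relu(x))             # works: relu is defined in torch.nn.functional
print(F.softmax(x, dim=-1))  # works: so is softmax

# `from torch import functional as F` would instead bind torch.functional,
# which has no `relu`/`softmax` attributes, so F.relu(x) raises AttributeError.
```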
@@ -92,7 +92,14 @@ class ShapeMixin:
             return output.shape[1:] if len(output.shape[1:]) > 1 else output.shape[-1]
         if self.in_shape is not None:
-            x = torch.randn(self.in_shape)
+            try:
+                device = self.device
+            except AttributeError:
+                try:
+                    device = next(self.parameters()).device
+                except StopIteration:
+                    device = 'cuda' if torch.cuda.is_available() else 'cpu'
+            x = torch.randn(self.in_shape, device=device)
             # This is needed for BatchNorm shape checking
             x = torch.stack((x, x))
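
This change makes the shape-probing tensor follow the module's device instead of defaulting to CPU, which would otherwise make the dummy forward pass fail on a model living on the GPU. The fallback chain (module attribute, then first parameter, then global CUDA availability) can be sketched as a standalone helper; `resolve_device` is a hypothetical name for illustration, not part of the repository:

```python
import torch
from torch import nn

def resolve_device(module: nn.Module) -> torch.device:
    """Hypothetical helper mirroring the fallback chain in the hunk above."""
    try:
        # pytorch_lightning modules expose a `.device` attribute directly
        return module.device
    except AttributeError:
        try:
            # plain nn.Modules: infer the device from the first parameter
            return next(module.parameters()).device
        except StopIteration:
            # parameter-less modules: fall back to whatever hardware is available
            return torch.device('cuda' if torch.cuda.is_available() else 'cpu')

layer = nn.Linear(4, 2)
# probe tensor created on the same device as the module's weights
x = torch.randn(8, 4, device=resolve_device(layer))
```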
@@ -248,7 +255,7 @@ class Splitter(nn.Module):
     def __init__(self, in_shape, n, dim=-1):
         super(Splitter, self).__init__()
-        self.in_shape = in_shape
+        self.in_shape = (in_shape, ) if isinstance(in_shape, int) else in_shape
         self.n = n
         self.dim = dim if dim > 0 else len(self.in_shape) - abs(dim)
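
Normalizing a bare `int` shape to a 1-tuple matters one line later, where `len(self.in_shape)` converts a negative `dim` into a positive index; with an un-normalized `int`, `len(...)` would raise a `TypeError`. A minimal sketch of the same normalization and index arithmetic outside the class:

```python
# Normalize a shape given as a bare int, e.g. 128, to the tuple (128,)
def normalize_shape(in_shape):
    return (in_shape,) if isinstance(in_shape, int) else in_shape

for shape, dim in [(128, -1), ((32, 64), -1)]:
    shape = normalize_shape(shape)
    # same negative-dim arithmetic as in Splitter.__init__
    pos_dim = dim if dim > 0 else len(shape) - abs(dim)
    print(shape, pos_dim)  # (128,) 0  /  (32, 64) 1
```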