About to lose my mind

Si11ium
2020-03-11 17:10:19 +01:00
parent 1b5a7dc69e
commit 1f4edae95c
12 changed files with 157 additions and 93 deletions

View File

@@ -4,11 +4,11 @@ import torch
from torch import nn
from lib.modules.utils import AutoPad, Interpolate
#
# Sub - Modules
###################
class ConvModule(nn.Module):
    @property
@@ -60,7 +60,7 @@ class DeConvModule(nn.Module):
    def __init__(self, in_shape, conv_filters=3, conv_kernel=5, conv_stride=1, conv_padding=0,
                 dropout: Union[int, float] = 0, autopad=False,
                 activation: Union[None, nn.Module] = nn.ReLU, interpolation_scale=None,
-                 use_bias=True, normalize=False):
+                 use_bias=True, use_norm=False):
        super(DeConvModule, self).__init__()
        in_channels, height, width = in_shape[0], in_shape[1], in_shape[2]
        self.padding = conv_padding
@@ -70,7 +70,7 @@ class DeConvModule(nn.Module):
        self.autopad = AutoPad() if autopad else lambda x: x
        self.interpolation = Interpolate(scale_factor=interpolation_scale) if interpolation_scale else lambda x: x
-        self.norm = nn.BatchNorm2d(in_channels, eps=1e-04, affine=False) if normalize else lambda x: x
+        self.norm = nn.BatchNorm2d(in_channels, eps=1e-04, affine=False) if use_norm else lambda x: x
        self.dropout = nn.Dropout2d(dropout) if dropout else lambda x: x
        self.de_conv = nn.ConvTranspose2d(in_channels, self.conv_filters, conv_kernel, bias=use_bias,
                                          padding=self.padding, stride=self.stride)

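For context, the renamed use_norm flag keeps the module's existing pattern of swapping each optional stage for an identity lambda when it is disabled. A minimal, hypothetical sketch of that pattern (the module name, channel counts and input shape are assumptions, not the repository's actual DeConvModule):

# Hypothetical illustration of the use_norm / dropout toggle pattern used above;
# names and shapes are assumptions, not the repository's actual module.
import torch
from torch import nn

class TinyDeConv(nn.Module):
    def __init__(self, in_channels, out_channels, use_norm=False, dropout=0.0):
        super(TinyDeConv, self).__init__()
        # Each optional stage is either a real layer or an identity-like lambda.
        self.norm = nn.BatchNorm2d(in_channels, eps=1e-04, affine=False) if use_norm else lambda x: x
        self.dropout = nn.Dropout2d(dropout) if dropout else lambda x: x
        self.de_conv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=5, stride=1, padding=0)

    def forward(self, x):
        return self.de_conv(self.dropout(self.norm(x)))

if __name__ == '__main__':
    x = torch.randn(2, 3, 16, 16)
    print(TinyDeConv(3, 8, use_norm=True, dropout=0.2)(x).shape)  # torch.Size([2, 8, 20, 20])
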
View File

@@ -24,7 +24,7 @@ class Generator(nn.Module):
        self.lat_dim = lat_dim
        self.dropout = dropout
        self.l1 = nn.Linear(self.lat_dim, reduce(mul, re_shape), bias=use_bias)
-        # re_shape = (self.lat_dim // reduce(mul, re_shape[1:]), ) + tuple(re_shape[1:])
+        # re_shape = (self.feature_mixed_dim // reduce(mul, re_shape[1:]), ) + tuple(re_shape[1:])
        self.flat = Flatten(to=re_shape)

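As a quick illustration, the Linear layer above is sized to reduce(mul, re_shape) so that its output can be reshaped into the target feature-map shape. A small sketch of that arithmetic with assumed dimensions; the project's Flatten(to=...) helper is not shown in this commit, so a plain view() stands in for it:

# Sketch of the Linear -> reshape step in the Generator above; dimensions are assumptions.
from functools import reduce
from operator import mul

import torch
from torch import nn

lat_dim = 16
re_shape = (8, 4, 4)                             # assumed target feature-map shape
l1 = nn.Linear(lat_dim, reduce(mul, re_shape))   # 16 -> 8 * 4 * 4 = 128 units

z = torch.randn(2, lat_dim)                      # a batch of 2 latent vectors
tensor = l1(z).view(-1, *re_shape)               # back to (2, 8, 4, 4) for the conv stack
print(tensor.shape)                              # torch.Size([2, 8, 4, 4])
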
View File

@@ -67,6 +67,23 @@ class AutoPad(nn.Module):
        return x


+class WeightInit:
+
+    def __init__(self, in_place_init_function):
+        self.in_place_init_function = in_place_init_function
+
+    def __call__(self, m):
+        if hasattr(m, 'weight'):
+            if isinstance(m.weight, torch.Tensor):
+                if m.weight.ndim < 2:
+                    m.weight.data.fill_(0.01)
+                else:
+                    self.in_place_init_function(m.weight)
+        if hasattr(m, 'bias'):
+            if isinstance(m.bias, torch.Tensor):
+                m.bias.data.fill_(0.01)


class LightningBaseModule(pl.LightningModule, ABC):

    @classmethod
@@ -128,15 +145,9 @@ class LightningBaseModule(pl.LightningModule, ABC):
    def test_epoch_end(self, outputs):
        raise NotImplementedError

-    def init_weights(self):
-        def _weight_init(m):
-            if hasattr(m, 'weight'):
-                if isinstance(m.weight, torch.Tensor):
-                    torch.nn.init.xavier_uniform_(m.weight)
-            if hasattr(m, 'bias'):
-                if isinstance(m.bias, torch.Tensor):
-                    m.bias.data.fill_(0.01)
-        self.apply(_weight_init)
+    def init_weights(self, in_place_init_func_=nn.init.xavier_uniform_):
+        weight_initializer = WeightInit(in_place_init_function=in_place_init_func_)
+        self.apply(weight_initializer)

    # Dataloaders
    # ================================================================================
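As a usage sketch, the WeightInit callable introduced in this commit can be passed any in-place initializer and applied to a whole module tree. The Sequential model and the kaiming initializer below are illustrative assumptions; only the apply() pattern mirrors the refactored init_weights():

# Minimal usage sketch for the WeightInit callable from this commit; the class body is
# repeated here only to keep the example self-contained.
import torch
from torch import nn

class WeightInit:
    def __init__(self, in_place_init_function):
        self.in_place_init_function = in_place_init_function

    def __call__(self, m):
        if hasattr(m, 'weight'):
            if isinstance(m.weight, torch.Tensor):
                if m.weight.ndim < 2:
                    # 1-D weights (e.g. norm layers) cannot use xavier/kaiming, so fill them.
                    m.weight.data.fill_(0.01)
                else:
                    self.in_place_init_function(m.weight)
        if hasattr(m, 'bias'):
            if isinstance(m.bias, torch.Tensor):
                m.bias.data.fill_(0.01)

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.Linear(8, 4))
# Roughly what model.init_weights(nn.init.kaiming_uniform_) would do on a LightningBaseModule subclass.
model.apply(WeightInit(in_place_init_function=nn.init.kaiming_uniform_))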