Kurz vorm durchdrehen
This commit is contained in:
@@ -4,11 +4,11 @@ import torch
|
||||
from torch import nn
|
||||
from lib.modules.utils import AutoPad, Interpolate
|
||||
|
||||
|
||||
#
|
||||
# Sub - Modules
|
||||
###################
|
||||
|
||||
|
||||
class ConvModule(nn.Module):
|
||||
|
||||
@property
|
||||
@@ -60,7 +60,7 @@ class DeConvModule(nn.Module):
|
||||
def __init__(self, in_shape, conv_filters=3, conv_kernel=5, conv_stride=1, conv_padding=0,
|
||||
dropout: Union[int, float] = 0, autopad=False,
|
||||
activation: Union[None, nn.Module] = nn.ReLU, interpolation_scale=None,
|
||||
use_bias=True, normalize=False):
|
||||
use_bias=True, use_norm=False):
|
||||
super(DeConvModule, self).__init__()
|
||||
in_channels, height, width = in_shape[0], in_shape[1], in_shape[2]
|
||||
self.padding = conv_padding
|
||||
@@ -70,7 +70,7 @@ class DeConvModule(nn.Module):
|
||||
|
||||
self.autopad = AutoPad() if autopad else lambda x: x
|
||||
self.interpolation = Interpolate(scale_factor=interpolation_scale) if interpolation_scale else lambda x: x
|
||||
self.norm = nn.BatchNorm2d(in_channels, eps=1e-04, affine=False) if normalize else lambda x: x
|
||||
self.norm = nn.BatchNorm2d(in_channels, eps=1e-04, affine=False) if use_norm else lambda x: x
|
||||
self.dropout = nn.Dropout2d(dropout) if dropout else lambda x: x
|
||||
self.de_conv = nn.ConvTranspose2d(in_channels, self.conv_filters, conv_kernel, bias=use_bias,
|
||||
padding=self.padding, stride=self.stride)
|
||||
|
||||
@@ -24,7 +24,7 @@ class Generator(nn.Module):
|
||||
self.lat_dim = lat_dim
|
||||
self.dropout = dropout
|
||||
self.l1 = nn.Linear(self.lat_dim, reduce(mul, re_shape), bias=use_bias)
|
||||
# re_shape = (self.lat_dim // reduce(mul, re_shape[1:]), ) + tuple(re_shape[1:])
|
||||
# re_shape = (self.feature_mixed_dim // reduce(mul, re_shape[1:]), ) + tuple(re_shape[1:])
|
||||
|
||||
self.flat = Flatten(to=re_shape)
|
||||
|
||||
|
||||
@@ -67,6 +67,23 @@ class AutoPad(nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
class WeightInit:
|
||||
|
||||
def __init__(self, in_place_init_function):
|
||||
self.in_place_init_function = in_place_init_function
|
||||
|
||||
def __call__(self, m):
|
||||
if hasattr(m, 'weight'):
|
||||
if isinstance(m.weight, torch.Tensor):
|
||||
if m.weight.ndim < 2:
|
||||
m.weight.data.fill_(0.01)
|
||||
else:
|
||||
self.in_place_init_function(m.weight)
|
||||
if hasattr(m, 'bias'):
|
||||
if isinstance(m.bias, torch.Tensor):
|
||||
m.bias.data.fill_(0.01)
|
||||
|
||||
|
||||
class LightningBaseModule(pl.LightningModule, ABC):
|
||||
|
||||
@classmethod
|
||||
@@ -128,15 +145,9 @@ class LightningBaseModule(pl.LightningModule, ABC):
|
||||
def test_epoch_end(self, outputs):
|
||||
raise NotImplementedError
|
||||
|
||||
def init_weights(self):
|
||||
def _weight_init(m):
|
||||
if hasattr(m, 'weight'):
|
||||
if isinstance(m.weight, torch.Tensor):
|
||||
torch.nn.init.xavier_uniform_(m.weight)
|
||||
if hasattr(m, 'bias'):
|
||||
if isinstance(m.bias, torch.Tensor):
|
||||
m.bias.data.fill_(0.01)
|
||||
self.apply(_weight_init)
|
||||
def init_weights(self, in_place_init_func_=nn.init.xavier_uniform_):
|
||||
weight_initializer = WeightInit(in_place_init_function=in_place_init_func_)
|
||||
self.apply(weight_initializer)
|
||||
|
||||
# Dataloaders
|
||||
# ================================================================================
|
||||
|
||||
Reference in New Issue
Block a user