Dataset ready

This commit is contained in:
Steffen Illium
2021-02-16 10:18:03 +01:00
parent a966321576
commit b5e3e5aec1
3 changed files with 21 additions and 12 deletions

View File

@ -178,13 +178,13 @@ class UnitGenerator(Generator):
return tensor
class BaseEncoder(ShapeMixin, nn.Module):
class BaseCNNEncoder(ShapeMixin, nn.Module):
# noinspection PyUnresolvedReferences
def __init__(self, in_shape, lat_dim=256, use_bias=True, use_norm=False, dropout: Union[int, float] = 0,
latent_activation: Union[nn.Module, None] = None, activation: nn.Module = nn.ELU,
filters: List[int] = None, kernels: List[int] = None, **kwargs):
super(BaseEncoder, self).__init__()
super(BaseCNNEncoder, self).__init__()
assert filters, '"Filters" has to be a list of int'
assert kernels, '"Kernels" has to be a list of int'
assert len(kernels) == len(filters), 'Length of "Filters" and "Kernels" has to be same.'
@ -227,11 +227,11 @@ class BaseEncoder(ShapeMixin, nn.Module):
return tensor
class UnitEncoder(BaseEncoder):
class UnitCNNEncoder(BaseCNNEncoder):
# noinspection PyUnresolvedReferences
def __init__(self, *args, **kwargs):
kwargs.update(use_norm=True)
super(UnitEncoder, self).__init__(*args, **kwargs)
super(UnitCNNEncoder, self).__init__(*args, **kwargs)
self.l1 = nn.Linear(reduce(mul, self.conv3.shape), self.lat_dim, bias=self.use_bias)
def forward(self, x):
@ -243,10 +243,10 @@ class UnitEncoder(BaseEncoder):
return c1, c2, c3, l1
class VariationalEncoder(BaseEncoder):
class VariationalCNNEncoder(BaseCNNEncoder):
# noinspection PyUnresolvedReferences
def __init__(self, *args, **kwargs):
super(VariationalEncoder, self).__init__(*args, **kwargs)
super(VariationalCNNEncoder, self).__init__(*args, **kwargs)
self.logvar = nn.Linear(reduce(mul, self.conv3.shape), self.lat_dim, bias=self.use_bias)
self.mu = nn.Linear(reduce(mul, self.conv3.shape), self.lat_dim, bias=self.use_bias)
@ -258,22 +258,22 @@ class VariationalEncoder(BaseEncoder):
return mu + eps*std
def forward(self, x):
tensor = super(VariationalEncoder, self).forward(x)
tensor = super(VariationalCNNEncoder, self).forward(x)
mu = self.mu(tensor)
logvar = self.logvar(tensor)
z = self.reparameterize(mu, logvar)
return mu, logvar, z
class Encoder(BaseEncoder):
class CNNEncoder(BaseCNNEncoder):
def __init__(self, *args, **kwargs):
super(Encoder, self).__init__(*args, **kwargs)
super(CNNEncoder, self).__init__(*args, **kwargs)
self.l1 = nn.Linear(self.flat.shape, self.lat_dim, bias=self.use_bias)
def forward(self, x):
tensor = super(Encoder, self).forward(x)
tensor = super(CNNEncoder, self).forward(x)
tensor = self.l1(tensor)
tensor = self.latent_activation(tensor) if self.latent_activation else tensor
return tensor