Dataset rdy
This commit is contained in:
@ -178,13 +178,13 @@ class UnitGenerator(Generator):
|
||||
return tensor
|
||||
|
||||
|
||||
class BaseEncoder(ShapeMixin, nn.Module):
|
||||
class BaseCNNEncoder(ShapeMixin, nn.Module):
|
||||
|
||||
# noinspection PyUnresolvedReferences
|
||||
def __init__(self, in_shape, lat_dim=256, use_bias=True, use_norm=False, dropout: Union[int, float] = 0,
|
||||
latent_activation: Union[nn.Module, None] = None, activation: nn.Module = nn.ELU,
|
||||
filters: List[int] = None, kernels: List[int] = None, **kwargs):
|
||||
super(BaseEncoder, self).__init__()
|
||||
super(BaseCNNEncoder, self).__init__()
|
||||
assert filters, '"Filters" has to be a list of int'
|
||||
assert kernels, '"Kernels" has to be a list of int'
|
||||
assert len(kernels) == len(filters), 'Length of "Filters" and "Kernels" has to be same.'
|
||||
@ -227,11 +227,11 @@ class BaseEncoder(ShapeMixin, nn.Module):
|
||||
return tensor
|
||||
|
||||
|
||||
class UnitEncoder(BaseEncoder):
|
||||
class UnitCNNEncoder(BaseCNNEncoder):
|
||||
# noinspection PyUnresolvedReferences
|
||||
def __init__(self, *args, **kwargs):
|
||||
kwargs.update(use_norm=True)
|
||||
super(UnitEncoder, self).__init__(*args, **kwargs)
|
||||
super(UnitCNNEncoder, self).__init__(*args, **kwargs)
|
||||
self.l1 = nn.Linear(reduce(mul, self.conv3.shape), self.lat_dim, bias=self.use_bias)
|
||||
|
||||
def forward(self, x):
|
||||
@ -243,10 +243,10 @@ class UnitEncoder(BaseEncoder):
|
||||
return c1, c2, c3, l1
|
||||
|
||||
|
||||
class VariationalEncoder(BaseEncoder):
|
||||
class VariationalCNNEncoder(BaseCNNEncoder):
|
||||
# noinspection PyUnresolvedReferences
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(VariationalEncoder, self).__init__(*args, **kwargs)
|
||||
super(VariationalCNNEncoder, self).__init__(*args, **kwargs)
|
||||
|
||||
self.logvar = nn.Linear(reduce(mul, self.conv3.shape), self.lat_dim, bias=self.use_bias)
|
||||
self.mu = nn.Linear(reduce(mul, self.conv3.shape), self.lat_dim, bias=self.use_bias)
|
||||
@ -258,22 +258,22 @@ class VariationalEncoder(BaseEncoder):
|
||||
return mu + eps*std
|
||||
|
||||
def forward(self, x):
|
||||
tensor = super(VariationalEncoder, self).forward(x)
|
||||
tensor = super(VariationalCNNEncoder, self).forward(x)
|
||||
mu = self.mu(tensor)
|
||||
logvar = self.logvar(tensor)
|
||||
z = self.reparameterize(mu, logvar)
|
||||
return mu, logvar, z
|
||||
|
||||
|
||||
class Encoder(BaseEncoder):
|
||||
class CNNEncoder(BaseCNNEncoder):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(Encoder, self).__init__(*args, **kwargs)
|
||||
super(CNNEncoder, self).__init__(*args, **kwargs)
|
||||
|
||||
self.l1 = nn.Linear(self.flat.shape, self.lat_dim, bias=self.use_bias)
|
||||
|
||||
def forward(self, x):
|
||||
tensor = super(Encoder, self).forward(x)
|
||||
tensor = super(CNNEncoder, self).forward(x)
|
||||
tensor = self.l1(tensor)
|
||||
tensor = self.latent_activation(tensor) if self.latent_activation else tensor
|
||||
return tensor
|
||||
|
Reference in New Issue
Block a user