Dataset ready
This commit is contained in:
parent a966321576
commit b5e3e5aec1
@@ -12,7 +12,7 @@ class TorchMelDataset(Dataset):
                  sampling_rate, mel_hop_len, n_mels, transform=None, auto_pad_to_shape=True):
         super(TorchMelDataset, self).__init__()
         self.sampling_rate = int(sampling_rate)
-        self.audio_file_len = int(audio_file_len)
+        self.audio_file_len = float(audio_file_len)
         if auto_pad_to_shape and sub_segment_len:
             self.padding = AutoPadToShape((int(n_mels), int(sub_segment_len)))
         else:
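Side note on the change above: int() silently truncates a fractional file length, while float() preserves it. A minimal illustration (the value is invented):

    audio_file_len = 2.74        # hypothetical length in seconds
    int(audio_file_len)          # -> 2, the fractional 0.74 s is lost
    float(audio_file_len)        # -> 2.74, duration preserved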
@@ -33,7 +33,14 @@ class TorchMelDataset(Dataset):
         with self.path.open('rb') as mel_file:
             mel_spec = pickle.load(mel_file, fix_imports=True)
         start = self.offsets[item]
-        duration = self.sub_segment_len if self.sub_segment_len and self.sub_segment_hop_len else mel_spec.shape[1]
+        sub_segments_attributes_set = self.sub_segment_len and self.sub_segment_hop_len
+        sub_segment_length_smaller_then_tot_length = self.sub_segment_len < mel_spec.shape[1]
+
+        if sub_segments_attributes_set and sub_segment_length_smaller_then_tot_length:
+            duration = self.sub_segment_len
+        else:
+            duration = mel_spec.shape[1]
+
         snippet = mel_spec[:, start: start + duration]
         if self.transform:
             snippet = self.transform(snippet)
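What the rewritten guard changes, in short: the old one-liner used sub_segment_len as duration whenever both sub-segment attributes were set, even if it exceeded the spectrogram's actual frame count; the new branch falls back to mel_spec.shape[1] in that case, so duration always matches the slice that is actually taken. A standalone sketch with NumPy (all shapes and values are invented):

    import numpy as np

    mel_spec = np.zeros((64, 100))                  # hypothetical (n_mels, frames)
    sub_segment_len, sub_segment_hop_len = 128, 32  # hypothetical, longer than 100 frames
    start = 0

    # new logic: only honor sub_segment_len when it actually fits
    if sub_segment_len and sub_segment_hop_len and sub_segment_len < mel_spec.shape[1]:
        duration = sub_segment_len
    else:
        duration = mel_spec.shape[1]                # 100, not the impossible 128

    snippet = mel_spec[:, start: start + duration]
    print(duration, snippet.shape)                  # -> 100 (64, 100)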
@@ -178,13 +178,13 @@ class UnitGenerator(Generator):
         return tensor


-class BaseEncoder(ShapeMixin, nn.Module):
+class BaseCNNEncoder(ShapeMixin, nn.Module):

     # noinspection PyUnresolvedReferences
     def __init__(self, in_shape, lat_dim=256, use_bias=True, use_norm=False, dropout: Union[int, float] = 0,
                  latent_activation: Union[nn.Module, None] = None, activation: nn.Module = nn.ELU,
                  filters: List[int] = None, kernels: List[int] = None, **kwargs):
-        super(BaseEncoder, self).__init__()
+        super(BaseCNNEncoder, self).__init__()
         assert filters, '"Filters" has to be a list of int'
         assert kernels, '"Kernels" has to be a list of int'
         assert len(kernels) == len(filters), 'Length of "Filters" and "Kernels" has to be same.'
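For orientation, a hypothetical instantiation of the renamed encoder family (all argument values are invented, not taken from the repository): filters and kernels must be non-empty lists of equal length, or the asserts above fire.

    # hypothetical usage; in_shape/filters/kernels values are assumptions
    encoder = CNNEncoder(in_shape=(1, 64, 128),   # e.g. (channels, n_mels, frames)
                         lat_dim=256,
                         filters=[16, 32, 64],    # one entry per conv stage
                         kernels=[3, 3, 3])       # len must equal len(filters)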
@@ -227,11 +227,11 @@ class BaseEncoder(ShapeMixin, nn.Module):
         return tensor


-class UnitEncoder(BaseEncoder):
+class UnitCNNEncoder(BaseCNNEncoder):
     # noinspection PyUnresolvedReferences
     def __init__(self, *args, **kwargs):
         kwargs.update(use_norm=True)
-        super(UnitEncoder, self).__init__(*args, **kwargs)
+        super(UnitCNNEncoder, self).__init__(*args, **kwargs)
         self.l1 = nn.Linear(reduce(mul, self.conv3.shape), self.lat_dim, bias=self.use_bias)

     def forward(self, x):
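Note that UnitCNNEncoder keeps the per-stage activations around: its forward (see the next hunk's context line, return c1, c2, c3, l1) hands back the three conv outputs alongside the flat latent l1, the usual shape for feeding skip connections into a matching decoder; whether the repository's UnitGenerator consumes them that way is an assumption.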
@@ -243,10 +243,10 @@ class UnitEncoder(BaseEncoder):
         return c1, c2, c3, l1


-class VariationalEncoder(BaseEncoder):
+class VariationalCNNEncoder(BaseCNNEncoder):
     # noinspection PyUnresolvedReferences
     def __init__(self, *args, **kwargs):
-        super(VariationalEncoder, self).__init__(*args, **kwargs)
+        super(VariationalCNNEncoder, self).__init__(*args, **kwargs)

         self.logvar = nn.Linear(reduce(mul, self.conv3.shape), self.lat_dim, bias=self.use_bias)
         self.mu = nn.Linear(reduce(mul, self.conv3.shape), self.lat_dim, bias=self.use_bias)
@@ -258,22 +258,22 @@ class VariationalEncoder(BaseEncoder):
         return mu + eps*std

     def forward(self, x):
-        tensor = super(VariationalEncoder, self).forward(x)
+        tensor = super(VariationalCNNEncoder, self).forward(x)
         mu = self.mu(tensor)
         logvar = self.logvar(tensor)
         z = self.reparameterize(mu, logvar)
         return mu, logvar, z


-class Encoder(BaseEncoder):
+class CNNEncoder(BaseCNNEncoder):

     def __init__(self, *args, **kwargs):
-        super(Encoder, self).__init__(*args, **kwargs)
+        super(CNNEncoder, self).__init__(*args, **kwargs)

         self.l1 = nn.Linear(self.flat.shape, self.lat_dim, bias=self.use_bias)

     def forward(self, x):
-        tensor = super(Encoder, self).forward(x)
+        tensor = super(CNNEncoder, self).forward(x)
         tensor = self.l1(tensor)
         tensor = self.latent_activation(tensor) if self.latent_activation else tensor
         return tensor
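The reparameterize method's closing line appears as context above (return mu + eps*std); it is the standard VAE reparameterization trick. A self-contained sketch of the full step, assuming the usual formulation:

    import torch

    def reparameterize(mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
        # std = exp(0.5 * log(sigma^2)); sampling this way stays differentiable in mu/logvar
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)        # eps ~ N(0, I)
        return mu + eps * std

    mu, logvar = torch.zeros(4, 256), torch.zeros(4, 256)  # logvar = 0 -> std = 1
    z = reparameterize(mu, logvar)
    print(z.shape)                         # -> torch.Size([4, 256])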
@@ -31,6 +31,8 @@ class ModelParameters(Namespace, Mapping):
             activation=self.__getattribute__('activation')
             )
         )
+        # Get rid of paramters that
+        paramter_mapping.__delitem__('in_shape')

         return paramter_mapping
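The added lines strip in_shape from the mapping before it is returned (the added comment is cut off in the commit itself). A standalone sketch of the pattern, assuming the mapping is later **-splatted into a constructor that receives in_shape separately:

    # hypothetical values; mirrors the __delitem__ call the commit adds
    paramter_mapping = {'in_shape': (1, 64, 128), 'lat_dim': 256, 'use_bias': True}
    paramter_mapping.__delitem__('in_shape')   # `del paramter_mapping['in_shape']` is the idiomatic spelling
    print(paramter_mapping)                    # -> {'lat_dim': 256, 'use_bias': True}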