Dataset rdy
This commit is contained in:
parent
a966321576
commit
b5e3e5aec1
@ -12,7 +12,7 @@ class TorchMelDataset(Dataset):
|
|||||||
sampling_rate, mel_hop_len, n_mels, transform=None, auto_pad_to_shape=True):
|
sampling_rate, mel_hop_len, n_mels, transform=None, auto_pad_to_shape=True):
|
||||||
super(TorchMelDataset, self).__init__()
|
super(TorchMelDataset, self).__init__()
|
||||||
self.sampling_rate = int(sampling_rate)
|
self.sampling_rate = int(sampling_rate)
|
||||||
self.audio_file_len = int(audio_file_len)
|
self.audio_file_len = float(audio_file_len)
|
||||||
if auto_pad_to_shape and sub_segment_len:
|
if auto_pad_to_shape and sub_segment_len:
|
||||||
self.padding = AutoPadToShape((int(n_mels), int(sub_segment_len)))
|
self.padding = AutoPadToShape((int(n_mels), int(sub_segment_len)))
|
||||||
else:
|
else:
|
||||||
@ -33,7 +33,14 @@ class TorchMelDataset(Dataset):
|
|||||||
with self.path.open('rb') as mel_file:
|
with self.path.open('rb') as mel_file:
|
||||||
mel_spec = pickle.load(mel_file, fix_imports=True)
|
mel_spec = pickle.load(mel_file, fix_imports=True)
|
||||||
start = self.offsets[item]
|
start = self.offsets[item]
|
||||||
duration = self.sub_segment_len if self.sub_segment_len and self.sub_segment_hop_len else mel_spec.shape[1]
|
sub_segments_attributes_set = self.sub_segment_len and self.sub_segment_hop_len
|
||||||
|
sub_segment_length_smaller_then_tot_length = self.sub_segment_len < mel_spec.shape[1]
|
||||||
|
|
||||||
|
if sub_segments_attributes_set and sub_segment_length_smaller_then_tot_length:
|
||||||
|
duration = self.sub_segment_len
|
||||||
|
else:
|
||||||
|
duration = mel_spec.shape[1]
|
||||||
|
|
||||||
snippet = mel_spec[:, start: start + duration]
|
snippet = mel_spec[:, start: start + duration]
|
||||||
if self.transform:
|
if self.transform:
|
||||||
snippet = self.transform(snippet)
|
snippet = self.transform(snippet)
|
||||||
|
@ -178,13 +178,13 @@ class UnitGenerator(Generator):
|
|||||||
return tensor
|
return tensor
|
||||||
|
|
||||||
|
|
||||||
class BaseEncoder(ShapeMixin, nn.Module):
|
class BaseCNNEncoder(ShapeMixin, nn.Module):
|
||||||
|
|
||||||
# noinspection PyUnresolvedReferences
|
# noinspection PyUnresolvedReferences
|
||||||
def __init__(self, in_shape, lat_dim=256, use_bias=True, use_norm=False, dropout: Union[int, float] = 0,
|
def __init__(self, in_shape, lat_dim=256, use_bias=True, use_norm=False, dropout: Union[int, float] = 0,
|
||||||
latent_activation: Union[nn.Module, None] = None, activation: nn.Module = nn.ELU,
|
latent_activation: Union[nn.Module, None] = None, activation: nn.Module = nn.ELU,
|
||||||
filters: List[int] = None, kernels: List[int] = None, **kwargs):
|
filters: List[int] = None, kernels: List[int] = None, **kwargs):
|
||||||
super(BaseEncoder, self).__init__()
|
super(BaseCNNEncoder, self).__init__()
|
||||||
assert filters, '"Filters" has to be a list of int'
|
assert filters, '"Filters" has to be a list of int'
|
||||||
assert kernels, '"Kernels" has to be a list of int'
|
assert kernels, '"Kernels" has to be a list of int'
|
||||||
assert len(kernels) == len(filters), 'Length of "Filters" and "Kernels" has to be same.'
|
assert len(kernels) == len(filters), 'Length of "Filters" and "Kernels" has to be same.'
|
||||||
@ -227,11 +227,11 @@ class BaseEncoder(ShapeMixin, nn.Module):
|
|||||||
return tensor
|
return tensor
|
||||||
|
|
||||||
|
|
||||||
class UnitEncoder(BaseEncoder):
|
class UnitCNNEncoder(BaseCNNEncoder):
|
||||||
# noinspection PyUnresolvedReferences
|
# noinspection PyUnresolvedReferences
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
kwargs.update(use_norm=True)
|
kwargs.update(use_norm=True)
|
||||||
super(UnitEncoder, self).__init__(*args, **kwargs)
|
super(UnitCNNEncoder, self).__init__(*args, **kwargs)
|
||||||
self.l1 = nn.Linear(reduce(mul, self.conv3.shape), self.lat_dim, bias=self.use_bias)
|
self.l1 = nn.Linear(reduce(mul, self.conv3.shape), self.lat_dim, bias=self.use_bias)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
@ -243,10 +243,10 @@ class UnitEncoder(BaseEncoder):
|
|||||||
return c1, c2, c3, l1
|
return c1, c2, c3, l1
|
||||||
|
|
||||||
|
|
||||||
class VariationalEncoder(BaseEncoder):
|
class VariationalCNNEncoder(BaseCNNEncoder):
|
||||||
# noinspection PyUnresolvedReferences
|
# noinspection PyUnresolvedReferences
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super(VariationalEncoder, self).__init__(*args, **kwargs)
|
super(VariationalCNNEncoder, self).__init__(*args, **kwargs)
|
||||||
|
|
||||||
self.logvar = nn.Linear(reduce(mul, self.conv3.shape), self.lat_dim, bias=self.use_bias)
|
self.logvar = nn.Linear(reduce(mul, self.conv3.shape), self.lat_dim, bias=self.use_bias)
|
||||||
self.mu = nn.Linear(reduce(mul, self.conv3.shape), self.lat_dim, bias=self.use_bias)
|
self.mu = nn.Linear(reduce(mul, self.conv3.shape), self.lat_dim, bias=self.use_bias)
|
||||||
@ -258,22 +258,22 @@ class VariationalEncoder(BaseEncoder):
|
|||||||
return mu + eps*std
|
return mu + eps*std
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
tensor = super(VariationalEncoder, self).forward(x)
|
tensor = super(VariationalCNNEncoder, self).forward(x)
|
||||||
mu = self.mu(tensor)
|
mu = self.mu(tensor)
|
||||||
logvar = self.logvar(tensor)
|
logvar = self.logvar(tensor)
|
||||||
z = self.reparameterize(mu, logvar)
|
z = self.reparameterize(mu, logvar)
|
||||||
return mu, logvar, z
|
return mu, logvar, z
|
||||||
|
|
||||||
|
|
||||||
class Encoder(BaseEncoder):
|
class CNNEncoder(BaseCNNEncoder):
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super(Encoder, self).__init__(*args, **kwargs)
|
super(CNNEncoder, self).__init__(*args, **kwargs)
|
||||||
|
|
||||||
self.l1 = nn.Linear(self.flat.shape, self.lat_dim, bias=self.use_bias)
|
self.l1 = nn.Linear(self.flat.shape, self.lat_dim, bias=self.use_bias)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
tensor = super(Encoder, self).forward(x)
|
tensor = super(CNNEncoder, self).forward(x)
|
||||||
tensor = self.l1(tensor)
|
tensor = self.l1(tensor)
|
||||||
tensor = self.latent_activation(tensor) if self.latent_activation else tensor
|
tensor = self.latent_activation(tensor) if self.latent_activation else tensor
|
||||||
return tensor
|
return tensor
|
||||||
|
@ -31,6 +31,8 @@ class ModelParameters(Namespace, Mapping):
|
|||||||
activation=self.__getattribute__('activation')
|
activation=self.__getattribute__('activation')
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
# Get rid of paramters that
|
||||||
|
paramter_mapping.__delitem__('in_shape')
|
||||||
|
|
||||||
return paramter_mapping
|
return paramter_mapping
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user