Dataset rdy

This commit is contained in:
Steffen Illium 2021-02-16 10:18:03 +01:00
parent a966321576
commit b5e3e5aec1
3 changed files with 21 additions and 12 deletions

View File

@ -12,7 +12,7 @@ class TorchMelDataset(Dataset):
sampling_rate, mel_hop_len, n_mels, transform=None, auto_pad_to_shape=True):
super(TorchMelDataset, self).__init__()
self.sampling_rate = int(sampling_rate)
self.audio_file_len = int(audio_file_len)
self.audio_file_len = float(audio_file_len)
if auto_pad_to_shape and sub_segment_len:
self.padding = AutoPadToShape((int(n_mels), int(sub_segment_len)))
else:
@ -33,7 +33,14 @@ class TorchMelDataset(Dataset):
with self.path.open('rb') as mel_file:
mel_spec = pickle.load(mel_file, fix_imports=True)
start = self.offsets[item]
duration = self.sub_segment_len if self.sub_segment_len and self.sub_segment_hop_len else mel_spec.shape[1]
sub_segments_attributes_set = self.sub_segment_len and self.sub_segment_hop_len
sub_segment_length_smaller_then_tot_length = self.sub_segment_len < mel_spec.shape[1]
if sub_segments_attributes_set and sub_segment_length_smaller_then_tot_length:
duration = self.sub_segment_len
else:
duration = mel_spec.shape[1]
snippet = mel_spec[:, start: start + duration]
if self.transform:
snippet = self.transform(snippet)

View File

@ -178,13 +178,13 @@ class UnitGenerator(Generator):
return tensor
class BaseEncoder(ShapeMixin, nn.Module):
class BaseCNNEncoder(ShapeMixin, nn.Module):
# noinspection PyUnresolvedReferences
def __init__(self, in_shape, lat_dim=256, use_bias=True, use_norm=False, dropout: Union[int, float] = 0,
latent_activation: Union[nn.Module, None] = None, activation: nn.Module = nn.ELU,
filters: List[int] = None, kernels: List[int] = None, **kwargs):
super(BaseEncoder, self).__init__()
super(BaseCNNEncoder, self).__init__()
assert filters, '"Filters" has to be a list of int'
assert kernels, '"Kernels" has to be a list of int'
assert len(kernels) == len(filters), 'Length of "Filters" and "Kernels" has to be same.'
@ -227,11 +227,11 @@ class BaseEncoder(ShapeMixin, nn.Module):
return tensor
class UnitEncoder(BaseEncoder):
class UnitCNNEncoder(BaseCNNEncoder):
# noinspection PyUnresolvedReferences
def __init__(self, *args, **kwargs):
kwargs.update(use_norm=True)
super(UnitEncoder, self).__init__(*args, **kwargs)
super(UnitCNNEncoder, self).__init__(*args, **kwargs)
self.l1 = nn.Linear(reduce(mul, self.conv3.shape), self.lat_dim, bias=self.use_bias)
def forward(self, x):
@ -243,10 +243,10 @@ class UnitEncoder(BaseEncoder):
return c1, c2, c3, l1
class VariationalEncoder(BaseEncoder):
class VariationalCNNEncoder(BaseCNNEncoder):
# noinspection PyUnresolvedReferences
def __init__(self, *args, **kwargs):
super(VariationalEncoder, self).__init__(*args, **kwargs)
super(VariationalCNNEncoder, self).__init__(*args, **kwargs)
self.logvar = nn.Linear(reduce(mul, self.conv3.shape), self.lat_dim, bias=self.use_bias)
self.mu = nn.Linear(reduce(mul, self.conv3.shape), self.lat_dim, bias=self.use_bias)
@ -258,22 +258,22 @@ class VariationalEncoder(BaseEncoder):
return mu + eps*std
def forward(self, x):
tensor = super(VariationalEncoder, self).forward(x)
tensor = super(VariationalCNNEncoder, self).forward(x)
mu = self.mu(tensor)
logvar = self.logvar(tensor)
z = self.reparameterize(mu, logvar)
return mu, logvar, z
class Encoder(BaseEncoder):
class CNNEncoder(BaseCNNEncoder):
def __init__(self, *args, **kwargs):
super(Encoder, self).__init__(*args, **kwargs)
super(CNNEncoder, self).__init__(*args, **kwargs)
self.l1 = nn.Linear(self.flat.shape, self.lat_dim, bias=self.use_bias)
def forward(self, x):
tensor = super(Encoder, self).forward(x)
tensor = super(CNNEncoder, self).forward(x)
tensor = self.l1(tensor)
tensor = self.latent_activation(tensor) if self.latent_activation else tensor
return tensor

View File

@ -31,6 +31,8 @@ class ModelParameters(Namespace, Mapping):
activation=self.__getattribute__('activation')
)
)
# Get rid of paramters that
paramter_mapping.__delitem__('in_shape')
return paramter_mapping