transition

This commit is contained in:
Si11ium
2021-02-01 10:23:22 +01:00
parent f6156c6cde
commit 010176e80b
18 changed files with 133 additions and 61 deletions

View File

@ -1,5 +1,3 @@
from typing import Union
import numpy as np
try:

View File

@ -20,7 +20,7 @@ class _AudioToMelDataset(Dataset, ABC):
def sampling_rate(self):
raise NotImplementedError
def __init__(self, audio_file_path, label, sample_segment_len=1, sample_hop_len=1, reset=False,
def __init__(self, audio_file_path, label, sample_segment_len=0, sample_hop_len=0, reset=False,
audio_augmentations=None, mel_augmentations=None, mel_kwargs=None, **kwargs):
self.ignored_kwargs = kwargs
self.mel_kwargs = mel_kwargs
@ -46,7 +46,7 @@ class _AudioToMelDataset(Dataset, ABC):
return self.dataset[item]
except FileNotFoundError:
assert self._build_mel()
return self.dataset[item]
return self.dataset[item]
def __len__(self):
return len(self.dataset)
@ -79,7 +79,6 @@ class LibrosaAudioToMelDataset(_AudioToMelDataset):
MelToImage()
])
def _build_mel(self):
if self.reset:
self.mel_file_path.unlink(missing_ok=True)

View File

@ -13,13 +13,16 @@ class TorchMelDataset(Dataset):
super(TorchMelDataset, self).__init__()
self.sampling_rate = sampling_rate
self.audio_file_len = audio_file_len
self.padding = AutoPadToShape((n_mels , sub_segment_len)) if auto_pad_to_shape else None
self.padding = AutoPadToShape((n_mels, sub_segment_len)) if auto_pad_to_shape and sub_segment_len else None
self.path = Path(mel_path)
self.sub_segment_len = sub_segment_len
self.mel_hop_len = mel_hop_len
self.sub_segment_hop_len = sub_segment_hop_len
self.n = int((self.sampling_rate / self.mel_hop_len) * self.audio_file_len + 1)
self.offsets = list(range(0, self.n - self.sub_segment_len, self.sub_segment_hop_len))
if self.sub_segment_len and self.sub_segment_hop_len:
self.offsets = list(range(0, self.n - self.sub_segment_len, self.sub_segment_hop_len))
else:
self.offsets = [0]
self.label = label
self.transform = transform
@ -29,7 +32,8 @@ class TorchMelDataset(Dataset):
with self.path.open('rb') as mel_file:
mel_spec = pickle.load(mel_file, fix_imports=True)
start = self.offsets[item]
snippet = mel_spec[: , start: start + self.sub_segment_len]
duration = self.sub_segment_len if self.sub_segment_len and self.sub_segment_hop_len else mel_spec.shape[1]
snippet = mel_spec[:, start: start + duration]
if self.transform:
snippet = self.transform(snippet)
if self.padding: