ml_lib/audio_toolset/mel_dataset.py
Steffen Illium b5e3e5aec1 Dataset rdy
2021-02-16 10:18:03 +01:00

53 lines
2.0 KiB
Python

import time
from pathlib import Path
import pickle
from torch.utils.data import Dataset
from ml_lib.modules.util import AutoPadToShape
class TorchMelDataset(Dataset):
    """Serve sub-segments cut from a single pickled mel spectrogram file.

    The spectrogram at ``mel_path`` is un-pickled on every ``__getitem__``
    call and sliced along its second axis. Each item is the tuple
    ``(str(path), snippet, label)``.

    Args:
        mel_path: Location of the pickled spectrogram (anything ``Path`` accepts).
        sub_segment_len: Window length along the second axis; ``0`` disables
            windowing so the whole spectrogram is returned as one item.
        sub_segment_hop_len: Step between consecutive window starts; ``0``
            disables windowing as well.
        label: Value returned unchanged as the third element of every item.
        audio_file_len: Length of the underlying audio file
            (presumably in seconds -- TODO confirm against caller).
        sampling_rate: Audio sampling rate used to estimate the frame count.
        mel_hop_len: Hop length of the mel transform, used to estimate the
            frame count.
        n_mels: Number of mel bins; only used to build the padding target shape.
        transform: Optional callable applied to each snippet before padding.
        auto_pad_to_shape: If truthy (and ``sub_segment_len`` is truthy), every
            snippet is passed through ``AutoPadToShape((n_mels, sub_segment_len))``.
    """

    def __init__(self, mel_path, sub_segment_len, sub_segment_hop_len, label, audio_file_len,
                 sampling_rate, mel_hop_len, n_mels, transform=None, auto_pad_to_shape=True):
        super().__init__()
        self.sampling_rate = int(sampling_rate)
        self.audio_file_len = float(audio_file_len)

        # Padding only makes sense when there is a fixed target window length.
        if auto_pad_to_shape and sub_segment_len:
            self.padding = AutoPadToShape((int(n_mels), int(sub_segment_len)))
        else:
            self.padding = None

        self.path = Path(mel_path)
        self.sub_segment_len = int(sub_segment_len)
        self.mel_hop_len = int(mel_hop_len)
        self.sub_segment_hop_len = int(sub_segment_hop_len)

        # Expected number of spectrogram frames, estimated from the audio parameters.
        self.n = int((self.sampling_rate / self.mel_hop_len) * self.audio_file_len + 1)

        # Default: a single item starting at frame 0.
        self.offsets = [0]
        if self.sub_segment_len and self.sub_segment_hop_len:
            window_starts = list(range(0, self.n - self.sub_segment_len, self.sub_segment_hop_len))
            # A file that is not longer than one window yields an empty range.
            # Keep the default [0] in that case so the file still contributes
            # one (whole-spectrogram, optionally padded) item instead of being
            # silently dropped with len(dataset) == 0.
            if window_starts:
                self.offsets = window_starts

        self.label = label
        self.transform = transform

    def __getitem__(self, item):
        """Return ``(path string, snippet, label)`` for the ``item``-th offset.

        Raises:
            IndexError: If ``item`` is outside the range of ``self.offsets``.
        """
        # NOTE(security): pickle.load can execute arbitrary code while loading;
        # only point this dataset at spectrogram files from a trusted source.
        with self.path.open('rb') as mel_file:
            mel_spec = pickle.load(mel_file, fix_imports=True)
        start = self.offsets[item]

        total_len = mel_spec.shape[1]
        windowing_enabled = self.sub_segment_len and self.sub_segment_hop_len
        if windowing_enabled and self.sub_segment_len < total_len:
            duration = self.sub_segment_len
        else:
            # No windowing requested, or the spectrogram is not longer than one
            # window: take everything from `start` to the end.
            duration = total_len

        snippet = mel_spec[:, start: start + duration]
        if self.transform:
            snippet = self.transform(snippet)
        if self.padding:
            snippet = self.padding(snippet)
        return str(self.path), snippet, self.label

    def __len__(self):
        """Return the number of sub-segments (window start offsets)."""
        return len(self.offsets)