import time
from pathlib import Path
import pickle

from torch.utils.data import Dataset

from ml_lib.modules.util import AutoPadToShape


class TorchMelDataset(Dataset):
    """Dataset yielding sub-segments cut from a single pickled mel spectrogram.

    The spectrogram is re-read from ``mel_path`` on every ``__getitem__`` call
    and sliced along its second axis (presumably the time/frame axis -- the
    code only relies on ``shape[1]`` and ``[:, a:b]`` indexing).

    The number of available frames is *estimated* at construction time from
    ``sampling_rate``, ``mel_hop_len`` and ``audio_file_len``; the file itself
    is not opened until an item is requested.
    """

    def __init__(self, mel_path, sub_segment_len, sub_segment_hop_len, label, audio_file_len,
                 sampling_rate, mel_hop_len, n_mels, transform=None, auto_pad_to_shape=True):
        """Prepare the list of sub-segment start offsets.

        Args:
            mel_path: Path of the pickled spectrogram file.
            sub_segment_len: Length of one sub-segment in frames; ``0`` disables
                sub-segmenting (and padding).
            sub_segment_hop_len: Step in frames between consecutive
                sub-segment starts; ``0`` disables sub-segmenting.
            label: Label returned unchanged with every item.
            audio_file_len: Audio duration, converted with ``float()``
                (presumably seconds, given it is multiplied by the sampling rate).
            sampling_rate: Audio sampling rate, converted with ``int()``.
            mel_hop_len: Hop length used for the spectrogram, converted with ``int()``.
            n_mels: First dimension of the padding target shape.
            transform: Optional callable applied to each snippet before padding.
            auto_pad_to_shape: If true and ``sub_segment_len`` is non-zero, each
                snippet is passed through ``AutoPadToShape((n_mels, sub_segment_len))``.
        """
        super().__init__()

        self.sampling_rate = int(sampling_rate)
        self.audio_file_len = float(audio_file_len)
        self.path = Path(mel_path)
        self.sub_segment_len = int(sub_segment_len)
        self.mel_hop_len = int(mel_hop_len)
        self.sub_segment_hop_len = int(sub_segment_hop_len)

        # Decide on the normalised integer length: raw inputs such as "0" or 0.5
        # are truthy but normalise to 0, which would give a zero-width pad target.
        if auto_pad_to_shape and self.sub_segment_len:
            self.padding = AutoPadToShape((int(n_mels), self.sub_segment_len))
        else:
            self.padding = None

        # Estimated number of spectrogram frames for the whole audio file.
        self.n = int((self.sampling_rate / self.mel_hop_len) * self.audio_file_len + 1)

        if self.sub_segment_len and self.sub_segment_hop_len and (self.n - self.sub_segment_len) > 0:
            # NOTE(review): the exclusive stop drops the window starting exactly at
            # n - sub_segment_len; kept as-is so the dataset length is unchanged.
            self.offsets = list(range(0, self.n - self.sub_segment_len, self.sub_segment_hop_len))
        else:
            # No sub-segmenting: a single item starting at frame 0.
            self.offsets = [0]

        if len(self) == 0:
            # Reachable e.g. with a negative hop length, which makes the range empty.
            print('TorchMelDataset: no sub-segment offsets for {} '
                  '(n={}, sub_segment_len={}, sub_segment_hop_len={}); dataset is empty'
                  .format(self.path, self.n, self.sub_segment_len, self.sub_segment_hop_len))

        self.label = label
        self.transform = transform

    def __getitem__(self, item):
        """Return ``(path_as_str, snippet, label)`` for the ``item``-th offset.

        Raises:
            IndexError: If ``item`` is outside the range of known offsets.
        """
        # Resolve the offset first so an invalid index fails before any file I/O.
        start = self.offsets[item]

        # NOTE(security): pickle.load can execute arbitrary code; only use this
        # dataset with spectrogram files from a trusted source.
        with self.path.open('rb') as mel_file:
            mel_spec = pickle.load(mel_file, fix_imports=True)

        n_frames = mel_spec.shape[1]
        use_sub_segments = (self.sub_segment_len and self.sub_segment_hop_len
                            and self.sub_segment_len < n_frames)
        # Fall back to the full width when sub-segmenting is off or the
        # spectrogram is not longer than one sub-segment.
        duration = self.sub_segment_len if use_sub_segments else n_frames

        snippet = mel_spec[:, start: start + duration]
        if self.transform:
            snippet = self.transform(snippet)
        if self.padding:
            snippet = self.padding(snippet)
        return str(self.path), snippet, self.label

    def __len__(self):
        """Return the number of sub-segment offsets."""
        return len(self.offsets)