torchaudio testing
This commit is contained in:
104
datasets/base_dataset.py
Normal file
104
datasets/base_dataset.py
Normal file
@@ -0,0 +1,104 @@
|
||||
import pickle
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
from abc import ABC
|
||||
|
||||
import variables as V
|
||||
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
class BaseAudioToMelDataset(Dataset, ABC):
|
||||
|
||||
@property
|
||||
def task_type(self):
|
||||
return self._task_type
|
||||
|
||||
@property
|
||||
def classes(self):
|
||||
return V.multi_classes
|
||||
|
||||
@property
|
||||
def n_classes(self):
|
||||
return V.N_CLASS_binary if self.task_type == V.TASK_OPTION_binary else V.N_CLASS_multi
|
||||
|
||||
@property
|
||||
def sample_shape(self):
|
||||
return self[0][0].shape
|
||||
|
||||
@property
|
||||
def _fingerprint(self):
|
||||
raise NotImplementedError
|
||||
return str(self._mel_transform)
|
||||
|
||||
# Data Structures
|
||||
@property
|
||||
def mel_folder(self):
|
||||
return self.data_root / 'mel'
|
||||
|
||||
@property
|
||||
def wav_folder(self):
|
||||
return self.data_root / self._wav_folder_name
|
||||
|
||||
@property
|
||||
def _container_ext(self):
|
||||
return '.mel'
|
||||
|
||||
def __init__(self, data_root: Union[str, Path], task_type, mel_kwargs,
|
||||
mel_augmentations=None, audio_augmentations=None, reset=False,
|
||||
wav_folder_name='wav', **_):
|
||||
super(BaseAudioToMelDataset, self).__init__()
|
||||
|
||||
# Dataset Parameters
|
||||
self.data_root = Path(data_root)
|
||||
self._wav_folder_name = wav_folder_name
|
||||
self.reset = reset
|
||||
self.mel_kwargs = mel_kwargs
|
||||
|
||||
# Transformations
|
||||
self.mel_augmentations = mel_augmentations
|
||||
self.audio_augmentations = audio_augmentations
|
||||
self._task_type = task_type
|
||||
|
||||
# Find all raw files and turn generator to persistent list:
|
||||
self._wav_files = list(self.wav_folder.rglob('*.wav'))
|
||||
|
||||
# Build the Dataset
|
||||
self._dataset = self._build_dataset()
|
||||
|
||||
|
||||
def __len__(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def __getitem__(self, item):
|
||||
raise NotImplementedError
|
||||
|
||||
def _build_dataset(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def _check_reset_and_clean_up(self, reset):
|
||||
all_mel_folders = set([str(x.parent).replace(self._wav_folder_name, 'mel') for x in self._wav_files])
|
||||
for mel_folder in all_mel_folders:
|
||||
param_storage = Path(mel_folder) / 'data_params.pik'
|
||||
param_storage.parent.mkdir(parents=True, exist_ok=True)
|
||||
try:
|
||||
pik_data = param_storage.read_bytes()
|
||||
fingerprint = pickle.loads(pik_data)
|
||||
if fingerprint == self._fingerprint:
|
||||
this_reset = reset
|
||||
else:
|
||||
print('Diverging parameters were found; Refreshing...')
|
||||
param_storage.unlink()
|
||||
pik_data = pickle.dumps(self._fingerprint)
|
||||
param_storage.write_bytes(pik_data)
|
||||
this_reset = True
|
||||
|
||||
except FileNotFoundError:
|
||||
pik_data = pickle.dumps(self._fingerprint)
|
||||
param_storage.write_bytes(pik_data)
|
||||
this_reset = True
|
||||
|
||||
if this_reset:
|
||||
all_mel_files = self.mel_folder.rglob(f'*{self._container_ext}')
|
||||
for mel_file in all_mel_files:
|
||||
mel_file.unlink()
|
||||
|
||||
@@ -20,17 +20,20 @@ class BinaryMasksDataset(Dataset):
|
||||
|
||||
@property
|
||||
def _fingerprint(self):
|
||||
return dict(**self._mel_kwargs, normalize=self.normalize)
|
||||
return dict(**self._mel_kwargs if self._mel_kwargs else dict())
|
||||
|
||||
def __init__(self, data_root, setting, mel_transforms, transforms=None, stretch_dataset=False,
|
||||
use_preprocessed=True):
|
||||
use_preprocessed=True, mel_kwargs=None):
|
||||
self.stretch = stretch_dataset
|
||||
assert isinstance(setting, str), f'Setting has to be a string, but was: {type(setting)}.'
|
||||
assert setting in V.DATA_OPTIONS, f'Setting must match one of: {V.DATA_OPTIONS}.'
|
||||
super(BinaryMasksDataset, self).__init__()
|
||||
|
||||
self.task = V.TASK_OPTION_binary
|
||||
|
||||
self.data_root = Path(data_root) / 'ComParE2020_Mask'
|
||||
self.setting = setting
|
||||
self._mel_kwargs = mel_kwargs
|
||||
self._wav_folder = self.data_root / 'wav'
|
||||
self._mel_folder = self.data_root / 'mel'
|
||||
self.container_ext = '.pik'
|
||||
|
||||
@@ -1,140 +1,78 @@
|
||||
import pickle
|
||||
from pathlib import Path
|
||||
import multiprocessing as mp
|
||||
from typing import Union, List
|
||||
|
||||
import librosa as librosa
|
||||
from torch.utils.data import Dataset, ConcatDataset
|
||||
import multiprocessing as mp
|
||||
from torch.utils.data import ConcatDataset
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
import variables as V
|
||||
from ml_lib.audio_toolset.mel_dataset import TorchMelDataset
|
||||
from ml_lib.modules.util import F_x
|
||||
from datasets.base_dataset import BaseAudioToMelDataset
|
||||
from ml_lib.audio_toolset.audio_to_mel_dataset import LibrosaAudioToMelDataset, PyTorchAudioToMelDataset
|
||||
|
||||
|
||||
class Urban8K(Dataset):
|
||||
try:
|
||||
torch.multiprocessing.set_sharing_strategy('file_system')
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
@property
|
||||
def sample_shape(self):
|
||||
return self[0][0].shape
|
||||
class Urban8K(BaseAudioToMelDataset):
|
||||
|
||||
@property
|
||||
def _fingerprint(self):
|
||||
return str(self._mel_transform)
|
||||
|
||||
def __init__(self, data_root, setting, mel_transforms, fold=1, transforms=None,
|
||||
use_preprocessed=True, audio_segment_len=62, audio_hop_len=30, num_worker=mp.cpu_count(),
|
||||
**_):
|
||||
def __init__(self,
|
||||
data_root, setting, fold: Union[int, List]=1, num_worker=mp.cpu_count(),
|
||||
reset=False, sample_segment_len=50, sample_hop_len=20,
|
||||
**kwargs):
|
||||
self.num_worker = num_worker
|
||||
assert isinstance(setting, str), f'Setting has to be a string, but was: {type(setting)}.'
|
||||
assert setting in V.DATA_OPTIONS, f'Setting must match one of: {V.DATA_OPTIONS}.'
|
||||
assert fold in range(1, 11)
|
||||
super(Urban8K, self).__init__()
|
||||
assert fold in range(1, 11) if isinstance(fold, int) else all([f in range(1, 11) for f in fold])
|
||||
|
||||
self.data_root = Path(data_root) / 'UrbanSound8K'
|
||||
#Dataset Paramters
|
||||
self.setting = setting
|
||||
self.num_worker = num_worker
|
||||
self.fold = fold if self.setting == V.DATA_OPTIONS.train else 10
|
||||
self.use_preprocessed = use_preprocessed
|
||||
self._wav_folder = self.data_root / 'audio' / f'fold{self.fold}'
|
||||
self._mel_folder = self.data_root / 'mel' / f'fold{self.fold}'
|
||||
self.container_ext = '.pik'
|
||||
self._mel_transform = mel_transforms
|
||||
fold = fold if self.setting != V.DATA_OPTION_test else 10
|
||||
self.fold = fold if isinstance(fold, list) else [fold]
|
||||
|
||||
self._labels = self._build_labels()
|
||||
self._wav_files = list(sorted(self._labels.keys()))
|
||||
transforms = transforms or F_x(in_shape=None)
|
||||
self.sample_segment_len = sample_segment_len
|
||||
self.sample_hop_len = sample_hop_len
|
||||
|
||||
param_storage = self._mel_folder / 'data_params.pik'
|
||||
self._mel_folder.mkdir(parents=True, exist_ok=True)
|
||||
try:
|
||||
pik_data = param_storage.read_bytes()
|
||||
fingerprint = pickle.loads(pik_data)
|
||||
if fingerprint == self._fingerprint:
|
||||
self.use_preprocessed = use_preprocessed
|
||||
else:
|
||||
print('Diverging parameters were found; Refreshing...')
|
||||
param_storage.unlink()
|
||||
pik_data = pickle.dumps(self._fingerprint)
|
||||
param_storage.write_bytes(pik_data)
|
||||
self.use_preprocessed = False
|
||||
# Dataset specific super init
|
||||
super(Urban8K, self).__init__(Path(data_root) / 'UrbanSound8K',
|
||||
V.TASK_OPTION_multiclass, reset=reset, wav_folder_name='audio', **kwargs
|
||||
)
|
||||
|
||||
except FileNotFoundError:
|
||||
pik_data = pickle.dumps(self._fingerprint)
|
||||
param_storage.write_bytes(pik_data)
|
||||
self.use_preprocessed = False
|
||||
def _build_subdataset(self, row):
|
||||
slice_file_name, fs_id, start, end, salience, fold, class_id, class_name = row.strip().split(',')
|
||||
fold, class_id = (int(x) for x in (fold, class_id))
|
||||
if int(fold) in self.fold:
|
||||
audio_file_path = self.wav_folder / f'fold{fold}' / slice_file_name
|
||||
return PyTorchAudioToMelDataset(audio_file_path, class_id, **self.__dict__)
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
while True:
|
||||
if not self.use_preprocessed:
|
||||
self._pre_process()
|
||||
try:
|
||||
self._dataset = ConcatDataset(
|
||||
[TorchMelDataset(identifier=key, mel_path=self._mel_folder, transform=transforms,
|
||||
segment_len=audio_segment_len, hop_len=audio_hop_len,
|
||||
label=self._labels[key]['label']
|
||||
) for key in self._labels.keys()]
|
||||
)
|
||||
break
|
||||
except IOError:
|
||||
self.use_preprocessed = False
|
||||
pass
|
||||
|
||||
def _build_labels(self):
|
||||
labeldict = dict()
|
||||
def _build_dataset(self):
|
||||
dataset= list()
|
||||
with open(Path(self.data_root) / 'metadata' / 'UrbanSound8K.csv', mode='r') as f:
|
||||
# Exclude the header
|
||||
_ = next(f)
|
||||
for row in f:
|
||||
slice_file_name, fs_id, start, end, salience, fold, class_id, class_name = row.strip().split(',')
|
||||
if int(fold) == self.fold:
|
||||
key = slice_file_name.replace('.wav', '')
|
||||
labeldict[key] = dict(label=int(class_id), fold=int(fold))
|
||||
all_rows = list(f)
|
||||
chunksize = len(all_rows) // max(self.num_worker, 1)
|
||||
with mp.Pool(processes=self.num_worker) as pool:
|
||||
with tqdm(total=len(all_rows)) as pbar:
|
||||
for i, sub_dataset in enumerate(
|
||||
pool.imap_unordered(self._build_subdataset, all_rows, chunksize=chunksize)):
|
||||
pbar.update()
|
||||
dataset.append(sub_dataset)
|
||||
|
||||
# Delete File if one exists.
|
||||
if not self.use_preprocessed:
|
||||
for key in labeldict.keys():
|
||||
for mel_file in self._mel_folder.rglob(f'{key}_*'):
|
||||
try:
|
||||
mel_file.unlink(missing_ok=True)
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
|
||||
return labeldict
|
||||
dataset = ConcatDataset([x for x in dataset if x is not None])
|
||||
return dataset
|
||||
|
||||
def __len__(self):
|
||||
return len(self._dataset)
|
||||
|
||||
def _pre_process(self):
|
||||
print('Preprocessing Mel Files....')
|
||||
with mp.Pool(processes=self.num_worker) as pool:
|
||||
with tqdm(total=len(self._labels)) as pbar:
|
||||
for i, _ in enumerate(pool.imap_unordered(self._build_mel, self._labels.keys())):
|
||||
pbar.update()
|
||||
|
||||
def _build_mel(self, filename):
|
||||
|
||||
wav_file = self._wav_folder / (filename.replace('X', '') + '.wav')
|
||||
mel_file = list(self._mel_folder.glob(f'{filename}_*'))
|
||||
|
||||
if not mel_file:
|
||||
raw_sample, sr = librosa.core.load(wav_file)
|
||||
mel_sample = self._mel_transform(raw_sample)
|
||||
m, n = mel_sample.shape
|
||||
mel_file = self._mel_folder / f'{filename}_{m}_{n}'
|
||||
self._mel_folder.mkdir(exist_ok=True, parents=True)
|
||||
with mel_file.open(mode='wb') as f:
|
||||
pickle.dump(mel_sample, f, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
else:
|
||||
# print(f"Already existed.. Skipping {filename}")
|
||||
mel_file = mel_file[0]
|
||||
|
||||
with mel_file.open(mode='rb') as f:
|
||||
mel_sample = pickle.load(f, fix_imports=True)
|
||||
return mel_sample, mel_file
|
||||
|
||||
def __getitem__(self, item):
|
||||
transformed_samples, label = self._dataset[item]
|
||||
|
||||
label = torch.as_tensor(label, dtype=torch.float)
|
||||
label = torch.as_tensor(label, dtype=torch.int)
|
||||
|
||||
return transformed_samples, label
|
||||
|
||||
@@ -1,140 +0,0 @@
|
||||
import pickle
|
||||
from pathlib import Path
|
||||
import multiprocessing as mp
|
||||
|
||||
import librosa as librosa
|
||||
from torch.utils.data import Dataset, ConcatDataset
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
import variables as V
|
||||
from ml_lib.audio_toolset.mel_dataset import TorchMelDataset
|
||||
from ml_lib.modules.util import F_x
|
||||
|
||||
|
||||
class Urban8K_TO(Dataset):
|
||||
|
||||
@property
|
||||
def sample_shape(self):
|
||||
return self[0][0].shape
|
||||
|
||||
@property
|
||||
def _fingerprint(self):
|
||||
return str(self._mel_transform)
|
||||
|
||||
def __init__(self, data_root, setting, mel_transforms, fold=1, transforms=None,
|
||||
use_preprocessed=True, audio_segment_len=1, audio_hop_len=1, num_worker=mp.cpu_count(),
|
||||
**_):
|
||||
assert isinstance(setting, str), f'Setting has to be a string, but was: {type(setting)}.'
|
||||
assert setting in V.DATA_OPTIONS, f'Setting must match one of: {V.DATA_OPTIONS}.'
|
||||
assert fold in range(1, 11)
|
||||
super(Urban8K_TO, self).__init__()
|
||||
|
||||
self.data_root = Path(data_root) / 'UrbanSound8K'
|
||||
self.setting = setting
|
||||
self.num_worker = num_worker
|
||||
self.fold = fold if self.setting == V.DATA_OPTIONS.train else 10
|
||||
self.use_preprocessed = use_preprocessed
|
||||
self._wav_folder = self.data_root / 'audio' / f'fold{self.fold}'
|
||||
self._mel_folder = self.data_root / 'mel' / f'fold{self.fold}'
|
||||
self.container_ext = '.pik'
|
||||
self._mel_transform = mel_transforms
|
||||
|
||||
self._labels = self._build_labels()
|
||||
self._wav_files = list(sorted(self._labels.keys()))
|
||||
transforms = transforms or F_x(in_shape=None)
|
||||
|
||||
param_storage = self._mel_folder / 'data_params.pik'
|
||||
self._mel_folder.mkdir(parents=True, exist_ok=True)
|
||||
try:
|
||||
pik_data = param_storage.read_bytes()
|
||||
fingerprint = pickle.loads(pik_data)
|
||||
if fingerprint == self._fingerprint:
|
||||
self.use_preprocessed = use_preprocessed
|
||||
else:
|
||||
print('Diverging parameters were found; Refreshing...')
|
||||
param_storage.unlink()
|
||||
pik_data = pickle.dumps(self._fingerprint)
|
||||
param_storage.write_bytes(pik_data)
|
||||
self.use_preprocessed = False
|
||||
|
||||
except FileNotFoundError:
|
||||
pik_data = pickle.dumps(self._fingerprint)
|
||||
param_storage.write_bytes(pik_data)
|
||||
self.use_preprocessed = False
|
||||
|
||||
|
||||
while True:
|
||||
if not self.use_preprocessed:
|
||||
self._pre_process()
|
||||
try:
|
||||
self._dataset = ConcatDataset(
|
||||
[TorchMelDataset(identifier=key, mel_path=self._mel_folder, transform=transforms,
|
||||
segment_len=audio_segment_len, hop_len=audio_hop_len,
|
||||
label=self._labels[key]['label']
|
||||
) for key in self._labels.keys()]
|
||||
)
|
||||
break
|
||||
except IOError:
|
||||
self.use_preprocessed = False
|
||||
pass
|
||||
|
||||
def _build_labels(self):
|
||||
labeldict = dict()
|
||||
with open(Path(self.data_root) / 'metadata' / 'UrbanSound8K.csv', mode='r') as f:
|
||||
# Exclude the header
|
||||
_ = next(f)
|
||||
for row in f:
|
||||
slice_file_name, fs_id, start, end, salience, fold, class_id, class_name = row.strip().split(',')
|
||||
if int(fold) == self.fold:
|
||||
key = slice_file_name.replace('.wav', '')
|
||||
labeldict[key] = dict(label=int(class_id), fold=int(fold))
|
||||
|
||||
# Delete File if one exists.
|
||||
if not self.use_preprocessed:
|
||||
for key in labeldict.keys():
|
||||
for mel_file in self._mel_folder.rglob(f'{key}_*'):
|
||||
try:
|
||||
mel_file.unlink(missing_ok=True)
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
|
||||
return labeldict
|
||||
|
||||
def __len__(self):
|
||||
return len(self._dataset)
|
||||
|
||||
def _pre_process(self):
|
||||
print('Preprocessing Mel Files....')
|
||||
with mp.Pool(processes=self.num_worker) as pool:
|
||||
with tqdm(total=len(self._labels)) as pbar:
|
||||
for i, _ in enumerate(pool.imap_unordered(self._build_mel, self._labels.keys())):
|
||||
pbar.update()
|
||||
|
||||
def _build_mel(self, filename):
|
||||
|
||||
wav_file = self._wav_folder / (filename.replace('X', '') + '.wav')
|
||||
mel_file = list(self._mel_folder.glob(f'{filename}_*'))
|
||||
|
||||
if not mel_file:
|
||||
raw_sample, sr = librosa.core.load(wav_file)
|
||||
mel_sample = self._mel_transform(raw_sample)
|
||||
m, n = mel_sample.shape
|
||||
mel_file = self._mel_folder / f'{filename}_{m}_{n}'
|
||||
self._mel_folder.mkdir(exist_ok=True, parents=True)
|
||||
with mel_file.open(mode='wb') as f:
|
||||
pickle.dump(mel_sample, f, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
else:
|
||||
# print(f"Already existed.. Skipping {filename}")
|
||||
mel_file = mel_file[0]
|
||||
|
||||
with mel_file.open(mode='rb') as f:
|
||||
mel_sample = pickle.load(f, fix_imports=True)
|
||||
return mel_sample, mel_file
|
||||
|
||||
def __getitem__(self, item):
|
||||
transformed_samples, label = self._dataset[item]
|
||||
|
||||
label = torch.as_tensor(label, dtype=torch.float)
|
||||
|
||||
return transformed_samples, label
|
||||
Reference in New Issue
Block a user