Audio Dataset

This commit is contained in:
Si11ium
2020-12-01 16:37:16 +01:00
parent 95561acc35
commit 95dcf22f3d
15 changed files with 468 additions and 145 deletions

View File

@@ -18,15 +18,18 @@ class BinaryMasksDataset(Dataset):
def sample_shape(self):
return self[0][0].shape
@property
def _fingerprint(self):
return dict(**self._mel_kwargs, normalize=self.normalize)
def __init__(self, data_root, setting, mel_transforms, transforms=None, stretch_dataset=False,
use_preprocessed=True):
self.use_preprocessed = use_preprocessed
self.stretch = stretch_dataset
assert isinstance(setting, str), f'Setting has to be a string, but was: {type(setting)}.'
assert setting in V.DATA_OPTIONS, f'Setting must match one of: {V.DATA_OPTIONS}.'
super(BinaryMasksDataset, self).__init__()
self.data_root = Path(data_root)
self.data_root = Path(data_root) / 'ComParE2020_Mask'
self.setting = setting
self._wav_folder = self.data_root / 'wav'
self._mel_folder = self.data_root / 'mel'
@@ -37,16 +40,36 @@ class BinaryMasksDataset(Dataset):
self._wav_files = list(sorted(self._labels.keys()))
self._transforms = transforms or F_x(in_shape=None)
param_storage = self._mel_folder / 'data_params.pik'
self._mel_folder.mkdir(parents=True, exist_ok=True)
try:
pik_data = param_storage.read_bytes()
fingerprint = pickle.loads(pik_data)
if fingerprint == self._fingerprint:
self.use_preprocessed = use_preprocessed
else:
print('Diverging parameters were found; Refreshing...')
param_storage.unlink()
pik_data = pickle.dumps(self._fingerprint)
param_storage.write_bytes(pik_data)
self.use_preprocessed = True
except FileNotFoundError:
pik_data = pickle.dumps(self._fingerprint)
param_storage.write_bytes(pik_data)
self.use_preprocessed = True
def _build_labels(self):
labeldict = dict()
with open(Path(self.data_root) / 'lab' / 'labels.csv', mode='r') as f:
labelfile = 'labels' if self.setting != V.DATA_OPTIONS.test else V.DATA_OPTIONS.test
with open(Path(self.data_root) / 'lab' / f'{labelfile}.csv', mode='r') as f:
# Exclude the header
_ = next(f)
for row in f:
if self.setting not in row:
continue
filename, label = row.strip().split(',')
labeldict[filename] = self._to_label[label.lower()] if not self.setting == 'test' else filename
labeldict[filename] = self._to_label[label.lower()] # if not self.setting == 'test' else filename
if self.stretch and self.setting == V.DATA_OPTIONS.train:
additional_dict = ({f'X{key}': val for key, val in labeldict.items()})
additional_dict.update({f'XX{key}': val for key, val in labeldict.items()})

View File

@@ -1,95 +1,140 @@
import pickle
from collections import defaultdict
from pathlib import Path
import multiprocessing as mp
import librosa as librosa
from torch.utils.data import Dataset
from torch.utils.data import Dataset, ConcatDataset
import torch
from tqdm import tqdm
import variables as V
from ml_lib.audio_toolset.mel_dataset import TorchMelDataset
from ml_lib.modules.util import F_x
class BinaryMasksDataset(Dataset):
_to_label = defaultdict(lambda: -1)
_to_label.update(dict(clear=V.CLEAR, mask=V.MASK))
class Urban8K(Dataset):
@property
def sample_shape(self):
return self[0][0].shape
def __init__(self, data_root, setting, mel_transforms, transforms=None, stretch_dataset=False,
use_preprocessed=True):
self.use_preprocessed = use_preprocessed
self.stretch = stretch_dataset
@property
def _fingerprint(self):
return str(self._mel_transform)
def __init__(self, data_root, setting, mel_transforms, fold=1, transforms=None,
use_preprocessed=True, audio_segment_len=62, audio_hop_len=30, num_worker=mp.cpu_count(),
**_):
assert isinstance(setting, str), f'Setting has to be a string, but was: {type(setting)}.'
assert setting in V.DATA_OPTIONS, f'Setting must match one of: {V.DATA_OPTIONS}.'
super(BinaryMasksDataset, self).__init__()
assert fold in range(1, 11)
super(Urban8K, self).__init__()
self.data_root = Path(data_root)
self.data_root = Path(data_root) / 'UrbanSound8K'
self.setting = setting
self._wav_folder = self.data_root / 'wav'
self._mel_folder = self.data_root / 'mel'
self.num_worker = num_worker
self.fold = fold if self.setting == V.DATA_OPTIONS.train else 10
self.use_preprocessed = use_preprocessed
self._wav_folder = self.data_root / 'audio' / f'fold{self.fold}'
self._mel_folder = self.data_root / 'mel' / f'fold{self.fold}'
self.container_ext = '.pik'
self._mel_transform = mel_transforms
self._labels = self._build_labels()
self._wav_files = list(sorted(self._labels.keys()))
self._transforms = transforms or F_x(in_shape=None)
transforms = transforms or F_x(in_shape=None)
param_storage = self._mel_folder / 'data_params.pik'
self._mel_folder.mkdir(parents=True, exist_ok=True)
try:
pik_data = param_storage.read_bytes()
fingerprint = pickle.loads(pik_data)
if fingerprint == self._fingerprint:
self.use_preprocessed = use_preprocessed
else:
print('Diverging parameters were found; Refreshing...')
param_storage.unlink()
pik_data = pickle.dumps(self._fingerprint)
param_storage.write_bytes(pik_data)
self.use_preprocessed = False
except FileNotFoundError:
pik_data = pickle.dumps(self._fingerprint)
param_storage.write_bytes(pik_data)
self.use_preprocessed = False
while True:
if not self.use_preprocessed:
self._pre_process()
try:
self._dataset = ConcatDataset(
[TorchMelDataset(identifier=key, mel_path=self._mel_folder, transform=transforms,
segment_len=audio_segment_len, hop_len=audio_hop_len,
label=self._labels[key]['label']
) for key in self._labels.keys()]
)
break
except IOError:
self.use_preprocessed = False
pass
def _build_labels(self):
labeldict = dict()
with open(Path(self.data_root) / 'lab' / 'labels.csv', mode='r') as f:
with open(Path(self.data_root) / 'metadata' / 'UrbanSound8K.csv', mode='r') as f:
# Exclude the header
_ = next(f)
for row in f:
if self.setting not in row:
continue
filename, label = row.strip().split(',')
labeldict[filename] = self._to_label[label.lower()] if not self.setting == 'test' else filename
if self.stretch and self.setting == V.DATA_OPTIONS.train:
additional_dict = ({f'X{key}': val for key, val in labeldict.items()})
additional_dict.update({f'XX{key}': val for key, val in labeldict.items()})
additional_dict.update({f'XXX{key}': val for key, val in labeldict.items()})
labeldict.update(additional_dict)
slice_file_name, fs_id, start, end, salience, fold, class_id, class_name = row.strip().split(',')
if int(fold) == self.fold:
key = slice_file_name.replace('.wav', '')
labeldict[key] = dict(label=int(class_id), fold=int(fold))
# Delete File if one exists.
if not self.use_preprocessed:
for key in labeldict.keys():
try:
(self._mel_folder / (key.replace('.wav', '') + self.container_ext)).unlink()
except FileNotFoundError:
pass
for mel_file in self._mel_folder.rglob(f'{key}_*'):
try:
mel_file.unlink(missing_ok=True)
except FileNotFoundError:
pass
return labeldict
def __len__(self):
return len(self._labels)
return len(self._dataset)
def _compute_or_retrieve(self, filename):
def _pre_process(self):
print('Preprocessing Mel Files....')
with mp.Pool(processes=self.num_worker) as pool:
with tqdm(total=len(self._labels)) as pbar:
for i, _ in enumerate(pool.imap_unordered(self._build_mel, self._labels.keys())):
pbar.update()
if not (self._mel_folder / (filename + self.container_ext)).exists():
raw_sample, sr = librosa.core.load(self._wav_folder / (filename.replace('X', '') + '.wav'))
def _build_mel(self, filename):
wav_file = self._wav_folder / (filename.replace('X', '') + '.wav')
mel_file = list(self._mel_folder.glob(f'{filename}_*'))
if not mel_file:
raw_sample, sr = librosa.core.load(wav_file)
mel_sample = self._mel_transform(raw_sample)
m, n = mel_sample.shape
mel_file = self._mel_folder / f'{filename}_{m}_{n}'
self._mel_folder.mkdir(exist_ok=True, parents=True)
with (self._mel_folder / (filename + self.container_ext)).open(mode='wb') as f:
with mel_file.open(mode='wb') as f:
pickle.dump(mel_sample, f, protocol=pickle.HIGHEST_PROTOCOL)
else:
# print(f"Already existed.. Skipping {filename}")
mel_file = mel_file[0]
with (self._mel_folder / (filename + self.container_ext)).open(mode='rb') as f:
with mel_file.open(mode='rb') as f:
mel_sample = pickle.load(f, fix_imports=True)
return mel_sample
return mel_sample, mel_file
def __getitem__(self, item):
transformed_samples, label = self._dataset[item]
key: str = list(self._labels.keys())[item]
filename = key.replace('.wav', '')
mel_sample = self._compute_or_retrieve(filename)
label = self._labels[key]
transformed_samples = self._transforms(mel_sample)
if self.setting != V.DATA_OPTIONS.test:
# In test, filenames instead of labels are returned. This is a little hacky though.
label = torch.as_tensor(label, dtype=torch.float)
label = torch.as_tensor(label, dtype=torch.float)
return transformed_samples, label

View File

@@ -0,0 +1,140 @@
import pickle
from pathlib import Path
import multiprocessing as mp
import librosa as librosa
from torch.utils.data import Dataset, ConcatDataset
import torch
from tqdm import tqdm
import variables as V
from ml_lib.audio_toolset.mel_dataset import TorchMelDataset
from ml_lib.modules.util import F_x
class Urban8K_TO(Dataset):
    """Mel-spectrogram dataset over a single fold of UrbanSound8K.

    On construction the labels of the selected fold are read from
    ``<data_root>/UrbanSound8K/metadata/UrbanSound8K.csv``. Each clip is turned into
    one pickled mel file below ``<data_root>/UrbanSound8K/mel/fold<k>`` (unless a
    usable cache already exists), and the clips are exposed as a ``ConcatDataset``
    holding one ``TorchMelDataset`` per clip.
    """

    @property
    def sample_shape(self):
        """Shape of the first transformed sample (``self[0][0].shape``)."""
        return self[0][0].shape

    @property
    def _fingerprint(self):
        """String form of the mel transform; used to detect an outdated mel cache."""
        return str(self._mel_transform)

    def __init__(self, data_root, setting, mel_transforms, fold=1, transforms=None,
                 use_preprocessed=True, audio_segment_len=1, audio_hop_len=1, num_worker=mp.cpu_count(),
                 **_):
        """Read the fold's labels, make sure the mel cache is usable and build the dataset.

        Args:
            data_root: Directory that contains the ``UrbanSound8K`` folder.
            setting: One of ``V.DATA_OPTIONS``.
            mel_transforms: Callable applied to the raw waveform to produce the mel sample.
            fold: Fold (1..10) used when ``setting`` is the train option; every other
                setting uses fold 10.
            transforms: Passed to each ``TorchMelDataset`` as ``transform``; defaults to
                ``F_x(in_shape=None)``.
            use_preprocessed: If False, cached mel files of this fold are discarded and
                rebuilt. Forced to False when the stored fingerprint is missing,
                unreadable or differs from the current one.
            audio_segment_len: Passed to ``TorchMelDataset`` as ``segment_len``.
            audio_hop_len: Passed to ``TorchMelDataset`` as ``hop_len``.
            num_worker: Number of processes used for preprocessing.
            **_: Additional keyword arguments are accepted and ignored.

        Raises:
            AssertionError: If ``setting`` or ``fold`` is invalid.
            IOError: If the per-clip datasets still cannot be built after a preprocessing pass.
        """
        assert isinstance(setting, str), f'Setting has to be a string, but was: {type(setting)}.'
        assert setting in V.DATA_OPTIONS, f'Setting must match one of: {V.DATA_OPTIONS}.'
        assert fold in range(1, 11)
        super(Urban8K_TO, self).__init__()

        self.data_root = Path(data_root) / 'UrbanSound8K'
        self.setting = setting
        self.num_worker = num_worker
        # Only the train setting honours the requested fold.
        self.fold = fold if self.setting == V.DATA_OPTIONS.train else 10
        self.use_preprocessed = use_preprocessed
        self._wav_folder = self.data_root / 'audio' / f'fold{self.fold}'
        self._mel_folder = self.data_root / 'mel' / f'fold{self.fold}'
        self.container_ext = '.pik'
        self._mel_transform = mel_transforms

        self._labels = self._build_labels()
        self._wav_files = list(sorted(self._labels.keys()))
        transforms = transforms or F_x(in_shape=None)

        self._mel_folder.mkdir(parents=True, exist_ok=True)
        if not self._fingerprint_is_current(self._mel_folder / 'data_params.pik'):
            self.use_preprocessed = False

        # The purge has to happen after the fingerprint check: _build_mel skips every clip
        # that already has a cached file, so stale files would otherwise survive a refresh.
        if not self.use_preprocessed:
            self._purge_mel_files()

        preprocessed = False
        while True:
            if not self.use_preprocessed:
                self._pre_process()
                preprocessed = True
            try:
                self._dataset = ConcatDataset(
                    [TorchMelDataset(identifier=key, mel_path=self._mel_folder, transform=transforms,
                                     segment_len=audio_segment_len, hop_len=audio_hop_len,
                                     label=self._labels[key]['label']
                                     ) for key in self._labels.keys()]
                )
                break
            except IOError:
                if preprocessed:
                    # A preprocessing pass already ran; repeating it cannot help, so
                    # surface the error instead of looping forever.
                    raise
                # Cache is incomplete: fill in the missing mel files and try once more.
                self.use_preprocessed = False

    def _fingerprint_is_current(self, param_storage):
        """Check the stored fingerprint against the current one and refresh it if needed.

        Args:
            param_storage: Path of the pickled fingerprint file.

        Returns:
            True if the file existed and held a value equal to ``self._fingerprint``.
            Otherwise the current fingerprint is written to ``param_storage`` and False
            is returned.
        """
        # NOTE(review): pickle must only be used on files written by this class itself;
        # never point data_root at an untrusted directory.
        try:
            stored = pickle.loads(param_storage.read_bytes())
        except FileNotFoundError:
            pass  # Nothing stored yet.
        except (pickle.UnpicklingError, EOFError):
            # An empty or truncated file cannot vouch for the cache either.
            print('Diverging parameters were found; Refreshing...')
        else:
            if stored == self._fingerprint:
                return True
            print('Diverging parameters were found; Refreshing...')
        param_storage.write_bytes(pickle.dumps(self._fingerprint))
        return False

    def _purge_mel_files(self):
        """Delete every cached mel file that belongs to a clip in ``self._labels``."""
        for key in self._labels:
            for mel_file in self._mel_folder.rglob(f'{key}_*'):
                mel_file.unlink(missing_ok=True)

    def _build_labels(self):
        """Parse the metadata CSV and collect the clips of ``self.fold``.

        Returns:
            dict mapping the clip name without ``.wav`` to ``dict(label=<class id>, fold=<fold>)``.
        """
        labeldict = dict()
        with (self.data_root / 'metadata' / 'UrbanSound8K.csv').open(mode='r') as f:
            # Exclude the header
            next(f, None)
            for row in f:
                row = row.strip()
                if not row:
                    # Tolerate blank (e.g. trailing) lines.
                    continue
                slice_file_name, fs_id, start, end, salience, fold, class_id, class_name = row.split(',')
                if int(fold) == self.fold:
                    key = slice_file_name.replace('.wav', '')
                    labeldict[key] = dict(label=int(class_id), fold=int(fold))
        return labeldict

    def __len__(self):
        """Number of items in the underlying ``ConcatDataset``."""
        return len(self._dataset)

    def _pre_process(self):
        """Build the mel file of every labelled clip in a pool of ``self.num_worker`` processes."""
        print('Preprocessing Mel Files....')
        with mp.Pool(processes=self.num_worker) as pool:
            with tqdm(total=len(self._labels)) as pbar:
                for _ in pool.imap_unordered(self._build_mel, self._labels.keys()):
                    pbar.update()

    def _build_mel(self, filename):
        """Return the mel sample of one clip, computing and caching it if no cached file exists.

        Args:
            filename: Clip name without extension (a key of ``self._labels``).

        Returns:
            Tuple ``(mel_sample, mel_file)``, where ``mel_file`` is the path of the pickle
            on disk.
        """
        # NOTE(review): every 'X' is stripped from the name before the wav lookup,
        # presumably a leftover from an augmentation naming scheme. Confirm it is still wanted.
        wav_file = self._wav_folder / (filename.replace('X', '') + '.wav')
        mel_file = list(self._mel_folder.glob(f'{filename}_*'))
        if not mel_file:
            raw_sample, sr = librosa.core.load(wav_file)
            mel_sample = self._mel_transform(raw_sample)
            # The sample's two dimensions are encoded in the cache file name.
            m, n = mel_sample.shape
            mel_file = self._mel_folder / f'{filename}_{m}_{n}'
            self._mel_folder.mkdir(exist_ok=True, parents=True)
            with mel_file.open(mode='wb') as f:
                pickle.dump(mel_sample, f, protocol=pickle.HIGHEST_PROTOCOL)
        else:
            # Already cached: reuse the first matching file.
            mel_file = mel_file[0]
            with mel_file.open(mode='rb') as f:
                mel_sample = pickle.load(f, fix_imports=True)
        return mel_sample, mel_file

    def __getitem__(self, item):
        """Return ``(sample, label)`` of item ``item`` with the label as a float tensor."""
        transformed_samples, label = self._dataset[item]
        label = torch.as_tensor(label, dtype=torch.float)
        return transformed_samples, label