Merge branch 'master' of gitlab.lrz.de:mobile-ifi/inter_challenge_2020

This commit is contained in:
Steffen Illium 2020-05-15 11:13:43 +02:00
commit bba1f74f78
4 changed files with 17 additions and 8 deletions

View File

@ -22,13 +22,13 @@ main_arg_parser.add_argument("--data_worker", type=int, default=11, help="")
main_arg_parser.add_argument("--data_root", type=str, default='data', help="") main_arg_parser.add_argument("--data_root", type=str, default='data', help="")
main_arg_parser.add_argument("--data_class_name", type=str, default='BinaryMasksDataset', help="") main_arg_parser.add_argument("--data_class_name", type=str, default='BinaryMasksDataset', help="")
main_arg_parser.add_argument("--data_normalized", type=strtobool, default=True, help="") main_arg_parser.add_argument("--data_normalized", type=strtobool, default=True, help="")
main_arg_parser.add_argument("--data_use_preprocessed", type=strtobool, default=True, help="") main_arg_parser.add_argument("--data_use_preprocessed", type=strtobool, default=False, help="")
main_arg_parser.add_argument("--data_n_mels", type=int, default=64, help="") main_arg_parser.add_argument("--data_n_mels", type=int, default=64, help="")
main_arg_parser.add_argument("--data_sr", type=int, default=16000, help="") main_arg_parser.add_argument("--data_sr", type=int, default=16000, help="")
main_arg_parser.add_argument("--data_hop_length", type=int, default=256, help="") main_arg_parser.add_argument("--data_hop_length", type=int, default=256, help="")
main_arg_parser.add_argument("--data_n_fft", type=int, default=512, help="") main_arg_parser.add_argument("--data_n_fft", type=int, default=512, help="")
main_arg_parser.add_argument("--data_mixup", type=strtobool, default=False, help="") main_arg_parser.add_argument("--data_mixup", type=strtobool, default=False, help="")
main_arg_parser.add_argument("--data_stretch", type=strtobool, default=False, help="") main_arg_parser.add_argument("--data_stretch", type=strtobool, default=True, help="")
# Transformation Parameters # Transformation Parameters
main_arg_parser.add_argument("--data_loudness_ratio", type=float, default=0, help="") # 0.4 main_arg_parser.add_argument("--data_loudness_ratio", type=float, default=0, help="") # 0.4

View File

@ -19,7 +19,9 @@ class BinaryMasksDataset(Dataset):
def sample_shape(self): def sample_shape(self):
return self[0][0].shape return self[0][0].shape
def __init__(self, data_root, setting, mel_transforms, transforms=None, mixup=False, stretch_dataset=False): def __init__(self, data_root, setting, mel_transforms, transforms=None, mixup=False, stretch_dataset=False,
use_preprocessed=True):
self.use_preprocessed = use_preprocessed
self.stretch = stretch_dataset self.stretch = stretch_dataset
assert isinstance(setting, str), f'Setting has to be a string, but was: {type(setting)}.' assert isinstance(setting, str), f'Setting has to be a string, but was: {type(setting)}.'
assert setting in V.DATA_OPTIONS, f'Setting must match one of: {V.DATA_OPTIONS}.' assert setting in V.DATA_OPTIONS, f'Setting must match one of: {V.DATA_OPTIONS}.'
@ -28,12 +30,13 @@ class BinaryMasksDataset(Dataset):
self.data_root = Path(data_root) self.data_root = Path(data_root)
self.setting = setting self.setting = setting
self.mixup = mixup self.mixup = mixup
self._wav_folder = self.data_root / 'wav'
self._mel_folder = self.data_root / 'mel'
self.container_ext = '.pik' self.container_ext = '.pik'
self._mel_transform = mel_transforms self._mel_transform = mel_transforms
self._labels = self._build_labels() self._labels = self._build_labels()
self._wav_folder = self.data_root / 'wav'
self._wav_files = list(sorted(self._labels.keys())) self._wav_files = list(sorted(self._labels.keys()))
self._mel_folder = self.data_root / 'mel'
self._transforms = transforms or F_x(in_shape=None) self._transforms = transforms or F_x(in_shape=None)
def _build_labels(self): def _build_labels(self):
@ -51,18 +54,25 @@ class BinaryMasksDataset(Dataset):
additional_dict.update({f'X_X_{key}': val for key, val in labeldict.items()}) additional_dict.update({f'X_X_{key}': val for key, val in labeldict.items()})
additional_dict.update({f'X_X_X_{key}': val for key, val in labeldict.items()}) additional_dict.update({f'X_X_X_{key}': val for key, val in labeldict.items()})
labeldict.update(additional_dict) labeldict.update(additional_dict)
# Delete File if one exists.
if not self.use_preprocessed:
for key in labeldict.keys():
(self._mel_folder / (key.replace('.wav', '') + self.container_ext)).unlink(missing_ok=True)
return labeldict return labeldict
def __len__(self): def __len__(self):
return len(self._labels) * 2 if self.mixup else len(self._labels) return len(self._labels) * 2 if self.mixup else len(self._labels)
def _compute_or_retrieve(self, filename): def _compute_or_retrieve(self, filename):
if not (self._mel_folder / (filename + self.container_ext)).exists(): if not (self._mel_folder / (filename + self.container_ext)).exists():
raw_sample, sr = librosa.core.load(self._wav_folder / (filename.replace('X_', '') + '.wav')) raw_sample, sr = librosa.core.load(self._wav_folder / (filename.replace('X_', '') + '.wav'))
mel_sample = self._mel_transform(raw_sample) mel_sample = self._mel_transform(raw_sample)
self._mel_folder.mkdir(exist_ok=True, parents=True) self._mel_folder.mkdir(exist_ok=True, parents=True)
with (self._mel_folder / (filename + self.container_ext)).open(mode='wb') as f: with (self._mel_folder / (filename + self.container_ext)).open(mode='wb') as f:
pickle.dump(mel_sample, f, protocol=pickle.HIGHEST_PROTOCOL) pickle.dump(mel_sample, f, protocol=pickle.HIGHEST_PROTOCOL)
with (self._mel_folder / (filename + self.container_ext)).open(mode='rb') as f: with (self._mel_folder / (filename + self.container_ext)).open(mode='rb') as f:
mel_sample = pickle.load(f, fix_imports=True) mel_sample = pickle.load(f, fix_imports=True)
return mel_sample return mel_sample

View File

@ -41,13 +41,11 @@ class ResidualConvClassifier(BinaryMaskDatasetMixin,
last_shape = self.conv_list[-1].shape last_shape = self.conv_list[-1].shape
for idx in range(len(self.conv_filters)): for idx in range(len(self.conv_filters)):
conv_module_params.update(conv_filters=self.conv_filters[idx]) conv_module_params.update(conv_filters=self.conv_filters[idx])
self.conv_list.append(ResidualModule(last_shape, ConvModule, 3, **conv_module_params)) self.conv_list.append(ResidualModule(last_shape, ConvModule, 2, **conv_module_params))
last_shape = self.conv_list[-1].shape last_shape = self.conv_list[-1].shape
try: try:
self.conv_list.append(ConvModule(last_shape, self.conv_filters[idx+1], (k, k), conv_stride=(2, 2), conv_padding=2, self.conv_list.append(ConvModule(last_shape, self.conv_filters[idx+1], (k, k), conv_stride=(2, 2), conv_padding=2,
**self.params.module_kwargs)) **self.params.module_kwargs))
for param in self.conv_list[-1].parameters():
param.requires_grad = False
last_shape = self.conv_list[-1].shape last_shape = self.conv_list[-1].shape
except IndexError: except IndexError:
pass pass

View File

@ -142,6 +142,7 @@ class BinaryMaskDatasetMixin:
**dict( **dict(
# TRAIN DATASET # TRAIN DATASET
train_dataset=BinaryMasksDataset(self.params.root, setting=V.DATA_OPTIONS.train, train_dataset=BinaryMasksDataset(self.params.root, setting=V.DATA_OPTIONS.train,
use_preprocessed=self.params.use_preprocessed,
mixup=self.params.mixup, stretch_dataset=self.params.stretch, mixup=self.params.mixup, stretch_dataset=self.params.stretch,
mel_transforms=mel_transforms_train, transforms=aug_transforms), mel_transforms=mel_transforms_train, transforms=aug_transforms),
# VALIDATION DATASET # VALIDATION DATASET