Merge remote-tracking branch 'origin/master'

# Conflicts:
#	multi_run.py
Steffen Illium 2020-05-21 12:16:45 +02:00
commit 95aa7c4cc5
6 changed files with 35 additions and 46 deletions


@@ -21,18 +21,16 @@ main_arg_parser.add_argument("--main_seed", type=int, default=69, help="")
 main_arg_parser.add_argument("--data_worker", type=int, default=11, help="")
 main_arg_parser.add_argument("--data_root", type=str, default='data', help="")
 main_arg_parser.add_argument("--data_class_name", type=str, default='BinaryMasksDataset', help="")
-main_arg_parser.add_argument("--data_normalized", type=strtobool, default=True, help="")
 main_arg_parser.add_argument("--data_use_preprocessed", type=strtobool, default=False, help="")
 main_arg_parser.add_argument("--data_n_mels", type=int, default=64, help="")
 main_arg_parser.add_argument("--data_sr", type=int, default=16000, help="")
 main_arg_parser.add_argument("--data_hop_length", type=int, default=256, help="")
 main_arg_parser.add_argument("--data_n_fft", type=int, default=512, help="")
-main_arg_parser.add_argument("--data_mixup", type=strtobool, default=False, help="")
 main_arg_parser.add_argument("--data_stretch", type=strtobool, default=True, help="")
 # Transformation Parameters
 main_arg_parser.add_argument("--data_loudness_ratio", type=float, default=0, help="") # 0.4
-main_arg_parser.add_argument("--data_shift_ratio", type=float, default=0.3, help="") # 0.4
+main_arg_parser.add_argument("--data_shift_ratio", type=float, default=0.3, help="") # 0.3
 main_arg_parser.add_argument("--data_noise_ratio", type=float, default=0, help="") # 0.4
 main_arg_parser.add_argument("--data_mask_ratio", type=float, default=0, help="") # 0.2
 main_arg_parser.add_argument("--data_speed_ratio", type=float, default=0, help="") # 0.3
@@ -54,7 +52,7 @@ main_arg_parser.add_argument("--train_outpath", type=str, default="output", help
 main_arg_parser.add_argument("--train_version", type=strtobool, required=False, help="")
 # FIXME: Stochastic weight Avaraging is not good, maybe its my implementation?
 main_arg_parser.add_argument("--train_sto_weight_avg", type=strtobool, default=False, help="")
-main_arg_parser.add_argument("--train_weight_decay", type=float, default=1e-8, help="")
+main_arg_parser.add_argument("--train_weight_decay", type=float, default=1e-7, help="")
 main_arg_parser.add_argument("--train_opt_reset_interval", type=int, default=0, help="")
 main_arg_parser.add_argument("--train_epochs", type=int, default=51, help="")
 main_arg_parser.add_argument("--train_batch_size", type=int, default=300, help="")


@@ -1,7 +1,6 @@
 import pickle
 from collections import defaultdict
 from pathlib import Path
-import random
 
 import librosa as librosa
 from torch.utils.data import Dataset
@@ -19,7 +18,7 @@ class BinaryMasksDataset(Dataset):
     def sample_shape(self):
         return self[0][0].shape
 
-    def __init__(self, data_root, setting, mel_transforms, transforms=None, mixup=False, stretch_dataset=False,
+    def __init__(self, data_root, setting, mel_transforms, transforms=None, stretch_dataset=False,
                  use_preprocessed=True):
         self.use_preprocessed = use_preprocessed
         self.stretch = stretch_dataset
@@ -29,7 +28,6 @@ class BinaryMasksDataset(Dataset):
 
         self.data_root = Path(data_root)
         self.setting = setting
-        self.mixup = mixup
         self._wav_folder = self.data_root / 'wav'
         self._mel_folder = self.data_root / 'mel'
         self.container_ext = '.pik'
@@ -40,19 +38,20 @@ class BinaryMasksDataset(Dataset):
         self._transforms = transforms or F_x(in_shape=None)
 
     def _build_labels(self):
+        labeldict = dict()
         with open(Path(self.data_root) / 'lab' / 'labels.csv', mode='r') as f:
             # Exclude the header
             _ = next(f)
-            labeldict = dict()
             for row in f:
                 if self.setting not in row:
                     continue
                 filename, label = row.strip().split(',')
                 labeldict[filename] = self._to_label[label.lower()] if not self.setting == 'test' else filename
         if self.stretch and self.setting == V.DATA_OPTIONS.train:
-            additional_dict = ({f'X_{key}': val for key, val in labeldict.items()})
-            additional_dict.update({f'X_X_{key}': val for key, val in labeldict.items()})
-            additional_dict.update({f'X_X_X_{key}': val for key, val in labeldict.items()})
+            additional_dict = ({f'X{key}': val for key, val in labeldict.items()})
+            additional_dict.update({f'XX{key}': val for key, val in labeldict.items()})
+            additional_dict.update({f'XXX{key}': val for key, val in labeldict.items()})
+            additional_dict.update({f'XXXX{key}': val for key, val in labeldict.items()})
             labeldict.update(additional_dict)
 
         # Delete File if one exists.
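
The stretch branch above now replicates the label dictionary five-fold by prefixing keys with 'X' through 'XXXX' (previously four-fold via 'X_' through 'X_X_X_'). A toy sketch of the effect, with made-up labels; note that filename.replace('X', '') later strips every capital 'X', so this relies on the real wav names containing none:

```python
# Toy labels, not the real labels.csv content.
labeldict = {'audio_001.wav': 1, 'audio_002.wav': 0}

additional_dict = {f'X{key}': val for key, val in labeldict.items()}
additional_dict.update({f'XX{key}': val for key, val in labeldict.items()})
additional_dict.update({f'XXX{key}': val for key, val in labeldict.items()})
additional_dict.update({f'XXXX{key}': val for key, val in labeldict.items()})
labeldict.update(additional_dict)

print(len(labeldict))                        # 10 -> five entries per original file
print('XXXaudio_001.wav'.replace('X', ''))   # 'audio_001.wav' -> back to the wav on disk
```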
@@ -66,12 +65,12 @@ class BinaryMasksDataset(Dataset):
         return labeldict
 
     def __len__(self):
-        return len(self._labels) * 2 if self.mixup else len(self._labels)
+        return len(self._labels)
 
     def _compute_or_retrieve(self, filename):
         if not (self._mel_folder / (filename + self.container_ext)).exists():
-            raw_sample, sr = librosa.core.load(self._wav_folder / (filename.replace('X_', '') + '.wav'))
+            raw_sample, sr = librosa.core.load(self._wav_folder / (filename.replace('X', '') + '.wav'))
             mel_sample = self._mel_transform(raw_sample)
             self._mel_folder.mkdir(exist_ok=True, parents=True)
             with (self._mel_folder / (filename + self.container_ext)).open(mode='wb') as f:
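
For context, the method above follows a compute-or-retrieve caching pattern: the mel spectrogram is computed once from the de-prefixed wav and pickled into the mel/ folder; later calls load the pickle. A self-contained sketch of that pattern, assuming a mel_transform callable and the folder layout used by the class (the function name and signature here are illustrative, not the project API; the loading branch itself is outside this hunk):

```python
import pickle
from pathlib import Path

import librosa


def compute_or_retrieve(filename, wav_folder, mel_folder, mel_transform, ext='.pik'):
    """Sketch of the caching idea used above."""
    cache_file = Path(mel_folder) / (filename + ext)
    if not cache_file.exists():
        # Strip the stretch prefix ('X', 'XX', ...) to locate the real wav on disk.
        raw_sample, sr = librosa.load(Path(wav_folder) / (filename.replace('X', '') + '.wav'))
        mel_sample = mel_transform(raw_sample)
        cache_file.parent.mkdir(exist_ok=True, parents=True)
        with cache_file.open(mode='wb') as f:
            pickle.dump(mel_sample, f)
    with cache_file.open(mode='rb') as f:
        return pickle.load(f)
```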
@@ -82,28 +81,16 @@ class BinaryMasksDataset(Dataset):
         return mel_sample
 
     def __getitem__(self, item):
-        is_mixed = item >= len(self._labels)
-        if is_mixed:
-            item = item - len(self._labels)
         key: str = list(self._labels.keys())[item]
         filename = key.replace('.wav', '')
         mel_sample = self._compute_or_retrieve(filename)
         label = self._labels[key]
-        if is_mixed:
-            label_sec = -1
-            while label_sec != self._labels[key]:
-                key_sec = random.choice(list(self._labels.keys()))
-                label_sec = self._labels[key_sec]
-            # noinspection PyUnboundLocalVariable
-            filename_sec = key_sec[:-4]
-            mel_sample_sec = self._compute_or_retrieve(filename_sec)
-            mix_in_border = int(random.random() * mel_sample.shape[-1]) * random.choice([1, -1])
-            mel_sample[:, :mix_in_border] = mel_sample_sec[:, :mix_in_border]
         transformed_samples = self._transforms(mel_sample)
-        if not self.setting == 'test':
+        if self.setting != V.DATA_OPTIONS.test:
+            # In test, filenames instead of labels are returned. This is a little hacky though.
             label = torch.as_tensor(label, dtype=torch.float)
         return transformed_samples, label

main.py

@@ -110,6 +110,7 @@ def run_lightning_loop(config_obj):
         inference_out = f'{parameters}_test_out.csv'
 
         from main_inference import prepare_dataloader
+        import variables as V
         test_dataloader = prepare_dataloader(config_obj)
 
         with (outpath / model_type / parameters / version / inference_out).open(mode='w') as outfile:
@@ -118,12 +119,12 @@ def run_lightning_loop(config_obj):
             from tqdm import tqdm
             for batch in tqdm(test_dataloader, total=len(test_dataloader)):
                 batch_x, file_name = batch
-                batch_x = batch_x.unsqueeze(0).to(device='cuda' if model.on_gpu else 'cpu')
+                batch_x = batch_x.to(device='cuda' if model.on_gpu else 'cpu')
                 y = model(batch_x).main_out
-                prediction = (y.squeeze() >= 0.5).int().item()
-                import variables as V
-                prediction = 'clear' if prediction == V.CLEAR else 'mask'
-                outfile.write(f'{file_name},{prediction}\n')
+                predictions = (y >= 0.5).int()
+                for prediction in predictions:
+                    prediction_text = 'clear' if prediction == V.CLEAR else 'mask'
+                    outfile.write(f'{file_name},{prediction_text}\n')
 
     return model
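
With the dataloader now batched (see the prepare_dataloader change below), file_name in the loop above is a sequence of names per batch. A hedged sketch of what a fully batched write-out typically looks like, pairing each prediction with its filename via zip; the function name and clear_label parameter are illustrative stand-ins (clear_label plays the role of V.CLEAR), while the .main_out attribute is taken from the code above:

```python
import torch


def write_predictions(model, test_dataloader, outfile, clear_label=1):
    # Sketch only: assumes each batch is (batch_x, file_names) with one name per sample.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = model.to(device).eval()
    with torch.no_grad():
        for batch_x, file_names in test_dataloader:
            y = model(batch_x.to(device)).main_out
            predictions = (y.squeeze(-1) >= 0.5).int().tolist()
            for name, pred in zip(file_names, predictions):
                outfile.write(f"{name},{'clear' if pred == clear_label else 'mask'}\n")
```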


@@ -43,7 +43,8 @@ def prepare_dataloader(config_obj):
                                 mel_transforms=mel_transforms, transforms=transforms
                                 )
    # noinspection PyTypeChecker
-   return DataLoader(dataset, batch_size=None, num_workers=0, shuffle=False)
+   return DataLoader(dataset, batch_size=config_obj.train.batch_size,
+                     num_workers=config_obj.data.worker, shuffle=False)
 
 
 def restore_logger_and_model(log_dir):
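
The old batch_size=None disabled PyTorch's automatic batching, so the dataloader yielded single, unbatched samples; that is why main.py previously needed batch_x.unsqueeze(0) and can drop it now. A small self-contained illustration of the difference:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

ds = TensorDataset(torch.randn(8, 3), torch.zeros(8))
unbatched = DataLoader(ds, batch_size=None)   # automatic batching off: yields single samples
batched = DataLoader(ds, batch_size=4)        # yields stacked tensors with a batch dimension

print(next(iter(unbatched))[0].shape)  # torch.Size([3])
print(next(iter(batched))[0].shape)    # torch.Size([4, 3])
```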


@@ -20,31 +20,31 @@ if __name__ == '__main__':
     config = MConfig().read_namespace(args)
     arg_dict = dict()
-    for seed in range(40, 45):
+    for seed in range(0, 10):
         arg_dict.update(main_seed=seed)
         for model in ['CC', 'BCMC', 'BCC', 'RCC']:
             arg_dict.update(model_type=model)
             raw_conf = dict(data_speed_factor=0.0, data_speed_ratio=0.0, data_mask_ratio=0.0,
                             data_noise_ratio=0.0, data_shift_ratio=0.0, data_loudness_ratio=0.0,
-                            data_stretch=False)
-            all_conf = dict(data_speed_factor=0.7, data_speed_ratio=0.3, data_mask_ratio=0.2,
+                            data_stretch=False, train_epochs=401)
+            all_conf = dict(data_speed_factor=0.7, data_speed_ratio=0.2, data_mask_ratio=0.2,
                             data_noise_ratio=0.4, data_shift_ratio=0.4, data_loudness_ratio=0.4,
-                            data_stretch=True)
+                            data_stretch=True, train_epochs=101)
             speed_conf = dict(data_speed_factor=0.7, data_speed_ratio=0.2, data_mask_ratio=0.0,
                               data_noise_ratio=0.0, data_shift_ratio=0.0, data_loudness_ratio=0.0,
-                              data_stretch=True)
+                              data_stretch=True, train_epochs=101)
             mask_conf = dict(data_speed_factor=0.0, data_speed_ratio=0.0, data_mask_ratio=0.2,
                              data_noise_ratio=0.0, data_shift_ratio=0.0, data_loudness_ratio=0.0,
-                             data_stretch=True)
+                             data_stretch=True, train_epochs=101)
             noise_conf = dict(data_speed_factor=0.0, data_speed_ratio=0.0, data_mask_ratio=0.0,
                               data_noise_ratio=0.4, data_shift_ratio=0.0, data_loudness_ratio=0.0,
-                              data_stretch=True)
+                              data_stretch=True, train_epochs=101)
             shift_conf = dict(data_speed_factor=0.0, data_speed_ratio=0.0, data_mask_ratio=0.0,
                               data_noise_ratio=0.0, data_shift_ratio=0.4, data_loudness_ratio=0.0,
-                              data_stretch=True)
+                              data_stretch=True, train_epochs=101)
             loudness_conf = dict(data_speed_factor=0.0, data_speed_ratio=0.0, data_mask_ratio=0.0,
                                  data_noise_ratio=0.0, data_shift_ratio=0.0, data_loudness_ratio=0.4,
-                                 data_stretch=True)
+                                 data_stretch=True, train_epochs=101)
             for dicts in [raw_conf, all_conf, speed_conf, mask_conf,noise_conf, shift_conf, loudness_conf]:
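
In this sweep, raw_conf disables all augmentations and now runs for 401 epochs, all_conf enables everything, and the remaining dicts each enable a single augmentation on top of dataset stretching, at 101 epochs. Purely illustrative: how one of these dicts combines with the shared arg_dict before a run is dispatched (the loop body itself is not part of this hunk):

```python
# Illustrative only; mirrors the arg_dict.update(...) pattern shown above.
arg_dict = dict(main_seed=0, model_type='CC')
noise_conf = dict(data_speed_factor=0.0, data_speed_ratio=0.0, data_mask_ratio=0.0,
                  data_noise_ratio=0.4, data_shift_ratio=0.0, data_loudness_ratio=0.0,
                  data_stretch=True, train_epochs=101)

arg_dict.update(noise_conf)  # later updates overwrite earlier values for the same keys
print(arg_dict['data_noise_ratio'], arg_dict['train_epochs'])  # 0.4 101
```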


@@ -122,7 +122,8 @@ class BinaryMaskDatasetMixin:
         mel_transforms = Compose([
             # Audio to Mel Transformations
             AudioToMel(sr=self.params.sr, n_mels=self.params.n_mels, n_fft=self.params.n_fft,
-                       hop_length=self.params.hop_length), MelToImage()])
+                       hop_length=self.params.hop_length),
+            MelToImage()])
         # Data Augmentations
         aug_transforms = Compose([
             RandomApply([
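
AudioToMel and MelToImage are project classes not shown in this diff. As a rough orientation, here is a sketch of the kind of mel spectrogram such a transform computes, using plain librosa with the parser defaults above (sr=16000, n_mels=64, n_fft=512, hop_length=256) and log-power scaling as a plausible stand-in for MelToImage:

```python
import numpy as np
import librosa


def audio_to_mel(y, sr=16000, n_mels=64, n_fft=512, hop_length=256):
    # Mel power spectrogram with the same parameters exposed by the arg parser.
    mel = librosa.feature.melspectrogram(y=y, sr=sr, n_mels=n_mels,
                                         n_fft=n_fft, hop_length=hop_length)
    return librosa.power_to_db(mel, ref=np.max)  # log-scaled, image-like 2D array


y = np.random.randn(16000).astype(np.float32)  # one second of noise as a stand-in signal
print(audio_to_mel(y).shape)                   # (64, number_of_frames)
```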
@@ -132,7 +133,8 @@ class BinaryMaskDatasetMixin:
                 MaskAug(self.params.mask_ratio),
             ], p=0.6),
             # Utility
-            NormalizeLocal(), ToTensor()
+            NormalizeLocal(),
+            ToTensor()
         ])
 
         val_transforms = Compose([NormalizeLocal(), ToTensor()])
@@ -143,7 +145,7 @@ class BinaryMaskDatasetMixin:
             # TRAIN DATASET
             train_dataset=BinaryMasksDataset(self.params.root, setting=V.DATA_OPTIONS.train,
                                              use_preprocessed=self.params.use_preprocessed,
-                                             mixup=self.params.mixup, stretch_dataset=self.params.stretch,
+                                             stretch_dataset=self.params.stretch,
                                              mel_transforms=mel_transforms_train, transforms=aug_transforms),
             # VALIDATION DATASET
             val_train_dataset=BinaryMasksDataset(self.params.root, setting=V.DATA_OPTIONS.train,