Merge remote-tracking branch 'origin/master'

# Conflicts:
#	multi_run.py

commit 95aa7c4cc5
@@ -21,18 +21,16 @@ main_arg_parser.add_argument("--main_seed", type=int, default=69, help="")
 main_arg_parser.add_argument("--data_worker", type=int, default=11, help="")
 main_arg_parser.add_argument("--data_root", type=str, default='data', help="")
 main_arg_parser.add_argument("--data_class_name", type=str, default='BinaryMasksDataset', help="")
-main_arg_parser.add_argument("--data_normalized", type=strtobool, default=True, help="")
 main_arg_parser.add_argument("--data_use_preprocessed", type=strtobool, default=False, help="")
 main_arg_parser.add_argument("--data_n_mels", type=int, default=64, help="")
 main_arg_parser.add_argument("--data_sr", type=int, default=16000, help="")
 main_arg_parser.add_argument("--data_hop_length", type=int, default=256, help="")
 main_arg_parser.add_argument("--data_n_fft", type=int, default=512, help="")
-main_arg_parser.add_argument("--data_mixup", type=strtobool, default=False, help="")
 main_arg_parser.add_argument("--data_stretch", type=strtobool, default=True, help="")

 # Transformation Parameters
 main_arg_parser.add_argument("--data_loudness_ratio", type=float, default=0, help="")  # 0.4
-main_arg_parser.add_argument("--data_shift_ratio", type=float, default=0.3, help="")  # 0.4
+main_arg_parser.add_argument("--data_shift_ratio", type=float, default=0.3, help="")  # 0.3
 main_arg_parser.add_argument("--data_noise_ratio", type=float, default=0, help="")  # 0.4
 main_arg_parser.add_argument("--data_mask_ratio", type=float, default=0, help="")  # 0.2
 main_arg_parser.add_argument("--data_speed_ratio", type=float, default=0, help="")  # 0.3
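A note on the strtobool flags above: this is distutils.util.strtobool, which maps strings such as 'true'/'false' to 1/0. It is used as the argparse type because a plain bool would turn any non-empty string into True. A minimal sketch using one of the flags from this hunk:

    from argparse import ArgumentParser
    from distutils.util import strtobool

    parser = ArgumentParser()
    # bool("false") would be True, so strtobool parses the string instead.
    parser.add_argument("--data_stretch", type=strtobool, default=True, help="")

    print(parser.parse_args(["--data_stretch", "false"]).data_stretch)  # -> 0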
@@ -54,7 +52,7 @@ main_arg_parser.add_argument("--train_outpath", type=str, default="output", help
 main_arg_parser.add_argument("--train_version", type=strtobool, required=False, help="")
 # FIXME: Stochastic Weight Averaging is not good, maybe it's my implementation?
 main_arg_parser.add_argument("--train_sto_weight_avg", type=strtobool, default=False, help="")
-main_arg_parser.add_argument("--train_weight_decay", type=float, default=1e-8, help="")
+main_arg_parser.add_argument("--train_weight_decay", type=float, default=1e-7, help="")
 main_arg_parser.add_argument("--train_opt_reset_interval", type=int, default=0, help="")
 main_arg_parser.add_argument("--train_epochs", type=int, default=51, help="")
 main_arg_parser.add_argument("--train_batch_size", type=int, default=300, help="")
@@ -1,7 +1,6 @@
 import pickle
 from collections import defaultdict
 from pathlib import Path
-import random

 import librosa as librosa
 from torch.utils.data import Dataset
@@ -19,7 +18,7 @@ class BinaryMasksDataset(Dataset):
     def sample_shape(self):
         return self[0][0].shape

-    def __init__(self, data_root, setting, mel_transforms, transforms=None, mixup=False, stretch_dataset=False,
+    def __init__(self, data_root, setting, mel_transforms, transforms=None, stretch_dataset=False,
                  use_preprocessed=True):
         self.use_preprocessed = use_preprocessed
         self.stretch = stretch_dataset
@@ -29,7 +28,6 @@ class BinaryMasksDataset(Dataset):

         self.data_root = Path(data_root)
         self.setting = setting
-        self.mixup = mixup
         self._wav_folder = self.data_root / 'wav'
         self._mel_folder = self.data_root / 'mel'
         self.container_ext = '.pik'
@@ -40,19 +38,20 @@ class BinaryMasksDataset(Dataset):
         self._transforms = transforms or F_x(in_shape=None)

     def _build_labels(self):
+        labeldict = dict()
         with open(Path(self.data_root) / 'lab' / 'labels.csv', mode='r') as f:
             # Exclude the header
             _ = next(f)
-            labeldict = dict()
             for row in f:
                 if self.setting not in row:
                     continue
                 filename, label = row.strip().split(',')
                 labeldict[filename] = self._to_label[label.lower()] if not self.setting == 'test' else filename
         if self.stretch and self.setting == V.DATA_OPTIONS.train:
-            additional_dict = ({f'X_{key}': val for key, val in labeldict.items()})
-            additional_dict.update({f'X_X_{key}': val for key, val in labeldict.items()})
-            additional_dict.update({f'X_X_X_{key}': val for key, val in labeldict.items()})
+            additional_dict = ({f'X{key}': val for key, val in labeldict.items()})
+            additional_dict.update({f'XX{key}': val for key, val in labeldict.items()})
+            additional_dict.update({f'XXX{key}': val for key, val in labeldict.items()})
+            additional_dict.update({f'XXXX{key}': val for key, val in labeldict.items()})
             labeldict.update(additional_dict)

         # Delete File if one exists.
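The stretch branch above oversamples the training split by re-registering every label under keys prefixed with one to four 'X' characters; _compute_or_retrieve later strips the prefix to locate the underlying wav file. A self-contained sketch of the idea (filenames are made up, logic simplified from the hunk above):

    labeldict = {'0001.wav': 1, '0002.wav': 0}

    # Re-register each entry under 'X', 'XX', 'XXX', and 'XXXX' prefixes,
    # quintupling the number of keys while reusing the same labels.
    additional_dict = {}
    for n in range(1, 5):
        additional_dict.update({f"{'X' * n}{key}": val for key, val in labeldict.items()})
    labeldict.update(additional_dict)

    assert len(labeldict) == 5 * 2  # each original entry now appears five times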
@@ -66,12 +65,12 @@ class BinaryMasksDataset(Dataset):
         return labeldict

     def __len__(self):
-        return len(self._labels) * 2 if self.mixup else len(self._labels)
+        return len(self._labels)

     def _compute_or_retrieve(self, filename):

         if not (self._mel_folder / (filename + self.container_ext)).exists():
-            raw_sample, sr = librosa.core.load(self._wav_folder / (filename.replace('X_', '') + '.wav'))
+            raw_sample, sr = librosa.core.load(self._wav_folder / (filename.replace('X', '') + '.wav'))
             mel_sample = self._mel_transform(raw_sample)
             self._mel_folder.mkdir(exist_ok=True, parents=True)
             with (self._mel_folder / (filename + self.container_ext)).open(mode='wb') as f:
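One thing to keep in mind with the new lookup: str.replace('X', '') strips every 'X' in the name, not just the duplication prefix, so the scheme assumes the original wav filenames contain no literal 'X'. With made-up names:

    print('XXX0001'.replace('X', ''))    # '0001'   - prefix removed as intended
    print('X0001_XL'.replace('X', ''))   # '0001_L' - an inner 'X' is lost too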
@@ -82,28 +81,16 @@ class BinaryMasksDataset(Dataset):
         return mel_sample

     def __getitem__(self, item):
-        is_mixed = item >= len(self._labels)
-        if is_mixed:
-            item = item - len(self._labels)

         key: str = list(self._labels.keys())[item]
         filename = key.replace('.wav', '')
         mel_sample = self._compute_or_retrieve(filename)
         label = self._labels[key]

-        if is_mixed:
-            label_sec = -1
-            while label_sec != self._labels[key]:
-                key_sec = random.choice(list(self._labels.keys()))
-                label_sec = self._labels[key_sec]
-            # noinspection PyUnboundLocalVariable
-            filename_sec = key_sec[:-4]
-            mel_sample_sec = self._compute_or_retrieve(filename_sec)
-            mix_in_border = int(random.random() * mel_sample.shape[-1]) * random.choice([1, -1])
-            mel_sample[:, :mix_in_border] = mel_sample_sec[:, :mix_in_border]

         transformed_samples = self._transforms(mel_sample)
-        if not self.setting == 'test':
+        if self.setting != V.DATA_OPTIONS.test:
+            # In test, filenames instead of labels are returned. This is a little hacky though.
             label = torch.as_tensor(label, dtype=torch.float)

         return transformed_samples, label
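For reference, the mixup being deleted was a same-label splice: indices past the real dataset length were flagged as "mixed", a second clip was redrawn until its label matched, and a random leading stretch of the mel spectrogram's time axis was overwritten with the second clip's frames. A standalone sketch of that splice (assuming 2D arrays of shape [n_mels, time]; not the repo's exact code):

    import random
    import numpy as np

    def same_label_splice(mel_a: np.ndarray, mel_b: np.ndarray) -> np.ndarray:
        """Overwrite a slice of mel_a's time axis with frames from mel_b.

        A positive border replaces the leading frames; a negative border (as the
        removed code allowed) replaces everything except the trailing |border| frames.
        """
        border = int(random.random() * mel_a.shape[-1]) * random.choice([1, -1])
        mixed = mel_a.copy()
        mixed[:, :border] = mel_b[:, :border]
        return mixed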
main.py
@@ -110,6 +110,7 @@ def run_lightning_loop(config_obj):
     inference_out = f'{parameters}_test_out.csv'

     from main_inference import prepare_dataloader
+    import variables as V
     test_dataloader = prepare_dataloader(config_obj)

     with (outpath / model_type / parameters / version / inference_out).open(mode='w') as outfile:
@@ -118,12 +119,12 @@ def run_lightning_loop(config_obj):
     from tqdm import tqdm
     for batch in tqdm(test_dataloader, total=len(test_dataloader)):
         batch_x, file_name = batch
-        batch_x = batch_x.unsqueeze(0).to(device='cuda' if model.on_gpu else 'cpu')
+        batch_x = batch_x.to(device='cuda' if model.on_gpu else 'cpu')
         y = model(batch_x).main_out
-        prediction = (y.squeeze() >= 0.5).int().item()
-        import variables as V
-        prediction = 'clear' if prediction == V.CLEAR else 'mask'
-        outfile.write(f'{file_name},{prediction}\n')
+        predictions = (y >= 0.5).int()
+        for prediction in predictions:
+            prediction_text = 'clear' if prediction == V.CLEAR else 'mask'
+            outfile.write(f'{file_name},{prediction_text}\n')
     return model

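With automatic batching now on (see the prepare_dataloader change below), each batch carries a tensor of inputs and a whole sequence of filenames, so every prediction should be written next to its own name; as written, the loop reuses the single file_name value for every row of the batch. A hedged sketch of pairing them explicitly (illustrative, not the commit's code):

    batch_x, file_names = batch  # file_names: one name per sample when batched
    y = model(batch_x).main_out
    predictions = (y >= 0.5).int()

    for file_name, prediction in zip(file_names, predictions):
        prediction_text = 'clear' if prediction.item() == V.CLEAR else 'mask'
        outfile.write(f'{file_name},{prediction_text}\n')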
main_inference.py
@@ -43,7 +43,8 @@ def prepare_dataloader(config_obj):
                                 mel_transforms=mel_transforms, transforms=transforms
                                 )
     # noinspection PyTypeChecker
-    return DataLoader(dataset, batch_size=None, num_workers=0, shuffle=False)
+    return DataLoader(dataset, batch_size=config_obj.train.batch_size,
+                      num_workers=config_obj.data.worker, shuffle=False)


 def restore_logger_and_model(log_dir):
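In PyTorch, batch_size=None disables automatic batching: the loader yields one unbatched sample per iteration, which is why the old inference loop needed unsqueeze(0). With a real batch_size, the default collate function stacks tensors into a leading batch dimension and gathers string labels (here, the test-set filenames) into a sequence. A brief illustration, using the repo's argparse defaults (300, 11) and assuming dataset from the surrounding function:

    from torch.utils.data import DataLoader

    # batch_size=None -> one sample at a time, no batch dimension
    # batch_size=N    -> tensors of shape [N, ...]; strings collated per batch
    loader = DataLoader(dataset, batch_size=300, num_workers=11, shuffle=False)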
multi_run.py
@@ -20,31 +20,31 @@ if __name__ == '__main__':
     config = MConfig().read_namespace(args)

     arg_dict = dict()
-    for seed in range(40, 45):
+    for seed in range(0, 10):
         arg_dict.update(main_seed=seed)
         for model in ['CC', 'BCMC', 'BCC', 'RCC']:
             arg_dict.update(model_type=model)
             raw_conf = dict(data_speed_factor=0.0, data_speed_ratio=0.0, data_mask_ratio=0.0,
                             data_noise_ratio=0.0, data_shift_ratio=0.0, data_loudness_ratio=0.0,
-                            data_stretch=False)
+                            data_stretch=False, train_epochs=401)
-            all_conf = dict(data_speed_factor=0.7, data_speed_ratio=0.3, data_mask_ratio=0.2,
+            all_conf = dict(data_speed_factor=0.7, data_speed_ratio=0.2, data_mask_ratio=0.2,
                             data_noise_ratio=0.4, data_shift_ratio=0.4, data_loudness_ratio=0.4,
-                            data_stretch=True)
+                            data_stretch=True, train_epochs=101)
             speed_conf = dict(data_speed_factor=0.7, data_speed_ratio=0.2, data_mask_ratio=0.0,
                               data_noise_ratio=0.0, data_shift_ratio=0.0, data_loudness_ratio=0.0,
-                              data_stretch=True)
+                              data_stretch=True, train_epochs=101)
             mask_conf = dict(data_speed_factor=0.0, data_speed_ratio=0.0, data_mask_ratio=0.2,
                              data_noise_ratio=0.0, data_shift_ratio=0.0, data_loudness_ratio=0.0,
-                             data_stretch=True)
+                             data_stretch=True, train_epochs=101)
             noise_conf = dict(data_speed_factor=0.0, data_speed_ratio=0.0, data_mask_ratio=0.0,
                               data_noise_ratio=0.4, data_shift_ratio=0.0, data_loudness_ratio=0.0,
-                              data_stretch=True)
+                              data_stretch=True, train_epochs=101)
             shift_conf = dict(data_speed_factor=0.0, data_speed_ratio=0.0, data_mask_ratio=0.0,
                               data_noise_ratio=0.0, data_shift_ratio=0.4, data_loudness_ratio=0.0,
-                              data_stretch=True)
+                              data_stretch=True, train_epochs=101)
             loudness_conf = dict(data_speed_factor=0.0, data_speed_ratio=0.0, data_mask_ratio=0.0,
                                  data_noise_ratio=0.0, data_shift_ratio=0.0, data_loudness_ratio=0.4,
-                                 data_stretch=True)
+                                 data_stretch=True, train_epochs=101)

             for dicts in [raw_conf, all_conf, speed_conf, mask_conf, noise_conf, shift_conf, loudness_conf]:
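Each dict above isolates one augmentation, with raw_conf as the unaugmented baseline (now trained longer: 401 vs. 101 epochs), and the trailing loop overlays one config at a time on the shared seed/model arguments before launching a run. A hedged sketch of that merge step (the launch call is hypothetical; only arg_dict.update appears in this hunk):

    for conf in [raw_conf, all_conf, speed_conf, mask_conf,
                 noise_conf, shift_conf, loudness_conf]:
        arg_dict.update(conf)  # one ablation setting per run, same seed and model
        # launch_run(arg_dict)  # hypothetical: the actual call sits outside this hunk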
@@ -122,7 +122,8 @@ class BinaryMaskDatasetMixin:
         mel_transforms = Compose([
             # Audio to Mel Transformations
             AudioToMel(sr=self.params.sr, n_mels=self.params.n_mels, n_fft=self.params.n_fft,
-                       hop_length=self.params.hop_length), MelToImage()])
+                       hop_length=self.params.hop_length),
+            MelToImage()])
         # Data Augmentations
         aug_transforms = Compose([
             RandomApply([
@@ -132,7 +133,8 @@ class BinaryMaskDatasetMixin:
                 MaskAug(self.params.mask_ratio),
             ], p=0.6),
             # Utility
-            NormalizeLocal(), ToTensor()
+            NormalizeLocal(),
+            ToTensor()
         ])
         val_transforms = Compose([NormalizeLocal(), ToTensor()])

@@ -143,7 +145,7 @@ class BinaryMaskDatasetMixin:
             # TRAIN DATASET
             train_dataset=BinaryMasksDataset(self.params.root, setting=V.DATA_OPTIONS.train,
                                              use_preprocessed=self.params.use_preprocessed,
-                                             mixup=self.params.mixup, stretch_dataset=self.params.stretch,
+                                             stretch_dataset=self.params.stretch,
                                              mel_transforms=mel_transforms_train, transforms=aug_transforms),
             # VALIDATION DATASET
             val_train_dataset=BinaryMasksDataset(self.params.root, setting=V.DATA_OPTIONS.train,
Loading…
x
Reference in New Issue
Block a user