requirements.txt updated @torch1.4

speed augmentation updated
parameters updated
This commit is contained in:
Steffen Illium 2020-05-21 14:12:54 +02:00
parent b529d130df
commit d58bcbf14b
2 changed files with 22 additions and 17 deletions

View File

@ -33,8 +33,9 @@ main_arg_parser.add_argument("--data_loudness_ratio", type=float, default=0, hel
main_arg_parser.add_argument("--data_shift_ratio", type=float, default=0.3, help="") # 0.3 main_arg_parser.add_argument("--data_shift_ratio", type=float, default=0.3, help="") # 0.3
main_arg_parser.add_argument("--data_noise_ratio", type=float, default=0, help="") # 0.4 main_arg_parser.add_argument("--data_noise_ratio", type=float, default=0, help="") # 0.4
main_arg_parser.add_argument("--data_mask_ratio", type=float, default=0, help="") # 0.2 main_arg_parser.add_argument("--data_mask_ratio", type=float, default=0, help="") # 0.2
main_arg_parser.add_argument("--data_speed_ratio", type=float, default=0, help="") # 0.3 main_arg_parser.add_argument("--data_speed_amount", type=float, default=0, help="") # 0.3
main_arg_parser.add_argument("--data_speed_factor", type=float, default=0, help="") # 0.7 main_arg_parser.add_argument("--data_speed_min", type=float, default=0, help="") # 0.7
main_arg_parser.add_argument("--data_speed_max", type=float, default=0, help="") # 1.7
# Model Parameters # Model Parameters
main_arg_parser.add_argument("--model_type", type=str, default="RCC", help="") main_arg_parser.add_argument("--model_type", type=str, default="RCC", help="")

View File

@ -113,17 +113,25 @@ class BinaryMaskDatasetMixin:
# Dataset # Dataset
# ============================================================================= # =============================================================================
# Mel Transforms # Mel Transforms
mel_transforms_train = Compose([
# Audio to Mel Transformations
Speed(speed_factor=self.params.speed_factor, max_ratio=self.params.speed_ratio),
AudioToMel(sr=self.params.sr, n_mels=self.params.n_mels, n_fft=self.params.n_fft,
hop_length=self.params.hop_length),
MelToImage()])
mel_transforms = Compose([ mel_transforms = Compose([
# Audio to Mel Transformations # Audio to Mel Transformations
AudioToMel(sr=self.params.sr, n_mels=self.params.n_mels, n_fft=self.params.n_fft, AudioToMel(sr=self.params.sr,
n_mels=self.params.n_mels,
n_fft=self.params.n_fft,
hop_length=self.params.hop_length), hop_length=self.params.hop_length),
MelToImage()]) MelToImage()])
mel_transforms_train = Compose([
# Audio to Mel Transformations
Speed(max_amount=self.params.speed_amount,
speed_min=self.params.speed_min,
speed_max=self.params.speed_max
),
mel_transforms])
# Utility
util_transforms = Compose([NormalizeLocal(), ToTensor()])
# Data Augmentations # Data Augmentations
aug_transforms = Compose([ aug_transforms = Compose([
RandomApply([ RandomApply([
@ -132,11 +140,7 @@ class BinaryMaskDatasetMixin:
ShiftTime(self.params.shift_ratio), ShiftTime(self.params.shift_ratio),
MaskAug(self.params.mask_ratio), MaskAug(self.params.mask_ratio),
], p=0.6), ], p=0.6),
# Utility util_transforms])
NormalizeLocal(),
ToTensor()
])
val_transforms = Compose([NormalizeLocal(), ToTensor()])
# Datasets # Datasets
from datasets.binar_masks import BinaryMasksDataset from datasets.binar_masks import BinaryMasksDataset
@ -149,12 +153,12 @@ class BinaryMaskDatasetMixin:
mel_transforms=mel_transforms_train, transforms=aug_transforms), mel_transforms=mel_transforms_train, transforms=aug_transforms),
# VALIDATION DATASET # VALIDATION DATASET
val_train_dataset=BinaryMasksDataset(self.params.root, setting=V.DATA_OPTIONS.train, val_train_dataset=BinaryMasksDataset(self.params.root, setting=V.DATA_OPTIONS.train,
mel_transforms=mel_transforms, transforms=val_transforms), mel_transforms=mel_transforms, transforms=util_transforms),
val_dataset=BinaryMasksDataset(self.params.root, setting=V.DATA_OPTIONS.devel, val_dataset=BinaryMasksDataset(self.params.root, setting=V.DATA_OPTIONS.devel,
mel_transforms=mel_transforms, transforms=val_transforms), mel_transforms=mel_transforms, transforms=util_transforms),
# TEST DATASET # TEST DATASET
test_dataset=BinaryMasksDataset(self.params.root, setting=V.DATA_OPTIONS.test, test_dataset=BinaryMasksDataset(self.params.root, setting=V.DATA_OPTIONS.test,
mel_transforms=mel_transforms, transforms=val_transforms), mel_transforms=mel_transforms, transforms=util_transforms),
) )
) )
return dataset return dataset