requirements

This commit is contained in:
Si11ium
2020-05-14 23:08:36 +02:00
parent 407df15bbf
commit e7d1a4895a
9 changed files with 52 additions and 38 deletions

View File

@ -5,7 +5,7 @@ from tqdm import tqdm
import variables as V
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import Compose
from torchvision.transforms import Compose, RandomApply
from ml_lib.audio_toolset.audio_io import AudioToMel, NormalizeLocal, MelToImage
@ -13,6 +13,7 @@ from ml_lib.audio_toolset.audio_io import AudioToMel, NormalizeLocal, MelToImage
# =============================================================================
# Transforms
from ml_lib.audio_toolset.mel_augmentation import NoiseInjection, LoudnessManipulator, ShiftTime, MaskAug
from ml_lib.utils.logging import Logger
from ml_lib.utils.model_io import SavedLightningModels
from ml_lib.utils.transforms import ToTensor
@ -28,8 +29,18 @@ def prepare_dataloader(config_obj):
AudioToMel(sr=config_obj.data.sr, n_mels=config_obj.data.n_mels, n_fft=config_obj.data.n_fft,
hop_length=config_obj.data.hop_length), MelToImage()])
transforms = Compose([NormalizeLocal(), ToTensor()])
aug_transforms = Compose([
RandomApply([
NoiseInjection(config_obj.data.noise_ratio),
LoudnessManipulator(config_obj.data.loudness_ratio),
ShiftTime(config_obj.data.shift_ratio),
MaskAug(config_obj.data.mask_ratio),
], p=0.6),
# Utility
NormalizeLocal(), ToTensor()
])
dataset: Dataset = BinaryMasksDataset(config_obj.data.root, setting='test',
dataset: Dataset = BinaryMasksDataset(config_obj.data.root, setting='train',
mel_transforms=mel_transforms, transforms=transforms
)
# noinspection PyTypeChecker
@ -49,9 +60,9 @@ def restore_logger_and_model(config_obj):
if __name__ == '__main__':
outpath = Path('output')
model_type = 'BandwiseConvMultiheadClassifier'
parameters = 'BCMC_9c70168a5711c269b33701f1650adfb9/'
version = 'version_1'
model_type = 'CC'
parameters = 'CC_213adb16e46592c5a405abfbd693835e/'
version = 'version_41'
config_filename = 'config.ini'
inference_out = 'manual_test_out.csv'