requirements
This commit is contained in:
@ -5,7 +5,7 @@ from tqdm import tqdm
|
||||
|
||||
import variables as V
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
from torchvision.transforms import Compose
|
||||
from torchvision.transforms import Compose, RandomApply
|
||||
|
||||
from ml_lib.audio_toolset.audio_io import AudioToMel, NormalizeLocal, MelToImage
|
||||
|
||||
@ -13,6 +13,7 @@ from ml_lib.audio_toolset.audio_io import AudioToMel, NormalizeLocal, MelToImage
|
||||
# =============================================================================
|
||||
|
||||
# Transforms
|
||||
from ml_lib.audio_toolset.mel_augmentation import NoiseInjection, LoudnessManipulator, ShiftTime, MaskAug
|
||||
from ml_lib.utils.logging import Logger
|
||||
from ml_lib.utils.model_io import SavedLightningModels
|
||||
from ml_lib.utils.transforms import ToTensor
|
||||
@ -28,8 +29,18 @@ def prepare_dataloader(config_obj):
|
||||
AudioToMel(sr=config_obj.data.sr, n_mels=config_obj.data.n_mels, n_fft=config_obj.data.n_fft,
|
||||
hop_length=config_obj.data.hop_length), MelToImage()])
|
||||
transforms = Compose([NormalizeLocal(), ToTensor()])
|
||||
aug_transforms = Compose([
|
||||
RandomApply([
|
||||
NoiseInjection(config_obj.data.noise_ratio),
|
||||
LoudnessManipulator(config_obj.data.loudness_ratio),
|
||||
ShiftTime(config_obj.data.shift_ratio),
|
||||
MaskAug(config_obj.data.mask_ratio),
|
||||
], p=0.6),
|
||||
# Utility
|
||||
NormalizeLocal(), ToTensor()
|
||||
])
|
||||
|
||||
dataset: Dataset = BinaryMasksDataset(config_obj.data.root, setting='test',
|
||||
dataset: Dataset = BinaryMasksDataset(config_obj.data.root, setting='train',
|
||||
mel_transforms=mel_transforms, transforms=transforms
|
||||
)
|
||||
# noinspection PyTypeChecker
|
||||
@ -49,9 +60,9 @@ def restore_logger_and_model(config_obj):
|
||||
|
||||
if __name__ == '__main__':
|
||||
outpath = Path('output')
|
||||
model_type = 'BandwiseConvMultiheadClassifier'
|
||||
parameters = 'BCMC_9c70168a5711c269b33701f1650adfb9/'
|
||||
version = 'version_1'
|
||||
model_type = 'CC'
|
||||
parameters = 'CC_213adb16e46592c5a405abfbd693835e/'
|
||||
version = 'version_41'
|
||||
config_filename = 'config.ini'
|
||||
inference_out = 'manual_test_out.csv'
|
||||
|
||||
|
Reference in New Issue
Block a user