45 lines
1.3 KiB
Python
45 lines
1.3 KiB
Python
from torch.utils.data import DataLoader, Dataset
|
|
from torchvision.transforms import Compose, ToTensor
|
|
|
|
from ml_lib.audio_toolset.audio_io import AudioToMel, NormalizeLocal
|
|
|
|
# Dataset and Dataloaders
|
|
# =============================================================================
|
|
|
|
# Transforms
|
|
from ml_lib.utils.model_io import SavedLightningModels
|
|
from util.config import MConfig
|
|
from util.logging import MLogger
|
|
|
|
# Shared preprocessing pipeline handed to BinaryMasksDataset in prepare_dataset().
# Compose applies the listed transforms in order: AudioToMel, ToTensor, NormalizeLocal.
# NOTE(review): presumably audio -> mel spectrogram -> tensor -> per-sample
# normalisation; confirm against the ml_lib implementations.
transforms = Compose([AudioToMel(), ToTensor(), NormalizeLocal()])
|
|
|
|
# Datasets
|
|
from datasets.binar_masks import BinaryMasksDataset
|
|
|
|
|
|
def prepare_dataset(config_obj):
    """Build the DataLoader over the 'test' setting of the binary-masks dataset.

    Args:
        config_obj: Configuration object. ``config_obj.data.root`` is passed to
            ``BinaryMasksDataset`` as the dataset root and
            ``config_obj.data.worker`` is used as the number of loader workers.

    Returns:
        A ``DataLoader`` that iterates the dataset in order (no shuffling)
        with automatic batching disabled.
    """
    dataset: Dataset = BinaryMasksDataset(
        config_obj.data.root,
        setting='test',
        transforms=transforms,
    )
    return DataLoader(
        dataset=dataset,
        # batch_size=None turns automatic batching off: samples are yielded
        # one by one exactly as the dataset returns them.
        batch_size=None,
        # DataLoader's keyword is ``num_workers``; the former ``worker=``
        # keyword raised TypeError (unexpected keyword argument).
        num_workers=config_obj.data.worker,
        shuffle=False,
    )
|
|
|
|
|
|
def restore_logger_and_model(config_obj):
    """Restore a saved model from the log directory of this configuration.

    An ``MLogger`` is created from ``config_obj``; its ``log_dir`` is used as
    the root path from which ``SavedLightningModels`` loads a checkpoint.

    Args:
        config_obj: Configuration object used to construct the ``MLogger``.

    Returns:
        The result of calling ``restore()`` on the loaded checkpoint. Note
        that, despite the function name, only the model is returned; the
        logger is not.
    """
    run_logger = MLogger(config_obj)
    checkpoint = SavedLightningModels().load_checkpoint(
        models_root_path=run_logger.log_dir,
    )
    return checkpoint.restore()
|
|
|
|
|
|
if __name__ == "__main__":
    # The CLI parser is only needed when this file is executed as a script.
    from _paramters import main_arg_parser

    # Build the run configuration from the command-line arguments.
    config = MConfig().read_argparser(main_arg_parser)

    # Test-set loader first, then the restored model (same order as before).
    test_dataset = prepare_dataset(config)
    loaded_model = restore_logger_and_model(config)

    # Placeholder: inference and output serialisation are not implemented yet.
    print('run model here and find a format to store the output')
|
|
|
|
|