# File: masks_augments_compare-21/main_inference.py
# (Listing metadata from the original page: 45 lines, 1.3 KiB, Python.)

from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import Compose, ToTensor
from ml_lib.audio_toolset.audio_io import AudioToMel, NormalizeLocal
# Dataset and Dataloaders
# =============================================================================
# Transforms
from ml_lib.utils.model_io import SavedLightningModels
from util.config import MConfig
from util.logging import MLogger
# Per-sample transform pipeline handed to the dataset in prepare_dataset().
# Compose applies the stages in list order. AudioToMel / NormalizeLocal come
# from ml_lib -- presumably audio -> mel spectrogram and a per-sample
# normalisation; confirm against ml_lib.audio_toolset.
transforms = Compose([
    AudioToMel(),
    ToTensor(),
    NormalizeLocal(),
])
# Datasets
from datasets.binar_masks import BinaryMasksDataset
def prepare_dataset(config_obj):
    """Build a DataLoader over the 'test' split of the binary-masks dataset.

    Args:
        config_obj: Project config object. Reads ``config_obj.data.root``
            (forwarded to ``BinaryMasksDataset`` as its root) and
            ``config_obj.data.worker`` (forwarded to ``DataLoader`` as
            ``num_workers``).

    Returns:
        A ``torch.utils.data.DataLoader`` over the test split, iterated in
        dataset order (``shuffle=False``) so inference output is reproducible.
    """
    dataset: Dataset = BinaryMasksDataset(config_obj.data.root, setting='test', transforms=transforms)
    # batch_size=None disables automatic batching: the loader yields the
    # dataset's samples one at a time instead of collating them into batches.
    # DataLoader has no ``worker`` keyword -- the worker-process count must be
    # passed as ``num_workers``, otherwise construction raises TypeError.
    return DataLoader(dataset=dataset,
                      batch_size=None,
                      num_workers=config_obj.data.worker,
                      shuffle=False)
def restore_logger_and_model(config_obj):
    """Reload the trained model from the current run's logger directory.

    An ``MLogger`` is built from ``config_obj`` to resolve its ``log_dir``.
    ``SavedLightningModels.load_checkpoint`` is pointed at that directory,
    and the result of calling ``restore()`` on what it returns is handed
    back. Only the restored object is returned, not the logger.

    Args:
        config_obj: Project config object, forwarded to ``MLogger``.

    Returns:
        The value of ``.restore()`` on the loaded checkpoint -- presumably the
        restored model instance (TODO confirm against ml_lib.utils.model_io).
    """
    run_logger = MLogger(config_obj)
    checkpoint_loader = SavedLightningModels()
    loaded_checkpoint = checkpoint_loader.load_checkpoint(models_root_path=run_logger.log_dir)
    return loaded_checkpoint.restore()
if __name__ == '__main__':
    # NOTE(review): '_paramters' looks misspelled, but it is left as-is because
    # the repository module may actually carry this name -- confirm.
    from _paramters import main_arg_parser

    # Turn CLI arguments into the project config object.
    run_config = MConfig().read_argparser(main_arg_parser)

    # Build the test-split loader first, then restore the trained model.
    test_loader = prepare_dataset(run_config)
    trained_model = restore_logger_and_model(run_config)

    # TODO: iterate trained_model over test_loader and persist the predictions.
    print("run model here and find a format to store the output")