Model Training
This commit is contained in:
44
main_inference.py
Normal file
44
main_inference.py
Normal file
@ -0,0 +1,44 @@
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
from torchvision.transforms import Compose, ToTensor
|
||||
|
||||
from ml_lib.audio_toolset.audio_io import Melspectogram, NormalizeLocal
|
||||
|
||||
# Dataset and Dataloaders
|
||||
# =============================================================================
|
||||
|
||||
# Transforms
|
||||
from ml_lib.utils.model_io import SavedLightningModels
|
||||
from util.config import MConfig
|
||||
from util.logging import MLogger
|
||||
|
||||
transforms = Compose([Melspectogram(), ToTensor(), NormalizeLocal()])
|
||||
|
||||
# Datasets
|
||||
from datasets.binar_masks import BinaryMasksDataset
|
||||
|
||||
|
||||
def prepare_dataset(config_obj):
|
||||
dataset: Dataset = BinaryMasksDataset(config_obj.data.root, setting='test', transforms=transforms)
|
||||
return DataLoader(dataset=dataset,
|
||||
batch_size=None,
|
||||
worker=config_obj.data.worker,
|
||||
shuffle=False)
|
||||
|
||||
|
||||
def restore_logger_and_model(config_obj):
|
||||
logger = MLogger(config_obj)
|
||||
model = SavedLightningModels().load_checkpoint(models_root_path=logger.log_dir)
|
||||
model = model.restore()
|
||||
|
||||
return model
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
from _paramters import main_arg_parser
|
||||
|
||||
config = MConfig().read_argparser(main_arg_parser)
|
||||
test_dataset = prepare_dataset(config)
|
||||
loaded_model = restore_logger_and_model(config)
|
||||
print("run model here and find a format to store the output")
|
||||
|
||||
|
Reference in New Issue
Block a user