Parameter Adjustments and Ensemble Model Implementation
@@ -1,44 +1,70 @@
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import Compose, ToTensor
from pathlib import Path

from ml_lib.audio_toolset.audio_io import AudioToMel, NormalizeLocal
import torch
from tqdm import tqdm

import variables as V
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import Compose

from ml_lib.audio_toolset.audio_io import AudioToMel, NormalizeLocal, MelToImage

# Dataset and Dataloaders
# =============================================================================

# Transforms
from ml_lib.utils.logging import Logger
from ml_lib.utils.model_io import SavedLightningModels
from ml_lib.utils.transforms import ToTensor
from util.config import MConfig
from util.logging import MLogger

transforms = Compose([AudioToMel(), ToTensor(), NormalizeLocal()])

# Datasets
from datasets.binar_masks import BinaryMasksDataset


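# Builds the test split with the module-level transforms and returns an unbatched DataLoader.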
def prepare_dataset(config_obj):
    dataset: Dataset = BinaryMasksDataset(config_obj.data.root, setting='test', transforms=transforms)
    return DataLoader(dataset=dataset,
                      batch_size=None,
                      num_workers=config_obj.data.worker,
                      shuffle=False)

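# Builds mel-spectrogram transforms from the run config (n_mels) and serves the
# test split sample-by-sample (batch_size=None), unshuffled.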
def prepare_dataloader(config_obj):
    mel_transforms = Compose([AudioToMel(n_mels=config_obj.data.n_mels), MelToImage()])
    transforms = Compose([NormalizeLocal(), ToTensor()])

    dataset: Dataset = BinaryMasksDataset(config_obj.data.root, setting='test',
                                          mel_transforms=mel_transforms, transforms=transforms
                                          )
    # noinspection PyTypeChecker
    return DataLoader(dataset, batch_size=None, num_workers=0, shuffle=False)


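# Restores the trained Lightning model from the checkpoint directory referenced by
# the logger and moves it to the GPU when one is available.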
def restore_logger_and_model(config_obj):
    logger = MLogger(config_obj)
    model = SavedLightningModels().load_checkpoint(models_root_path=logger.log_dir)
    logger = Logger(config_obj)
    model = SavedLightningModels.load_checkpoint(models_root_path=logger.log_dir, n=-2)
    model = model.restore()

    if torch.cuda.is_available():
        model.cuda()
    else:
        model.cpu()
    return model


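# Inference entry point: read the stored run config, rebuild the test dataloader and
# the saved model, then write one prediction per test file to a CSV.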
if __name__ == '__main__':
    from _paramters import main_arg_parser
    outpath = Path('output')
    model_type = 'BandwiseConvMultiheadClassifier'
    parameters = 'BCMC_9c70168a5711c269b33701f1650adfb9/'
    version = 'version_1'
    config_filename = 'config.ini'
    inference_out = 'manual_test_out.csv'

    config = MConfig().read_argparser(main_arg_parser)
    test_dataset = prepare_dataset(config)
    config = MConfig()
    config.read_file((outpath / model_type / parameters / version / config_filename).open('r'))
    test_dataloader = prepare_dataloader(config)
    loaded_model = restore_logger_and_model(config)
    print("run model here and find a format to store the output")
    loaded_model.eval()

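    # Run the restored model over every test file; the main output is thresholded at 0.5
    # and mapped to the 'clear' / 'mask' labels before being written as 'file_name,prediction'.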
    with (outpath / model_type / parameters / version / inference_out).open(mode='w') as outfile:
        outfile.write(f'file_name,prediction\n')

        for batch in tqdm(test_dataloader, total=len(test_dataloader)):
            batch_x, file_name = batch
            y = loaded_model(batch_x.unsqueeze(0).to(device='cuda' if torch.cuda.is_available() else 'cpu')).main_out
            prediction = (y.squeeze() >= 0.5).int().item()
            prediction = 'clear' if prediction == V.CLEAR else 'mask'
            outfile.write(f'{file_name},{prediction}\n')
    print('Done')