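"""Ensemble evaluation over per-checkpoint predictions.

Accumulates the prediction CSVs of every checkpoint, fuses them by mean and
majority voting over a range of decision boundaries, scores the fused
predictions with UAR (unweighted average recall), and gathers the scores into
one CSV per model type.

The directory traversal below roughly assumes an output tree of the form
output/<model_type>/<param_config>/<version>/outputs/, where each 'outputs'
folder holds one prediction CSV per checkpoint and the version folder holds
the experiment's config file.
"""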
import csv
import pickle
from collections import defaultdict
from pathlib import Path

import numpy as np
from sklearn import metrics
from tqdm import tqdm

from util.config import MConfig


def accumulate_predictions(config_filename, output_folders):
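    """For every given output folder, collect the per-checkpoint prediction CSVs
    into a single dict (one list of predictions per checkpoint plus the labels)
    and pickle it next to the CSVs."""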
    for output_folder in tqdm(output_folders, total=len(output_folders)):
        # Gather the per-checkpoint prediction CSVs in this folder
        inference_files = output_folder.glob('*.csv')

        config = MConfig()
        with (output_folder.parent / config_filename).open('r') as config_file:
            config.read_file(config_file)

        result_dict = defaultdict(list)
        for inf_idx, inference_file in enumerate(inference_files):
            with inference_file.open('r') as f:
                # Read the header line once to skip it
                _ = f.readline()
                for row in f:
                    prediction, label = [float(x) for x in row.strip().split(',')]
                    # One prediction list per checkpoint, keyed by the CSV name without '.csv'
                    result_dict[inference_file.name[:-4]].append(prediction)
                    if inf_idx == 0:
                        # Labels are taken from the first CSV only (assumed identical across checkpoints)
                        result_dict['labels'].append(label)
        result_dict = dict(result_dict)

        # Pickle the accumulated predictions next to the CSVs, named after this script
        with (output_folder / Path(__file__).stem).open('wb') as f:
            pickle.dump(result_dict, f, protocol=pickle.HIGHEST_PROTOCOL)


def accumulate_uars(outpath):
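    """Load the accumulated prediction pickles under every parameter configuration,
    fuse the checkpoint predictions by mean and majority voting over a range of
    decision boundaries, and pickle the resulting UAR scores next to each ensemble file."""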
    for model_type in outpath.iterdir():
        for param_config in model_type.iterdir():
            # The accumulated prediction pickles are named after this script
            per_seed_ensemble_files = param_config.rglob(Path(__file__).stem)

            for ensemble_file in per_seed_ensemble_files:
                uar_dict = dict()
                with ensemble_file.open('rb') as f:
                    loaded_ensemble_file = pickle.load(f)
                labels = loaded_ensemble_file.pop('labels')

                for decision_boundary in range(10, 91, 5):
                    decision_boundary = round(decision_boundary * 0.01, 2)
                    majority_votes = []
                    mean_votes = []
                    # A strict majority needs more than half of the checkpoint voters
                    voters = len(loaded_ensemble_file.keys()) * 0.5
                    for i in range(len(labels)):
                        majority_vote = []
                        mean_vote = []
                        for key in loaded_ensemble_file.keys():
                            majority_vote.append(loaded_ensemble_file[key][i] > decision_boundary)
                            mean_vote.append(loaded_ensemble_file[key][i])
                        # Mean fusion: average the raw predictions, then apply the decision boundary
                        mean_votes.append(int(sum(mean_vote) / len(loaded_ensemble_file.keys()) > decision_boundary))
                        # Majority fusion: apply the boundary per checkpoint, then count the votes
                        majority_votes.append(sum(majority_vote) > voters)

                    for predictions, name in zip([mean_votes, majority_votes], ['mean', 'majority']):
                        # UAR (unweighted average recall) = macro-averaged recall over both classes
                        uar_score = metrics.recall_score(labels, predictions, labels=[0, 1], average='macro',
                                                         sample_weight=None, zero_division='warn')
                        uar_dict[f'{name}_decb_{decision_boundary}'] = uar_score

                with (ensemble_file.parent / 'ensemble_uar_dict_decb').open('wb') as ef:
                    pickle.dump(uar_dict, ef, protocol=pickle.HIGHEST_PROTOCOL)


def gather_results(config_filename, outpath):
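    """Collect the per-version UAR dicts of every parameter configuration, aggregate
    them into mean and standard deviation alongside selected config parameters, and
    write one results CSV per model type."""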
    for model_type in outpath.iterdir():
        result_dict = defaultdict(list)
        for param_config in model_type.iterdir():
            tmp_uar_dict = defaultdict(list)
            config: MConfig
            for idx, version_uar in enumerate(param_config.rglob('uar_dict_decb')):
                if not idx:
                    # Read the experiment parameters once per parameter configuration
                    config = MConfig()
                    with (version_uar.parent.parent / config_filename).open('r') as config_file:
                        config.read_file(config_file)
                    for parameter, value in config.model_paramters.items():
                        if parameter in ['exp_path', 'exp_fingerprint', 'loudness_ratio', 'mask_ratio',
                                         'noise_ratio', 'shift_ratio', 'speed_amount', 'speed_max', 'speed_min']:
                            result_dict[parameter].append(value)

                with version_uar.open('rb') as f:
                    loaded_uar_file = pickle.load(f)

                # Collect the UAR scores of every version for later averaging
                for key in loaded_uar_file.keys():
                    tmp_uar_dict[key].append(loaded_uar_file[key])

            # Aggregate the per-version UAR scores into mean and standard deviation
            for key, value in tmp_uar_dict.items():
                result_dict[f'{key}_mean'].append(np.mean(value))
                result_dict[f'{key}_std'].append(np.std(value))

        with (model_type / 'checkpoint_ensemble_results.csv').open('w') as f:
            headers = list(result_dict.keys())

            writer = csv.DictWriter(f, delimiter=',', lineterminator='\n', fieldnames=headers)
            writer.writeheader()

            # One row per parameter configuration
            for row_idx in range(len(result_dict['exp_path'])):
                writer.writerow({key: result_dict[key][row_idx] for key in headers})


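# Pipeline entry point: the three stages below are meant to run in order
# (predictions -> ensemble UARs -> results CSV); toggle them by (un)commenting.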
if __name__ == '__main__':
    outpath = Path().absolute().parent / 'output'

    config_filename = 'config.ini'
    output_folders_path = list(outpath.rglob('outputs'))

    # Accumulate the predictions
    # accumulate_predictions(config_filename, output_folders_path)

    # Accumulate the UARs
    accumulate_uars(outpath)

    # Gather the results into the final CSV
    # gather_results(config_filename, outpath)