ensembles

ensemble_methods/ensemble_checkpoints.py (new file, 117 lines)
import csv
import pickle
from collections import defaultdict
from pathlib import Path

import numpy as np
from sklearn import metrics
from tqdm import tqdm

from util.config import MConfig


def accumulate_predictions(config_filename, output_folders):
    for output_folder in tqdm(output_folders, total=len(output_folders)):
        # Gather predictions and labels from every per-checkpoint inference CSV.
        inference_files = output_folder.glob('*.csv')

        config = MConfig()
        config.read_file((output_folder.parent / config_filename).open('r'))

        result_dict = defaultdict(list)
        for inf_idx, inference_file in enumerate(inference_files):
            with inference_file.open('r') as f:
                _ = f.readline()  # skip the header row
                for row in f:
                    prediction, label = [float(x) for x in row.strip().split(',')]
                    result_dict[inference_file.stem].append(prediction)
                    # The labels are identical across checkpoints, so store them only once.
                    if inf_idx == 0:
                        result_dict['labels'].append(label)
        result_dict = dict(result_dict)
        # The pickle is named after this module ('ensemble_checkpoints').
        with (output_folder / Path(__file__).stem).open('wb') as f:
            pickle.dump(result_dict, f, protocol=pickle.HIGHEST_PROTOCOL)
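# A minimal sketch of the assumed data layout (file names and values are
# illustrative, not taken from a real run): each inference CSV holds one
# "prediction,label" pair per line,
#     prediction,label
#     0.83,1.0
#     0.12,0.0
# so for 'ckpt_a.csv' and 'ckpt_b.csv' the pickled dict would look like
#     {'ckpt_a': [0.83, 0.12], 'ckpt_b': [0.79, 0.08], 'labels': [1.0, 0.0]}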

def accumulate_uars(output_folders):
    for model_type in output_folders.iterdir():
        for param_config in model_type.iterdir():
            # Find the prediction pickles written by accumulate_predictions.
            per_seed_ensemble_files = param_config.rglob(Path(__file__).stem)

            for ensemble_file in per_seed_ensemble_files:
                uar_dict = dict()
                with ensemble_file.open('rb') as f:
                    loaded_ensemble_file = pickle.load(f)
                labels = loaded_ensemble_file.pop('labels')
                for decision_boundary in range(10, 91, 5):
                    decision_boundary = round(decision_boundary * 0.01, 2)
                    majority_votes = []
                    mean_votes = []
                    # A strict majority needs more than half of the checkpoints.
                    voters = len(loaded_ensemble_file) * 0.5
                    for i in range(len(labels)):
                        scores = [loaded_ensemble_file[key][i] for key in loaded_ensemble_file]
                        # Mean vote: average the raw scores, then threshold once.
                        mean_votes.append(int(sum(scores) / len(scores) > decision_boundary))
                        # Majority vote: threshold each score, then count the positives.
                        majority_votes.append(int(sum(s > decision_boundary for s in scores) > voters))

                    for predictions, name in zip([mean_votes, majority_votes], ['mean', 'majority']):
                        uar_score = metrics.recall_score(labels, predictions, labels=[0, 1], average='macro',
                                                         sample_weight=None, zero_division='warn')
                        uar_dict[f'{name}_decb_{decision_boundary}'] = uar_score
                with (ensemble_file.parent / 'ensemble_uar_dict_decb').open('wb') as ef:
                    pickle.dump(uar_dict, ef, protocol=pickle.HIGHEST_PROTOCOL)
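# Worked example with illustrative numbers: three checkpoints score one sample
# as [0.7, 0.4, 0.9] and decision_boundary = 0.5.
#     mean vote:     (0.7 + 0.4 + 0.9) / 3 = 0.667 > 0.5          -> 1
#     majority vote: thresholded votes [1, 0, 1], sum 2 > 1.5     -> 1
# The two schemes can disagree, e.g. scores [0.9, 0.4, 0.45] give a mean of
# 0.583 (vote 1) but only one positive thresholded vote (vote 0).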

def gather_results(config_filename, outpath):
    for model_type in outpath.iterdir():
        result_dict = defaultdict(list)
        for param_config in model_type.iterdir():
            tmp_uar_dict = defaultdict(list)
            config: MConfig
            # Match the files written by accumulate_uars.
            for idx, version_uar in enumerate(param_config.rglob('ensemble_uar_dict_decb')):
                # The config only needs to be parsed once per parameter configuration.
                if not idx:
                    config = MConfig()
                    config.read_file((version_uar.parent.parent / config_filename).open('r'))
                    for parameter, value in config.model_paramters.items():  # attribute name as spelled in util.config
                        if parameter in ['exp_path', 'exp_fingerprint', 'loudness_ratio', 'mask_ratio',
                                         'noise_ratio', 'shift_ratio', 'speed_amount', 'speed_max', 'speed_min']:
                            result_dict[parameter].append(value)

                with version_uar.open('rb') as f:
                    loaded_uar_file = pickle.load(f)

                for key, value in loaded_uar_file.items():
                    tmp_uar_dict[key].append(value)
            # Aggregate the per-seed UAR scores for this parameter configuration.
            for key, value in tmp_uar_dict.items():
                result_dict[f'{key}_mean'].append(np.mean(value))
                result_dict[f'{key}_std'].append(np.std(value))
        with (model_type / 'checkpoint_ensemble_results.csv').open('w') as f:
            headers = list(result_dict.keys())
            writer = csv.DictWriter(f, delimiter=',', lineterminator='\n', fieldnames=headers)
            writer.writeheader()
            for row_idx in range(len(result_dict['exp_path'])):
                writer.writerow({key: result_dict[key][row_idx] for key in headers})
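# The resulting CSV holds one row per parameter configuration; the column
# names follow the f-strings above, and the values here are purely
# illustrative:
#     exp_path,exp_fingerprint,mean_decb_0.5_mean,mean_decb_0.5_std,majority_decb_0.5_mean,...
#     output/exp_a,a1b2c3,0.712,0.013,0.705,...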

if __name__ == '__main__':
    outpath = Path().absolute().parent / 'output'

    config_filename = 'config.ini'
    output_folders_path = list(outpath.rglob('outputs'))

    # Accumulate the predictions
    # accumulate_predictions(config_filename, output_folders_path)

    # Accumulate the UARs
    accumulate_uars(outpath)

    # Gather the results into the final CSV
    # gather_results(config_filename, outpath)
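# Note: the stages build on each other's outputs (accumulate_predictions writes
# the per-folder prediction pickles, accumulate_uars turns them into UAR
# pickles, and gather_results collects those into the CSV), so on a fresh
# output tree uncomment and run them in this order.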