341 lines
15 KiB
Plaintext
341 lines
15 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 1,
|
|
"metadata": {
|
|
"collapsed": true,
|
|
"pycharm": {
|
|
"name": "#%% IMPORTS\n"
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"from collections import defaultdict\n",
|
|
"from pathlib import Path\n",
|
|
"from natsort import natsorted\n",
|
|
"from pytorch_lightning.core.saving import *\n",
|
|
"\n",
|
|
"import torch\n",
|
|
"from sklearn.manifold import TSNE\n",
|
|
"\n",
|
|
"import seaborn as sns\n",
|
|
"from matplotlib import pyplot as plt\n",
|
|
"import numpy as np\n",
|
|
"import pandas as pd"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 2,
|
|
"outputs": [],
|
|
"source": [
|
|
"from ml_lib.metrics.binary_class_classifictaion import BinaryScores\n",
|
|
"from ml_lib.utils.tools import locate_and_import_class\n",
|
|
"_ROOT = Path()\n",
|
|
"out_path = 'output'\n",
|
|
"model_name = 'VisualTransformer'\n",
|
|
"exp_name = 'VT_7899c07a4809a45c57cba58047cefb5e'\n",
|
|
"version = 'version_7'"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false,
|
|
"pycharm": {
|
|
"name": "#%%M Path resolving and variables\n"
|
|
}
|
|
}
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 3,
|
|
"outputs": [],
|
|
"source": [
|
|
"plt.style.use('default')\n",
|
|
"sns.set_palette('Dark2')\n",
|
|
"\n",
|
|
"tex_fonts = {\n",
|
|
" # Use LaTeX to write all text\n",
|
|
" \"text.usetex\": True,\n",
|
|
" \"font.family\": \"serif\",\n",
|
|
" # Use 10pt font in plots, to match 10pt font in document\n",
|
|
" \"axes.labelsize\": 10,\n",
|
|
" \"font.size\": 10,\n",
|
|
" # Make the legend/label fonts a little smaller\n",
|
|
" \"legend.fontsize\": 8,\n",
|
|
" \"xtick.labelsize\": 8,\n",
|
|
" \"ytick.labelsize\": 8\n",
|
|
"}\n",
|
|
"\n",
|
|
"# plt.rcParams.update(tex_fonts)\n",
|
|
"\n",
|
|
"Path('figures').mkdir(exist_ok=True)"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false,
|
|
"pycharm": {
|
|
"name": "#%% Seaborn Settings\n"
|
|
}
|
|
}
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 4,
|
|
"outputs": [],
|
|
"source": [
|
|
"def reconstruct_model_data_params(yaml_file_path: str):\n",
|
|
" hparams_dict = load_hparams_from_yaml(yaml_file_path)\n",
|
|
"\n",
|
|
" # Try to get model_name and data_name from yaml:\n",
|
|
" model_name = hparams_dict['model_name']\n",
|
|
" data_name = hparams_dict['data_name']\n",
|
|
" # Try to find the original model and data class by name:\n",
|
|
" found_data_class = locate_and_import_class(data_name, 'datasets')\n",
|
|
" found_model_class = locate_and_import_class(model_name, 'models')\n",
|
|
" # Possible way of automatic loading args:\n",
|
|
" # args = inspect.signature(found_data_class)\n",
|
|
" # then access _parameter.ini and retrieve not set parameters\n",
|
|
"\n",
|
|
" hparams_dict.update(target_mel_length_in_seconds=1, num_worker=10, data_root='data')\n",
|
|
"\n",
|
|
" h_params = Namespace(**hparams_dict)\n",
|
|
"\n",
|
|
" # Let Datamodule pull what it wants\n",
|
|
" datamodule = found_data_class.from_argparse_args(h_params)\n",
|
|
"\n",
|
|
" hparams_dict.update(in_shape=datamodule.shape, n_classes=datamodule.n_classes, variable_length=False)\n",
|
|
"\n",
|
|
" return datamodule, found_model_class, hparams_dict\n",
|
|
"\n",
|
|
"def gather_predictions_and_labels(model, data_option):\n",
|
|
" preds = list()\n",
|
|
" labels = list()\n",
|
|
" filenames = list()\n",
|
|
" with torch.no_grad():\n",
|
|
" for file_name, x, y in datamodule.datasets[data_option]:\n",
|
|
" preds.append(model(x.unsqueeze(0)).main_out)\n",
|
|
" labels.append(y)\n",
|
|
" filenames.append(file_name)\n",
|
|
" labels = np.stack(labels).squeeze()\n",
|
|
" preds = np.stack(preds).squeeze()\n",
|
|
" return preds, labels, filenames\n",
|
|
"\n",
|
|
"def build_tsne_dataframe(preds, labels):\n",
|
|
" tsne = np.stack(TSNE().fit_transform(preds)).squeeze()\n",
|
|
" tsne_dataframe = pd.DataFrame(data=tsne, columns=['x', 'y'])\n",
|
|
"\n",
|
|
" tsne_dataframe['labels'] = labels\n",
|
|
" tsne_dataframe['labels'] = tsne_dataframe['labels'].map({val: key for key, val in datamodule.class_names.items()})\n",
|
|
" return tsne_dataframe\n",
|
|
"\n",
|
|
"def plot_scatterplot(data, option):\n",
|
|
" p = sns.scatterplot(data=data, x='x', y='y', hue='labels', legend=True)\n",
|
|
" p.set_title(f'TSNE - distribution of logits for {option}')\n",
|
|
" plt.show()\n",
|
|
"\n",
|
|
"def redo_predictions(experiment_path, preds, fnames, data_class):\n",
|
|
" sorted_y = defaultdict(list)\n",
|
|
" for idx, (pred, fname) in enumerate(zip(preds, fnames)):\n",
|
|
" sorted_y[fname].append(pred)\n",
|
|
" sorted_y = dict(sorted_y)\n",
|
|
"\n",
|
|
" for file_name in sorted_y:\n",
|
|
" sorted_y.update({file_name: np.stack(sorted_y[file_name])})\n",
|
|
"\n",
|
|
"\n",
|
|
" if data_class.n_classes > 2:\n",
|
|
" pred = np.stack(\n",
|
|
" [np.argmax(x.mean(axis=0)) if x.shape[0] > 1 else np.argmax(x) for x in sorted_y.values()]\n",
|
|
" ).squeeze()\n",
|
|
" class_names = {val: key for val, key in\n",
|
|
" enumerate(['background', 'chimpanze', 'geunon', 'mandrille', 'redcap'])}\n",
|
|
" else:\n",
|
|
" pred = [x.mean(axis=0) if x.shape[0] > 1 else x.squeeze().unsqueeze(-1) for x in sorted_y.values()]\n",
|
|
" pred = np.stack(pred).squeeze()\n",
|
|
" pred = np.where(pred > 0.5, 1, 0)\n",
|
|
" class_names = {val: key for val, key in enumerate(['negative', 'positive'])}\n",
|
|
"\n",
|
|
"\n",
|
|
" df = pd.DataFrame(data=dict(filename=[Path(x).name.replace('.npy', '.wav') for x in sorted_y.keys()],\n",
|
|
" prediction=[class_names[x.item()] for x in pred]))\n",
|
|
" result_file = Path(experiment_path / 'predictions_new.csv')\n",
|
|
" if result_file.exists():\n",
|
|
" try:\n",
|
|
" result_file.unlink()\n",
|
|
" except:\n",
|
|
" print('File already existed')\n",
|
|
" pass\n",
|
|
" with result_file.open(mode='wb') as csv_file:\n",
|
|
" df.to_csv(index=False, path_or_buf=csv_file)\n",
|
|
"\n",
|
|
"\n",
|
|
"def re_valida(preds, labels, fnames, data_class):\n",
|
|
" sorted_y = defaultdict(list)\n",
|
|
" for idx, (pred, fname) in enumerate(zip(preds, fnames)):\n",
|
|
" sorted_y[fname].append(pred)\n",
|
|
" sorted_y = dict(sorted_y)\n",
|
|
"\n",
|
|
" for file_name in sorted_y:\n",
|
|
" sorted_y.update({file_name: np.stack(sorted_y[file_name])})\n",
|
|
"\n",
|
|
" for key, val in list(sorted_y.items()):\n",
|
|
" if val.ndim > 1:\n",
|
|
" val = val.mean(axis=0)\n",
|
|
" print(val.ndim)\n",
|
|
" if not val[0] > 0.8:\n",
|
|
" val[0] = 0\n",
|
|
" sorted_y[key] = val\n",
|
|
"\n",
|
|
" pred = np.stack(\n",
|
|
" [np.argmax(x) if x.shape[0] > 1 else np.argmax(x) for x in sorted_y.values()]\n",
|
|
" ).squeeze()\n",
|
|
"\n",
|
|
" one_hot_targets = np.eye(data_class.n_classes)[pred]\n",
|
|
"\n",
|
|
" # Sklearn Scores\n",
|
|
" print(BinaryScores(dict(y=one_hot_targets, batch_y=labels)))\n"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false,
|
|
"pycharm": {
|
|
"name": "#%% Utility Functions\n"
|
|
}
|
|
}
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 5,
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Selected Checkopint is output\\VisualTransformer\\VT_7899c07a4809a45c57cba58047cefb5e\\version_7\\ckpt_weights-v1.ckpt\n",
|
|
"PrimatesLibrosaDatamodule\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"exp_path = _ROOT / out_path / model_name / exp_name / version\n",
|
|
"checkpoint = natsorted(exp_path.glob('*.ckpt'))[-4]\n",
|
|
"print(f'Selected Checkopint is {checkpoint}')\n",
|
|
"hparams_yaml = next(exp_path.glob('*.yaml'))\n",
|
|
"print(load_hparams_from_yaml(hparams_yaml)['data_name'])\n",
|
|
"# LADE DAS MODELL HIER VON HAND AUS DER KLASSE DIE ABGELEGT WURDE\n",
|
|
"datamodule, model_class, h_params = reconstruct_model_data_params(hparams_yaml.__str__())\n",
|
|
"# h_params.update(return_logits=True)"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false,
|
|
"pycharm": {
|
|
"name": "#%%\n"
|
|
}
|
|
}
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 6,
|
|
"outputs": [
|
|
{
|
|
"ename": "RuntimeError",
|
|
"evalue": "Error(s) in loading state_dict for VisualTransformer:\n\tsize mismatch for pos_embedding: copying a param with shape torch.Size([1, 67, 30]) from checkpoint, the shape in current model is torch.Size([1, 122, 30]).",
|
|
"output_type": "error",
|
|
"traceback": [
|
|
"\u001B[1;31m---------------------------------------------------------------------------\u001B[0m",
|
|
"\u001B[1;31mRuntimeError\u001B[0m Traceback (most recent call last)",
|
|
"\u001B[1;32m<ipython-input-6-c8e208607217>\u001B[0m in \u001B[0;36m<module>\u001B[1;34m\u001B[0m\n\u001B[1;32m----> 1\u001B[1;33m \u001B[0mmodel\u001B[0m \u001B[1;33m=\u001B[0m \u001B[0mmodel_class\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mload_from_checkpoint\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mcheckpoint\u001B[0m\u001B[1;33m,\u001B[0m \u001B[1;33m**\u001B[0m\u001B[0mh_params\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0meval\u001B[0m\u001B[1;33m(\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0m\u001B[0;32m 2\u001B[0m \u001B[0mdatamodule\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mprepare_data\u001B[0m\u001B[1;33m(\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m 3\u001B[0m \u001B[1;33m\u001B[0m\u001B[0m\n",
|
|
"\u001B[1;32mc:\\users\\steff\\envs\\compare_21\\lib\\site-packages\\pytorch_lightning\\core\\saving.py\u001B[0m in \u001B[0;36mload_from_checkpoint\u001B[1;34m(cls, checkpoint_path, map_location, hparams_file, strict, **kwargs)\u001B[0m\n\u001B[0;32m 154\u001B[0m \u001B[0mcheckpoint\u001B[0m\u001B[1;33m[\u001B[0m\u001B[0mcls\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mCHECKPOINT_HYPER_PARAMS_KEY\u001B[0m\u001B[1;33m]\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mupdate\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mkwargs\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m 155\u001B[0m \u001B[1;33m\u001B[0m\u001B[0m\n\u001B[1;32m--> 156\u001B[1;33m \u001B[0mmodel\u001B[0m \u001B[1;33m=\u001B[0m \u001B[0mcls\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0m_load_model_state\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mcheckpoint\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0mstrict\u001B[0m\u001B[1;33m=\u001B[0m\u001B[0mstrict\u001B[0m\u001B[1;33m,\u001B[0m \u001B[1;33m**\u001B[0m\u001B[0mkwargs\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0m\u001B[0;32m 157\u001B[0m \u001B[1;32mreturn\u001B[0m \u001B[0mmodel\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m 158\u001B[0m \u001B[1;33m\u001B[0m\u001B[0m\n",
|
|
"\u001B[1;32mc:\\users\\steff\\envs\\compare_21\\lib\\site-packages\\pytorch_lightning\\core\\saving.py\u001B[0m in \u001B[0;36m_load_model_state\u001B[1;34m(cls, checkpoint, strict, **cls_kwargs_new)\u001B[0m\n\u001B[0;32m 202\u001B[0m \u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m 203\u001B[0m \u001B[1;31m# load the state_dict on the model automatically\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[1;32m--> 204\u001B[1;33m \u001B[0mmodel\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mload_state_dict\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mcheckpoint\u001B[0m\u001B[1;33m[\u001B[0m\u001B[1;34m'state_dict'\u001B[0m\u001B[1;33m]\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0mstrict\u001B[0m\u001B[1;33m=\u001B[0m\u001B[0mstrict\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0m\u001B[0;32m 205\u001B[0m \u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m 206\u001B[0m \u001B[1;32mreturn\u001B[0m \u001B[0mmodel\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n",
|
|
"\u001B[1;32mc:\\users\\steff\\envs\\compare_21\\lib\\site-packages\\torch\\nn\\modules\\module.py\u001B[0m in \u001B[0;36mload_state_dict\u001B[1;34m(self, state_dict, strict)\u001B[0m\n\u001B[0;32m 1049\u001B[0m \u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m 1050\u001B[0m \u001B[1;32mif\u001B[0m \u001B[0mlen\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0merror_msgs\u001B[0m\u001B[1;33m)\u001B[0m \u001B[1;33m>\u001B[0m \u001B[1;36m0\u001B[0m\u001B[1;33m:\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[1;32m-> 1051\u001B[1;33m raise RuntimeError('Error(s) in loading state_dict for {}:\\n\\t{}'.format(\n\u001B[0m\u001B[0;32m 1052\u001B[0m self.__class__.__name__, \"\\n\\t\".join(error_msgs)))\n\u001B[0;32m 1053\u001B[0m \u001B[1;32mreturn\u001B[0m \u001B[0m_IncompatibleKeys\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mmissing_keys\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0munexpected_keys\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n",
|
|
"\u001B[1;31mRuntimeError\u001B[0m: Error(s) in loading state_dict for VisualTransformer:\n\tsize mismatch for pos_embedding: copying a param with shape torch.Size([1, 67, 30]) from checkpoint, the shape in current model is torch.Size([1, 122, 30])."
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"model = model_class.load_from_checkpoint(checkpoint, **h_params).eval()\n",
|
|
"datamodule.prepare_data()"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false,
|
|
"pycharm": {
|
|
"name": "#%%\n"
|
|
}
|
|
}
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"outputs": [],
|
|
"source": [
|
|
"predictions, labels_y, filenames = gather_predictions_and_labels(model, 'devel')"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false,
|
|
"pycharm": {
|
|
"name": "#%%\n"
|
|
}
|
|
}
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"outputs": [],
|
|
"source": [
|
|
"# tsne_dataframe = build_tsne_dataframe(predictions, labels_y)\n",
|
|
"# plot_scatterplot(tsne_dataframe, data_option)"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false,
|
|
"pycharm": {
|
|
"name": "#%%\n"
|
|
}
|
|
}
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"outputs": [],
|
|
"source": [
|
|
"re_valida(predictions,labels_y, filenames, datamodule)"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false,
|
|
"pycharm": {
|
|
"name": "#%%\n"
|
|
}
|
|
}
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"outputs": [],
|
|
"source": [],
|
|
"metadata": {
|
|
"collapsed": false,
|
|
"pycharm": {
|
|
"name": "#%%\n"
|
|
}
|
|
}
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Python 3",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 2
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython2",
|
|
"version": "2.7.6"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 0
|
|
} |