{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Checkpoint evaluation: file-level re-scoring\n",
    "\n",
    "Loads a trained checkpoint from `output/<model_name>/<exp_name>/<version>`, rebuilds its datamodule from the stored hparams YAML, runs the model over one data split and re-scores the item-level predictions per source file."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true,
    "pycharm": {
     "name": "#%% IMPORTS\n"
    }
   },
   "outputs": [],
   "source": [
    "from argparse import Namespace\n",
    "from collections import defaultdict\n",
    "from pathlib import Path\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import seaborn as sns\n",
    "import torch\n",
    "from matplotlib import pyplot as plt\n",
    "from natsort import natsorted\n",
    "from pytorch_lightning.core.saving import load_hparams_from_yaml\n",
    "from sklearn.manifold import TSNE\n",
    "\n",
    "from ml_lib.metrics.binary_class_classifictaion import BinaryScores\n",
    "from ml_lib.utils.tools import locate_and_import_class"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% Path resolving and variables\n"
    }
   },
   "outputs": [],
   "source": [
    "_ROOT = Path()\n",
    "out_path = 'output'\n",
    "model_name = 'VisualTransformer'\n",
    "exp_name = 'VT_7899c07a4809a45c57cba58047cefb5e'\n",
    "version = 'version_7'\n",
    "\n",
    "# Index into the natsorted list of *.ckpt files (-4 selected `ckpt_weights-v1.ckpt` in the last saved run).\n",
    "CHECKPOINT_INDEX = -4\n",
    "# Dataset split to evaluate (key into `datamodule.datasets`).\n",
    "DATA_OPTION = 'devel'\n",
    "# Set to True to compute and plot the t-SNE projection of the predictions.\n",
    "PLOT_TSNE = False\n",
    "TSNE_SEED = 42"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% Seaborn Settings\n"
    }
   },
   "outputs": [],
   "source": [
    "plt.style.use('default')\n",
    "sns.set_palette('Dark2')\n",
    "\n",
    "# Set to True to render all text with LaTeX (needs a working LaTeX installation).\n",
    "USE_TEX = False\n",
    "\n",
    "tex_fonts = {\n",
    "    # Use LaTeX to write all text\n",
    "    \"text.usetex\": True,\n",
    "    \"font.family\": \"serif\",\n",
    "    # Use 10pt font in plots, to match 10pt font in document\n",
    "    \"axes.labelsize\": 10,\n",
    "    \"font.size\": 10,\n",
    "    # Make the legend/label fonts a little smaller\n",
    "    \"legend.fontsize\": 8,\n",
    "    \"xtick.labelsize\": 8,\n",
    "    \"ytick.labelsize\": 8\n",
    "}\n",
    "if USE_TEX:\n",
    "    plt.rcParams.update(tex_fonts)\n",
    "\n",
    "Path('figures').mkdir(exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% Utility Functions\n"
    }
   },
   "outputs": [],
   "source": [
    "def reconstruct_model_data_params(yaml_file_path: str, **data_overrides):\n",
    "    \"\"\"Rebuild the datamodule and look up the model class from a run's hparams YAML.\n",
    "\n",
    "    Args:\n",
    "        yaml_file_path: Path of the hparams YAML written during training.\n",
    "        **data_overrides: Extra hparams applied on top of the evaluation defaults\n",
    "            below, e.g. ``target_mel_length_in_seconds=...`` to match training.\n",
    "\n",
    "    Returns:\n",
    "        ``(datamodule, model_class, hparams_dict)``; ``hparams_dict`` is extended with\n",
    "        ``in_shape``, ``n_classes`` and ``variable_length`` from the datamodule.\n",
    "    \"\"\"\n",
    "    hparams_dict = load_hparams_from_yaml(yaml_file_path)\n",
    "\n",
    "    # Find the original model and data classes by the names stored in the YAML.\n",
    "    found_data_class = locate_and_import_class(hparams_dict['data_name'], 'datasets')\n",
    "    found_model_class = locate_and_import_class(hparams_dict['model_name'], 'models')\n",
    "    # Possible way of automatic loading args:\n",
    "    # args = inspect.signature(found_data_class)\n",
    "    # then access _parameter.ini and retrieve not set parameters\n",
    "\n",
    "    # Evaluation-time defaults; caller-supplied overrides take precedence.\n",
    "    overrides = dict(target_mel_length_in_seconds=1, num_worker=10, data_root='data')\n",
    "    overrides.update(data_overrides)\n",
    "    hparams_dict.update(overrides)\n",
    "\n",
    "    # Let the datamodule pull the arguments it wants.\n",
    "    datamodule = found_data_class.from_argparse_args(Namespace(**hparams_dict))\n",
    "\n",
    "    hparams_dict.update(in_shape=datamodule.shape, n_classes=datamodule.n_classes, variable_length=False)\n",
    "\n",
    "    return datamodule, found_model_class, hparams_dict\n",
    "\n",
    "\n",
    "def gather_predictions_and_labels(model, data_option, data_module=None):\n",
    "    \"\"\"Run ``model`` over every item of one dataset split.\n",
    "\n",
    "    Args:\n",
    "        model: Model whose output exposes ``.main_out``; call ``.eval()`` on it beforehand.\n",
    "        data_option: Key into ``data_module.datasets`` (e.g. ``'devel'``).\n",
    "        data_module: Datamodule to read from; defaults to the notebook-global ``datamodule``.\n",
    "\n",
    "    Returns:\n",
    "        ``(preds, labels, filenames)``: ``preds`` and ``labels`` stacked and squeezed into\n",
    "        numpy arrays, ``filenames`` a list aligned with them.\n",
    "    \"\"\"\n",
    "    if data_module is None:\n",
    "        data_module = datamodule\n",
    "    preds, labels, filenames = [], [], []\n",
    "    with torch.no_grad():\n",
    "        for file_name, x, y in data_module.datasets[data_option]:\n",
    "            # Add a batch dimension of size 1 for the single item.\n",
    "            preds.append(model(x.unsqueeze(0)).main_out)\n",
    "            labels.append(y)\n",
    "            filenames.append(file_name)\n",
    "    return np.stack(preds).squeeze(), np.stack(labels).squeeze(), filenames\n",
    "\n",
    "\n",
    "def build_tsne_dataframe(preds, labels, data_module=None, random_state=None):\n",
    "    \"\"\"Project ``preds`` to 2-D with t-SNE and attach label names.\n",
    "\n",
    "    Args:\n",
    "        preds: 2-D array of per-item model outputs.\n",
    "        labels: Labels aligned with ``preds``.\n",
    "        data_module: Provides ``class_names``; defaults to the notebook-global ``datamodule``.\n",
    "        random_state: Seed forwarded to ``TSNE`` (``None`` keeps it unseeded).\n",
    "\n",
    "    Returns:\n",
    "        DataFrame with columns ``x``, ``y`` and ``labels``.\n",
    "    \"\"\"\n",
    "    if data_module is None:\n",
    "        data_module = datamodule\n",
    "    embedding = TSNE(random_state=random_state).fit_transform(preds)\n",
    "    tsne_dataframe = pd.DataFrame(data=embedding, columns=['x', 'y'])\n",
    "    # Invert class_names so that label values map back to their names.\n",
    "    label_to_name = {val: key for key, val in data_module.class_names.items()}\n",
    "    tsne_dataframe['labels'] = labels\n",
    "    tsne_dataframe['labels'] = tsne_dataframe['labels'].map(label_to_name)\n",
    "    return tsne_dataframe\n",
    "\n",
    "\n",
    "def plot_scatterplot(data, option):\n",
    "    \"\"\"Scatter-plot a frame from ``build_tsne_dataframe``, coloured by ``labels``.\"\"\"\n",
    "    _, ax = plt.subplots()\n",
    "    sns.scatterplot(data=data, x='x', y='y', hue='labels', legend=True, ax=ax)\n",
    "    ax.set_title(f'TSNE - distribution of logits for {option}')\n",
    "    plt.show()\n",
    "    return ax\n",
    "\n",
    "\n",
    "def _group_by_file(preds, fnames):\n",
    "    \"\"\"Stack per-item predictions per file name, keeping first-seen file order.\"\"\"\n",
    "    grouped = defaultdict(list)\n",
    "    for pred, fname in zip(preds, fnames):\n",
    "        grouped[fname].append(pred)\n",
    "    return {fname: np.stack(file_preds) for fname, file_preds in grouped.items()}\n",
    "\n",
    "\n",
    "def redo_predictions(experiment_path, preds, fnames, data_class, threshold=0.5):\n",
    "    \"\"\"Aggregate per-item predictions per file and write ``predictions_new.csv``.\n",
    "\n",
    "    Predictions of items sharing a file name are averaged; multi-class outputs are\n",
    "    then arg-maxed, binary outputs compared against ``threshold``.\n",
    "\n",
    "    Args:\n",
    "        experiment_path: Directory the CSV is written to (str or Path).\n",
    "        preds: Per-item predictions aligned with ``fnames``.\n",
    "        fnames: File name of every item.\n",
    "        data_class: Object exposing ``n_classes`` (e.g. the datamodule).\n",
    "        threshold: Decision threshold for the binary case.\n",
    "\n",
    "    Returns:\n",
    "        Path of the written CSV (columns ``filename`` and ``prediction``).\n",
    "    \"\"\"\n",
    "    per_file = _group_by_file(preds, fnames)\n",
    "\n",
    "    # Averaging over axis 0 also covers single-item files (it returns that one row).\n",
    "    mean_preds = [file_preds.mean(axis=0) for file_preds in per_file.values()]\n",
    "    if data_class.n_classes > 2:\n",
    "        pred = np.array([np.argmax(mean_pred) for mean_pred in mean_preds], dtype=int)\n",
    "        class_names = dict(enumerate(['background', 'chimpanze', 'geunon', 'mandrille', 'redcap']))\n",
    "    else:\n",
    "        # Assumes a single score per item; reshape fails loudly otherwise.\n",
    "        scores = np.array(mean_preds).reshape(len(mean_preds))\n",
    "        pred = np.where(scores > threshold, 1, 0)\n",
    "        class_names = dict(enumerate(['negative', 'positive']))\n",
    "\n",
    "    df = pd.DataFrame(data=dict(filename=[Path(x).name.replace('.npy', '.wav') for x in per_file],\n",
    "                                prediction=[class_names[int(x)] for x in pred]))\n",
    "    result_file = Path(experiment_path) / 'predictions_new.csv'\n",
    "    # to_csv overwrites an existing file, so no explicit unlink is needed.\n",
    "    df.to_csv(result_file, index=False)\n",
    "    return result_file\n",
    "\n",
    "\n",
    "def re_valida(preds, labels, fnames, data_class, background_threshold=0.8):\n",
    "    \"\"\"Re-score predictions per file and print ``BinaryScores``.\n",
    "\n",
    "    Per-item class-score vectors are averaged per file. The class-0 (background) score\n",
    "    is zeroed unless it exceeds ``background_threshold`` before taking the arg-max.\n",
    "    Labels are reduced to one per file so they align with the file-level predictions.\n",
    "\n",
    "    Args:\n",
    "        preds: Per-item class-score vectors aligned with ``labels`` and ``fnames``.\n",
    "        labels: Per-item labels.\n",
    "        fnames: File name of every item.\n",
    "        data_class: Object exposing ``n_classes`` (e.g. the datamodule).\n",
    "        background_threshold: Minimum averaged class-0 score to keep background.\n",
    "\n",
    "    Returns:\n",
    "        The ``BinaryScores`` object that is printed.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If there are no predictions.\n",
    "    \"\"\"\n",
    "    grouped_preds = defaultdict(list)\n",
    "    file_labels = {}\n",
    "    for pred, label, fname in zip(preds, labels, fnames):\n",
    "        grouped_preds[fname].append(pred)\n",
    "        # NOTE(review): assumes all items of one file share a label; the first one is kept.\n",
    "        file_labels.setdefault(fname, label)\n",
    "    if not grouped_preds:\n",
    "        raise ValueError('re_valida: no predictions to score')\n",
    "\n",
    "    file_pred_classes = []\n",
    "    for item_preds in grouped_preds.values():\n",
    "        # mean() returns a fresh array, so zeroing below does not touch the inputs.\n",
    "        mean_pred = np.stack(item_preds).mean(axis=0)\n",
    "        # `not >` (rather than `<=`) also zeroes a NaN background score.\n",
    "        if not mean_pred[0] > background_threshold:\n",
    "            mean_pred[0] = 0\n",
    "        file_pred_classes.append(np.argmax(mean_pred))\n",
    "\n",
    "    one_hot_targets = np.eye(data_class.n_classes)[np.array(file_pred_classes)]\n",
    "    file_label_array = np.stack([file_labels[fname] for fname in grouped_preds])\n",
    "\n",
    "    # Sklearn Scores\n",
    "    scores = BinaryScores(dict(y=one_hot_targets, batch_y=file_label_array))\n",
    "    print(scores)\n",
    "    return scores"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "exp_path = _ROOT / out_path / model_name / exp_name / version\n",
    "checkpoint = natsorted(exp_path.glob('*.ckpt'))[CHECKPOINT_INDEX]\n",
    "print(f'Selected checkpoint is {checkpoint}')\n",
    "hparams_yaml = next(exp_path.glob('*.yaml'), None)\n",
    "if hparams_yaml is None:\n",
    "    raise FileNotFoundError(f'No hparams *.yaml found in {exp_path}')\n",
    "print(load_hparams_from_yaml(hparams_yaml)['data_name'])\n",
    "# Load the model here by hand from the class that was stored with the run.\n",
    "datamodule, model_class, h_params = reconstruct_model_data_params(str(hparams_yaml))\n",
    "# h_params.update(return_logits=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# NOTE(review): the last saved run of this cell failed with\n",
    "#   size mismatch for pos_embedding: checkpoint torch.Size([1, 67, 30]) vs. model torch.Size([1, 122, 30]).\n",
    "# The model is built with the datamodule-derived `in_shape`, which presumably differs from training\n",
    "# (e.g. due to the `target_mel_length_in_seconds=1` default) - TODO confirm and pass the training value\n",
    "# via `reconstruct_model_data_params(..., target_mel_length_in_seconds=...)`.\n",
    "model = model_class.load_from_checkpoint(checkpoint, **h_params).eval()\n",
    "datamodule.prepare_data()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "predictions, labels_y, filenames = gather_predictions_and_labels(model, DATA_OPTION)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "if PLOT_TSNE:\n",
    "    tsne_dataframe = build_tsne_dataframe(predictions, labels_y, random_state=TSNE_SEED)\n",
    "    plot_scatterplot(tsne_dataframe, DATA_OPTION)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "re_valida(predictions, labels_y, filenames, datamodule);"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}