Mel_Vision_Transformer_ComP.../reload model.ipynb
2021-04-02 08:45:11 +02:00

341 lines
15 KiB
Plaintext

{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": true,
"pycharm": {
"name": "#%% IMPORTS\n"
}
},
"outputs": [],
"source": [
"from collections import defaultdict\n",
"from pathlib import Path\n",
"from natsort import natsorted\n",
"from pytorch_lightning.core.saving import *\n",
"\n",
"import torch\n",
"from sklearn.manifold import TSNE\n",
"\n",
"import seaborn as sns\n",
"from matplotlib import pyplot as plt\n",
"import numpy as np\n",
"import pandas as pd"
]
},
{
"cell_type": "code",
"execution_count": 2,
"outputs": [],
"source": [
"from ml_lib.metrics.binary_class_classifictaion import BinaryScores\n",
"from ml_lib.utils.tools import locate_and_import_class\n",
"# Experiment to reload: <_ROOT>/<out_path>/<model_name>/<exp_name>/<version>\n",
"_ROOT = Path('.')\n",
"out_path, model_name = 'output', 'VisualTransformer'\n",
"exp_name, version = 'VT_7899c07a4809a45c57cba58047cefb5e', 'version_7'"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%% Path resolving and variables\n"
}
}
},
{
"cell_type": "code",
"execution_count": 3,
"outputs": [],
"source": [
"plt.style.use('default')\n",
"sns.set_palette('Dark2')\n",
"\n",
"# Set to True to typeset all figure text with LaTeX (needs a working LaTeX installation).\n",
"USE_TEX = False\n",
"\n",
"tex_fonts = {\n",
"    # Use LaTeX to write all text\n",
"    'text.usetex': True,\n",
"    'font.family': 'serif',\n",
"    # Use 10pt font in plots, to match 10pt font in document\n",
"    'axes.labelsize': 10,\n",
"    'font.size': 10,\n",
"    # Make the legend/label fonts a little smaller\n",
"    'legend.fontsize': 8,\n",
"    'xtick.labelsize': 8,\n",
"    'ytick.labelsize': 8,\n",
"}\n",
"if USE_TEX:\n",
"    plt.rcParams.update(tex_fonts)\n",
"\n",
"Path('figures').mkdir(exist_ok=True)"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%% Seaborn Settings\n"
}
}
},
{
"cell_type": "code",
"execution_count": 4,
"outputs": [],
"source": [
"def reconstruct_model_data_params(yaml_file_path: str, overrides: dict = None):\n",
"    \"\"\"Rebuild the datamodule and locate the model class from a saved hparams YAML.\n",
"\n",
"    Args:\n",
"        yaml_file_path: Path of the hparams YAML written during training.\n",
"        overrides: Entries that replace the stored hyper-parameters before the\n",
"            datamodule is built. Defaults to\n",
"            ``dict(target_mel_length_in_seconds=1, num_worker=10, data_root='data')``.\n",
"\n",
"    Returns:\n",
"        ``(datamodule, model_class, hparams_dict)``; ``hparams_dict`` is extended with\n",
"        ``in_shape``, ``n_classes`` and ``variable_length=False`` from the datamodule.\n",
"    \"\"\"\n",
"    if overrides is None:\n",
"        overrides = dict(target_mel_length_in_seconds=1, num_worker=10, data_root='data')\n",
"    hparams_dict = load_hparams_from_yaml(yaml_file_path)\n",
"\n",
"    # The YAML stores class names only; resolve them to the actual classes.\n",
"    model_name = hparams_dict['model_name']\n",
"    data_name = hparams_dict['data_name']\n",
"    found_data_class = locate_and_import_class(data_name, 'datasets')\n",
"    found_model_class = locate_and_import_class(model_name, 'models')\n",
"\n",
"    # NOTE(review): data overrides such as target_mel_length_in_seconds presumably change\n",
"    # datamodule.shape (and hence in_shape below); if they differ from the training run,\n",
"    # loading the checkpoint can fail with a size mismatch (e.g. pos_embedding).\n",
"    hparams_dict.update(overrides)\n",
"\n",
"    # Let the datamodule pull the arguments it needs from the namespace.\n",
"    datamodule = found_data_class.from_argparse_args(Namespace(**hparams_dict))\n",
"\n",
"    hparams_dict.update(in_shape=datamodule.shape, n_classes=datamodule.n_classes, variable_length=False)\n",
"    return datamodule, found_model_class, hparams_dict\n",
"\n",
"def gather_predictions_and_labels(model, data_option, data_module=None):\n",
"    \"\"\"Run ``model`` over every sample of one dataset split and collect the outputs.\n",
"\n",
"    Args:\n",
"        model: Model whose forward output exposes ``main_out``.\n",
"        data_option: Key into ``data_module.datasets`` (e.g. ``'devel'``).\n",
"        data_module: Datamodule holding the split; defaults to the notebook-global\n",
"            ``datamodule`` for backward compatibility.\n",
"\n",
"    Returns:\n",
"        ``(preds, labels, filenames)``: squeezed stacks of predictions and labels plus\n",
"        the file name of every sample, all in dataset order.\n",
"    \"\"\"\n",
"    if data_module is None:\n",
"        data_module = datamodule  # fall back to the notebook-global datamodule\n",
"    preds, labels, filenames = [], [], []\n",
"    with torch.no_grad():\n",
"        for file_name, x, y in data_module.datasets[data_option]:\n",
"            # Add a batch dimension; move the output to host memory so numpy can stack it.\n",
"            preds.append(model(x.unsqueeze(0)).main_out.cpu().numpy())\n",
"            labels.append(y)\n",
"            filenames.append(file_name)\n",
"    labels = np.stack(labels).squeeze()\n",
"    preds = np.stack(preds).squeeze()\n",
"    return preds, labels, filenames\n",
"\n",
"def build_tsne_dataframe(preds, labels, class_names=None, random_state=None):\n",
"    \"\"\"Project ``preds`` to 2-D with t-SNE and attach readable class labels.\n",
"\n",
"    Args:\n",
"        preds: Per-sample model outputs, one row per sample.\n",
"        labels: Per-sample label values, matching the values of ``class_names``.\n",
"        class_names: Mapping whose values are label values and whose keys are the names\n",
"            to show; defaults to the notebook-global ``datamodule.class_names``.\n",
"        random_state: Seed passed to ``TSNE``; set it for reproducible embeddings.\n",
"\n",
"    Returns:\n",
"        DataFrame with the embedding in columns ``x``/``y`` and names in ``labels``.\n",
"    \"\"\"\n",
"    if class_names is None:\n",
"        class_names = datamodule.class_names  # fall back to the notebook-global datamodule\n",
"    embedding = TSNE(random_state=random_state).fit_transform(preds)\n",
"    tsne_dataframe = pd.DataFrame(data=embedding, columns=['x', 'y'])\n",
"\n",
"    # Invert the mapping so label values translate back to names.\n",
"    value_to_name = {val: key for key, val in class_names.items()}\n",
"    tsne_dataframe['labels'] = labels\n",
"    tsne_dataframe['labels'] = tsne_dataframe['labels'].map(value_to_name)\n",
"    return tsne_dataframe\n",
"\n",
"def plot_scatterplot(data, option):\n",
"    \"\"\"Show a scatter plot of the ``x``/``y`` columns of ``data`` coloured by ``labels``.\"\"\"\n",
"    title = f'TSNE - distribution of logits for {option}'\n",
"    axes = sns.scatterplot(x='x', y='y', hue='labels', data=data, legend=True)\n",
"    axes.set_title(title)\n",
"    plt.show()\n",
"\n",
"def redo_predictions(experiment_path, preds, fnames, data_class, threshold=0.5):\n",
"    \"\"\"Aggregate per-sample predictions per file and write them to ``predictions_new.csv``.\n",
"\n",
"    All predictions sharing a file name are averaged. With more than two classes the\n",
"    arg-max class is reported; otherwise the averaged score is compared with ``threshold``.\n",
"\n",
"    Args:\n",
"        experiment_path: Folder (``Path`` or ``str``) the CSV is written to.\n",
"        preds: Per-sample predictions, aligned with ``fnames``.\n",
"        fnames: Source file of every prediction; ``.npy`` becomes ``.wav`` in the CSV.\n",
"        data_class: Object exposing ``n_classes`` (the datamodule).\n",
"        threshold: Decision threshold for the binary case.\n",
"\n",
"    Returns:\n",
"        The written DataFrame with columns ``filename`` and ``prediction``.\n",
"    \"\"\"\n",
"    # Group predictions by file; dicts keep first-seen order.\n",
"    grouped = defaultdict(list)\n",
"    for pred, fname in zip(preds, fnames):\n",
"        grouped[fname].append(pred)\n",
"    # Mean over the file's predictions (a single prediction is returned unchanged).\n",
"    file_scores = {fname: np.stack(file_preds).mean(axis=0) for fname, file_preds in grouped.items()}\n",
"\n",
"    if data_class.n_classes > 2:\n",
"        pred = np.array([np.argmax(score) for score in file_scores.values()])\n",
"        # NOTE(review): hard-coded class order; confirm it matches the datamodule's indices.\n",
"        class_names = dict(enumerate(['background', 'chimpanze', 'geunon', 'mandrille', 'redcap']))\n",
"    else:\n",
"        # One score per file; reshape keeps a 1-D array even for a single file.\n",
"        scores = np.array(list(file_scores.values()), dtype=float).reshape(-1)\n",
"        pred = np.where(scores > threshold, 1, 0)\n",
"        class_names = dict(enumerate(['negative', 'positive']))\n",
"\n",
"    df = pd.DataFrame(data=dict(filename=[Path(x).name.replace('.npy', '.wav') for x in file_scores],\n",
"                                prediction=[class_names[int(x)] for x in pred]))\n",
"    # to_csv truncates an existing file, so no separate delete step is needed.\n",
"    result_file = Path(experiment_path) / 'predictions_new.csv'\n",
"    df.to_csv(result_file, index=False)\n",
"    return df\n",
"\n",
"\n",
"def re_valida(preds, labels, fnames, data_class, background_threshold=0.8):\n",
"    \"\"\"Re-score predictions on file level and print the resulting metrics.\n",
"\n",
"    Per-sample class scores are averaged per file. The class-0 score is zeroed unless it\n",
"    exceeds ``background_threshold``, so class 0 only wins when it is confident; the\n",
"    arg-max class is then one-hot encoded and scored against the file's label.\n",
"    Assumes each prediction is a class-score vector (multi-class output) -- TODO confirm.\n",
"\n",
"    Args:\n",
"        preds: Per-sample class-score vectors, aligned with ``labels`` and ``fnames``.\n",
"        labels: Per-sample labels; the first label seen for a file is used as its label.\n",
"        fnames: Source file of every sample.\n",
"        data_class: Object exposing ``n_classes`` (the datamodule).\n",
"        background_threshold: Minimum averaged class-0 score needed to keep class 0.\n",
"\n",
"    Returns:\n",
"        The ``BinaryScores`` object that is also printed.\n",
"    \"\"\"\n",
"    file_preds = defaultdict(list)\n",
"    file_labels = {}\n",
"    for pred, label, fname in zip(preds, labels, fnames):\n",
"        file_preds[fname].append(pred)\n",
"        # Keep one label per file so labels line up with the file-level predictions.\n",
"        file_labels.setdefault(fname, label)\n",
"\n",
"    file_classes = []\n",
"    for fname, samples in file_preds.items():\n",
"        scores = np.stack(samples).mean(axis=0)\n",
"        if not scores[0] > background_threshold:\n",
"            scores[0] = 0\n",
"        file_classes.append(np.argmax(scores))\n",
"\n",
"    one_hot_preds = np.eye(data_class.n_classes)[np.array(file_classes, dtype=int)]\n",
"    file_level_labels = np.stack([file_labels[fname] for fname in file_preds])\n",
"\n",
"    # Sklearn Scores\n",
"    metrics = BinaryScores(dict(y=one_hot_preds, batch_y=file_level_labels))\n",
"    print(metrics)\n",
"    return metrics\n"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%% Utility Functions\n"
}
}
},
{
"cell_type": "code",
"execution_count": 5,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Selected Checkopint is output\\VisualTransformer\\VT_7899c07a4809a45c57cba58047cefb5e\\version_7\\ckpt_weights-v1.ckpt\n",
"PrimatesLibrosaDatamodule\n"
]
}
],
"source": [
"exp_path = _ROOT / out_path / model_name / exp_name / version\n",
"# Position of the checkpoint to reload in the natural-sort order of the *.ckpt files.\n",
"# NOTE(review): -4 picked ckpt_weights-v1.ckpt in the recorded run; re-check when files change.\n",
"CHECKPOINT_INDEX = -4\n",
"checkpoints = natsorted(exp_path.glob('*.ckpt'))\n",
"try:\n",
"    checkpoint = checkpoints[CHECKPOINT_INDEX]\n",
"except IndexError:\n",
"    raise FileNotFoundError(f'Only {len(checkpoints)} checkpoint(s) found in {exp_path}') from None\n",
"print(f'Selected Checkpoint is {checkpoint}')\n",
"hparams_yaml = next(exp_path.glob('*.yaml'), None)\n",
"if hparams_yaml is None:\n",
"    raise FileNotFoundError(f'No hparams *.yaml found in {exp_path}')\n",
"print(load_hparams_from_yaml(hparams_yaml)['data_name'])\n",
"# Rebuild the datamodule and locate the model class that was stored with the run.\n",
"datamodule, model_class, h_params = reconstruct_model_data_params(str(hparams_yaml))\n",
"# h_params.update(return_logits=True)"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": 6,
"outputs": [
{
"ename": "RuntimeError",
"evalue": "Error(s) in loading state_dict for VisualTransformer:\n\tsize mismatch for pos_embedding: copying a param with shape torch.Size([1, 67, 30]) from checkpoint, the shape in current model is torch.Size([1, 122, 30]).",
"output_type": "error",
"traceback": [
"\u001B[1;31m---------------------------------------------------------------------------\u001B[0m",
"\u001B[1;31mRuntimeError\u001B[0m Traceback (most recent call last)",
"\u001B[1;32m<ipython-input-6-c8e208607217>\u001B[0m in \u001B[0;36m<module>\u001B[1;34m\u001B[0m\n\u001B[1;32m----> 1\u001B[1;33m \u001B[0mmodel\u001B[0m \u001B[1;33m=\u001B[0m \u001B[0mmodel_class\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mload_from_checkpoint\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mcheckpoint\u001B[0m\u001B[1;33m,\u001B[0m \u001B[1;33m**\u001B[0m\u001B[0mh_params\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0meval\u001B[0m\u001B[1;33m(\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0m\u001B[0;32m 2\u001B[0m \u001B[0mdatamodule\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mprepare_data\u001B[0m\u001B[1;33m(\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m 3\u001B[0m \u001B[1;33m\u001B[0m\u001B[0m\n",
"\u001B[1;32mc:\\users\\steff\\envs\\compare_21\\lib\\site-packages\\pytorch_lightning\\core\\saving.py\u001B[0m in \u001B[0;36mload_from_checkpoint\u001B[1;34m(cls, checkpoint_path, map_location, hparams_file, strict, **kwargs)\u001B[0m\n\u001B[0;32m 154\u001B[0m \u001B[0mcheckpoint\u001B[0m\u001B[1;33m[\u001B[0m\u001B[0mcls\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mCHECKPOINT_HYPER_PARAMS_KEY\u001B[0m\u001B[1;33m]\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mupdate\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mkwargs\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m 155\u001B[0m \u001B[1;33m\u001B[0m\u001B[0m\n\u001B[1;32m--> 156\u001B[1;33m \u001B[0mmodel\u001B[0m \u001B[1;33m=\u001B[0m \u001B[0mcls\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0m_load_model_state\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mcheckpoint\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0mstrict\u001B[0m\u001B[1;33m=\u001B[0m\u001B[0mstrict\u001B[0m\u001B[1;33m,\u001B[0m \u001B[1;33m**\u001B[0m\u001B[0mkwargs\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0m\u001B[0;32m 157\u001B[0m \u001B[1;32mreturn\u001B[0m \u001B[0mmodel\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m 158\u001B[0m \u001B[1;33m\u001B[0m\u001B[0m\n",
"\u001B[1;32mc:\\users\\steff\\envs\\compare_21\\lib\\site-packages\\pytorch_lightning\\core\\saving.py\u001B[0m in \u001B[0;36m_load_model_state\u001B[1;34m(cls, checkpoint, strict, **cls_kwargs_new)\u001B[0m\n\u001B[0;32m 202\u001B[0m \u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m 203\u001B[0m \u001B[1;31m# load the state_dict on the model automatically\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[1;32m--> 204\u001B[1;33m \u001B[0mmodel\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mload_state_dict\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mcheckpoint\u001B[0m\u001B[1;33m[\u001B[0m\u001B[1;34m'state_dict'\u001B[0m\u001B[1;33m]\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0mstrict\u001B[0m\u001B[1;33m=\u001B[0m\u001B[0mstrict\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0m\u001B[0;32m 205\u001B[0m \u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m 206\u001B[0m \u001B[1;32mreturn\u001B[0m \u001B[0mmodel\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n",
"\u001B[1;32mc:\\users\\steff\\envs\\compare_21\\lib\\site-packages\\torch\\nn\\modules\\module.py\u001B[0m in \u001B[0;36mload_state_dict\u001B[1;34m(self, state_dict, strict)\u001B[0m\n\u001B[0;32m 1049\u001B[0m \u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m 1050\u001B[0m \u001B[1;32mif\u001B[0m \u001B[0mlen\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0merror_msgs\u001B[0m\u001B[1;33m)\u001B[0m \u001B[1;33m>\u001B[0m \u001B[1;36m0\u001B[0m\u001B[1;33m:\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[1;32m-> 1051\u001B[1;33m raise RuntimeError('Error(s) in loading state_dict for {}:\\n\\t{}'.format(\n\u001B[0m\u001B[0;32m 1052\u001B[0m self.__class__.__name__, \"\\n\\t\".join(error_msgs)))\n\u001B[0;32m 1053\u001B[0m \u001B[1;32mreturn\u001B[0m \u001B[0m_IncompatibleKeys\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mmissing_keys\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0munexpected_keys\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n",
"\u001B[1;31mRuntimeError\u001B[0m: Error(s) in loading state_dict for VisualTransformer:\n\tsize mismatch for pos_embedding: copying a param with shape torch.Size([1, 67, 30]) from checkpoint, the shape in current model is torch.Size([1, 122, 30])."
]
}
],
"source": [
"# NOTE(review): in the recorded run this raised a pos_embedding size mismatch\n",
"# (checkpoint [1, 67, 30] vs. rebuilt model [1, 122, 30]); the hparam overrides from\n",
"# reconstruct_model_data_params presumably differ from the training configuration.\n",
"model = model_class.load_from_checkpoint(checkpoint, **h_params)\n",
"model.eval()\n",
"datamodule.prepare_data()"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"# Collect per-sample predictions, labels and file names for the development split.\n",
"predictions, labels_y, filenames = gather_predictions_and_labels(model, data_option='devel')"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"# Optional t-SNE view of the development-split predictions; uncomment to run.\n",
"# tsne_dataframe = build_tsne_dataframe(predictions, labels_y)\n",
"# plot_scatterplot(tsne_dataframe, 'devel')"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"# File-level re-evaluation; the trailing ';' suppresses the cell's Out[] display.\n",
"re_valida(preds=predictions, labels=labels_y, fnames=filenames, data_class=datamodule);"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 0
}