2021-03-22 16:43:19 +01:00

104 lines
15 KiB
Plaintext

{
"cells": [
{
"cell_type": "code",
"execution_count": 47,
"metadata": {
"collapsed": true,
"pycharm": {
"name": "#%% IMPORTS\n"
}
},
"outputs": [],
"source": [
"from pathlib import Path\n",
"from natsort import natsorted\n",
"from pytorch_lightning.core.saving import *\n",
"from ml_lib.utils.model_io import SavedLightningModels\n"
]
},
{
"cell_type": "code",
"execution_count": 48,
"outputs": [],
"source": [
"from ml_lib.utils.tools import locate_and_import_class\n",
"from models.transformer_model import VisualTransformer\n",
"_ROOT = Path('..')\n",
"out_path = 'output'\n",
"model_class = VisualTransformer\n",
"model_name = model_class.name()\n",
"\n",
"exp_name = 'VT_01123c93daaffa92d2ed341bda32426d'\n",
"version = 'version_2'"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%M Path resolving and variables\n"
}
}
},
{
"cell_type": "code",
"execution_count": 50,
"outputs": [
{
"ename": "ValueError",
"evalue": "When you set `reduce` as 'macro', you have to provide the number of classes.",
"output_type": "error",
"traceback": [
"\u001B[1;31m---------------------------------------------------------------------------\u001B[0m",
"\u001B[1;31mValueError\u001B[0m Traceback (most recent call last)",
"\u001B[1;32m<ipython-input-50-0216292a172f>\u001B[0m in \u001B[0;36m<module>\u001B[1;34m\u001B[0m\n\u001B[0;32m 6\u001B[0m \u001B[0madditional_kwargs\u001B[0m \u001B[1;33m=\u001B[0m \u001B[0mdict\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mvariable_length\u001B[0m \u001B[1;33m=\u001B[0m \u001B[1;32mFalse\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0mc_classes\u001B[0m\u001B[1;33m=\u001B[0m\u001B[1;36m5\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m 7\u001B[0m \u001B[1;33m\u001B[0m\u001B[0m\n\u001B[1;32m----> 8\u001B[1;33m \u001B[0mmodel\u001B[0m \u001B[1;33m=\u001B[0m \u001B[0mmodel_class\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mload_from_checkpoint\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mcheckpoint\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0mhparams_file\u001B[0m\u001B[1;33m=\u001B[0m\u001B[0mstr\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mhparams_yaml\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m,\u001B[0m \u001B[1;33m**\u001B[0m\u001B[0madditional_kwargs\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0m\u001B[0;32m 9\u001B[0m \u001B[1;33m\u001B[0m\u001B[0m\n",
"\u001B[1;32mc:\\users\\steff\\envs\\compare_21\\lib\\site-packages\\pytorch_lightning\\core\\saving.py\u001B[0m in \u001B[0;36mload_from_checkpoint\u001B[1;34m(cls, checkpoint_path, map_location, hparams_file, strict, **kwargs)\u001B[0m\n\u001B[0;32m 154\u001B[0m \u001B[0mcheckpoint\u001B[0m\u001B[1;33m[\u001B[0m\u001B[0mcls\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mCHECKPOINT_HYPER_PARAMS_KEY\u001B[0m\u001B[1;33m]\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mupdate\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mkwargs\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m 155\u001B[0m \u001B[1;33m\u001B[0m\u001B[0m\n\u001B[1;32m--> 156\u001B[1;33m \u001B[0mmodel\u001B[0m \u001B[1;33m=\u001B[0m \u001B[0mcls\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0m_load_model_state\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mcheckpoint\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0mstrict\u001B[0m\u001B[1;33m=\u001B[0m\u001B[0mstrict\u001B[0m\u001B[1;33m,\u001B[0m \u001B[1;33m**\u001B[0m\u001B[0mkwargs\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0m\u001B[0;32m 157\u001B[0m \u001B[1;32mreturn\u001B[0m \u001B[0mmodel\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m 158\u001B[0m \u001B[1;33m\u001B[0m\u001B[0m\n",
"\u001B[1;32mc:\\users\\steff\\envs\\compare_21\\lib\\site-packages\\pytorch_lightning\\core\\saving.py\u001B[0m in \u001B[0;36m_load_model_state\u001B[1;34m(cls, checkpoint, strict, **cls_kwargs_new)\u001B[0m\n\u001B[0;32m 196\u001B[0m \u001B[0m_cls_kwargs\u001B[0m \u001B[1;33m=\u001B[0m \u001B[1;33m{\u001B[0m\u001B[0mk\u001B[0m\u001B[1;33m:\u001B[0m \u001B[0mv\u001B[0m \u001B[1;32mfor\u001B[0m \u001B[0mk\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0mv\u001B[0m \u001B[1;32min\u001B[0m \u001B[0m_cls_kwargs\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mitems\u001B[0m\u001B[1;33m(\u001B[0m\u001B[1;33m)\u001B[0m \u001B[1;32mif\u001B[0m \u001B[0mk\u001B[0m \u001B[1;32min\u001B[0m \u001B[0mcls_init_args_name\u001B[0m\u001B[1;33m}\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m 197\u001B[0m \u001B[1;33m\u001B[0m\u001B[0m\n\u001B[1;32m--> 198\u001B[1;33m \u001B[0mmodel\u001B[0m \u001B[1;33m=\u001B[0m \u001B[0mcls\u001B[0m\u001B[1;33m(\u001B[0m\u001B[1;33m**\u001B[0m\u001B[0m_cls_kwargs\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0m\u001B[0;32m 199\u001B[0m \u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m 200\u001B[0m \u001B[1;31m# give model a chance to load something\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n",
"\u001B[1;32m~\\projects\\compare_21\\models\\transformer_model.py\u001B[0m in \u001B[0;36m__init__\u001B[1;34m(self, in_shape, n_classes, weight_init, activation, embedding_size, heads, attn_depth, patch_size, use_residual, variable_length, use_bias, use_norm, dropout, lat_dim, loss, scheduler, mlp_dim, head_dim, lr, weight_decay, sto_weight_avg, lr_scheduler_parameter, opt_reset_interval)\u001B[0m\n\u001B[0;32m 27\u001B[0m \u001B[0ma\u001B[0m \u001B[1;33m=\u001B[0m \u001B[0mdict\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mlocals\u001B[0m\u001B[1;33m(\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m 28\u001B[0m \u001B[0mparams\u001B[0m \u001B[1;33m=\u001B[0m \u001B[1;33m{\u001B[0m\u001B[0marg\u001B[0m\u001B[1;33m:\u001B[0m \u001B[0ma\u001B[0m\u001B[1;33m[\u001B[0m\u001B[0marg\u001B[0m\u001B[1;33m]\u001B[0m \u001B[1;32mfor\u001B[0m \u001B[0marg\u001B[0m \u001B[1;32min\u001B[0m \u001B[0minspect\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0msignature\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mself\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0m__init__\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mparameters\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mkeys\u001B[0m\u001B[1;33m(\u001B[0m\u001B[1;33m)\u001B[0m \u001B[1;32mif\u001B[0m \u001B[0marg\u001B[0m \u001B[1;33m!=\u001B[0m \u001B[1;34m'self'\u001B[0m\u001B[1;33m}\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[1;32m---> 29\u001B[1;33m \u001B[0msuper\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mVisualTransformer\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0mself\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0m__init__\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mparams\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0m\u001B[0;32m 30\u001B[0m \u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m 31\u001B[0m \u001B[0mself\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0min_shape\u001B[0m \u001B[1;33m=\u001B[0m \u001B[0min_shape\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n",
"\u001B[1;32m~\\projects\\compare_21\\ml_lib\\modules\\util.py\u001B[0m in \u001B[0;36m__init__\u001B[1;34m(self, model_parameters, weight_init)\u001B[0m\n\u001B[0;32m 112\u001B[0m \u001B[0mself\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mparams\u001B[0m \u001B[1;33m=\u001B[0m \u001B[0mModelParameters\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mmodel_parameters\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m 113\u001B[0m \u001B[1;33m\u001B[0m\u001B[0m\n\u001B[1;32m--> 114\u001B[1;33m \u001B[0mself\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mmetrics\u001B[0m \u001B[1;33m=\u001B[0m \u001B[0mPLMetrics\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mself\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mparams\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mn_classes\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0mtag\u001B[0m\u001B[1;33m=\u001B[0m\u001B[1;34m'PL'\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0m\u001B[0;32m 115\u001B[0m \u001B[1;32mpass\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m 116\u001B[0m \u001B[1;33m\u001B[0m\u001B[0m\n",
"\u001B[1;32m~\\projects\\compare_21\\ml_lib\\modules\\util.py\u001B[0m in \u001B[0;36m__init__\u001B[1;34m(self, n_classes, tag)\u001B[0m\n\u001B[0;32m 30\u001B[0m \u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m 31\u001B[0m \u001B[0mself\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0maccuracy_score\u001B[0m \u001B[1;33m=\u001B[0m \u001B[0mpl\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mmetrics\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mAccuracy\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mcompute_on_step\u001B[0m\u001B[1;33m=\u001B[0m\u001B[1;32mFalse\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[1;32m---> 32\u001B[1;33m \u001B[0mself\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mprecision\u001B[0m \u001B[1;33m=\u001B[0m \u001B[0mpl\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mmetrics\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mPrecision\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mnum_classes\u001B[0m\u001B[1;33m=\u001B[0m\u001B[0mself\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mn_classes\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0maverage\u001B[0m\u001B[1;33m=\u001B[0m\u001B[1;34m'macro'\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0mcompute_on_step\u001B[0m\u001B[1;33m=\u001B[0m\u001B[1;32mFalse\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0m\u001B[0;32m 33\u001B[0m \u001B[0mself\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mrecall\u001B[0m \u001B[1;33m=\u001B[0m \u001B[0mpl\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mmetrics\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mRecall\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mnum_classes\u001B[0m\u001B[1;33m=\u001B[0m\u001B[0mself\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mn_classes\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0maverage\u001B[0m\u001B[1;33m=\u001B[0m\u001B[1;34m'macro'\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0mcompute_on_step\u001B[0m\u001B[1;33m=\u001B[0m\u001B[1;32mFalse\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m 34\u001B[0m \u001B[0mself\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mconfusion_matrix\u001B[0m \u001B[1;33m=\u001B[0m \u001B[0mpl\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mmetrics\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mConfusionMatrix\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mself\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mn_classes\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0mnormalize\u001B[0m\u001B[1;33m=\u001B[0m\u001B[1;34m'true'\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0mcompute_on_step\u001B[0m\u001B[1;33m=\u001B[0m\u001B[1;32mFalse\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n",
"\u001B[1;32mc:\\users\\steff\\envs\\compare_21\\lib\\site-packages\\pytorch_lightning\\metrics\\classification\\precision_recall.py\u001B[0m in \u001B[0;36m__init__\u001B[1;34m(self, num_classes, threshold, average, multilabel, mdmc_average, ignore_index, top_k, is_multiclass, compute_on_step, dist_sync_on_step, process_group, dist_sync_fn)\u001B[0m\n\u001B[0;32m 139\u001B[0m \u001B[1;32mraise\u001B[0m \u001B[0mValueError\u001B[0m\u001B[1;33m(\u001B[0m\u001B[1;34mf\"The `average` has to be one of {allowed_average}, got {average}.\"\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m 140\u001B[0m \u001B[1;33m\u001B[0m\u001B[0m\n\u001B[1;32m--> 141\u001B[1;33m super().__init__(\n\u001B[0m\u001B[0;32m 142\u001B[0m \u001B[0mreduce\u001B[0m\u001B[1;33m=\u001B[0m\u001B[1;34m\"macro\"\u001B[0m \u001B[1;32mif\u001B[0m \u001B[0maverage\u001B[0m \u001B[1;32min\u001B[0m \u001B[1;33m[\u001B[0m\u001B[1;34m\"weighted\"\u001B[0m\u001B[1;33m,\u001B[0m \u001B[1;34m\"none\"\u001B[0m\u001B[1;33m,\u001B[0m \u001B[1;32mNone\u001B[0m\u001B[1;33m]\u001B[0m \u001B[1;32melse\u001B[0m \u001B[0maverage\u001B[0m\u001B[1;33m,\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m 143\u001B[0m \u001B[0mmdmc_reduce\u001B[0m\u001B[1;33m=\u001B[0m\u001B[0mmdmc_average\u001B[0m\u001B[1;33m,\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n",
"\u001B[1;32mc:\\users\\steff\\envs\\compare_21\\lib\\site-packages\\pytorch_lightning\\metrics\\classification\\stat_scores.py\u001B[0m in \u001B[0;36m__init__\u001B[1;34m(self, threshold, top_k, reduce, num_classes, ignore_index, mdmc_reduce, is_multiclass, compute_on_step, dist_sync_on_step, process_group, dist_sync_fn)\u001B[0m\n\u001B[0;32m 157\u001B[0m \u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m 158\u001B[0m \u001B[1;32mif\u001B[0m \u001B[0mreduce\u001B[0m \u001B[1;33m==\u001B[0m \u001B[1;34m\"macro\"\u001B[0m \u001B[1;32mand\u001B[0m \u001B[1;33m(\u001B[0m\u001B[1;32mnot\u001B[0m \u001B[0mnum_classes\u001B[0m \u001B[1;32mor\u001B[0m \u001B[0mnum_classes\u001B[0m \u001B[1;33m<\u001B[0m \u001B[1;36m1\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m:\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[1;32m--> 159\u001B[1;33m \u001B[1;32mraise\u001B[0m \u001B[0mValueError\u001B[0m\u001B[1;33m(\u001B[0m\u001B[1;34m\"When you set `reduce` as 'macro', you have to provide the number of classes.\"\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0m\u001B[0;32m 160\u001B[0m \u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m 161\u001B[0m \u001B[1;32mif\u001B[0m \u001B[0mnum_classes\u001B[0m \u001B[1;32mand\u001B[0m \u001B[0mignore_index\u001B[0m \u001B[1;32mis\u001B[0m \u001B[1;32mnot\u001B[0m \u001B[1;32mNone\u001B[0m \u001B[1;32mand\u001B[0m \u001B[1;33m(\u001B[0m\u001B[1;32mnot\u001B[0m \u001B[1;36m0\u001B[0m \u001B[1;33m<=\u001B[0m \u001B[0mignore_index\u001B[0m \u001B[1;33m<\u001B[0m \u001B[0mnum_classes\u001B[0m \u001B[1;32mor\u001B[0m \u001B[0mnum_classes\u001B[0m \u001B[1;33m==\u001B[0m \u001B[1;36m1\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m:\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n",
"\u001B[1;31mValueError\u001B[0m: When you set `reduce` as 'macro', you have to provide the number of classes."
]
}
],
"source": [
"exp_path = _ROOT / out_path / model_name / exp_name / version\n",
"checkpoint = natsorted(exp_path.glob('*.ckpt'))[-1]\n",
"hparams_yaml = next(exp_path.glob('*.yaml'))\n",
"\n",
"hparams_file = load_hparams_from_yaml(hparams_yaml)\n",
"additional_kwargs = dict(variable_length = False, c_classes=5)\n",
"\n",
"model = model_class.load_from_checkpoint(checkpoint, hparams_file=str(hparams_yaml), **additional_kwargs)\n"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 0
}