adjustment fot CCS, notebook folder
This commit is contained in:
parent
78b3139d1a
commit
c12f3866c8
@ -116,6 +116,7 @@ class CCSLibrosaDatamodule(_BaseDataModule):
|
||||
for sub_dataset in results.get():
|
||||
dataset.append(sub_dataset[0])
|
||||
datasets[data_option] = ConcatDataset(dataset)
|
||||
print(f'{data_option}-dataset prepared.')
|
||||
self.datasets = datasets
|
||||
return datasets
|
||||
|
||||
@ -133,6 +134,7 @@ class CCSLibrosaDatamodule(_BaseDataModule):
|
||||
for row in all_rows:
|
||||
mel_dataset, class_id, _ = self._build_subdataset(row)
|
||||
dataset.append(mel_dataset)
|
||||
print(f'{data_option}-dataset prepared!')
|
||||
datasets[data_option] = ConcatDataset(dataset)
|
||||
|
||||
# Build Weighted Sampler for train and val
|
||||
@ -158,6 +160,7 @@ class CCSLibrosaDatamodule(_BaseDataModule):
|
||||
samplers[data_option] = None
|
||||
self.datasets = datasets
|
||||
self.samplers = samplers
|
||||
print(f'Dataset {self.__class__.__name__} setup done.')
|
||||
return datasets
|
||||
|
||||
def purge(self):
|
||||
|
@ -3,6 +3,7 @@ from argparse import Namespace
|
||||
|
||||
from torch import nn
|
||||
|
||||
from ml_lib.metrics.binary_class_classifictaion import BinaryScores
|
||||
from ml_lib.metrics.multi_class_classification import MultiClassScores
|
||||
from ml_lib.modules.blocks import LinearModule
|
||||
from ml_lib.modules.model_parts import CNNEncoder
|
||||
@ -36,9 +37,11 @@ class CNNBaseline(CombinedModelMixins,
|
||||
# Modules with Parameters
|
||||
self.encoder = CNNEncoder(in_shape=self.in_shape, **self.params.module_kwargs)
|
||||
|
||||
# Make Decision between binary and Multiclass Classification
|
||||
logits = n_classes if n_classes > 2 else 1
|
||||
module_kwargs = self.params.module_kwargs
|
||||
module_kwargs.update(activation=nn.Softmax)
|
||||
self.classifier = LinearModule(self.encoder.shape, n_classes, **module_kwargs)
|
||||
module_kwargs.update(activation=(nn.Softmax if logits > 1 else nn.Sigmoid))
|
||||
self.classifier = LinearModule(self.encoder.shape, logits, **module_kwargs)
|
||||
|
||||
def forward(self, x, mask=None, return_attn_weights=False):
|
||||
"""
|
||||
@ -52,4 +55,7 @@ class CNNBaseline(CombinedModelMixins,
|
||||
return Namespace(main_out=tensor)
|
||||
|
||||
def additional_scores(self, outputs):
|
||||
if self.params.n_classes > 2:
|
||||
return MultiClassScores(self)(outputs)
|
||||
else:
|
||||
return BinaryScores(self)(outputs)
|
||||
|
@ -1,8 +1,6 @@
|
||||
import inspect
|
||||
from argparse import Namespace
|
||||
|
||||
import warnings
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
@ -70,13 +68,15 @@ class VisualTransformer(CombinedModelMixins,
|
||||
|
||||
self.to_cls_token = nn.Identity()
|
||||
|
||||
logits = self.params.n_classes if self.params.n_classes > 2 else 1
|
||||
|
||||
self.mlp_head = nn.Sequential(
|
||||
nn.LayerNorm(self.embed_dim),
|
||||
nn.Linear(self.embed_dim, self.params.lat_dim),
|
||||
nn.GELU(),
|
||||
nn.Dropout(self.params.dropout),
|
||||
nn.Linear(self.params.lat_dim, n_classes),
|
||||
nn.Softmax()
|
||||
nn.Linear(self.params.lat_dim, logits),
|
||||
nn.Softmax() if logits > 1 else nn.Sigmoid()
|
||||
)
|
||||
|
||||
def forward(self, x, mask=None, return_attn_weights=False):
|
||||
|
@ -11,13 +11,12 @@ if __name__ == '__main__':
|
||||
|
||||
# Set new values
|
||||
hparams_dict = dict(seed=range(10),
|
||||
model_name=['VisualTransformer'],
|
||||
data_name=['CCSLibrosaDatamodule'],
|
||||
model_name=['CNNBaseline'],
|
||||
data_name=['CCSLibrosaDatamodule'], # 'CCSLibrosaDatamodule'],
|
||||
batch_size=[50],
|
||||
max_epochs=[200],
|
||||
variable_length=[False],
|
||||
sample_segment_len=[40],
|
||||
sample_hop_len=[15],
|
||||
target_mel_length_in_seconds=[0.5],
|
||||
random_apply_chance=[0.5], # trial.suggest_float('random_apply_chance', 0.1, 0.5, step=0.1),
|
||||
loudness_ratio=[0], # trial.suggest_float('loudness_ratio', 0.0, 0.5, step=0.1),
|
||||
shift_ratio=[0.3], # trial.suggest_float('shift_ratio', 0.0, 0.5, step=0.1),
|
||||
|
104
notebooks/reload model.ipynb
Normal file
104
notebooks/reload model.ipynb
Normal file
@ -0,0 +1,104 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 47,
|
||||
"metadata": {
|
||||
"collapsed": true,
|
||||
"pycharm": {
|
||||
"name": "#%% IMPORTS\n"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from pathlib import Path\n",
|
||||
"from natsort import natsorted\n",
|
||||
"from pytorch_lightning.core.saving import *\n",
|
||||
"from ml_lib.utils.model_io import SavedLightningModels\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 48,
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from ml_lib.utils.tools import locate_and_import_class\n",
|
||||
"from models.transformer_model import VisualTransformer\n",
|
||||
"_ROOT = Path('..')\n",
|
||||
"out_path = 'output'\n",
|
||||
"model_class = VisualTransformer\n",
|
||||
"model_name = model_class.name()\n",
|
||||
"\n",
|
||||
"exp_name = 'VT_01123c93daaffa92d2ed341bda32426d'\n",
|
||||
"version = 'version_2'"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%%M Path resolving and variables\n"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 50,
|
||||
"outputs": [
|
||||
{
|
||||
"ename": "ValueError",
|
||||
"evalue": "When you set `reduce` as 'macro', you have to provide the number of classes.",
|
||||
"output_type": "error",
|
||||
"traceback": [
|
||||
"\u001B[1;31m---------------------------------------------------------------------------\u001B[0m",
|
||||
"\u001B[1;31mValueError\u001B[0m Traceback (most recent call last)",
|
||||
"\u001B[1;32m<ipython-input-50-0216292a172f>\u001B[0m in \u001B[0;36m<module>\u001B[1;34m\u001B[0m\n\u001B[0;32m 6\u001B[0m \u001B[0madditional_kwargs\u001B[0m \u001B[1;33m=\u001B[0m \u001B[0mdict\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mvariable_length\u001B[0m \u001B[1;33m=\u001B[0m \u001B[1;32mFalse\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0mc_classes\u001B[0m\u001B[1;33m=\u001B[0m\u001B[1;36m5\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m 7\u001B[0m \u001B[1;33m\u001B[0m\u001B[0m\n\u001B[1;32m----> 8\u001B[1;33m \u001B[0mmodel\u001B[0m \u001B[1;33m=\u001B[0m \u001B[0mmodel_class\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mload_from_checkpoint\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mcheckpoint\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0mhparams_file\u001B[0m\u001B[1;33m=\u001B[0m\u001B[0mstr\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mhparams_yaml\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m,\u001B[0m \u001B[1;33m**\u001B[0m\u001B[0madditional_kwargs\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0m\u001B[0;32m 9\u001B[0m \u001B[1;33m\u001B[0m\u001B[0m\n",
|
||||
"\u001B[1;32mc:\\users\\steff\\envs\\compare_21\\lib\\site-packages\\pytorch_lightning\\core\\saving.py\u001B[0m in \u001B[0;36mload_from_checkpoint\u001B[1;34m(cls, checkpoint_path, map_location, hparams_file, strict, **kwargs)\u001B[0m\n\u001B[0;32m 154\u001B[0m \u001B[0mcheckpoint\u001B[0m\u001B[1;33m[\u001B[0m\u001B[0mcls\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mCHECKPOINT_HYPER_PARAMS_KEY\u001B[0m\u001B[1;33m]\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mupdate\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mkwargs\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m 155\u001B[0m \u001B[1;33m\u001B[0m\u001B[0m\n\u001B[1;32m--> 156\u001B[1;33m \u001B[0mmodel\u001B[0m \u001B[1;33m=\u001B[0m \u001B[0mcls\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0m_load_model_state\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mcheckpoint\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0mstrict\u001B[0m\u001B[1;33m=\u001B[0m\u001B[0mstrict\u001B[0m\u001B[1;33m,\u001B[0m \u001B[1;33m**\u001B[0m\u001B[0mkwargs\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0m\u001B[0;32m 157\u001B[0m \u001B[1;32mreturn\u001B[0m \u001B[0mmodel\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m 158\u001B[0m \u001B[1;33m\u001B[0m\u001B[0m\n",
|
||||
"\u001B[1;32mc:\\users\\steff\\envs\\compare_21\\lib\\site-packages\\pytorch_lightning\\core\\saving.py\u001B[0m in \u001B[0;36m_load_model_state\u001B[1;34m(cls, checkpoint, strict, **cls_kwargs_new)\u001B[0m\n\u001B[0;32m 196\u001B[0m \u001B[0m_cls_kwargs\u001B[0m \u001B[1;33m=\u001B[0m \u001B[1;33m{\u001B[0m\u001B[0mk\u001B[0m\u001B[1;33m:\u001B[0m \u001B[0mv\u001B[0m \u001B[1;32mfor\u001B[0m \u001B[0mk\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0mv\u001B[0m \u001B[1;32min\u001B[0m \u001B[0m_cls_kwargs\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mitems\u001B[0m\u001B[1;33m(\u001B[0m\u001B[1;33m)\u001B[0m \u001B[1;32mif\u001B[0m \u001B[0mk\u001B[0m \u001B[1;32min\u001B[0m \u001B[0mcls_init_args_name\u001B[0m\u001B[1;33m}\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m 197\u001B[0m \u001B[1;33m\u001B[0m\u001B[0m\n\u001B[1;32m--> 198\u001B[1;33m \u001B[0mmodel\u001B[0m \u001B[1;33m=\u001B[0m \u001B[0mcls\u001B[0m\u001B[1;33m(\u001B[0m\u001B[1;33m**\u001B[0m\u001B[0m_cls_kwargs\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0m\u001B[0;32m 199\u001B[0m \u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m 200\u001B[0m \u001B[1;31m# give model a chance to load something\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n",
|
||||
"\u001B[1;32m~\\projects\\compare_21\\models\\transformer_model.py\u001B[0m in \u001B[0;36m__init__\u001B[1;34m(self, in_shape, n_classes, weight_init, activation, embedding_size, heads, attn_depth, patch_size, use_residual, variable_length, use_bias, use_norm, dropout, lat_dim, loss, scheduler, mlp_dim, head_dim, lr, weight_decay, sto_weight_avg, lr_scheduler_parameter, opt_reset_interval)\u001B[0m\n\u001B[0;32m 27\u001B[0m \u001B[0ma\u001B[0m \u001B[1;33m=\u001B[0m \u001B[0mdict\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mlocals\u001B[0m\u001B[1;33m(\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m 28\u001B[0m \u001B[0mparams\u001B[0m \u001B[1;33m=\u001B[0m \u001B[1;33m{\u001B[0m\u001B[0marg\u001B[0m\u001B[1;33m:\u001B[0m \u001B[0ma\u001B[0m\u001B[1;33m[\u001B[0m\u001B[0marg\u001B[0m\u001B[1;33m]\u001B[0m \u001B[1;32mfor\u001B[0m \u001B[0marg\u001B[0m \u001B[1;32min\u001B[0m \u001B[0minspect\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0msignature\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mself\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0m__init__\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mparameters\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mkeys\u001B[0m\u001B[1;33m(\u001B[0m\u001B[1;33m)\u001B[0m \u001B[1;32mif\u001B[0m \u001B[0marg\u001B[0m \u001B[1;33m!=\u001B[0m \u001B[1;34m'self'\u001B[0m\u001B[1;33m}\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[1;32m---> 29\u001B[1;33m \u001B[0msuper\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mVisualTransformer\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0mself\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0m__init__\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mparams\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0m\u001B[0;32m 30\u001B[0m \u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m 31\u001B[0m \u001B[0mself\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0min_shape\u001B[0m \u001B[1;33m=\u001B[0m \u001B[0min_shape\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n",
|
||||
"\u001B[1;32m~\\projects\\compare_21\\ml_lib\\modules\\util.py\u001B[0m in \u001B[0;36m__init__\u001B[1;34m(self, model_parameters, weight_init)\u001B[0m\n\u001B[0;32m 112\u001B[0m \u001B[0mself\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mparams\u001B[0m \u001B[1;33m=\u001B[0m \u001B[0mModelParameters\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mmodel_parameters\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m 113\u001B[0m \u001B[1;33m\u001B[0m\u001B[0m\n\u001B[1;32m--> 114\u001B[1;33m \u001B[0mself\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mmetrics\u001B[0m \u001B[1;33m=\u001B[0m \u001B[0mPLMetrics\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mself\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mparams\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mn_classes\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0mtag\u001B[0m\u001B[1;33m=\u001B[0m\u001B[1;34m'PL'\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0m\u001B[0;32m 115\u001B[0m \u001B[1;32mpass\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m 116\u001B[0m \u001B[1;33m\u001B[0m\u001B[0m\n",
|
||||
"\u001B[1;32m~\\projects\\compare_21\\ml_lib\\modules\\util.py\u001B[0m in \u001B[0;36m__init__\u001B[1;34m(self, n_classes, tag)\u001B[0m\n\u001B[0;32m 30\u001B[0m \u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m 31\u001B[0m \u001B[0mself\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0maccuracy_score\u001B[0m \u001B[1;33m=\u001B[0m \u001B[0mpl\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mmetrics\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mAccuracy\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mcompute_on_step\u001B[0m\u001B[1;33m=\u001B[0m\u001B[1;32mFalse\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[1;32m---> 32\u001B[1;33m \u001B[0mself\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mprecision\u001B[0m \u001B[1;33m=\u001B[0m \u001B[0mpl\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mmetrics\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mPrecision\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mnum_classes\u001B[0m\u001B[1;33m=\u001B[0m\u001B[0mself\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mn_classes\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0maverage\u001B[0m\u001B[1;33m=\u001B[0m\u001B[1;34m'macro'\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0mcompute_on_step\u001B[0m\u001B[1;33m=\u001B[0m\u001B[1;32mFalse\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0m\u001B[0;32m 33\u001B[0m \u001B[0mself\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mrecall\u001B[0m \u001B[1;33m=\u001B[0m \u001B[0mpl\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mmetrics\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mRecall\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mnum_classes\u001B[0m\u001B[1;33m=\u001B[0m\u001B[0mself\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mn_classes\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0maverage\u001B[0m\u001B[1;33m=\u001B[0m\u001B[1;34m'macro'\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0mcompute_on_step\u001B[0m\u001B[1;33m=\u001B[0m\u001B[1;32mFalse\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m 34\u001B[0m \u001B[0mself\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mconfusion_matrix\u001B[0m \u001B[1;33m=\u001B[0m \u001B[0mpl\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mmetrics\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mConfusionMatrix\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mself\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mn_classes\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0mnormalize\u001B[0m\u001B[1;33m=\u001B[0m\u001B[1;34m'true'\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0mcompute_on_step\u001B[0m\u001B[1;33m=\u001B[0m\u001B[1;32mFalse\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n",
|
||||
"\u001B[1;32mc:\\users\\steff\\envs\\compare_21\\lib\\site-packages\\pytorch_lightning\\metrics\\classification\\precision_recall.py\u001B[0m in \u001B[0;36m__init__\u001B[1;34m(self, num_classes, threshold, average, multilabel, mdmc_average, ignore_index, top_k, is_multiclass, compute_on_step, dist_sync_on_step, process_group, dist_sync_fn)\u001B[0m\n\u001B[0;32m 139\u001B[0m \u001B[1;32mraise\u001B[0m \u001B[0mValueError\u001B[0m\u001B[1;33m(\u001B[0m\u001B[1;34mf\"The `average` has to be one of {allowed_average}, got {average}.\"\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m 140\u001B[0m \u001B[1;33m\u001B[0m\u001B[0m\n\u001B[1;32m--> 141\u001B[1;33m super().__init__(\n\u001B[0m\u001B[0;32m 142\u001B[0m \u001B[0mreduce\u001B[0m\u001B[1;33m=\u001B[0m\u001B[1;34m\"macro\"\u001B[0m \u001B[1;32mif\u001B[0m \u001B[0maverage\u001B[0m \u001B[1;32min\u001B[0m \u001B[1;33m[\u001B[0m\u001B[1;34m\"weighted\"\u001B[0m\u001B[1;33m,\u001B[0m \u001B[1;34m\"none\"\u001B[0m\u001B[1;33m,\u001B[0m \u001B[1;32mNone\u001B[0m\u001B[1;33m]\u001B[0m \u001B[1;32melse\u001B[0m \u001B[0maverage\u001B[0m\u001B[1;33m,\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m 143\u001B[0m \u001B[0mmdmc_reduce\u001B[0m\u001B[1;33m=\u001B[0m\u001B[0mmdmc_average\u001B[0m\u001B[1;33m,\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n",
|
||||
"\u001B[1;32mc:\\users\\steff\\envs\\compare_21\\lib\\site-packages\\pytorch_lightning\\metrics\\classification\\stat_scores.py\u001B[0m in \u001B[0;36m__init__\u001B[1;34m(self, threshold, top_k, reduce, num_classes, ignore_index, mdmc_reduce, is_multiclass, compute_on_step, dist_sync_on_step, process_group, dist_sync_fn)\u001B[0m\n\u001B[0;32m 157\u001B[0m \u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m 158\u001B[0m \u001B[1;32mif\u001B[0m \u001B[0mreduce\u001B[0m \u001B[1;33m==\u001B[0m \u001B[1;34m\"macro\"\u001B[0m \u001B[1;32mand\u001B[0m \u001B[1;33m(\u001B[0m\u001B[1;32mnot\u001B[0m \u001B[0mnum_classes\u001B[0m \u001B[1;32mor\u001B[0m \u001B[0mnum_classes\u001B[0m \u001B[1;33m<\u001B[0m \u001B[1;36m1\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m:\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[1;32m--> 159\u001B[1;33m \u001B[1;32mraise\u001B[0m \u001B[0mValueError\u001B[0m\u001B[1;33m(\u001B[0m\u001B[1;34m\"When you set `reduce` as 'macro', you have to provide the number of classes.\"\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0m\u001B[0;32m 160\u001B[0m \u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m 161\u001B[0m \u001B[1;32mif\u001B[0m \u001B[0mnum_classes\u001B[0m \u001B[1;32mand\u001B[0m \u001B[0mignore_index\u001B[0m \u001B[1;32mis\u001B[0m \u001B[1;32mnot\u001B[0m \u001B[1;32mNone\u001B[0m \u001B[1;32mand\u001B[0m \u001B[1;33m(\u001B[0m\u001B[1;32mnot\u001B[0m \u001B[1;36m0\u001B[0m \u001B[1;33m<=\u001B[0m \u001B[0mignore_index\u001B[0m \u001B[1;33m<\u001B[0m \u001B[0mnum_classes\u001B[0m \u001B[1;32mor\u001B[0m \u001B[0mnum_classes\u001B[0m \u001B[1;33m==\u001B[0m \u001B[1;36m1\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m:\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n",
|
||||
"\u001B[1;31mValueError\u001B[0m: When you set `reduce` as 'macro', you have to provide the number of classes."
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"exp_path = _ROOT / out_path / model_name / exp_name / version\n",
|
||||
"checkpoint = natsorted(exp_path.glob('*.ckpt'))[-1]\n",
|
||||
"hparams_yaml = next(exp_path.glob('*.yaml'))\n",
|
||||
"\n",
|
||||
"hparams_file = load_hparams_from_yaml(hparams_yaml)\n",
|
||||
"additional_kwargs = dict(variable_length = False, c_classes=5)\n",
|
||||
"\n",
|
||||
"model = model_class.load_from_checkpoint(checkpoint, hparams_file=str(hparams_yaml), **additional_kwargs)\n"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
}
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 2
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython2",
|
||||
"version": "2.7.6"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 0
|
||||
}
|
@ -17,13 +17,14 @@ class TrainMixin:
|
||||
assert isinstance(self, LightningBaseModule)
|
||||
batch_files, batch_x, batch_y = batch_xy
|
||||
y = self(batch_x).main_out
|
||||
|
||||
if self.params.n_classes <= 2:
|
||||
loss = self.bce_loss(y, batch_y.long())
|
||||
else:
|
||||
if self.params.loss == 'focal_loss_rob':
|
||||
labels_one_hot = torch.nn.functional.one_hot(batch_y, num_classes=self.params.n_classes)
|
||||
loss = self.__getattribute__(self.params.loss)(y, labels_one_hot)
|
||||
else:
|
||||
loss = self.__getattribute__(self.params.loss)(y, batch_y.long())
|
||||
|
||||
return dict(loss=loss)
|
||||
|
||||
def training_epoch_end(self, outputs):
|
||||
@ -60,13 +61,16 @@ class ValMixin:
|
||||
).squeeze()
|
||||
y_one_hot = torch.nn.functional.one_hot(y_max, num_classes=self.params.n_classes).float()
|
||||
target_y = torch.stack(tuple(sorted_batch_y.values())).long()
|
||||
if self.params.n_classes <= 2:
|
||||
if y_one_hot.ndim == 1:
|
||||
y_one_hot = y_one_hot.unsqueeze(0)
|
||||
if target_y.ndim == 1:
|
||||
target_y = target_y.unsqueeze(0)
|
||||
target_y = target_y.unsqueeze(-1)
|
||||
|
||||
self.metrics.update(y_one_hot, target_y)
|
||||
|
||||
if self.params.n_classes <= 2:
|
||||
val_loss = self.bce_loss(y.squeeze().float(), batch_y.float())
|
||||
else:
|
||||
val_loss = self.ce_loss(y, batch_y.long())
|
||||
|
||||
return dict(batch_files=batch_files, val_loss=val_loss,
|
||||
@ -93,17 +97,26 @@ class ValMixin:
|
||||
for file_name in sorted_y:
|
||||
sorted_y.update({file_name: torch.stack(sorted_y[file_name])})
|
||||
|
||||
y_mean = torch.stack(
|
||||
[torch.mean(x, dim=0, keepdim=True) if x.shape[0] > 1 else x for x in sorted_y.values()]
|
||||
).squeeze()
|
||||
mean_vote_loss = self.ce_loss(y_mean, sorted_batch_y)
|
||||
summary_dict.update(val_mean_vote_loss=mean_vote_loss)
|
||||
#y_mean = torch.stack(
|
||||
# [torch.mean(x, dim=0, keepdim=True) if x.shape[0] > 1 else x for x in sorted_y.values()]
|
||||
#).squeeze()
|
||||
|
||||
#if y_mean.ndim == 1:
|
||||
# y_mean = y_mean.unsqueeze(0)
|
||||
#if sorted_batch_y.ndim == 1:
|
||||
# sorted_batch_y = sorted_batch_y.unsqueeze(-1)
|
||||
#
|
||||
#mean_vote_loss = self.ce_loss(y_mean, sorted_batch_y)
|
||||
#summary_dict.update(val_mean_vote_loss=mean_vote_loss)
|
||||
|
||||
y_max = torch.stack(
|
||||
[torch.argmax(x.mean(dim=0)) if x.shape[0] > 1 else torch.argmax(x) for x in sorted_y.values()]
|
||||
).squeeze()
|
||||
y_one_hot = torch.nn.functional.one_hot(y_max, num_classes=self.params.n_classes).float()
|
||||
if self.params.n_classes >= 2:
|
||||
max_vote_loss = self.ce_loss(y_one_hot, sorted_batch_y)
|
||||
else:
|
||||
max_vote_loss = self.bce_loss(y_one_hot, sorted_batch_y)
|
||||
summary_dict.update(val_max_vote_loss=max_vote_loss)
|
||||
|
||||
summary_dict.update({f'mean_{key}': torch.mean(torch.stack([output[key]
|
||||
@ -156,6 +169,8 @@ class TestMixin:
|
||||
enumerate(['background', 'chimpanze', 'geunon', 'mandrille', 'redcap'])}
|
||||
elif self.params.n_classes == 2:
|
||||
class_names = {val: key for val, key in ['negative', 'positive']}
|
||||
else:
|
||||
raise AttributeError(f'n_classes has to be any of: [2, 5]')
|
||||
|
||||
df = pd.DataFrame(data=dict(filename=[Path(x).name for x in sorted_y.keys()],
|
||||
prediction=[class_names[x.item()] for x in y_max.cpu()]))
|
||||
|
Loading…
x
Reference in New Issue
Block a user