In [1]:
from collections import defaultdict
from pathlib import Path
from natsort import natsorted
from pytorch_lightning.core.saving import *

import torch
from sklearn.manifold import TSNE

import seaborn as sns
from matplotlib import pyplot as plt
import numpy as np
import pandas as pd

In [2]:
from ml_lib.metrics.binary_class_classifictaion import BinaryScores
from ml_lib.utils.tools import locate_and_import_class
# Location of the trained run: <_ROOT>/<out_path>/<model_name>/<exp_name>/<version>
_ROOT = Path()
out_path, model_name = 'output', 'VisualTransformer'
exp_name, version = 'VT_7899c07a4809a45c57cba58047cefb5e', 'version_7'

In [3]:
plt.style.use('default')
sns.set_palette('Dark2')

# Optional LaTeX typography for paper figures: serif fonts, 10 pt body text to
# match a 10 pt document, 8 pt for legend and tick labels.
# It is only applied if the rcParams update below is uncommented; note that
# matplotlib's "text.usetex" requires a LaTeX installation.
tex_fonts = {"text.usetex": True, "font.family": "serif"}
tex_fonts.update({key: 10 for key in ("axes.labelsize", "font.size")})
tex_fonts.update({key: 8 for key in ("legend.fontsize", "xtick.labelsize", "ytick.labelsize")})

# plt.rcParams.update(tex_fonts)

# Target directory for exported figures (no error if it already exists).
Path('figures').mkdir(exist_ok=True)

In [4]:
def reconstruct_model_data_params(yaml_file_path: str, overrides: dict = None):
    """Rebuild the datamodule and locate the model class of a finished experiment.

    Reads the stored hyper-parameters, resolves the classes named by the
    ``model_name`` / ``data_name`` entries, applies ``overrides`` and
    instantiates the datamodule from the result.

    Parameters
    ----------
    yaml_file_path : str
        Path to the experiment's hparams ``.yaml`` file.
    overrides : dict, optional
        Entries written over the stored hyper-parameters before the datamodule
        is built. When ``None`` (default) the historical overrides
        ``target_mel_length_in_seconds=1, num_worker=10, data_root='data'`` are
        used. When given, it replaces those defaults entirely.
        NOTE(review): ``target_mel_length_in_seconds`` presumably changes
        ``datamodule.shape`` and therefore ``in_shape``. If it differs from the
        training value, shape-dependent weights may fail to load from the
        checkpoint. Pass ``overrides`` without it to keep the trained value.

    Returns
    -------
    tuple
        ``(datamodule, model_class, hparams_dict)``. ``hparams_dict`` is
        extended with ``in_shape``, ``n_classes`` and ``variable_length=False``.
    """
    # Explicit import instead of relying on the star import at the top of the notebook.
    from argparse import Namespace

    hparams_dict = load_hparams_from_yaml(yaml_file_path)

    # Resolve the original model and data classes by the names stored in the yaml.
    model_class_name = hparams_dict['model_name']
    data_class_name = hparams_dict['data_name']
    found_data_class = locate_and_import_class(data_class_name, 'datasets')
    found_model_class = locate_and_import_class(model_class_name, 'models')
    # TODO: arguments could be recovered automatically, e.g. via
    # inspect.signature(found_data_class) plus the stored _parameter.ini.

    if overrides is None:
        overrides = dict(target_mel_length_in_seconds=1, num_worker=10, data_root='data')
    hparams_dict.update(overrides)

    # The datamodule picks the arguments it needs from the namespace.
    data_module = found_data_class.from_argparse_args(Namespace(**hparams_dict))

    # The model needs the input shape and class count that this datamodule produces.
    hparams_dict.update(in_shape=data_module.shape, n_classes=data_module.n_classes, variable_length=False)

    return data_module, found_model_class, hparams_dict

def gather_predictions_and_labels(model, data_option, data_module=None):
    """Run ``model`` over every sample of one dataset split and collect the results.

    Parameters
    ----------
    model : callable
        Called as ``model(x.unsqueeze(0))``; the result must expose ``.main_out``.
    data_option : str
        Key into ``data_module.datasets`` (e.g. ``'devel'``).
    data_module : object, optional
        Exposes a ``datasets`` mapping whose entries yield ``(file_name, x, y)``.
        Defaults to the notebook-global ``datamodule`` for backward compatibility.

    Returns
    -------
    tuple
        ``(preds, labels, filenames)``. ``preds`` and ``labels`` are NumPy arrays
        (stacked per sample, then squeezed). ``filenames`` is a list aligned with
        them, one entry per sample.
    """
    if data_module is None:
        # Earlier versions always read the global implicitly; keep that as the fallback.
        data_module = datamodule

    def _as_numpy(value):
        # Tensors may live on the GPU; np.stack cannot convert those implicitly.
        if isinstance(value, torch.Tensor):
            return value.detach().cpu().numpy()
        return np.asarray(value)

    preds, labels, filenames = [], [], []
    with torch.no_grad():
        for file_name, x, y in data_module.datasets[data_option]:
            # Add a batch dimension of 1 for the single sample.
            preds.append(_as_numpy(model(x.unsqueeze(0)).main_out))
            labels.append(_as_numpy(y))
            filenames.append(file_name)
    labels = np.stack(labels).squeeze()
    preds = np.stack(preds).squeeze()
    return preds, labels, filenames

def build_tsne_dataframe(preds, labels, class_names=None, random_state=None):
    """Embed predictions into 2-D with t-SNE and attach human-readable labels.

    Parameters
    ----------
    preds : array-like
        One row per sample; passed to ``TSNE().fit_transform``.
    labels : array-like
        One integer class index per sample.
    class_names : dict, optional
        Mapping that is inverted and used to translate ``labels``. It is
        presumably ``{name: index}``; confirm against the datamodule. Defaults
        to the notebook-global ``datamodule.class_names``.
    random_state : int, optional
        Forwarded to ``TSNE``; set it for reproducible embeddings (default
        ``None`` keeps the previous unseeded behaviour).

    Returns
    -------
    pandas.DataFrame
        Columns ``x``, ``y`` (embedding coordinates) and ``labels`` (mapped names).
    """
    if class_names is None:
        # Backward-compatible fallback to the notebook-global datamodule.
        class_names = datamodule.class_names

    # fit_transform already returns an (n_samples, 2) array.
    embedding = TSNE(random_state=random_state).fit_transform(np.asarray(preds))
    tsne_dataframe = pd.DataFrame(data=embedding, columns=['x', 'y'])

    index_to_name = {val: key for key, val in class_names.items()}
    tsne_dataframe['labels'] = labels
    tsne_dataframe['labels'] = tsne_dataframe['labels'].map(index_to_name)
    return tsne_dataframe

def plot_scatterplot(data, option):
    """Show a scatter plot of a t-SNE frame, coloured by its ``labels`` column.

    ``data`` needs the columns ``x``, ``y`` and ``labels``. ``option`` is only
    used in the plot title (e.g. the dataset split name).
    """
    plot_kwargs = dict(x='x', y='y', hue='labels', legend=True)
    sns.scatterplot(data=data, **plot_kwargs).set_title(f'TSNE - distribution of logits for {option}')
    plt.show()

def redo_predictions(experiment_path, preds, fnames, data_class, threshold=0.5):
    """Aggregate chunk-level predictions per file and write ``predictions_new.csv``.

    All chunks sharing a filename are averaged into one score per file.
    - Multi-class (``data_class.n_classes > 2``): the argmax of the averaged
      scores picks one of the five hard-coded class names.
    - Binary: the averaged score is compared against ``threshold``.

    Parameters
    ----------
    experiment_path : str or pathlib.Path
        Directory the CSV is written to.
    preds : array-like
        One model output per chunk, aligned with ``fnames``.
    fnames : sequence
        Filename of each chunk. Output rows use the basename with ``.npy``
        replaced by ``.wav``.
    data_class : object
        Must expose ``n_classes``.
    threshold : float, default 0.5
        Binary decision boundary; a file is 'positive' when its mean score > threshold.

    Returns
    -------
    pandas.DataFrame
        The written table with columns ``filename`` and ``prediction``.

    Raises
    ------
    ValueError
        If ``preds`` / ``fnames`` are empty.
    """
    grouped = defaultdict(list)
    for pred, fname in zip(preds, fnames):
        grouped[fname].append(pred)
    if not grouped:
        raise ValueError('redo_predictions received no predictions')

    # Mean over a file's chunks; for a single chunk this is the chunk itself,
    # so no special case (and no torch-only .unsqueeze on NumPy data) is needed.
    file_scores = [np.stack(chunks).mean(axis=0) for chunks in grouped.values()]

    if data_class.n_classes > 2:
        file_labels = [np.argmax(score) for score in file_scores]
        class_names = dict(enumerate(['background', 'chimpanze', 'geunon', 'mandrille', 'redcap']))
    else:
        # atleast_1d keeps a single-file result iterable after squeeze().
        scores = np.atleast_1d(np.stack(file_scores).squeeze())
        file_labels = np.where(scores > threshold, 1, 0)
        class_names = dict(enumerate(['negative', 'positive']))

    df = pd.DataFrame(data=dict(
        filename=[Path(name).name.replace('.npy', '.wav') for name in grouped],
        prediction=[class_names[int(label)] for label in file_labels],
    ))
    result_file = Path(experiment_path) / 'predictions_new.csv'
    # to_csv opens the target in write mode, replacing any previous output.
    df.to_csv(result_file, index=False)
    return df


def re_valida(preds, labels, fnames, data_class, background_threshold=0.8):
    """Re-score file-level predictions built from chunk-level model outputs.

    All chunks that share a filename are averaged into one score vector per
    file. After averaging, the class-0 score is zeroed unless it exceeds
    ``background_threshold``, so class 0 only wins when it is confident. Class 0
    is presumably 'background'; confirm against the datamodule. The resulting
    argmax classes are one-hot encoded and scored against one label per file.

    Parameters
    ----------
    preds : array-like
        One model output per chunk, aligned with ``labels`` and ``fnames``.
    labels : array-like
        One target per chunk. Assumes every chunk of a file carries the same
        label; the first one seen is used for the file.
    fnames : sequence
        Filename of each chunk; chunks are grouped by this key.
    data_class : object
        Must expose ``n_classes``; used for the one-hot encoding.
    background_threshold : float, default 0.8
        Minimum averaged class-0 score required to keep class 0.

    Returns
    -------
    The ``BinaryScores`` result (it is also printed).
    """
    chunk_preds = defaultdict(list)
    file_labels = {}
    for pred, label, fname in zip(preds, labels, fnames):
        chunk_preds[fname].append(pred)
        # Aggregate labels per file too, so targets line up with per-file predictions.
        file_labels.setdefault(fname, label)

    file_classes = []
    for chunks in chunk_preds.values():
        scores = np.stack(chunks)
        if scores.ndim > 1:
            scores = scores.mean(axis=0)
            if not scores[0] > background_threshold:
                scores[0] = 0
        # NOTE(review): scalar per-chunk outputs give 1-D `scores` that skip averaging,
        # so argmax picks a chunk index rather than a class; this path looks multi-class only.
        file_classes.append(np.argmax(scores))

    # No squeeze(): a single file must still give a (1, n_classes) one-hot matrix.
    pred = np.array(file_classes, dtype=int)
    one_hot_targets = np.eye(data_class.n_classes)[pred]
    file_targets = np.stack(list(file_labels.values()))

    # Sklearn Scores
    metrics = BinaryScores(dict(y=one_hot_targets, batch_y=file_targets))
    print(metrics)
    return metrics


In [5]:
exp_path = _ROOT / out_path / model_name / exp_name / version

# Checkpoints are sorted by natural filename order; this index picks the one to evaluate.
# NOTE(review): -4 is a hand-picked choice; confirm it is the intended checkpoint.
checkpoint_index = -4
checkpoints = natsorted(exp_path.glob('*.ckpt'))
try:
    checkpoint = checkpoints[checkpoint_index]
except IndexError:
    raise FileNotFoundError(
        f'Only {len(checkpoints)} checkpoint(s) in {exp_path}; cannot select index {checkpoint_index}'
    ) from None
print(f'Selected Checkpoint is {checkpoint}')

hparams_yaml = next(exp_path.glob('*.yaml'), None)
if hparams_yaml is None:
    raise FileNotFoundError(f'No hparams *.yaml found in {exp_path}')
print(load_hparams_from_yaml(hparams_yaml)['data_name'])

# Rebuild the model class and datamodule by hand from the stored hyper-parameters.
datamodule, model_class, h_params = reconstruct_model_data_params(str(hparams_yaml))
# h_params.update(return_logits=True)

Selected Checkopint is output\VisualTransformer\VT_7899c07a4809a45c57cba58047cefb5e\version_7\ckpt_weights-v1.ckpt
PrimatesLibrosaDatamodule


In [6]:
# Restore the trained weights; h_params carries in_shape / n_classes rebuilt from the datamodule.
# NOTE(review): the saved output shows a pos_embedding size mismatch (67 vs 122 positions).
# Its size presumably follows in_shape, which depends on the target_mel_length_in_seconds
# override applied while rebuilding the datamodule. Confirm that override (and the chosen
# checkpoint) match the training run.
model = model_class.load_from_checkpoint(checkpoint, **h_params)
model.eval()
datamodule.prepare_data()

RuntimeError: Error(s) in loading state_dict for VisualTransformer:
	size mismatch for pos_embedding: copying a param with shape torch.Size([1, 67, 30]) from checkpoint, the shape in current model is torch.Size([1, 122, 30]).

In [None]:
# Per-chunk model outputs, targets and source filenames for the development split.
predictions, labels_y, filenames = gather_predictions_and_labels(model, data_option='devel')

In [None]:
# tsne_dataframe = build_tsne_dataframe(predictions, labels_y)
# plot_scatterplot(tsne_dataframe, data_option)

In [None]:
# Score the per-file aggregated predictions (the scores are printed by re_valida).
re_valida(predictions, labels_y, filenames, data_class=datamodule);