CCS integration training running

notebooks
This commit is contained in:
Steffen 2021-03-24 08:03:12 +01:00
parent c12f3866c8
commit 82835295a1
11 changed files with 1264 additions and 445 deletions

0
datasets/__init__.py Normal file


@@ -1,172 +1,19 @@
import multiprocessing as mp
from collections import defaultdict
from pathlib import Path
from torch.utils.data import DataLoader, ConcatDataset, WeightedRandomSampler
from torchvision.transforms import Compose, RandomApply
from ml_lib.audio_toolset.audio_io import NormalizeLocal
from ml_lib.audio_toolset.audio_to_mel_dataset import LibrosaAudioToMelDataset
from ml_lib.audio_toolset.mel_augmentation import NoiseInjection, LoudnessManipulator, ShiftTime, MaskAug
from ml_lib.utils._basedatamodule import _BaseDataModule, DATA_OPTION_test, DATA_OPTION_train, DATA_OPTION_devel
from ml_lib.utils.equal_sampler import EqualSampler
from ml_lib.utils.transforms import ToTensor
data_options = [DATA_OPTION_test, DATA_OPTION_train, DATA_OPTION_devel]
from datasets.compare_base import CompareBase
from ml_lib.utils.tools import add_argparse_args
class CCSLibrosaDatamodule(_BaseDataModule):
class CCSLibrosaDatamodule(CompareBase):
@property
def class_names(self):
return {key: val for val, key in enumerate(['negative', 'positive'])}
class_names = ['negative', 'positive']
sub_dataset_name = 'ComParE2021_CCS'
@property
def n_classes(self):
return len(self.class_names)
def __init__(self, *args, **kwargs):
super(CCSLibrosaDatamodule, self).__init__(*args, **kwargs)
@property
def shape(self):
return self.datasets[DATA_OPTION_train].datasets[0][0][1].shape
@classmethod
def add_argparse_args(cls, parent_parser):
return add_argparse_args(CompareBase, parent_parser)
@property
def mel_folder(self):
return self.root / 'mel_folder'
@property
def wav_folder(self):
return self.root / 'wav'
def __init__(self, data_root, batch_size, num_worker, sr, n_mels, n_fft, hop_length, sampler=None,
random_apply_chance=0.5, target_mel_length_in_seconds=1,
loudness_ratio=0.3, shift_ratio=0.3, noise_ratio=0.3, mask_ratio=0.3):
super(CCSLibrosaDatamodule, self).__init__()
self.sampler = sampler
self.samplers = None
self.num_worker = num_worker or 1
self.batch_size = batch_size
self.root = Path(data_root) / 'ComParE2021_CCS'
self.mel_length_in_seconds = target_mel_length_in_seconds
# Mel Transforms - will be pushed with all other parameters by self.__dict__ to the sub-dataset class
self.mel_kwargs = dict(sr=sr, n_mels=n_mels, n_fft=n_fft, hop_length=hop_length)
# Utility
self.utility_transforms = Compose([NormalizeLocal(), ToTensor()])
# Data Augmentations
self.random_apply_chance = random_apply_chance
self.mel_augmentations = Compose([
RandomApply([NoiseInjection(noise_ratio)], p=random_apply_chance),
RandomApply([LoudnessManipulator(loudness_ratio)], p=random_apply_chance),
RandomApply([ShiftTime(shift_ratio)], p=random_apply_chance),
RandomApply([MaskAug(mask_ratio)], p=random_apply_chance),
self.utility_transforms])
def train_dataloader(self):
return DataLoader(dataset=self.datasets[DATA_OPTION_train], num_workers=self.num_worker, pin_memory=True,
sampler=self.samplers[DATA_OPTION_train], batch_size=self.batch_size)
# Validation Dataloader
def val_dataloader(self):
return DataLoader(dataset=self.datasets[DATA_OPTION_devel], shuffle=False,
batch_size=self.batch_size, pin_memory=False,
num_workers=self.num_worker)
# Test Dataloader
def test_dataloader(self):
return DataLoader(dataset=self.datasets[DATA_OPTION_test], shuffle=False,
batch_size=self.batch_size, pin_memory=False,
num_workers=self.num_worker)
def _build_subdataset(self, row, build=False):
slice_file_name, class_name = row.strip().split(',')
class_id = self.class_names.get(class_name, -1)
audio_file_path = self.wav_folder / slice_file_name
# DATA OPTION DIFFERENTIATION !!!!!!!!!!! - Begin
kwargs = self.__dict__
if any([x in slice_file_name for x in [DATA_OPTION_devel, DATA_OPTION_test]]):
kwargs.update(mel_augmentations=self.utility_transforms)
# DATA OPTION DIFFERENTIATION !!!!!!!!!!! - End
target_frames = self.mel_length_in_seconds * self.mel_kwargs['sr']
sample_segment_length = target_frames // self.mel_kwargs['hop_length'] + 1
kwargs.update(sample_segment_len=sample_segment_length, sample_hop_len=sample_segment_length//2)
mel_dataset = LibrosaAudioToMelDataset(audio_file_path, class_id, **kwargs)
if build:
assert mel_dataset.build_mel()
return mel_dataset, class_id, slice_file_name
def prepare_data(self, *args, **kwargs):
datasets = dict()
for data_option in data_options:
with open(Path(self.root) / 'lab' / f'{data_option}.csv', mode='r') as f:
# Exclude the header
_ = next(f)
all_rows = list(f)
chunksize = len(all_rows) // max(self.num_worker, 1)
dataset = list()
with mp.Pool(processes=self.num_worker) as pool:
from itertools import repeat
results = pool.starmap_async(self._build_subdataset, zip(all_rows, repeat(True, len(all_rows))),
chunksize=chunksize)
for sub_dataset in results.get():
dataset.append(sub_dataset[0])
datasets[data_option] = ConcatDataset(dataset)
print(f'{data_option}-dataset prepared.')
self.datasets = datasets
return datasets
def setup(self, stage=None):
datasets = dict()
samplers = dict()
weights = dict()
for data_option in data_options:
with open(Path(self.root) / 'lab' / f'{data_option}.csv', mode='r') as f:
# Exclude the header
_ = next(f)
all_rows = list(f)
dataset = list()
for row in all_rows:
mel_dataset, class_id, _ = self._build_subdataset(row)
dataset.append(mel_dataset)
print(f'{data_option}-dataset prepared!')
datasets[data_option] = ConcatDataset(dataset)
# Build Weighted Sampler for train and val
if data_option in [DATA_OPTION_train]:
if self.sampler == EqualSampler.__name__:
class_idxs = [[idx for idx, (_, __, label) in enumerate(datasets[data_option]) if label == class_idx]
for class_idx in range(len(self.class_names))
]
samplers[data_option] = EqualSampler(class_idxs)
elif self.sampler == WeightedRandomSampler.__name__:
class_counts = defaultdict(lambda: 0)
for _, __, label in datasets[data_option]:
class_counts[label] += 1
len_largest_class = max(class_counts.values())
weights[data_option] = [1 / class_counts[x] for x in range(len(class_counts))]
##############################################################################
weights[data_option] = [weights[data_option][datasets[data_option][i][-1]]
for i in range(len(datasets[data_option]))]
samplers[data_option] = WeightedRandomSampler(weights[data_option],
len_largest_class * len(self.class_names))
else:
samplers[data_option] = None
self.datasets = datasets
self.samplers = samplers
print(f'Dataset {self.__class__.__name__} setup done.')
return datasets
def purge(self):
import shutil
shutil.rmtree(self.mel_folder, ignore_errors=True)
print('Mel Folder has been recursively deleted')
print(f'Folder still exists: {self.mel_folder.exists()}')
return not self.mel_folder.exists()
@classmethod
def from_argparse_args(cls, args, **kwargs):
return CompareBase.from_argparse_args(args, class_names=cls.class_names, sub_dataset_name=cls.sub_dataset_name)
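
After this refactor, a ComParE sub-dataset module only declares its class names and sub-dataset folder and delegates everything else to CompareBase. A minimal sketch of the surviving CCS module, assembled from the added lines above:

    from datasets.compare_base import CompareBase
    from ml_lib.utils.tools import add_argparse_args


    class CCSLibrosaDatamodule(CompareBase):
        class_names = ['negative', 'positive']
        sub_dataset_name = 'ComParE2021_CCS'

        @classmethod
        def add_argparse_args(cls, parent_parser):
            return add_argparse_args(CompareBase, parent_parser)

        @classmethod
        def from_argparse_args(cls, args, **kwargs):
            return CompareBase.from_argparse_args(args, class_names=cls.class_names,
                                                  sub_dataset_name=cls.sub_dataset_name)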

181
datasets/compare_base.py Normal file

@@ -0,0 +1,181 @@
import multiprocessing as mp
from collections import defaultdict
from pathlib import Path
from torch.utils.data import DataLoader, ConcatDataset, WeightedRandomSampler
from torchvision.transforms import Compose, RandomApply
from ml_lib.audio_toolset.audio_io import NormalizeLocal
from ml_lib.audio_toolset.audio_to_mel_dataset import LibrosaAudioToMelDataset
from ml_lib.audio_toolset.mel_augmentation import NoiseInjection, LoudnessManipulator, ShiftTime, MaskAug
from ml_lib.utils._basedatamodule import _BaseDataModule, DATA_OPTION_test, DATA_OPTION_train, DATA_OPTION_devel
from ml_lib.utils.equal_sampler import EqualSampler
from ml_lib.utils.tools import add_argparse_args
from ml_lib.utils.transforms import ToTensor
data_options = [DATA_OPTION_test, DATA_OPTION_train, DATA_OPTION_devel]
class CompareBase(_BaseDataModule):
@property
def class_names(self):
return {key: val for val, key in enumerate(self._class_names)}
@property
def n_classes(self):
return len(self.class_names)
@property
def shape(self):
return 1, int(self.mel_kwargs['n_mels']), int(self.sample_segment_length)
@property
def mel_folder(self):
return self.root / 'mel_folder'
@property
def wav_folder(self):
return self.root / 'wav'
def __init__(self, sub_dataset_name, class_names, data_root, batch_size, num_worker, sr, n_mels, n_fft, hop_length, sampler=None,
random_apply_chance=0.5, target_mel_length_in_seconds=1,
loudness_ratio=0.3, shift_ratio=0.3, noise_ratio=0.3, mask_ratio=0.3):
super(CompareBase, self).__init__()
self.sampler = sampler
self.samplers = None
self.num_worker = num_worker or 1
self.batch_size = batch_size
self.root = Path(data_root) / sub_dataset_name
self._class_names = class_names
self.mel_length_in_seconds = target_mel_length_in_seconds
# Mel Transforms - will be pushed with all other parameters by self.__dict__ to the sub-dataset class
self.mel_kwargs = dict(sr=sr, n_mels=n_mels, n_fft=n_fft, hop_length=hop_length)
target_frames = self.mel_length_in_seconds * self.mel_kwargs['sr']
self.sample_segment_length = target_frames // self.mel_kwargs['hop_length'] + 1
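# The sample_segment_length above is the number of mel frames covering the target
# duration. A worked example with assumed (not committed) values sr=16000,
# hop_length=256 and target_mel_length_in_seconds=1:
#   target_frames = 1 * 16000                  -> 16000 audio samples
#   sample_segment_length = 16000 // 256 + 1   -> 63 mel frames per segment
#   sample_hop_len = 63 // 2                   -> 31 frames between segment starts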
# Utility
self.utility_transforms = Compose([NormalizeLocal(), ToTensor()])
# Data Augmentations
self.random_apply_chance = random_apply_chance
self.mel_augmentations = Compose([
RandomApply([NoiseInjection(noise_ratio)], p=random_apply_chance),
RandomApply([LoudnessManipulator(loudness_ratio)], p=random_apply_chance),
RandomApply([ShiftTime(shift_ratio)], p=random_apply_chance),
RandomApply([MaskAug(mask_ratio)], p=random_apply_chance),
self.utility_transforms])
def train_dataloader(self):
return DataLoader(dataset=self.datasets[DATA_OPTION_train], num_workers=self.num_worker, pin_memory=True,
sampler=self.samplers[DATA_OPTION_train], batch_size=self.batch_size)
# Validation Dataloader
def val_dataloader(self):
return DataLoader(dataset=self.datasets[DATA_OPTION_devel], shuffle=False,
batch_size=self.batch_size, pin_memory=False,
num_workers=self.num_worker)
# Test Dataloader
def test_dataloader(self):
return DataLoader(dataset=self.datasets[DATA_OPTION_test], shuffle=False,
batch_size=self.batch_size, pin_memory=False,
num_workers=self.num_worker)
def _build_subdataset(self, row, build=False):
slice_file_name, class_name = row.strip().split(',')
class_id = self.class_names.get(class_name, -1)
audio_file_path = self.wav_folder / slice_file_name
# DATA OPTION DIFFERENTIATION !!!!!!!!!!! - Begin
kwargs = self.__dict__
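# NOTE: self.__dict__ is passed by reference, so the update() calls below also
# mutate the datamodule's attributes; LibrosaAudioToMelDataset is assumed to
# ignore the extra keys it does not recognise.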
if any([x in slice_file_name for x in [DATA_OPTION_devel, DATA_OPTION_test]]):
kwargs.update(mel_augmentations=self.utility_transforms)
# DATA OPTION DIFFERENTIATION !!!!!!!!!!! - End
kwargs.update(sample_segment_len=self.sample_segment_length, sample_hop_len=self.sample_segment_length//2)
mel_dataset = LibrosaAudioToMelDataset(audio_file_path, class_id, **kwargs)
if build:
assert mel_dataset.build_mel()
return mel_dataset, class_id, slice_file_name
def manual_setup(self, stage=None):
datasets = dict()
for data_option in data_options:
with open(Path(self.root) / 'lab' / f'{data_option}.csv', mode='r') as f:
# Exclude the header
_ = next(f)
all_rows = list(f)
chunksize = len(all_rows) // max(self.num_worker, 1)
dataset = list()
with mp.Pool(processes=self.num_worker) as pool:
from itertools import repeat
results = pool.starmap_async(self._build_subdataset, zip(all_rows, repeat(True, len(all_rows))),
chunksize=chunksize)
for sub_dataset in results.get():
dataset.append(sub_dataset[0])
datasets[data_option] = ConcatDataset(dataset)
print(f'{data_option}-dataset prepared.')
self.datasets = datasets
return datasets
def prepare_data(self, *args, rebuild=False, **kwargs):
datasets = dict()
samplers = dict()
weights = dict()
for data_option in data_options:
with open(Path(self.root) / 'lab' / f'{data_option}.csv', mode='r') as f:
# Exclude the header
_ = next(f)
all_rows = list(f)
chunksize = len(all_rows) // max(self.num_worker, 1)
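# NOTE: with fewer rows than workers this integer division yields 0, which the
# pool's result bookkeeping divides by; max(chunksize, 1) would be a safe guard.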
dataset = list()
with mp.Pool(processes=self.num_worker) as pool:
from itertools import repeat
results = pool.starmap_async(self._build_subdataset, zip(all_rows, repeat(rebuild, len(all_rows))),
chunksize=chunksize)
for sub_dataset in results.get():
dataset.append(sub_dataset[0])
datasets[data_option] = ConcatDataset(dataset)
print(f'{data_option}-dataset set up!')
# Build Weighted Sampler for train and val
if data_option in [DATA_OPTION_train]:
if self.sampler == EqualSampler.__name__:
class_idxs = [[idx for idx, (_, __, label) in enumerate(datasets[data_option]) if label == class_idx]
for class_idx in range(len(self.class_names))
]
samplers[data_option] = EqualSampler(class_idxs)
elif self.sampler == WeightedRandomSampler.__name__:
class_counts = defaultdict(lambda: 0)
for _, __, label in datasets[data_option]:
class_counts[label] += 1
len_largest_class = max(class_counts.values())
weights[data_option] = [1 / class_counts[x] for x in range(len(class_counts))]
##############################################################################
weights[data_option] = [weights[data_option][datasets[data_option][i][-1]]
for i in range(len(datasets[data_option]))]
samplers[data_option] = WeightedRandomSampler(weights[data_option],
len_largest_class * len(self.class_names))
else:
samplers[data_option] = None
self.datasets = datasets
self.samplers = samplers
print(f'Dataset {self.__class__.__name__} setup done.')
return datasets
def purge(self):
import shutil
shutil.rmtree(self.mel_folder, ignore_errors=True)
print('Mel Folder has been recursively deleted')
print(f'Folder still exists: {self.mel_folder.exists()}')
return not self.mel_folder.exists()
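
The WeightedRandomSampler branch in prepare_data() gives each training sample the inverse frequency of its class, so rare classes are drawn about as often as common ones. A toy sketch of the same computation (labels invented for illustration):

    from collections import defaultdict
    from torch.utils.data import WeightedRandomSampler

    labels = [0, 0, 0, 0, 1]                  # hypothetical: four 'negative', one 'positive'
    class_counts = defaultdict(int)
    for label in labels:
        class_counts[label] += 1              # {0: 4, 1: 1}

    per_class = [1 / class_counts[c] for c in range(len(class_counts))]  # [0.25, 1.0]
    per_sample = [per_class[label] for label in labels]                  # one weight per sample
    sampler = WeightedRandomSampler(per_sample,
                                    num_samples=max(class_counts.values()) * len(class_counts))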


@@ -1,170 +1,24 @@
import multiprocessing as mp
from collections import defaultdict
from pathlib import Path
from argparse import ArgumentParser, Namespace
from ctypes import Union
from torch.utils.data import DataLoader, ConcatDataset, WeightedRandomSampler
from torchvision.transforms import Compose, RandomApply
from ml_lib.audio_toolset.audio_io import NormalizeLocal
from ml_lib.audio_toolset.audio_to_mel_dataset import LibrosaAudioToMelDataset
from ml_lib.audio_toolset.mel_augmentation import NoiseInjection, LoudnessManipulator, ShiftTime, MaskAug
from ml_lib.utils._basedatamodule import _BaseDataModule, DATA_OPTION_test, DATA_OPTION_train, DATA_OPTION_devel
from ml_lib.utils.equal_sampler import EqualSampler
from ml_lib.utils.transforms import ToTensor
data_options = [DATA_OPTION_test, DATA_OPTION_train, DATA_OPTION_devel]
from datasets.compare_base import CompareBase
from ml_lib.utils.tools import add_argparse_args
class PrimatesLibrosaDatamodule(_BaseDataModule):
class PrimatesLibrosaDatamodule(CompareBase):
@property
def class_names(self):
return {key: val for val, key in enumerate(['background', 'chimpanze', 'geunon', 'mandrille', 'redcap'])}
class_names = ['background', 'chimpanze', 'geunon', 'mandrille', 'redcap']
sub_dataset_name = 'primates'
@property
def n_classes(self):
return len(self.class_names)
def __init__(self, *args, **kwargs):
super(PrimatesLibrosaDatamodule, self).__init__(*args, **kwargs)
@property
def shape(self):
@classmethod
def add_argparse_args(cls, parent_parser):
return add_argparse_args(CompareBase, parent_parser)
return self.datasets[DATA_OPTION_train].datasets[0][0][1].shape
@property
def mel_folder(self):
return self.root / 'mel_folder'
@classmethod
def from_argparse_args(cls, args, **kwargs):
return CompareBase.from_argparse_args(args, class_names=cls.class_names, sub_dataset_name=cls.sub_dataset_name)
@property
def wav_folder(self):
return self.root / 'wav'
def __init__(self, data_root, batch_size, num_worker, sr, n_mels, n_fft, hop_length, sampler=None,
target_mel_length_in_seconds=0.7, random_apply_chance=0.5,
loudness_ratio=0.3, shift_ratio=0.3, noise_ratio=0.3, mask_ratio=0.3):
super(PrimatesLibrosaDatamodule, self).__init__()
self.sampler = sampler
self.samplers = None
self.num_worker = num_worker or 1
self.batch_size = batch_size
self.root = Path(data_root) / 'primates'
self.target_mel_length_in_seconds = target_mel_length_in_seconds
# Mel Transforms - will be pushed with all other parameters by self.__dict__ to the sub-dataset class
self.mel_kwargs = dict(sr=sr, n_mels=n_mels, n_fft=n_fft, hop_length=hop_length)
# Utility
self.utility_transforms = Compose([NormalizeLocal(), ToTensor()])
# Data Augmentations
self.random_apply_chance = random_apply_chance
self.mel_augmentations = Compose([
RandomApply([NoiseInjection(noise_ratio)], p=random_apply_chance),
RandomApply([LoudnessManipulator(loudness_ratio)], p=random_apply_chance),
RandomApply([ShiftTime(shift_ratio)], p=random_apply_chance),
RandomApply([MaskAug(mask_ratio)], p=random_apply_chance),
self.utility_transforms])
def train_dataloader(self):
return DataLoader(dataset=self.datasets[DATA_OPTION_train], num_workers=self.num_worker, pin_memory=True,
sampler=self.samplers[DATA_OPTION_train], batch_size=self.batch_size)
# Validation Dataloader
def val_dataloader(self):
return DataLoader(dataset=self.datasets[DATA_OPTION_devel], shuffle=False,
batch_size=self.batch_size, pin_memory=False,
num_workers=self.num_worker)
# Test Dataloader
def test_dataloader(self):
return DataLoader(dataset=self.datasets[DATA_OPTION_test], shuffle=False,
batch_size=self.batch_size, pin_memory=False,
num_workers=self.num_worker)
def _build_subdataset(self, row, build=False):
slice_file_name, class_name = row.strip().split(',')
class_id = self.class_names.get(class_name, -1)
audio_file_path = self.wav_folder / slice_file_name
# DATA OPTION DIFFERENTIATION !!!!!!!!!!! - Begin
kwargs = self.__dict__
if any([x in slice_file_name for x in [DATA_OPTION_devel, DATA_OPTION_test]]):
kwargs.update(mel_augmentations=self.utility_transforms)
# DATA OPTION DIFFERENTIATION !!!!!!!!!!! - End
target_frames = self.target_mel_length_in_seconds * self.mel_kwargs['sr']
sample_segment_length = target_frames // self.mel_kwargs['hop_length'] + 1
kwargs.update(sample_segment_len=sample_segment_length, sample_hop_len=sample_segment_length//2)
mel_dataset = LibrosaAudioToMelDataset(audio_file_path, class_id, **kwargs)
if build:
assert mel_dataset.build_mel()
return mel_dataset, class_id, slice_file_name
def prepare_data(self, *args, **kwargs):
datasets = dict()
for data_option in data_options:
with open(Path(self.root) / 'lab' / f'{data_option}.csv', mode='r') as f:
# Exclude the header
_ = next(f)
all_rows = list(f)
chunksize = len(all_rows) // max(self.num_worker, 1)
dataset = list()
with mp.Pool(processes=self.num_worker) as pool:
from itertools import repeat
results = pool.starmap_async(self._build_subdataset, zip(all_rows, repeat(True, len(all_rows))),
chunksize=chunksize)
for sub_dataset in results.get():
dataset.append(sub_dataset[0])
datasets[data_option] = ConcatDataset(dataset)
self.datasets = datasets
return datasets
def setup(self, stage=None):
datasets = dict()
samplers = dict()
weights = dict()
for data_option in data_options:
with open(Path(self.root) / 'lab' / f'{data_option}.csv', mode='r') as f:
# Exclude the header
_ = next(f)
all_rows = list(f)
dataset = list()
for row in all_rows:
mel_dataset, class_id, _ = self._build_subdataset(row)
dataset.append(mel_dataset)
datasets[data_option] = ConcatDataset(dataset)
# Build Weighted Sampler for train and val
if data_option in [DATA_OPTION_train]:
if self.sampler == EqualSampler.__name__:
class_idxs = [[idx for idx, (_, __, label) in enumerate(datasets[data_option]) if label == class_idx]
for class_idx in range(len(self.class_names))
]
samplers[data_option] = EqualSampler(class_idxs)
elif self.sampler == WeightedRandomSampler.__name__:
class_counts = defaultdict(lambda: 0)
for _, __, label in datasets[data_option]:
class_counts[label] += 1
len_largest_class = max(class_counts.values())
weights[data_option] = [1 / class_counts[x] for x in range(len(class_counts))]
##############################################################################
weights[data_option] = [weights[data_option][datasets[data_option][i][-1]]
for i in range(len(datasets[data_option]))]
samplers[data_option] = WeightedRandomSampler(weights[data_option],
len_largest_class * len(self.class_names))
else:
samplers[data_option] = None
self.datasets = datasets
self.samplers = samplers
return datasets
def purge(self):
import shutil
shutil.rmtree(self.mel_folder, ignore_errors=True)
print('Mel Folder has been recursively deleted')
print(f'Folder still exists: {self.mel_folder.exists()}')
return not self.mel_folder.exists()

23
main.py

@@ -2,6 +2,7 @@ from argparse import Namespace
import warnings
import yaml
from pytorch_lightning import Trainer, Callback
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
@@ -16,7 +17,7 @@ warnings.filterwarnings('ignore', category=FutureWarning)
warnings.filterwarnings('ignore', category=UserWarning)
def run_lightning_loop(h_params, data_class, model_class, seed=69, additional_callbacks=None):
def run_lightning_loop(h_params: Namespace, data_class, model_class, seed=69, additional_callbacks=None):
fix_all_random_seeds(seed)
@@ -54,16 +55,23 @@ def run_lightning_loop(h_params, data_class, model_class, seed=69, additional_ca
# =============================================================================
# Let Datamodule pull what it wants
datamodule = data_class.from_argparse_args(h_params)
datamodule.setup()
# Final h_params Setup:
h_params = vars(h_params)
h_params.update(in_shape=datamodule.shape, n_classes=datamodule.n_classes)
h_params = Namespace(**h_params)
# Let Trainer pull what it wants and add callbacks
trainer = Trainer.from_argparse_args(h_params, logger=logger, callbacks=callbacks)
# Let Model pull what it wants
model = model_class.from_argparse_args(h_params, in_shape=datamodule.shape, n_classes=datamodule.n_classes)
model = model_class.from_argparse_args(h_params)
model.init_weights()
# trainer.test(model=model, datamodule=datamodule)
# Store Model in Object File:
model.save_to_disk(logger.save_dir)
# Store h_params to yaml_file File & Neptune (if available):
logger.log_hyperparams(h_params)
trainer.fit(model, datamodule)
trainer.save_checkpoint(logger.save_dir / 'last_weights.ckpt')
@@ -73,10 +81,9 @@ def run_lightning_loop(h_params, data_class, model_class, seed=69, additional_ca
except:
print('Test did not succeed!')
pass
try:
logger.log_metrics(score_callback.best_scores, step=trainer.global_step+1)
except:
print('debug max_score_logging')
logger.log_metrics(score_callback.best_scores, step=trainer.global_step+1)
return score_callback.best_scores['PL_recall_score']
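
The substantive change in main.py is the ordering: the datamodule is now set up first, its shape and n_classes are folded into h_params, and the model then pulls them from there instead of receiving extra keyword arguments. A condensed sketch of the resulting flow, using the names from the diff:

    datamodule = data_class.from_argparse_args(h_params)
    datamodule.setup()                                   # datasets must exist before .shape is known

    h_params = vars(h_params)
    h_params.update(in_shape=datamodule.shape, n_classes=datamodule.n_classes)
    h_params = Namespace(**h_params)

    model = model_class.from_argparse_args(h_params)     # reads in_shape / n_classes from h_params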

models/transformer_model.py

@@ -7,7 +7,7 @@ from torch import nn
from einops import rearrange, repeat
from ml_lib.metrics.multi_class_classification import MultiClassScores
from ml_lib.modules.blocks import TransformerModule
from ml_lib.modules.blocks import (TransformerModule, F_x)
from ml_lib.modules.util import (LightningBaseModule, AutoPadToShape)
from util.module_mixins import CombinedModelMixins
@@ -21,7 +21,8 @@ class VisualTransformer(CombinedModelMixins,
def __init__(self, in_shape, n_classes, weight_init, activation,
embedding_size, heads, attn_depth, patch_size, use_residual, variable_length,
use_bias, use_norm, dropout, lat_dim, loss, scheduler, mlp_dim, head_dim,
lr, weight_decay, sto_weight_avg, lr_scheduler_parameter, opt_reset_interval):
lr, weight_decay, sto_weight_avg, lr_scheduler_parameter, opt_reset_interval,
return_logits=False):
# TODO: Move this to parent class, or make it much easier to access... But How...
a = dict(locals())
@@ -69,14 +70,20 @@
self.to_cls_token = nn.Identity()
logits = self.params.n_classes if self.params.n_classes > 2 else 1
if return_logits:
outbound_activation = nn.Identity()
else:
outbound_activation = nn.Softmax() if logits > 1 else nn.Sigmoid()
self.mlp_head = nn.Sequential(
nn.LayerNorm(self.embed_dim),
nn.Linear(self.embed_dim, self.params.lat_dim),
nn.GELU(),
self.params.activation(),
nn.Dropout(self.params.dropout),
nn.Linear(self.params.lat_dim, logits),
nn.Softmax() if logits > 1 else nn.Sigmoid()
outbound_activation
)
def forward(self, x, mask=None, return_attn_weights=False):

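With return_logits=True the head ends in nn.Identity() and the model emits raw scores. The natural pairing, an assumption about intended use that is consistent with the TrainMixin loss change at the bottom of this commit, is a loss that applies the activation itself:

    import torch.nn as nn

    binary_criterion = nn.BCEWithLogitsLoss()     # folds the sigmoid into the loss (numerically stabler)
    multiclass_criterion = nn.CrossEntropyLoss()  # expects raw logits, applies log-softmax internally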
File diff suppressed because one or more lines are too long

397
notebooks/Train Eval.ipynb Normal file

@@ -0,0 +1,397 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 6,
"outputs": [],
"source": [
"from collections import defaultdict\n",
"from pathlib import Path\n",
"from natsort import natsorted\n",
"from pytorch_lightning.core.saving import ModelIO\n",
"from ml_lib.utils.model_io import SavedLightningModels\n",
"from ml_lib.utils.tools import locate_and_import_class\n",
"\n",
"import yaml\n",
"\n",
"import numpy as np\n",
"import torch\n",
"import pytorch_lightning as pl\n",
"import librosa\n",
"import pandas as pd\n",
"import variables as v\n",
"import seaborn as sns\n",
"from tqdm import tqdm\n",
"from matplotlib import pyplot as plt"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%% Imports go here\n"
}
}
},
{
"cell_type": "code",
"execution_count": 12,
"outputs": [],
"source": [
"# Settings and Variables\n",
"\n",
"# This Experiment (= Model and Parameter Configuration\n",
"_ROOT = Path('..')\n",
"out_path = Path('..') / Path('output')\n",
"model_name = 'VisualTransformer'\n"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": 42,
"outputs": [],
"source": [
"def print_stats(data_option, mean_duration, std_duration, min_duration, max_duration):\n",
" print(f'For {data_option}; statistics are:')\n",
" print(f'Scores - mean: {mean_duration:.3f}s\\tstd: {std_duration:.3f}s'\n",
" f'min: {min_duration:.3f}s\\t max: {max_duration:.3f}s')\n",
"\n",
"def print_metrics(exp_path):\n",
" print(f'--------------{exp_path.name}------------------')\n",
" best_scores = []\n",
" had_errors = []\n",
" for run_folder in [x for x in exp_path.iterdir() if x.is_dir()]:\n",
" # model_class = locate_and_import_class(model_name, 'models')\n",
" # sorted_checkpoints = natsorted(run_folder.glob('*.ckpt'))\n",
" # model = ModelIO.load_from_checkpoint(str(sorted_checkpoints[0]), strict=True)\n",
" try:\n",
" metrics = pd.read_csv(run_folder / 'metrics.csv')\n",
"\n",
" # Possible keys are:\n",
" # -- CE - Losses:\n",
" # val_max_vote_loss, val_mean_vote_loss, mean_val_loss\n",
" # -- Fallback:\n",
" # mean_loss,epoch,step,macro_f1_score, macro_roc_auc_ovr, uar_score, micro_f1_score\n",
" # Pytorch Metrics:\n",
" # PL_f1_score,PL_accuracy_score_score, PL_fbeta_score,PL_recall_score,PL_precision_score,\n",
" score = metrics.PL_recall_score[-1]\n",
" print(f'{exp_path.name} - {run_folder.name}: {score}')\n",
" best_scores.append(score)\n",
" had_errors.append(False)\n",
" except (AttributeError, FileNotFoundError):\n",
" had_errors.append(True)\n",
" pass\n",
" if any(had_errors):\n",
" return\n",
" else:\n",
" print('\\n')\n",
" stats = np.mean(best_scores), np.std(best_scores), np.min(best_scores), np.max(best_scores)\n",
" print_stats(exp_path.name, *stats)\n",
" print('--------------------------------------------')\n"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%% Util Functions\n"
}
}
},
{
"cell_type": "code",
"execution_count": 32,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"--------------VT_259ee495ee2d2dc0e56bb23d12476f17------------------\n",
"VT_259ee495ee2d2dc0e56bb23d12476f17 - version_1: 0.8403531908988953\n",
"VT_259ee495ee2d2dc0e56bb23d12476f17 - version_3: 0.8312729001045227\n",
"VT_259ee495ee2d2dc0e56bb23d12476f17 - version_0: 0.8342075347900391\n",
"VT_259ee495ee2d2dc0e56bb23d12476f17 - version_5: 0.8459098935127258\n",
"VT_259ee495ee2d2dc0e56bb23d12476f17 - version_2: 0.8468937277793884\n",
"VT_259ee495ee2d2dc0e56bb23d12476f17 - version_4: 0.8404075503349304\n",
"\n",
"\n",
"For VT_259ee495ee2d2dc0e56bb23d12476f17; statistics are:\n",
"Scores - mean: 0.840s\tstd: 0.006smin: 0.831s\t max: 0.847s\n",
"--------------------------------------------\n",
"--------------VT_012aff7c1c667073aedafcbebfa35ec7------------------\n",
"VT_012aff7c1c667073aedafcbebfa35ec7 - version_6: 0.8637051582336426\n",
"VT_012aff7c1c667073aedafcbebfa35ec7 - version_1: 0.864475429058075\n",
"VT_012aff7c1c667073aedafcbebfa35ec7 - version_3: 0.854859471321106\n",
"VT_012aff7c1c667073aedafcbebfa35ec7 - version_0: 0.8631429672241211\n",
"VT_012aff7c1c667073aedafcbebfa35ec7 - version_8: 0.8484407663345337\n",
"VT_012aff7c1c667073aedafcbebfa35ec7 - version_5: 0.8564963340759277\n",
"VT_012aff7c1c667073aedafcbebfa35ec7 - version_7: 0.8519455194473267\n",
"VT_012aff7c1c667073aedafcbebfa35ec7 - version_2: 0.8683117032051086\n",
"VT_012aff7c1c667073aedafcbebfa35ec7 - version_9: 0.8730489611625671\n",
"VT_012aff7c1c667073aedafcbebfa35ec7 - version_4: 0.8658838272094727\n",
"\n",
"\n",
"For VT_012aff7c1c667073aedafcbebfa35ec7; statistics are:\n",
"Scores - mean: 0.861s\tstd: 0.007smin: 0.848s\t max: 0.873s\n",
"--------------------------------------------\n",
"--------------VT_fdf2a86085b508c1325b181c830a4cf7------------------\n",
"VT_fdf2a86085b508c1325b181c830a4cf7 - version_6: 0.854997456073761\n",
"VT_fdf2a86085b508c1325b181c830a4cf7 - version_1: 0.8609604835510254\n",
"VT_fdf2a86085b508c1325b181c830a4cf7 - version_3: 0.8558254837989807\n",
"VT_fdf2a86085b508c1325b181c830a4cf7 - version_0: 0.8728921413421631\n",
"VT_fdf2a86085b508c1325b181c830a4cf7 - version_8: 0.8631933927536011\n",
"VT_fdf2a86085b508c1325b181c830a4cf7 - version_5: 0.8612215518951416\n",
"VT_fdf2a86085b508c1325b181c830a4cf7 - version_7: 0.8661960959434509\n",
"VT_fdf2a86085b508c1325b181c830a4cf7 - version_2: 0.8636621832847595\n",
"VT_fdf2a86085b508c1325b181c830a4cf7 - version_9: 0.8614727258682251\n",
"VT_fdf2a86085b508c1325b181c830a4cf7 - version_4: 0.8657329082489014\n",
"\n",
"\n",
"For VT_fdf2a86085b508c1325b181c830a4cf7; statistics are:\n",
"Scores - mean: 0.863s\tstd: 0.005smin: 0.855s\t max: 0.873s\n",
"--------------------------------------------\n",
"--------------VT_cc64c06847a7ca26f5ea4d465f9cc5bc------------------\n",
"VT_cc64c06847a7ca26f5ea4d465f9cc5bc - version_6: 0.8572231531143188\n",
"VT_cc64c06847a7ca26f5ea4d465f9cc5bc - version_1: 0.8442623615264893\n",
"VT_cc64c06847a7ca26f5ea4d465f9cc5bc - version_3: 0.8498414754867554\n",
"VT_cc64c06847a7ca26f5ea4d465f9cc5bc - version_0: 0.8569087982177734\n",
"VT_cc64c06847a7ca26f5ea4d465f9cc5bc - version_8: 0.8455194234848022\n",
"VT_cc64c06847a7ca26f5ea4d465f9cc5bc - version_5: 0.8435630798339844\n",
"VT_cc64c06847a7ca26f5ea4d465f9cc5bc - version_7: 0.845982551574707\n",
"VT_cc64c06847a7ca26f5ea4d465f9cc5bc - version_2: 0.8571171164512634\n",
"VT_cc64c06847a7ca26f5ea4d465f9cc5bc - version_9: 0.8448543548583984\n",
"VT_cc64c06847a7ca26f5ea4d465f9cc5bc - version_4: 0.845399022102356\n",
"\n",
"\n",
"For VT_cc64c06847a7ca26f5ea4d465f9cc5bc; statistics are:\n",
"Scores - mean: 0.849s\tstd: 0.005smin: 0.844s\t max: 0.857s\n",
"--------------------------------------------\n",
"--------------VT_2c7afd50e127f5a2339db0ddfd6bfd7c------------------\n",
"VT_2c7afd50e127f5a2339db0ddfd6bfd7c - version_6: 0.8630585670471191\n",
"VT_2c7afd50e127f5a2339db0ddfd6bfd7c - version_1: 0.8686699271202087\n",
"VT_2c7afd50e127f5a2339db0ddfd6bfd7c - version_3: 0.8729345798492432\n",
"VT_2c7afd50e127f5a2339db0ddfd6bfd7c - version_0: 0.8636038899421692\n",
"VT_2c7afd50e127f5a2339db0ddfd6bfd7c - version_8: 0.8558077812194824\n",
"VT_2c7afd50e127f5a2339db0ddfd6bfd7c - version_5: 0.8710847496986389\n",
"VT_2c7afd50e127f5a2339db0ddfd6bfd7c - version_7: 0.8619015216827393\n",
"VT_2c7afd50e127f5a2339db0ddfd6bfd7c - version_2: 0.8499867916107178\n",
"VT_2c7afd50e127f5a2339db0ddfd6bfd7c - version_9: 0.8507344722747803\n",
"VT_2c7afd50e127f5a2339db0ddfd6bfd7c - version_4: 0.8555077314376831\n",
"\n",
"\n",
"For VT_2c7afd50e127f5a2339db0ddfd6bfd7c; statistics are:\n",
"Scores - mean: 0.861s\tstd: 0.008smin: 0.850s\t max: 0.873s\n",
"--------------------------------------------\n",
"--------------VT_63b9fee765cdda91756af1f35cd320a3------------------\n",
"VT_63b9fee765cdda91756af1f35cd320a3 - version_6: 0.8663593530654907\n",
"VT_63b9fee765cdda91756af1f35cd320a3 - version_1: 0.8519773483276367\n",
"VT_63b9fee765cdda91756af1f35cd320a3 - version_3: 0.8519774675369263\n",
"VT_63b9fee765cdda91756af1f35cd320a3 - version_0: 0.8603388071060181\n",
"VT_63b9fee765cdda91756af1f35cd320a3 - version_8: 0.8614517450332642\n",
"VT_63b9fee765cdda91756af1f35cd320a3 - version_5: 0.8558711409568787\n",
"VT_63b9fee765cdda91756af1f35cd320a3 - version_7: 0.8537712097167969\n",
"VT_63b9fee765cdda91756af1f35cd320a3 - version_2: 0.8558205962181091\n",
"VT_63b9fee765cdda91756af1f35cd320a3 - version_9: 0.8647329211235046\n",
"VT_63b9fee765cdda91756af1f35cd320a3 - version_4: 0.8546129465103149\n",
"\n",
"\n",
"For VT_63b9fee765cdda91756af1f35cd320a3; statistics are:\n",
"Scores - mean: 0.858s\tstd: 0.005smin: 0.852s\t max: 0.866s\n",
"--------------------------------------------\n",
"--------------VT_aca900a5b9566af61c91aea6525190e6------------------\n",
"VT_aca900a5b9566af61c91aea6525190e6 - version_6: 0.8575441241264343\n",
"VT_aca900a5b9566af61c91aea6525190e6 - version_1: 0.8453981280326843\n",
"VT_aca900a5b9566af61c91aea6525190e6 - version_3: 0.8621359467506409\n",
"VT_aca900a5b9566af61c91aea6525190e6 - version_0: 0.8547767400741577\n",
"VT_aca900a5b9566af61c91aea6525190e6 - version_8: 0.8613359928131104\n",
"VT_aca900a5b9566af61c91aea6525190e6 - version_5: 0.8667657375335693\n",
"VT_aca900a5b9566af61c91aea6525190e6 - version_7: 0.8474754095077515\n",
"VT_aca900a5b9566af61c91aea6525190e6 - version_2: 0.8628634214401245\n",
"VT_aca900a5b9566af61c91aea6525190e6 - version_9: 0.8585749268531799\n",
"VT_aca900a5b9566af61c91aea6525190e6 - version_4: 0.8380126357078552\n",
"\n",
"\n",
"For VT_aca900a5b9566af61c91aea6525190e6; statistics are:\n",
"Scores - mean: 0.855s\tstd: 0.009smin: 0.838s\t max: 0.867s\n",
"--------------------------------------------\n",
"--------------VT_fb6b96a190455106d29f0630f002ac6f------------------\n",
"VT_fb6b96a190455106d29f0630f002ac6f - version_6: 0.8635155558586121\n",
"VT_fb6b96a190455106d29f0630f002ac6f - version_1: 0.8261691927909851\n",
"VT_fb6b96a190455106d29f0630f002ac6f - version_3: 0.8444902896881104\n",
"VT_fb6b96a190455106d29f0630f002ac6f - version_0: 0.865719735622406\n",
"VT_fb6b96a190455106d29f0630f002ac6f - version_8: 0.8533784747123718\n",
"VT_fb6b96a190455106d29f0630f002ac6f - version_5: 0.8555656671524048\n",
"VT_fb6b96a190455106d29f0630f002ac6f - version_7: 0.837948739528656\n",
"VT_fb6b96a190455106d29f0630f002ac6f - version_2: 0.8545827865600586\n",
"VT_fb6b96a190455106d29f0630f002ac6f - version_9: 0.8541560769081116\n",
"VT_fb6b96a190455106d29f0630f002ac6f - version_4: 0.85297691822052\n",
"\n",
"\n",
"For VT_fb6b96a190455106d29f0630f002ac6f; statistics are:\n",
"Scores - mean: 0.851s\tstd: 0.011smin: 0.826s\t max: 0.866s\n",
"--------------------------------------------\n",
"--------------VT_378971720b930050ad7662bb96699e20------------------\n",
"VT_378971720b930050ad7662bb96699e20 - version_6: 0.8388294577598572\n",
"VT_378971720b930050ad7662bb96699e20 - version_1: 0.8333806395530701\n",
"VT_378971720b930050ad7662bb96699e20 - version_3: 0.847841203212738\n",
"VT_378971720b930050ad7662bb96699e20 - version_0: 0.8287097811698914\n",
"VT_378971720b930050ad7662bb96699e20 - version_8: 0.8436978459358215\n",
"VT_378971720b930050ad7662bb96699e20 - version_5: 0.8392724990844727\n",
"VT_378971720b930050ad7662bb96699e20 - version_7: 0.8410612344741821\n",
"VT_378971720b930050ad7662bb96699e20 - version_2: 0.8407015204429626\n",
"VT_378971720b930050ad7662bb96699e20 - version_9: 0.8334627151489258\n",
"VT_378971720b930050ad7662bb96699e20 - version_4: 0.8400266766548157\n",
"\n",
"\n",
"For VT_378971720b930050ad7662bb96699e20; statistics are:\n",
"Scores - mean: 0.839s\tstd: 0.005smin: 0.829s\t max: 0.848s\n",
"--------------------------------------------\n",
"--------------VT_d55f1492ff29a3cd1026013948ce7fa7------------------\n",
"VT_d55f1492ff29a3cd1026013948ce7fa7 - version_6: 0.8385945558547974\n",
"VT_d55f1492ff29a3cd1026013948ce7fa7 - version_1: 0.8324360251426697\n",
"VT_d55f1492ff29a3cd1026013948ce7fa7 - version_3: 0.8386826515197754\n",
"VT_d55f1492ff29a3cd1026013948ce7fa7 - version_0: 0.8366813063621521\n",
"VT_d55f1492ff29a3cd1026013948ce7fa7 - version_8: 0.8460721969604492\n",
"VT_d55f1492ff29a3cd1026013948ce7fa7 - version_5: 0.8374781608581543\n",
"VT_d55f1492ff29a3cd1026013948ce7fa7 - version_7: 0.8320286273956299\n",
"VT_d55f1492ff29a3cd1026013948ce7fa7 - version_2: 0.8370164632797241\n",
"VT_d55f1492ff29a3cd1026013948ce7fa7 - version_9: 0.8495808839797974\n",
"VT_d55f1492ff29a3cd1026013948ce7fa7 - version_4: 0.8332125544548035\n",
"\n",
"\n",
"For VT_d55f1492ff29a3cd1026013948ce7fa7; statistics are:\n",
"Scores - mean: 0.838s\tstd: 0.005smin: 0.832s\t max: 0.850s\n",
"--------------------------------------------\n",
"--------------VT_15cbb349b2b50dbb97beec16af2bedab------------------\n",
"VT_15cbb349b2b50dbb97beec16af2bedab - version_6: 0.8407894372940063\n",
"VT_15cbb349b2b50dbb97beec16af2bedab - version_1: 0.836580216884613\n",
"VT_15cbb349b2b50dbb97beec16af2bedab - version_3: 0.8312996029853821\n",
"VT_15cbb349b2b50dbb97beec16af2bedab - version_0: 0.8336991667747498\n",
"VT_15cbb349b2b50dbb97beec16af2bedab - version_8: 0.8231534957885742\n",
"VT_15cbb349b2b50dbb97beec16af2bedab - version_5: 0.8243923187255859\n",
"VT_15cbb349b2b50dbb97beec16af2bedab - version_7: 0.8342592120170593\n",
"VT_15cbb349b2b50dbb97beec16af2bedab - version_2: 0.8349334001541138\n",
"VT_15cbb349b2b50dbb97beec16af2bedab - version_9: 0.8382810950279236\n",
"VT_15cbb349b2b50dbb97beec16af2bedab - version_4: 0.8381868600845337\n",
"\n",
"\n",
"For VT_15cbb349b2b50dbb97beec16af2bedab; statistics are:\n",
"Scores - mean: 0.834s\tstd: 0.006smin: 0.823s\t max: 0.841s\n",
"--------------------------------------------\n"
]
}
],
"source": [
"for model_configuration in [x for x in (out_path / model_name).iterdir() if x.is_dir()]:\n",
" # Print metrics\n",
" print_metrics(model_configuration)"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%% Mass - Load Model and read Metrics\n"
}
}
},
{
"cell_type": "code",
"execution_count": 15,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"--------------VT_fdf2a86085b508c1325b181c830a4cf7------------------\n",
"--------------VT_fdf2a86085b508c1325b181c830a4cf7------------------\n",
"VT_fdf2a86085b508c1325b181c830a4cf7 - version_6: 0.854997456073761\n",
"VT_fdf2a86085b508c1325b181c830a4cf7 - version_1: 0.8609604835510254\n",
"VT_fdf2a86085b508c1325b181c830a4cf7 - version_3: 0.8558254837989807\n",
"VT_fdf2a86085b508c1325b181c830a4cf7 - version_0: 0.8728921413421631\n",
"VT_fdf2a86085b508c1325b181c830a4cf7 - version_8: 0.8631933927536011\n",
"VT_fdf2a86085b508c1325b181c830a4cf7 - version_5: 0.8612215518951416\n",
"VT_fdf2a86085b508c1325b181c830a4cf7 - version_7: 0.8661960959434509\n",
"VT_fdf2a86085b508c1325b181c830a4cf7 - version_2: 0.8636621832847595\n",
"VT_fdf2a86085b508c1325b181c830a4cf7 - version_9: 0.8614727258682251\n",
"VT_fdf2a86085b508c1325b181c830a4cf7 - version_4: 0.8657329082489014\n",
"--------------------------------------------\n",
"--------------------------------------------\n"
]
}
],
"source": [
"# fingerprint = '012aff7c1c667073aedafcbebfa35ec7'\n",
"fingerprint = 'fdf2a86085b508c1325b181c830a4cf7'\n",
"exp_name = f'{\"\".join([x for x in model_name if x.isupper()])}_{fingerprint}'\n",
"\n",
"# Print metrics\n",
"print_metrics(out_path/model_name/exp_name)\n",
"\n"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%% Single - Load Model and read Metrics\n"
}
}
},
{
"cell_type": "code",
"execution_count": 39,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" filenames prediction prediction_named\n",
"0 test_00001 1 chimpanze\n",
"1 test_00002 0 background\n",
"2 test_00003 0 background\n",
"3 test_00004 1 chimpanze\n",
"4 test_00005 4 redcap\n"
]
}
],
"source": [
"predictions_file = out_path/model_name/'VT_15cbb349b2b50dbb97beec16af2bedab'/'version_9'/'predictions.csv'\n",
"df_predictions = pd.read_csv(predictions_file)\n",
"print(df_predictions.head())\n",
"df_predictions = df_predictions[['filenames', 'prediction_named']]\n",
"df_predictions.columns = ['filename', 'prediction']\n",
"df_predictions['filename'] = df_predictions['filename'] + '.wav'\n",
"predictions_file_new = predictions_file.parent / 'prediction_final.csv'\n",
"df_predictions.to_csv(index=False, path_or_buf=predictions_file_new)\n",
"\n",
"\n"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%% Combine Predictions#\n"
}
}
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
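
print_metrics() above reads the last logged PL_recall_score with plain integer indexing (metrics.PL_recall_score[-1]), which older pandas resolves positionally. An equivalent, explicit form, assuming the same metrics.csv layout and a hypothetical run directory:

    import pandas as pd
    from pathlib import Path

    run_folder = Path('output') / 'VisualTransformer' / 'VT_example' / 'version_0'  # hypothetical
    metrics = pd.read_csv(run_folder / 'metrics.csv')
    score = metrics['PL_recall_score'].iloc[-1]   # positional: last row = most recent epoch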


@@ -1,104 +0,0 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 47,
"metadata": {
"collapsed": true,
"pycharm": {
"name": "#%% IMPORTS\n"
}
},
"outputs": [],
"source": [
"from pathlib import Path\n",
"from natsort import natsorted\n",
"from pytorch_lightning.core.saving import *\n",
"from ml_lib.utils.model_io import SavedLightningModels\n"
]
},
{
"cell_type": "code",
"execution_count": 48,
"outputs": [],
"source": [
"from ml_lib.utils.tools import locate_and_import_class\n",
"from models.transformer_model import VisualTransformer\n",
"_ROOT = Path('..')\n",
"out_path = 'output'\n",
"model_class = VisualTransformer\n",
"model_name = model_class.name()\n",
"\n",
"exp_name = 'VT_01123c93daaffa92d2ed341bda32426d'\n",
"version = 'version_2'"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%M Path resolving and variables\n"
}
}
},
{
"cell_type": "code",
"execution_count": 50,
"outputs": [
{
"ename": "ValueError",
"evalue": "When you set `reduce` as 'macro', you have to provide the number of classes.",
"output_type": "error",
"traceback": [
"\u001B[1;31m---------------------------------------------------------------------------\u001B[0m",
"\u001B[1;31mValueError\u001B[0m Traceback (most recent call last)",
"\u001B[1;32m<ipython-input-50-0216292a172f>\u001B[0m in \u001B[0;36m<module>\u001B[1;34m\u001B[0m\n\u001B[0;32m 6\u001B[0m \u001B[0madditional_kwargs\u001B[0m \u001B[1;33m=\u001B[0m \u001B[0mdict\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mvariable_length\u001B[0m \u001B[1;33m=\u001B[0m \u001B[1;32mFalse\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0mc_classes\u001B[0m\u001B[1;33m=\u001B[0m\u001B[1;36m5\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m 7\u001B[0m \u001B[1;33m\u001B[0m\u001B[0m\n\u001B[1;32m----> 8\u001B[1;33m \u001B[0mmodel\u001B[0m \u001B[1;33m=\u001B[0m \u001B[0mmodel_class\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mload_from_checkpoint\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mcheckpoint\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0mhparams_file\u001B[0m\u001B[1;33m=\u001B[0m\u001B[0mstr\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mhparams_yaml\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m,\u001B[0m \u001B[1;33m**\u001B[0m\u001B[0madditional_kwargs\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0m\u001B[0;32m 9\u001B[0m \u001B[1;33m\u001B[0m\u001B[0m\n",
"\u001B[1;32mc:\\users\\steff\\envs\\compare_21\\lib\\site-packages\\pytorch_lightning\\core\\saving.py\u001B[0m in \u001B[0;36mload_from_checkpoint\u001B[1;34m(cls, checkpoint_path, map_location, hparams_file, strict, **kwargs)\u001B[0m\n\u001B[0;32m 154\u001B[0m \u001B[0mcheckpoint\u001B[0m\u001B[1;33m[\u001B[0m\u001B[0mcls\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mCHECKPOINT_HYPER_PARAMS_KEY\u001B[0m\u001B[1;33m]\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mupdate\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mkwargs\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m 155\u001B[0m \u001B[1;33m\u001B[0m\u001B[0m\n\u001B[1;32m--> 156\u001B[1;33m \u001B[0mmodel\u001B[0m \u001B[1;33m=\u001B[0m \u001B[0mcls\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0m_load_model_state\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mcheckpoint\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0mstrict\u001B[0m\u001B[1;33m=\u001B[0m\u001B[0mstrict\u001B[0m\u001B[1;33m,\u001B[0m \u001B[1;33m**\u001B[0m\u001B[0mkwargs\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0m\u001B[0;32m 157\u001B[0m \u001B[1;32mreturn\u001B[0m \u001B[0mmodel\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m 158\u001B[0m \u001B[1;33m\u001B[0m\u001B[0m\n",
"\u001B[1;32mc:\\users\\steff\\envs\\compare_21\\lib\\site-packages\\pytorch_lightning\\core\\saving.py\u001B[0m in \u001B[0;36m_load_model_state\u001B[1;34m(cls, checkpoint, strict, **cls_kwargs_new)\u001B[0m\n\u001B[0;32m 196\u001B[0m \u001B[0m_cls_kwargs\u001B[0m \u001B[1;33m=\u001B[0m \u001B[1;33m{\u001B[0m\u001B[0mk\u001B[0m\u001B[1;33m:\u001B[0m \u001B[0mv\u001B[0m \u001B[1;32mfor\u001B[0m \u001B[0mk\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0mv\u001B[0m \u001B[1;32min\u001B[0m \u001B[0m_cls_kwargs\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mitems\u001B[0m\u001B[1;33m(\u001B[0m\u001B[1;33m)\u001B[0m \u001B[1;32mif\u001B[0m \u001B[0mk\u001B[0m \u001B[1;32min\u001B[0m \u001B[0mcls_init_args_name\u001B[0m\u001B[1;33m}\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m 197\u001B[0m \u001B[1;33m\u001B[0m\u001B[0m\n\u001B[1;32m--> 198\u001B[1;33m \u001B[0mmodel\u001B[0m \u001B[1;33m=\u001B[0m \u001B[0mcls\u001B[0m\u001B[1;33m(\u001B[0m\u001B[1;33m**\u001B[0m\u001B[0m_cls_kwargs\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0m\u001B[0;32m 199\u001B[0m \u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m 200\u001B[0m \u001B[1;31m# give model a chance to load something\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n",
"\u001B[1;32m~\\projects\\compare_21\\models\\transformer_model.py\u001B[0m in \u001B[0;36m__init__\u001B[1;34m(self, in_shape, n_classes, weight_init, activation, embedding_size, heads, attn_depth, patch_size, use_residual, variable_length, use_bias, use_norm, dropout, lat_dim, loss, scheduler, mlp_dim, head_dim, lr, weight_decay, sto_weight_avg, lr_scheduler_parameter, opt_reset_interval)\u001B[0m\n\u001B[0;32m 27\u001B[0m \u001B[0ma\u001B[0m \u001B[1;33m=\u001B[0m \u001B[0mdict\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mlocals\u001B[0m\u001B[1;33m(\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m 28\u001B[0m \u001B[0mparams\u001B[0m \u001B[1;33m=\u001B[0m \u001B[1;33m{\u001B[0m\u001B[0marg\u001B[0m\u001B[1;33m:\u001B[0m \u001B[0ma\u001B[0m\u001B[1;33m[\u001B[0m\u001B[0marg\u001B[0m\u001B[1;33m]\u001B[0m \u001B[1;32mfor\u001B[0m \u001B[0marg\u001B[0m \u001B[1;32min\u001B[0m \u001B[0minspect\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0msignature\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mself\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0m__init__\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mparameters\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mkeys\u001B[0m\u001B[1;33m(\u001B[0m\u001B[1;33m)\u001B[0m \u001B[1;32mif\u001B[0m \u001B[0marg\u001B[0m \u001B[1;33m!=\u001B[0m \u001B[1;34m'self'\u001B[0m\u001B[1;33m}\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[1;32m---> 29\u001B[1;33m \u001B[0msuper\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mVisualTransformer\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0mself\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0m__init__\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mparams\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0m\u001B[0;32m 30\u001B[0m \u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m 31\u001B[0m \u001B[0mself\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0min_shape\u001B[0m \u001B[1;33m=\u001B[0m \u001B[0min_shape\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n",
"\u001B[1;32m~\\projects\\compare_21\\ml_lib\\modules\\util.py\u001B[0m in \u001B[0;36m__init__\u001B[1;34m(self, model_parameters, weight_init)\u001B[0m\n\u001B[0;32m 112\u001B[0m \u001B[0mself\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mparams\u001B[0m \u001B[1;33m=\u001B[0m \u001B[0mModelParameters\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mmodel_parameters\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m 113\u001B[0m \u001B[1;33m\u001B[0m\u001B[0m\n\u001B[1;32m--> 114\u001B[1;33m \u001B[0mself\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mmetrics\u001B[0m \u001B[1;33m=\u001B[0m \u001B[0mPLMetrics\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mself\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mparams\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mn_classes\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0mtag\u001B[0m\u001B[1;33m=\u001B[0m\u001B[1;34m'PL'\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0m\u001B[0;32m 115\u001B[0m \u001B[1;32mpass\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m 116\u001B[0m \u001B[1;33m\u001B[0m\u001B[0m\n",
"\u001B[1;32m~\\projects\\compare_21\\ml_lib\\modules\\util.py\u001B[0m in \u001B[0;36m__init__\u001B[1;34m(self, n_classes, tag)\u001B[0m\n\u001B[0;32m 30\u001B[0m \u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m 31\u001B[0m \u001B[0mself\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0maccuracy_score\u001B[0m \u001B[1;33m=\u001B[0m \u001B[0mpl\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mmetrics\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mAccuracy\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mcompute_on_step\u001B[0m\u001B[1;33m=\u001B[0m\u001B[1;32mFalse\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[1;32m---> 32\u001B[1;33m \u001B[0mself\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mprecision\u001B[0m \u001B[1;33m=\u001B[0m \u001B[0mpl\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mmetrics\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mPrecision\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mnum_classes\u001B[0m\u001B[1;33m=\u001B[0m\u001B[0mself\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mn_classes\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0maverage\u001B[0m\u001B[1;33m=\u001B[0m\u001B[1;34m'macro'\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0mcompute_on_step\u001B[0m\u001B[1;33m=\u001B[0m\u001B[1;32mFalse\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0m\u001B[0;32m 33\u001B[0m \u001B[0mself\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mrecall\u001B[0m \u001B[1;33m=\u001B[0m \u001B[0mpl\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mmetrics\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mRecall\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mnum_classes\u001B[0m\u001B[1;33m=\u001B[0m\u001B[0mself\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mn_classes\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0maverage\u001B[0m\u001B[1;33m=\u001B[0m\u001B[1;34m'macro'\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0mcompute_on_step\u001B[0m\u001B[1;33m=\u001B[0m\u001B[1;32mFalse\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m 34\u001B[0m \u001B[0mself\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mconfusion_matrix\u001B[0m \u001B[1;33m=\u001B[0m \u001B[0mpl\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mmetrics\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mConfusionMatrix\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mself\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mn_classes\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0mnormalize\u001B[0m\u001B[1;33m=\u001B[0m\u001B[1;34m'true'\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0mcompute_on_step\u001B[0m\u001B[1;33m=\u001B[0m\u001B[1;32mFalse\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n",
"\u001B[1;32mc:\\users\\steff\\envs\\compare_21\\lib\\site-packages\\pytorch_lightning\\metrics\\classification\\precision_recall.py\u001B[0m in \u001B[0;36m__init__\u001B[1;34m(self, num_classes, threshold, average, multilabel, mdmc_average, ignore_index, top_k, is_multiclass, compute_on_step, dist_sync_on_step, process_group, dist_sync_fn)\u001B[0m\n\u001B[0;32m 139\u001B[0m \u001B[1;32mraise\u001B[0m \u001B[0mValueError\u001B[0m\u001B[1;33m(\u001B[0m\u001B[1;34mf\"The `average` has to be one of {allowed_average}, got {average}.\"\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m 140\u001B[0m \u001B[1;33m\u001B[0m\u001B[0m\n\u001B[1;32m--> 141\u001B[1;33m super().__init__(\n\u001B[0m\u001B[0;32m 142\u001B[0m \u001B[0mreduce\u001B[0m\u001B[1;33m=\u001B[0m\u001B[1;34m\"macro\"\u001B[0m \u001B[1;32mif\u001B[0m \u001B[0maverage\u001B[0m \u001B[1;32min\u001B[0m \u001B[1;33m[\u001B[0m\u001B[1;34m\"weighted\"\u001B[0m\u001B[1;33m,\u001B[0m \u001B[1;34m\"none\"\u001B[0m\u001B[1;33m,\u001B[0m \u001B[1;32mNone\u001B[0m\u001B[1;33m]\u001B[0m \u001B[1;32melse\u001B[0m \u001B[0maverage\u001B[0m\u001B[1;33m,\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m 143\u001B[0m \u001B[0mmdmc_reduce\u001B[0m\u001B[1;33m=\u001B[0m\u001B[0mmdmc_average\u001B[0m\u001B[1;33m,\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n",
"\u001B[1;32mc:\\users\\steff\\envs\\compare_21\\lib\\site-packages\\pytorch_lightning\\metrics\\classification\\stat_scores.py\u001B[0m in \u001B[0;36m__init__\u001B[1;34m(self, threshold, top_k, reduce, num_classes, ignore_index, mdmc_reduce, is_multiclass, compute_on_step, dist_sync_on_step, process_group, dist_sync_fn)\u001B[0m\n\u001B[0;32m 157\u001B[0m \u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m 158\u001B[0m \u001B[1;32mif\u001B[0m \u001B[0mreduce\u001B[0m \u001B[1;33m==\u001B[0m \u001B[1;34m\"macro\"\u001B[0m \u001B[1;32mand\u001B[0m \u001B[1;33m(\u001B[0m\u001B[1;32mnot\u001B[0m \u001B[0mnum_classes\u001B[0m \u001B[1;32mor\u001B[0m \u001B[0mnum_classes\u001B[0m \u001B[1;33m<\u001B[0m \u001B[1;36m1\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m:\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[1;32m--> 159\u001B[1;33m \u001B[1;32mraise\u001B[0m \u001B[0mValueError\u001B[0m\u001B[1;33m(\u001B[0m\u001B[1;34m\"When you set `reduce` as 'macro', you have to provide the number of classes.\"\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0m\u001B[0;32m 160\u001B[0m \u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m 161\u001B[0m \u001B[1;32mif\u001B[0m \u001B[0mnum_classes\u001B[0m \u001B[1;32mand\u001B[0m \u001B[0mignore_index\u001B[0m \u001B[1;32mis\u001B[0m \u001B[1;32mnot\u001B[0m \u001B[1;32mNone\u001B[0m \u001B[1;32mand\u001B[0m \u001B[1;33m(\u001B[0m\u001B[1;32mnot\u001B[0m \u001B[1;36m0\u001B[0m \u001B[1;33m<=\u001B[0m \u001B[0mignore_index\u001B[0m \u001B[1;33m<\u001B[0m \u001B[0mnum_classes\u001B[0m \u001B[1;32mor\u001B[0m \u001B[0mnum_classes\u001B[0m \u001B[1;33m==\u001B[0m \u001B[1;36m1\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m:\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n",
"\u001B[1;31mValueError\u001B[0m: When you set `reduce` as 'macro', you have to provide the number of classes."
]
}
],
"source": [
"exp_path = _ROOT / out_path / model_name / exp_name / version\n",
"checkpoint = natsorted(exp_path.glob('*.ckpt'))[-1]\n",
"hparams_yaml = next(exp_path.glob('*.yaml'))\n",
"\n",
"hparams_file = load_hparams_from_yaml(hparams_yaml)\n",
"additional_kwargs = dict(variable_length = False, c_classes=5)\n",
"\n",
"model = model_class.load_from_checkpoint(checkpoint, hparams_file=str(hparams_yaml), **additional_kwargs)\n"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
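
The ValueError above is raised because n_classes never reaches ModelParameters, so PLMetrics builds its macro-averaged Precision with num_classes=None. The cell passes c_classes=5, which is presumably a typo for n_classes; a sketch of the corrected call, reusing the cell's checkpoint and hparams_yaml:

    additional_kwargs = dict(variable_length=False, n_classes=5)
    model = model_class.load_from_checkpoint(checkpoint, hparams_file=str(hparams_yaml), **additional_kwargs)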

247
reload model.ipynb Normal file

File diff suppressed because one or more lines are too long


@@ -18,7 +18,7 @@ class TrainMixin:
batch_files, batch_x, batch_y = batch_xy
y = self(batch_x).main_out
if self.params.n_classes <= 2:
loss = self.bce_loss(y, batch_y.long())
loss = self.bce_loss(y.squeeze().float(), batch_y.float())
else:
if self.params.loss == 'focal_loss_rob':
labels_one_hot = torch.nn.functional.one_hot(batch_y, num_classes=self.params.n_classes)
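
The reworked binary branch aligns shapes and dtypes: the model emits (batch, 1) while the targets are (batch,), and BCE-style losses require float targets. A minimal illustration with invented sizes:

    import torch

    y = torch.rand(8, 1)                    # model output: one score per sample
    batch_y = torch.randint(0, 2, (8,))     # integer class labels

    loss_input = y.squeeze().float()        # shape (8,) to match the target
    loss_target = batch_y.float()           # BCE expects float targets, not long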