paper preparations and notebooks, optuna callbacks
+15
-4
@@ -10,7 +10,7 @@ data_name = CCSLibrosaDatamodule
[data]
num_worker = 10
data_root = data
variable_length = True
variable_length = False
target_mel_length_in_seconds = 0.7

n_mels = 128
@@ -49,6 +49,15 @@ dropout = 0.2
lat_dim = 32
filters = [16, 32, 64, 128]

[BandwiseConvClassifier]
weight_init = xavier_normal_
activation = gelu
use_bias = True
use_norm = True
dropout = 0.2
lat_dim = 32
filters = [16, 32, 64, 128]

[VisualTransformer]
weight_init = xavier_normal_
activation = gelu
@@ -73,11 +82,13 @@ use_norm = True
use_residual = True
dropout = 0.2

lat_dim = 32
mlp_dim = 6
lat_dim = 6
head_dim = 6
patch_size = 8
attn_depth = 12
attn_depth = 6
heads = 4
embedding_size = 128
embedding_size = 30

[HorizontalVisualTransformer]
weight_init = xavier_normal_

@@ -118,7 +118,9 @@ class CompareBase(_BaseDataModule):
        lab_file = None

        for data_option in data_options:
            if lab_file is not None:
            if lab_file is None:
                lab_file = f'{data_option}.csv'
            elif lab_file is not None:
                if any([x in lab_file for x in data_options]):
                    lab_file = f'{data_option}.csv'
            dataset = self._load_from_file(lab_file, data_option, rebuild=True)

@@ -76,15 +76,12 @@ def run_lightning_loop(h_params :Namespace, data_class, model_class, seed=69, ad
    trainer.fit(model, datamodule)
    trainer.save_checkpoint(logger.save_dir / 'last_weights.ckpt')

    trainer.test(model=model, datamodule=datamodule, ckpt_path='best')

    trainer.test(model=model, datamodule=datamodule)
    #except:
    #    print('Test did not Suceed!')
    #    pass

    logger.log_metrics(score_callback.best_scores, step=trainer.global_step+1)

    return score_callback.best_scores['PL_recall_score']
    return Namespace(model=model,
                     best_model_path=ckpt_callback.best_model_path,
                     best_model_score=ckpt_callback.best_model_score.item(),
                     max_score_monitor=score_callback.best_scores)


if __name__ == '__main__':
@@ -96,6 +93,6 @@ if __name__ == '__main__':

    # Start
    # -----------------
    run_lightning_loop(hparams, found_data_class, found_model_class, found_seed)
    print(run_lightning_loop(hparams, found_data_class, found_model_class, found_seed))
    print('done')
    pass

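For context, a hypothetical caller-side sketch (not part of this commit): the loop now returns a Namespace rather than a bare recall score, so callers pick the field they need.

    results = run_lightning_loop(hparams, found_data_class, found_model_class, found_seed)
    print(results.best_model_score)    # checkpoint metric, as consumed by the Optuna objective below
    print(results.best_model_path)     # path to the best checkpoint
    print(results.max_score_monitor)   # dict of best scores tracked by the score callback
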
@@ -0,0 +1,69 @@
import inspect
from argparse import Namespace

from torch import nn
from torch.nn import ModuleList

from ml_lib.modules.blocks import ConvModule, LinearModule
from ml_lib.modules.util import (LightningBaseModule, Splitter, Merger)
from util.module_mixins import CombinedModelMixins


class BandwiseConvClassifier(CombinedModelMixins,
                             LightningBaseModule
                             ):
    def __init__(self, in_shape, n_classes, weight_init, activation,
                 use_bias, use_norm, dropout, lat_dim, filters,
                 lr, weight_decay, sto_weight_avg, lr_warm_restart_epochs, opt_reset_interval,
                 loss, scheduler, lr_scheduler_parameter
                 ):
        # TODO: Move this to parent class, or make it much easier to access....
        a = dict(locals())
        params = {arg: a[arg] for arg in inspect.signature(self.__init__).parameters.keys() if arg != 'self'}
        super(BandwiseConvClassifier, self).__init__(params)

        # Model Parameters
        # =============================================================================
        # Additional parameters
        self.n_band_sections = 8

        # Modules
        # =============================================================================
        self.split = Splitter(in_shape, self.n_band_sections)

        k = 3
        self.band_list = ModuleList()
        for band in range(self.n_band_sections):
            last_shape = self.split.shape[band]
            conv_list = ModuleList()
            for conv_filters in self.params.filters:
                conv_list.append(ConvModule(last_shape, conv_filters, (k, k), conv_stride=(2, 2), conv_padding=2,
                                            **self.params.module_kwargs))
                last_shape = conv_list[-1].shape
                # self.conv_list.append(ConvModule(last_shape, 1, 1, conv_stride=1, **self.params.module_kwargs))
                # last_shape = self.conv_list[-1].shape
            self.band_list.append(conv_list)

        self.merge = Merger(self.band_list[-1][-1].shape, self.n_band_sections)

        self.full_1 = LinearModule(self.merge.shape, self.params.lat_dim, **self.params.module_kwargs)
        self.full_2 = LinearModule(self.full_1.shape, self.params.lat_dim, **self.params.module_kwargs)

        # Make Decision between binary and Multiclass Classification
        logits = n_classes if n_classes > 2 else 1
        module_kwargs = self.params.module_kwargs
        module_kwargs.update(activation=(nn.Softmax if logits > 1 else nn.Sigmoid))
        self.full_out = LinearModule(self.full_2.shape, logits, **module_kwargs)

    def forward(self, batch, **kwargs):
        tensors = self.split(batch)
        for idx, (tensor, convs) in enumerate(zip(tensors, self.band_list)):
            for conv in convs:
                tensor = conv(tensor)
            tensors[idx] = tensor

        tensor = self.merge(tensors)
        tensor = self.full_1(tensor)
        tensor = self.full_2(tensor)
        tensor = self.full_out(tensor)
        return Namespace(main_out=tensor)
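A rough sketch of the band-wise idea, assuming Splitter/Merger behave like a chunk-and-concatenate over the mel axis (the real modules live in ml_lib and may differ):

    import torch

    x = torch.randn(4, 1, 128, 40)            # (batch, channel, n_mels, time)
    bands = torch.chunk(x, chunks=8, dim=2)    # 8 frequency bands of 16 mel bins each
    feats = [band.flatten(start_dim=1) for band in bands]   # per-band conv stacks would go here
    merged = torch.cat(feats, dim=1)           # merged band features feed the fully connected head
    print(merged.shape)                        # torch.Size([4, 5120])
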
@@ -3,8 +3,6 @@ from argparse import Namespace

from torch import nn

from ml_lib.metrics.binary_class_classifictaion import BinaryScores
from ml_lib.metrics.multi_class_classification import MultiClassScores
from ml_lib.modules.blocks import LinearModule
from ml_lib.modules.model_parts import CNNEncoder
from ml_lib.modules.util import (LightningBaseModule)
@@ -52,9 +50,3 @@ class CNNBaseline(CombinedModelMixins,

        tensor = self.classifier(tensor)
        return Namespace(main_out=tensor)

    def additional_scores(self, outputs):
        if self.params.n_classes > 2:
            return MultiClassScores(self)(outputs)
        else:
            return BinaryScores(self)(outputs)

@@ -130,8 +130,6 @@ try:
            tensor = self.mlp_head(tensor)
            return Namespace(main_out=tensor)

        def additional_scores(self, outputs):
            return MultiClassScores(self)(outputs)

except ImportError:  # pragma: do not provide model class
    print('You want to use `performer_pytorch` plugins which are not installed yet,'  # pragma: no-cover

@@ -109,9 +109,3 @@ class Tester(CombinedModelMixins,
        tensor = self.to_cls_token(tensor[:, 0])
        tensor = self.mlp_head(tensor)
        return Namespace(main_out=tensor, attn_weights=None)

    def additional_scores(self, outputs):
        if self.params.n_classes > 2:
            return MultiClassScores(self)(outputs)
        else:
            return BinaryScores(self)(outputs)

@@ -73,9 +73,9 @@ class VisualTransformer(CombinedModelMixins,
        logits = self.params.n_classes if self.params.n_classes > 2 else 1

        if return_logits:
            outbound_activation = nn.Identity()
            outbound_activation = nn.Identity
        else:
            outbound_activation = nn.Softmax() if logits > 1 else nn.Sigmoid()
            outbound_activation = nn.Softmax if logits > 1 else nn.Sigmoid


        self.mlp_head = nn.Sequential(
@@ -84,7 +84,7 @@ class VisualTransformer(CombinedModelMixins,
            self.params.activation(),
            nn.Dropout(self.params.dropout),
            nn.Linear(self.params.lat_dim, logits),
            outbound_activation
            outbound_activation()
        )

    def forward(self, x, mask=None, return_attn_weights=False):
@@ -128,8 +128,3 @@ class VisualTransformer(CombinedModelMixins,
        tensor = self.mlp_head(tensor)
        return Namespace(main_out=tensor, attn_weights=attn_weights)

    def additional_scores(self, outputs):
        if self.params.n_classes <= 2:
            return BinaryScores(self)(outputs)
        else:
            return MultiClassScores(self)(outputs)

@@ -116,5 +116,3 @@ class HorizontalVisualTransformer(CombinedModelMixins,
        tensor = self.mlp_head(tensor)
        return Namespace(main_out=tensor, attn_weights=attn_weights)

    def additional_scores(self, outputs):
        return MultiClassScores(self)(outputs)

@@ -18,9 +18,10 @@ MIN_NUM_PATCHES = 16
class VerticalVisualTransformer(CombinedModelMixins, LightningBaseModule):

    def __init__(self, in_shape, n_classes, weight_init, activation,
                 embedding_size, heads, attn_depth, patch_size, use_residual,
                 use_bias, use_norm, dropout, lat_dim, features, loss, scheduler,
                 lr, weight_decay, sto_weight_avg, lr_warm_restart_epochs, opt_reset_interval):
                 embedding_size, heads, attn_depth, patch_size, use_residual, variable_length,
                 use_bias, use_norm, dropout, lat_dim, loss, scheduler, mlp_dim, head_dim,
                 lr, weight_decay, sto_weight_avg, lr_scheduler_parameter, opt_reset_interval,
                 return_logits=False):

        # TODO: Move this to parent class, or make it much easier to access... But How...
        a = dict(locals())
@@ -47,14 +48,6 @@ class VerticalVisualTransformer(CombinedModelMixins, LightningBaseModule):
        assert num_patches >= MIN_NUM_PATCHES, f'your number of patches ({num_patches}) is way too small for ' + \
                                               f'attention. Try decreasing your patch size'

        # Correct the Embedding Dim
        if not self.embed_dim % self.params.heads == 0:
            self.embed_dim = (self.embed_dim // self.params.heads) * self.params.heads
            message = ('Embedding Dimension was fixed to be devideable by the number' +
                       f' of attention heads, is now: {self.embed_dim}')
            for func in print, warnings.warn:
                func(message)

        # Utility Modules
        self.autopad = AutoPadToShape((self.height, self.new_width))
        self.dropout = nn.Dropout(self.params.dropout)
@@ -62,10 +55,11 @@ class VerticalVisualTransformer(CombinedModelMixins, LightningBaseModule):
                                             keepdim=False)

        # Modules with Parameters
        self.transformer = TransformerModule(in_shape=self.embed_dim, mlp_dim=self.params.lat_dim,
        self.transformer = TransformerModule(in_shape=self.embed_dim, mlp_dim=self.params.mlp_dim,
                                             head_dim=self.params.head_dim,
                                             heads=self.params.heads, depth=self.params.attn_depth,
                                             dropout=self.params.dropout, use_norm=self.params.use_norm,
                                             activation=self.params.activation
                                             activation=self.params.activation, use_residual=self.params.use_residual
                                             )

        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, self.embed_dim))
@@ -74,13 +68,17 @@ class VerticalVisualTransformer(CombinedModelMixins, LightningBaseModule):
        self.cls_token = nn.Parameter(torch.randn(1, 1, self.embed_dim))
        self.to_cls_token = nn.Identity()

        logits = self.params.n_classes if self.params.n_classes > 2 else 1

        outbound_activation = nn.Softmax if logits > 1 else nn.Sigmoid

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(self.embed_dim),
            nn.Linear(self.embed_dim, self.params.lat_dim),
            nn.GELU(),
            self.params.activation(),
            nn.Dropout(self.params.dropout),
            nn.Linear(self.params.lat_dim, self.n_classes),
            nn.Softmax()
            nn.Linear(self.params.lat_dim, logits),
            outbound_activation()
        )

    def forward(self, x, mask=None, return_attn_weights=False):
@@ -112,5 +110,3 @@ class VerticalVisualTransformer(CombinedModelMixins, LightningBaseModule):
        tensor = self.mlp_head(tensor)
        return Namespace(main_out=tensor, attn_weights=attn_weights)

    def additional_scores(self, outputs):
        return MultiClassScores(self)(outputs)

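The embedding-dimension correction simply rounds embed_dim down to the nearest multiple of the head count; a worked example with assumed values:

    embed_dim, heads = 30, 4
    if embed_dim % heads != 0:
        embed_dim = (embed_dim // heads) * heads   # 30 -> 28, now divisible by 4 heads
    print(embed_dim)  # 28
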
+40
-21
@@ -10,35 +10,54 @@ import itertools
if __name__ == '__main__':

    # Set new values
    hparams_dict = dict(seed=range(1, 11),
                        model_name=['CNNBaseline'],
                        data_name=['CCSLibrosaDatamodule'],  # 'CCSLibrosaDatamodule'],
                        batch_size=[50],
                        max_epochs=[200],
                        variable_length=[False],  # THIS IS NEXT
                        target_mel_length_in_seconds=[0.7],
                        random_apply_chance=[0.5],  # trial.suggest_float('random_apply_chance', 0.1, 0.5, step=0.1),
                        loudness_ratio=[0],  # trial.suggest_float('loudness_ratio', 0.0, 0.5, step=0.1),
                        shift_ratio=[0.3],  # trial.suggest_float('shift_ratio', 0.0, 0.5, step=0.1),
                        noise_ratio=[0.3],  # trial.suggest_float('noise_ratio', 0.0, 0.5, step=0.1),
                        mask_ratio=[0.3],  # trial.suggest_float('mask_ratio', 0.0, 0.5, step=0.1),
                        lr=[1e-3],  # trial.suggest_uniform('lr', 1e-3, 3e-3),
    hparams_dict = dict(seed=range(1, 6),
                        # BandwiseConvClassifier, CNNBaseline, VisualTransformer, VerticalVisualTransformer
                        model_name=['VisualTransformer'],
                        # CCSLibrosaDatamodule, PrimatesLibrosaDatamodule,
                        data_name=['PrimatesLibrosaDatamodule'],
                        batch_size=[30],
                        max_epochs=[150],
                        target_mel_length_in_seconds=[0.5],
                        outpath=['head_exp'],

                        dropout=[0.2],  # trial.suggest_float('dropout', 0.0, 0.3, step=0.05),
                        lat_dim=[32],  # 2 ** trial.suggest_int('lat_dim', 1, 5, step=1),
                        mlp_dim=[16],  # 2 ** trial.suggest_int('mlp_dim', 1, 5, step=1),
                        head_dim=[6],  # 2 ** trial.suggest_int('head_dim', 1, 5, step=1),
                        patch_size=[12],  # trial.suggest_int('patch_size', 6, 12, step=3),
                        attn_depth=[12],  # trial.suggest_int('attn_depth', 2, 14, step=4),
                        heads=[6],  # trial.suggest_int('heads', 2, 16, step=2),

                        scheduler=['LambdaLR'],  # trial.suggest_categorical('scheduler', [None, 'LambdaLR']),
                        lr_scheduler_parameter=[0.94, 0.93, 0.95],  # [0.98],
                        embedding_size=[30],  # trial.suggest_int('embedding_size', 12, 64, step=12),
                        lr_scheduler_parameter=[0.95],  # [0.95],

                        loss=['ce_loss'],
                        sampler=['WeightedRandomSampler'],
                        # trial.suggest_categorical('sampler', [None, 'WeightedRandomSampler']),
                        weight_decay=[0],  # trial.suggest_loguniform('weight_decay', 1e-20, 1e-1),
                        )

    # Data Aug Parameters
    hparams_dict.update(random_apply_chance=[0.3],  # trial.suggest_float('random_apply_chance', 0.1, 0.5, step=0.1),
                        loudness_ratio=[0],  # trial.suggest_float('loudness_ratio', 0.0, 0.5, step=0.1),
                        shift_ratio=[0.2],  # trial.suggest_float('shift_ratio', 0.0, 0.5, step=0.1),
                        noise_ratio=[0.4],  # trial.suggest_float('noise_ratio', 0.0, 0.5, step=0.1),
                        mask_ratio=[0.2],  # trial.suggest_float('mask_ratio', 0.0, 0.5, step=0.1),)
                        )
    if False:
        # CNN Parameters:
        hparams_dict.update(filters=[[16, 32, 64, 32]],
                            lr=[1e-3],  # trial.suggest_uniform('lr', 1e-3, 3e-3),
                            variable_length=[False],  # THIS does not Work
                            lat_dim=[64],  # 2 ** trial.suggest_int('lat_dim', 1, 5, step=1),
                            )
    else:
        # Transformer Parameters:
        hparams_dict.update(lr=[1e-3],  # trial.suggest_uniform('lr', 1e-3, 3e-3),
                            lat_dim=[32],  # 2 ** trial.suggest_int('lat_dim', 1, 5, step=1),
                            mlp_dim=[16],
                            head_dim=[6],  # 2 ** trial.suggest_int('head_dim', 1, 5, step=1),
                            patch_size=[12],  # trial.suggest_int('patch_size', 6, 12, step=3),
                            attn_depth=[14],  # trial.suggest_int('attn_depth', 2, 14, step=4),
                            heads=[2,4,6,8,10],  # trial.suggest_int('heads', 2, 16, step=2),
                            embedding_size=[30],  # trial.suggest_int('embedding_size', 12, 64, step=12),
                            variable_length=[False],  # THIS does not Work
                            )

    keys, values = zip(*hparams_dict.items())
    permutations_dicts = [dict(zip(keys, v)) for v in itertools.product(*values)]
    for permutations_dict in tqdm(permutations_dicts, total=len(permutations_dicts)):

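The grid expansion at the end is a plain cartesian product over every list in hparams_dict; a toy illustration with made-up values:

    import itertools

    toy = dict(seed=[1, 2], heads=[4, 6], lr=[1e-3])
    keys, values = zip(*toy.items())
    runs = [dict(zip(keys, v)) for v in itertools.product(*values)]
    print(len(runs))   # 2 * 2 * 1 = 4 runs
    print(runs[0])     # {'seed': 1, 'heads': 4, 'lr': 0.001}
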
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
+397
-636
File diff suppressed because one or more lines are too long
+68
-22
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
+64
-15
@@ -1,6 +1,7 @@
import pickle
from argparse import Namespace
from pathlib import Path
from typing import Union

import optuna as optuna
from optuna.integration import PyTorchLightningPruningCallback
@@ -8,6 +9,56 @@ from optuna.integration import PyTorchLightningPruningCallback
from main import run_lightning_loop
from ml_lib.utils.config import parse_comandline_args_add_defaults


class ContiniousSavingCallback:

    @property
    def study(self):
        return self._study

    @property
    def tmp_study_path(self):
        return Path(self.root) / f'TMP_{self.study.study_name}_trial{self.study.trials[-1].number}.pkl'

    @property
    def final_study_path(self):
        return Path(self.root) / f'FINAL_{self.study.study_name}_' \
                                 f'best_{self.study.best_trial.number}_' \
                                 f'score_{self.study.best_value}.pkl'

    def __init__(self, root: Union[str, Path], study: optuna.Study):
        self._study = study
        self.root = Path(root)
        pass

    @staticmethod
    def _write_to_disk(object, path):
        path = Path(path)
        path.parent.mkdir(exist_ok=True)
        if path.exists():
            path.unlink(missing_ok=True)
        with path.open(mode='wb') as f:
            pickle.dump(object, f)

    def save_final(self):
        self._write_to_disk(self.study, self.final_study_path())

    def clean_up(self):
        temp_study_files = self.root.glob(f'TMP_{self.study.study_name}*')
        for temp_study_file in temp_study_files:
            temp_study_file.unlink(missing_ok=True)

    def __call__(self, study: optuna.study.Study, trial: optuna.trial.FrozenTrial) -> None:
        self._write_to_disk(study, self.tmp_study_path(trial.number))

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.save_final()
        self.clean_up()


def optimize(trial: optuna.Trial):
    # Optuna configuration
    folder = Path('study')
@@ -19,10 +70,10 @@ def optimize(trial: optuna.Trial):
    lr_scheduler_parameter = None

    optuna_suggestions = dict(
        model_name='CNNBaseline',
        data_name='MaskLibrosaDatamodule',
        model_name='VisualTransformer',
        data_name='CCSLibrosaDatamodule',
        batch_size=trial.suggest_int('batch_size', 5, 50, step=5),
        max_epochs=75,
        max_epochs=200,
        target_mel_length_in_seconds=trial.suggest_float('target_mel_length_in_seconds', 0.2, 1.5, step=0.1),
        random_apply_chance=trial.suggest_float('random_apply_chance', 0.1, 0.5, step=0.1),
        loudness_ratio=trial.suggest_float('loudness_ratio', 0.0, 0.5, step=0.1),
@@ -44,7 +95,7 @@ def optimize(trial: optuna.Trial):
        for layer_idx in range(model_depth):
            filters.append(2 ** trial.suggest_int(f'filters_{layer_idx}', 2, 6, step=1))
        optuna_suggestions.update(filters=filters)
    elif optuna_suggestions['model_name'] == 'VisualTransformer':
    elif optuna_suggestions['model_name'] in ['VisualTransformer', 'VerticalVisualTransformer']:
        transformer_dict = dict(
            mlp_dim=2 ** trial.suggest_int('mlp_dim', 1, 5, step=1),
            head_dim=2 ** trial.suggest_int('head_dim', 1, 5, step=1),
@@ -62,8 +113,10 @@ def optimize(trial: optuna.Trial):
                                                     '_parameters.ini', overrides=optuna_suggestions)
    h_params = Namespace(**h_params)
    try:
        best_score = run_lightning_loop(h_params, data_class=found_data_class, model_class=found_model_class,
        results = run_lightning_loop(h_params, data_class=found_data_class, model_class=found_model_class,
                                     additional_callbacks=pruning_callback, seed=seed)
        best_score = results.best_model_score

    except Exception as e:
        print(e)
        best_score = 0
@@ -71,22 +124,18 @@ def optimize(trial: optuna.Trial):


if __name__ == '__main__':
    study = optuna.create_study(direction='maximize', sampler=optuna.samplers.TPESampler(seed=1337))
    # study.optimize(optimize, n_trials=50, callbacks=[opt_utils.NeptuneCallback(log_study=True, log_charts=True)])
    study.optimize(optimize, n_trials=100)
    optuna_study = optuna.create_study(direction='maximize', sampler=optuna.samplers.TPESampler(seed=1337))
    with ContiniousSavingCallback('study', optuna_study) as continious_save_callback:
        # study.optimize(optimize, n_trials=50, callbacks=[opt_utils.NeptuneCallback(log_study=True, log_charts=True)])
        optuna_study.optimize(optimize, n_trials=200, show_progress_bar=True, callbacks=[continious_save_callback])

    print("Number of finished trials: {}".format(len(study.trials)))
    print("Number of finished trials: {}".format(len(optuna_study.trials)))

    print("Best trial:")
    trial = study.best_trial
    trial = optuna_study.best_trial

    print("  Value: {}".format(trial.value))

    print("  Params: ")
    for key, value in trial.params.items():
        print("    {}: {}".format(key, value))

    optuna_study_object = Path('study') / 'study.pkl'
    optuna_study_object.parent.mkdir(exist_ok=True)
    with optuna_study_object.open(mode='wb') as f:
        pickle.dump(study, f)

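A small follow-up sketch (not part of the commit) for inspecting one of the pickled studies later, assuming the FINAL_* naming used by ContiniousSavingCallback:

    import pickle
    from pathlib import Path

    study_file = next(Path('study').glob('FINAL_*.pkl'))   # pick any saved study file
    with study_file.open('rb') as f:
        study = pickle.load(f)
    print(study.best_value, study.best_trial.params)
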
+1
-1
@@ -23,7 +23,7 @@ def rebuild_dataset(h_params, data_class):


if __name__ == '__main__':
    for dataset in [MaskLibrosaDatamodule]:  # [PrimatesLibrosaDatamodule, CCSLibrosaDatamodule]:
    for dataset in [CCSLibrosaDatamodule]:
        # Parse comandline args, read config and get model
        cmd_args, _, _, _ = parse_comandline_args_add_defaults('_parameters.ini')


@@ -0,0 +1,341 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": true,
"pycharm": {
"name": "#%% IMPORTS\n"
}
},
"outputs": [],
"source": [
"from collections import defaultdict\n",
"from pathlib import Path\n",
"from natsort import natsorted\n",
"from pytorch_lightning.core.saving import *\n",
"\n",
"import torch\n",
"from sklearn.manifold import TSNE\n",
"\n",
"import seaborn as sns\n",
"from matplotlib import pyplot as plt\n",
"import numpy as np\n",
"import pandas as pd"
]
},
{
"cell_type": "code",
"execution_count": 2,
"outputs": [],
"source": [
"from ml_lib.metrics.binary_class_classifictaion import BinaryScores\n",
"from ml_lib.utils.tools import locate_and_import_class\n",
"_ROOT = Path()\n",
"out_path = 'output'\n",
"model_name = 'VisualTransformer'\n",
"exp_name = 'VT_7899c07a4809a45c57cba58047cefb5e'\n",
"version = 'version_7'"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%M Path resolving and variables\n"
}
}
},
{
"cell_type": "code",
"execution_count": 3,
"outputs": [],
"source": [
"plt.style.use('default')\n",
"sns.set_palette('Dark2')\n",
"\n",
"tex_fonts = {\n",
"    # Use LaTeX to write all text\n",
"    \"text.usetex\": True,\n",
"    \"font.family\": \"serif\",\n",
"    # Use 10pt font in plots, to match 10pt font in document\n",
"    \"axes.labelsize\": 10,\n",
"    \"font.size\": 10,\n",
"    # Make the legend/label fonts a little smaller\n",
"    \"legend.fontsize\": 8,\n",
"    \"xtick.labelsize\": 8,\n",
"    \"ytick.labelsize\": 8\n",
"}\n",
"\n",
"# plt.rcParams.update(tex_fonts)\n",
"\n",
"Path('figures').mkdir(exist_ok=True)"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%% Seaborn Settings\n"
}
}
},
{
"cell_type": "code",
"execution_count": 4,
"outputs": [],
"source": [
"def reconstruct_model_data_params(yaml_file_path: str):\n",
"    hparams_dict = load_hparams_from_yaml(yaml_file_path)\n",
"\n",
"    # Try to get model_name and data_name from yaml:\n",
"    model_name = hparams_dict['model_name']\n",
"    data_name = hparams_dict['data_name']\n",
"    # Try to find the original model and data class by name:\n",
"    found_data_class = locate_and_import_class(data_name, 'datasets')\n",
"    found_model_class = locate_and_import_class(model_name, 'models')\n",
"    # Possible way of automatic loading args:\n",
"    # args = inspect.signature(found_data_class)\n",
"    # then access _parameter.ini and retrieve not set parameters\n",
"\n",
"    hparams_dict.update(target_mel_length_in_seconds=1, num_worker=10, data_root='data')\n",
"\n",
"    h_params = Namespace(**hparams_dict)\n",
"\n",
"    # Let Datamodule pull what it wants\n",
"    datamodule = found_data_class.from_argparse_args(h_params)\n",
"\n",
"    hparams_dict.update(in_shape=datamodule.shape, n_classes=datamodule.n_classes, variable_length=False)\n",
"\n",
"    return datamodule, found_model_class, hparams_dict\n",
"\n",
"def gather_predictions_and_labels(model, data_option):\n",
"    preds = list()\n",
"    labels = list()\n",
"    filenames = list()\n",
"    with torch.no_grad():\n",
"        for file_name, x, y in datamodule.datasets[data_option]:\n",
"            preds.append(model(x.unsqueeze(0)).main_out)\n",
"            labels.append(y)\n",
"            filenames.append(file_name)\n",
"    labels = np.stack(labels).squeeze()\n",
"    preds = np.stack(preds).squeeze()\n",
"    return preds, labels, filenames\n",
"\n",
"def build_tsne_dataframe(preds, labels):\n",
"    tsne = np.stack(TSNE().fit_transform(preds)).squeeze()\n",
"    tsne_dataframe = pd.DataFrame(data=tsne, columns=['x', 'y'])\n",
"\n",
"    tsne_dataframe['labels'] = labels\n",
"    tsne_dataframe['labels'] = tsne_dataframe['labels'].map({val: key for key, val in datamodule.class_names.items()})\n",
"    return tsne_dataframe\n",
"\n",
"def plot_scatterplot(data, option):\n",
"    p = sns.scatterplot(data=data, x='x', y='y', hue='labels', legend=True)\n",
"    p.set_title(f'TSNE - distribution of logits for {option}')\n",
"    plt.show()\n",
"\n",
"def redo_predictions(experiment_path, preds, fnames, data_class):\n",
"    sorted_y = defaultdict(list)\n",
"    for idx, (pred, fname) in enumerate(zip(preds, fnames)):\n",
"        sorted_y[fname].append(pred)\n",
"    sorted_y = dict(sorted_y)\n",
"\n",
"    for file_name in sorted_y:\n",
"        sorted_y.update({file_name: np.stack(sorted_y[file_name])})\n",
"\n",
"\n",
"    if data_class.n_classes > 2:\n",
"        pred = np.stack(\n",
"            [np.argmax(x.mean(axis=0)) if x.shape[0] > 1 else np.argmax(x) for x in sorted_y.values()]\n",
"        ).squeeze()\n",
"        class_names = {val: key for val, key in\n",
"                       enumerate(['background', 'chimpanze', 'geunon', 'mandrille', 'redcap'])}\n",
"    else:\n",
"        pred = [x.mean(axis=0) if x.shape[0] > 1 else x.squeeze().unsqueeze(-1) for x in sorted_y.values()]\n",
"        pred = np.stack(pred).squeeze()\n",
"        pred = np.where(pred > 0.5, 1, 0)\n",
"        class_names = {val: key for val, key in enumerate(['negative', 'positive'])}\n",
"\n",
"\n",
"    df = pd.DataFrame(data=dict(filename=[Path(x).name.replace('.npy', '.wav') for x in sorted_y.keys()],\n",
"                                prediction=[class_names[x.item()] for x in pred]))\n",
"    result_file = Path(experiment_path / 'predictions_new.csv')\n",
"    if result_file.exists():\n",
"        try:\n",
"            result_file.unlink()\n",
"        except:\n",
"            print('File already existed')\n",
"            pass\n",
"    with result_file.open(mode='wb') as csv_file:\n",
"        df.to_csv(index=False, path_or_buf=csv_file)\n",
"\n",
"\n",
"def re_valida(preds, labels, fnames, data_class):\n",
"    sorted_y = defaultdict(list)\n",
"    for idx, (pred, fname) in enumerate(zip(preds, fnames)):\n",
"        sorted_y[fname].append(pred)\n",
"    sorted_y = dict(sorted_y)\n",
"\n",
"    for file_name in sorted_y:\n",
"        sorted_y.update({file_name: np.stack(sorted_y[file_name])})\n",
"\n",
"    for key, val in list(sorted_y.items()):\n",
"        if val.ndim > 1:\n",
"            val = val.mean(axis=0)\n",
"            print(val.ndim)\n",
"        if not val[0] > 0.8:\n",
"            val[0] = 0\n",
"        sorted_y[key] = val\n",
"\n",
"    pred = np.stack(\n",
"        [np.argmax(x) if x.shape[0] > 1 else np.argmax(x) for x in sorted_y.values()]\n",
"    ).squeeze()\n",
"\n",
"    one_hot_targets = np.eye(data_class.n_classes)[pred]\n",
"\n",
"    # Sklearn Scores\n",
"    print(BinaryScores(dict(y=one_hot_targets, batch_y=labels)))\n"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%% Utility Functions\n"
}
}
},
{
"cell_type": "code",
"execution_count": 5,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Selected Checkopint is output\\VisualTransformer\\VT_7899c07a4809a45c57cba58047cefb5e\\version_7\\ckpt_weights-v1.ckpt\n",
"PrimatesLibrosaDatamodule\n"
]
}
],
"source": [
"exp_path = _ROOT / out_path / model_name / exp_name / version\n",
"checkpoint = natsorted(exp_path.glob('*.ckpt'))[-4]\n",
"print(f'Selected Checkopint is {checkpoint}')\n",
"hparams_yaml = next(exp_path.glob('*.yaml'))\n",
"print(load_hparams_from_yaml(hparams_yaml)['data_name'])\n",
"# LADE DAS MODELL HIER VON HAND AUS DER KLASSE DIE ABGELEGT WURDE\n",
"datamodule, model_class, h_params = reconstruct_model_data_params(hparams_yaml.__str__())\n",
"# h_params.update(return_logits=True)"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": 6,
"outputs": [
{
"ename": "RuntimeError",
"evalue": "Error(s) in loading state_dict for VisualTransformer:\n\tsize mismatch for pos_embedding: copying a param with shape torch.Size([1, 67, 30]) from checkpoint, the shape in current model is torch.Size([1, 122, 30]).",
"output_type": "error",
"traceback": [
"\u001B[1;31m---------------------------------------------------------------------------\u001B[0m",
"\u001B[1;31mRuntimeError\u001B[0m Traceback (most recent call last)",
"\u001B[1;32m<ipython-input-6-c8e208607217>\u001B[0m in \u001B[0;36m<module>\u001B[1;34m\u001B[0m\n\u001B[1;32m----> 1\u001B[1;33m \u001B[0mmodel\u001B[0m \u001B[1;33m=\u001B[0m \u001B[0mmodel_class\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mload_from_checkpoint\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mcheckpoint\u001B[0m\u001B[1;33m,\u001B[0m \u001B[1;33m**\u001B[0m\u001B[0mh_params\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0meval\u001B[0m\u001B[1;33m(\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0m\u001B[0;32m 2\u001B[0m \u001B[0mdatamodule\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mprepare_data\u001B[0m\u001B[1;33m(\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m 3\u001B[0m \u001B[1;33m\u001B[0m\u001B[0m\n",
"\u001B[1;32mc:\\users\\steff\\envs\\compare_21\\lib\\site-packages\\pytorch_lightning\\core\\saving.py\u001B[0m in \u001B[0;36mload_from_checkpoint\u001B[1;34m(cls, checkpoint_path, map_location, hparams_file, strict, **kwargs)\u001B[0m\n\u001B[0;32m 154\u001B[0m \u001B[0mcheckpoint\u001B[0m\u001B[1;33m[\u001B[0m\u001B[0mcls\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mCHECKPOINT_HYPER_PARAMS_KEY\u001B[0m\u001B[1;33m]\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mupdate\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mkwargs\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m 155\u001B[0m \u001B[1;33m\u001B[0m\u001B[0m\n\u001B[1;32m--> 156\u001B[1;33m \u001B[0mmodel\u001B[0m \u001B[1;33m=\u001B[0m \u001B[0mcls\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0m_load_model_state\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mcheckpoint\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0mstrict\u001B[0m\u001B[1;33m=\u001B[0m\u001B[0mstrict\u001B[0m\u001B[1;33m,\u001B[0m \u001B[1;33m**\u001B[0m\u001B[0mkwargs\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0m\u001B[0;32m 157\u001B[0m \u001B[1;32mreturn\u001B[0m \u001B[0mmodel\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m 158\u001B[0m \u001B[1;33m\u001B[0m\u001B[0m\n",
"\u001B[1;32mc:\\users\\steff\\envs\\compare_21\\lib\\site-packages\\pytorch_lightning\\core\\saving.py\u001B[0m in \u001B[0;36m_load_model_state\u001B[1;34m(cls, checkpoint, strict, **cls_kwargs_new)\u001B[0m\n\u001B[0;32m 202\u001B[0m \u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m 203\u001B[0m \u001B[1;31m# load the state_dict on the model automatically\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[1;32m--> 204\u001B[1;33m \u001B[0mmodel\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mload_state_dict\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mcheckpoint\u001B[0m\u001B[1;33m[\u001B[0m\u001B[1;34m'state_dict'\u001B[0m\u001B[1;33m]\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0mstrict\u001B[0m\u001B[1;33m=\u001B[0m\u001B[0mstrict\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0m\u001B[0;32m 205\u001B[0m \u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m 206\u001B[0m \u001B[1;32mreturn\u001B[0m \u001B[0mmodel\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n",
"\u001B[1;32mc:\\users\\steff\\envs\\compare_21\\lib\\site-packages\\torch\\nn\\modules\\module.py\u001B[0m in \u001B[0;36mload_state_dict\u001B[1;34m(self, state_dict, strict)\u001B[0m\n\u001B[0;32m 1049\u001B[0m \u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m 1050\u001B[0m \u001B[1;32mif\u001B[0m \u001B[0mlen\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0merror_msgs\u001B[0m\u001B[1;33m)\u001B[0m \u001B[1;33m>\u001B[0m \u001B[1;36m0\u001B[0m\u001B[1;33m:\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[1;32m-> 1051\u001B[1;33m raise RuntimeError('Error(s) in loading state_dict for {}:\\n\\t{}'.format(\n\u001B[0m\u001B[0;32m 1052\u001B[0m self.__class__.__name__, \"\\n\\t\".join(error_msgs)))\n\u001B[0;32m 1053\u001B[0m \u001B[1;32mreturn\u001B[0m \u001B[0m_IncompatibleKeys\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mmissing_keys\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0munexpected_keys\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n",
"\u001B[1;31mRuntimeError\u001B[0m: Error(s) in loading state_dict for VisualTransformer:\n\tsize mismatch for pos_embedding: copying a param with shape torch.Size([1, 67, 30]) from checkpoint, the shape in current model is torch.Size([1, 122, 30])."
]
}
],
"source": [
"model = model_class.load_from_checkpoint(checkpoint, **h_params).eval()\n",
"datamodule.prepare_data()"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"predictions, labels_y, filenames = gather_predictions_and_labels(model, 'devel')"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"# tsne_dataframe = build_tsne_dataframe(predictions, labels_y)\n",
"# plot_scatterplot(tsne_dataframe, data_option)"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"re_valida(predictions,labels_y, filenames, datamodule)"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
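The RuntimeError in the notebook likely stems from rebuilding the datamodule with target_mel_length_in_seconds=1 while the checkpoint was trained on a different mel length, so the patch count behind pos_embedding no longer matches. A hedged sketch of one way to recover the trained value (key names follow the standard PyTorch Lightning checkpoint layout; adjust as needed):

    import torch

    ckpt = torch.load(checkpoint, map_location='cpu')           # `checkpoint` path from the notebook cell above
    trained_hparams = ckpt.get('hyper_parameters', {})
    print(trained_hparams.get('target_mel_length_in_seconds'))  # reuse this instead of the hard-coded 1
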
+17
-10
@@ -5,6 +5,7 @@ from abc import ABC

import torch
import pandas as pd
from matplotlib import pyplot as plt

from ml_lib.modules.util import LightningBaseModule
from util.loss_mixin import LossMixin
@@ -59,7 +60,7 @@ class ValMixin:

        target_y = torch.stack(tuple(sorted_batch_y.values())).long()
        if self.params.n_classes <= 2:
            mean_sorted_y = torch.stack([x.mean(dim=0) if x.shape[0] > 1 else x for x in sorted_y.values()])
            mean_sorted_y = torch.stack([x.mean(dim=0) if x.shape[0] > 1 else x.squeeze(-1) for x in sorted_y.values()])
            self.metrics.update(mean_sorted_y, target_y)
        else:
            y_max = torch.stack(
@@ -97,8 +98,11 @@ class ValMixin:
            sorted_y.update({file_name: torch.stack(sorted_y[file_name])})

        if self.params.n_classes <= 2:
            mean_sorted_y = torch.stack([x.mean(dim=0) if x.shape[0] > 1 else x for x in sorted_y.values()]).squeeze()
            mean_sorted_y = [x.mean(dim=0) if x.shape[0] > 1 else x.squeeze().unsqueeze(-1) for x in sorted_y.values()]
            mean_sorted_y = torch.stack(mean_sorted_y).squeeze(1)
            # mean_sorted_y = mean_sorted_y if mean_sorted_y.numel() > 1 else mean_sorted_y.unsqueeze(-1)
            max_vote_loss = self.bce_loss(mean_sorted_y.float(), sorted_batch_y.float())

            # Sklearn Scores
            additional_scores = self.additional_scores(dict(y=mean_sorted_y, batch_y=sorted_batch_y))

@@ -129,6 +133,7 @@ class ValMixin:

        for name, image in pl_images.items():
            self.logger.log_image(name, image, step=self.global_step)
            plt.close(image)
        pass


@@ -162,12 +167,13 @@ class TestMixin:
            class_names = {val: key for val, key in
                           enumerate(['background', 'chimpanze', 'geunon', 'mandrille', 'redcap'])}
        else:
            pred = torch.stack([x.mean(dim=0) if x.shape[0] > 1 else x for x in sorted_y.values()]).squeeze()
            pred = [x.mean(dim=0) if x.shape[0] > 1 else x.squeeze().unsqueeze(-1) for x in sorted_y.values()]
            pred = torch.stack(pred).squeeze()
            pred = torch.where(pred > 0.5, 1, 0)
            class_names = {val: key for val, key in enumerate(['negative', 'positive'])}


        df = pd.DataFrame(data=dict(filename=[Path(x).name for x in sorted_y.keys()],
        df = pd.DataFrame(data=dict(filename=[Path(x).name.replace('.npy', '.wav') for x in sorted_y.keys()],
                                    prediction=[class_names[x.item()] for x in pred.cpu()]))
        result_file = Path(self.logger.log_dir / 'predictions.csv')
        if result_file.exists():
@@ -178,12 +184,13 @@ class TestMixin:
            pass
        with result_file.open(mode='wb') as csv_file:
            df.to_csv(index=False, path_or_buf=csv_file)
        with result_file.open(mode='rb') as csv_file:
            try:
                self.logger.neptunelogger.log_artifact(csv_file)
            except:
                print('No possible to send to neptune')
                pass
        if False:
            with result_file.open(mode='rb') as csv_file:
                try:
                    self.logger.neptunelogger.log_artifact(csv_file)
                except:
                    print('No possible to send to neptune')
                    pass


class CombinedModelMixins(LossMixin,

@@ -5,5 +5,6 @@ sr = 16000

PRIMATES_Root = Path(__file__).parent / 'data' / 'primates'
CCS_Root = Path(__file__).parent / 'data' / 'ComParE2021_CCS'
MASK_Root = Path(__file__).parent / 'data' / 'ComParE2020_Mask'

N_CLASS_multi = 5
