paper preparations and notebooks, optuna callbacks
This commit is contained in:
parent
7c88602776
commit
cec3a07d60
@@ -10,7 +10,7 @@ data_name = CCSLibrosaDatamodule
 [data]
 num_worker = 10
 data_root = data
-variable_length = True
+variable_length = False
 target_mel_length_in_seconds = 0.7
 
 n_mels = 128
@@ -49,6 +49,15 @@ dropout = 0.2
 lat_dim = 32
 filters = [16, 32, 64, 128]
 
+[BandwiseConvClassifier]
+weight_init = xavier_normal_
+activation = gelu
+use_bias = True
+use_norm = True
+dropout = 0.2
+lat_dim = 32
+filters = [16, 32, 64, 128]
+
 [VisualTransformer]
 weight_init = xavier_normal_
 activation = gelu
@@ -73,11 +82,13 @@ use_norm = True
 use_residual = True
 dropout = 0.2
 
-lat_dim = 32
+mlp_dim = 6
+lat_dim = 6
+head_dim = 6
 patch_size = 8
-attn_depth = 12
+attn_depth = 6
 heads = 4
-embedding_size = 128
+embedding_size = 30
 
 [HorizontalVisualTransformer]
 weight_init = xavier_normal_
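Reading the three hunks together: variable_length is switched off, a [BandwiseConvClassifier] section is registered for the new model added below, and the [VisualTransformer] block is retuned (new mlp_dim/head_dim keys, attn_depth 12 -> 6, embedding_size 128 -> 30). As a sketch of how such a section turns into model kwargs — a hypothetical illustration only; the repo's actual loader is ml_lib.utils.config.parse_comandline_args_add_defaults, which this diff does not show:

from ast import literal_eval
from configparser import ConfigParser

def coerce(value):
    try:
        return literal_eval(value)  # numbers, lists, booleans
    except (ValueError, SyntaxError):
        return value                # plain strings such as 'gelu'

cfg = ConfigParser()
cfg.read('_parameters.ini')
vt_kwargs = {key: coerce(val) for key, val in cfg['VisualTransformer'].items()}
assert vt_kwargs['attn_depth'] == 6 and vt_kwargs['heads'] == 4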
@@ -118,7 +118,9 @@ class CompareBase(_BaseDataModule):
         lab_file = None
 
         for data_option in data_options:
-            if lab_file is not None:
+            if lab_file is None:
+                lab_file = f'{data_option}.csv'
+            elif lab_file is not None:
                 if any([x in lab_file for x in data_options]):
                     lab_file = f'{data_option}.csv'
             dataset = self._load_from_file(lab_file, data_option, rebuild=True)
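The rewritten branch derives a default label file per split when none is given, and only rewrites an explicitly passed file when its name refers to one of the data options (note that `elif lab_file is not None:` is always true at that point, so it reads as a plain `else`). A standalone sketch of the selection logic, with names mirroring the diff above:

def resolve_lab_file(lab_file, data_option, data_options):
    if lab_file is None:                          # no file given: derive one per split
        return f'{data_option}.csv'
    if any(x in lab_file for x in data_options):  # given file is split-specific
        return f'{data_option}.csv'
    return lab_file                               # custom file: keep as-is

assert resolve_lab_file(None, 'devel', ['train', 'devel']) == 'devel.csv'
assert resolve_lab_file('train.csv', 'devel', ['train', 'devel']) == 'devel.csv'
assert resolve_lab_file('custom.csv', 'devel', ['train', 'devel']) == 'custom.csv'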
main.py (15 lines changed)
@@ -76,15 +76,12 @@ def run_lightning_loop(h_params :Namespace, data_class, model_class, seed=69, ad
     trainer.fit(model, datamodule)
     trainer.save_checkpoint(logger.save_dir / 'last_weights.ckpt')
 
+    trainer.test(model=model, datamodule=datamodule, ckpt_path='best')
 
-    trainer.test(model=model, datamodule=datamodule)
-    #except:
-    #    print('Test did not Suceed!')
-    #    pass
-
-    logger.log_metrics(score_callback.best_scores, step=trainer.global_step+1)
-
-    return score_callback.best_scores['PL_recall_score']
+    return Namespace(model=model,
+                     best_model_path=ckpt_callback.best_model_path,
+                     best_model_score=ckpt_callback.best_model_score.item(),
+                     max_score_monitor=score_callback.best_scores)
 
 
 if __name__ == '__main__':
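run_lightning_loop now tests against the best checkpoint (ckpt_path='best') and returns a Namespace carrying the fitted model, the checkpoint callback's best path and score, and the monitored maxima, instead of logging and returning only PL_recall_score. A minimal sketch of a caller consuming the new return value (field names taken from the diff above):

results = run_lightning_loop(hparams, found_data_class, found_model_class, found_seed)
print(f'best checkpoint: {results.best_model_path}')
print(f'best score:      {results.best_model_score:.4f}')
print(f'monitored max:   {results.max_score_monitor}')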
@@ -96,6 +93,6 @@ if __name__ == '__main__':
 
     # Start
     # -----------------
-    run_lightning_loop(hparams, found_data_class, found_model_class, found_seed)
+    print(run_lightning_loop(hparams, found_data_class, found_model_class, found_seed))
     print('done')
     pass
models/bandwise_conv_classifier.py (new file, 69 lines)
@@ -0,0 +1,69 @@
+import inspect
+from argparse import Namespace
+
+from torch import nn
+from torch.nn import ModuleList
+
+from ml_lib.modules.blocks import ConvModule, LinearModule
+from ml_lib.modules.util import (LightningBaseModule, Splitter, Merger)
+from util.module_mixins import CombinedModelMixins
+
+
+class BandwiseConvClassifier(CombinedModelMixins,
+                             LightningBaseModule
+                             ):
+    def __init__(self, in_shape, n_classes, weight_init, activation,
+                 use_bias, use_norm, dropout, lat_dim, filters,
+                 lr, weight_decay, sto_weight_avg, lr_warm_restart_epochs, opt_reset_interval,
+                 loss, scheduler, lr_scheduler_parameter
+                 ):
+        # TODO: Move this to the parent class, or make it much easier to access...
+        a = dict(locals())
+        params = {arg: a[arg] for arg in inspect.signature(self.__init__).parameters.keys() if arg != 'self'}
+        super(BandwiseConvClassifier, self).__init__(params)
+
+        # Model Parameters
+        # =============================================================================
+        # Additional parameters
+        self.n_band_sections = 8
+
+        # Modules
+        # =============================================================================
+        self.split = Splitter(in_shape, self.n_band_sections)
+
+        k = 3
+        self.band_list = ModuleList()
+        for band in range(self.n_band_sections):
+            last_shape = self.split.shape[band]
+            conv_list = ModuleList()
+            for conv_filters in self.params.filters:
+                conv_list.append(ConvModule(last_shape, conv_filters, (k, k), conv_stride=(2, 2), conv_padding=2,
+                                            **self.params.module_kwargs))
+                last_shape = conv_list[-1].shape
+                # self.conv_list.append(ConvModule(last_shape, 1, 1, conv_stride=1, **self.params.module_kwargs))
+                # last_shape = self.conv_list[-1].shape
+            self.band_list.append(conv_list)
+
+        self.merge = Merger(self.band_list[-1][-1].shape, self.n_band_sections)
+
+        self.full_1 = LinearModule(self.merge.shape, self.params.lat_dim, **self.params.module_kwargs)
+        self.full_2 = LinearModule(self.full_1.shape, self.params.lat_dim, **self.params.module_kwargs)
+
+        # Decide between binary and multiclass classification
+        logits = n_classes if n_classes > 2 else 1
+        module_kwargs = self.params.module_kwargs
+        module_kwargs.update(activation=(nn.Softmax if logits > 1 else nn.Sigmoid))
+        self.full_out = LinearModule(self.full_2.shape, logits, **module_kwargs)
+
+    def forward(self, batch, **kwargs):
+        tensors = self.split(batch)
+        for idx, (tensor, convs) in enumerate(zip(tensors, self.band_list)):
+            for conv in convs:
+                tensor = conv(tensor)
+            tensors[idx] = tensor
+
+        tensor = self.merge(tensors)
+        tensor = self.full_1(tensor)
+        tensor = self.full_2(tensor)
+        tensor = self.full_out(tensor)
+        return Namespace(main_out=tensor)
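BandwiseConvClassifier splits the input spectrogram into n_band_sections = 8 frequency bands, pushes each band through its own convolution stack, and merges the results before the linear head. Splitter and Merger live in ml_lib and are not part of this diff; a rough standalone approximation of the idea, assuming the split is an even chunking along the mel axis and the merge a flatten-and-concatenate:

import torch
from torch import nn

n_bands = 8
x = torch.randn(4, 1, 128, 64)                      # (batch, channel, n_mels, time)
bands = torch.chunk(x, n_bands, dim=2)              # 8 sections of 16 mel bins each
conv_stacks = nn.ModuleList(nn.Sequential(nn.Conv2d(1, 16, 3, stride=2, padding=2), nn.GELU())
                            for _ in range(n_bands))
processed = [stack(b) for stack, b in zip(conv_stacks, bands)]            # per-band convs
merged = torch.cat([p.flatten(start_dim=1) for p in processed], dim=1)    # input to the MLP head
print(merged.shape)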
@@ -3,8 +3,6 @@ from argparse import Namespace
 
 from torch import nn
 
-from ml_lib.metrics.binary_class_classifictaion import BinaryScores
-from ml_lib.metrics.multi_class_classification import MultiClassScores
 from ml_lib.modules.blocks import LinearModule
 from ml_lib.modules.model_parts import CNNEncoder
 from ml_lib.modules.util import (LightningBaseModule)
@@ -52,9 +50,3 @@ class CNNBaseline(CombinedModelMixins,
 
         tensor = self.classifier(tensor)
         return Namespace(main_out=tensor)
-
-    def additional_scores(self, outputs):
-        if self.params.n_classes > 2:
-            return MultiClassScores(self)(outputs)
-        else:
-            return BinaryScores(self)(outputs)
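This commit strips the per-model additional_scores implementations here and in the transformer files below. The move target is not visible in this diff; presumably a single shared version now lives with the mixins, along these lines (assumed, not shown in the commit):

from ml_lib.metrics.binary_class_classifictaion import BinaryScores
from ml_lib.metrics.multi_class_classification import MultiClassScores


class ScoresMixin:
    # One copy of the method that was previously duplicated in every model class.
    def additional_scores(self, outputs):
        if self.params.n_classes > 2:
            return MultiClassScores(self)(outputs)
        return BinaryScores(self)(outputs)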
@@ -130,8 +130,6 @@ try:
             tensor = self.mlp_head(tensor)
             return Namespace(main_out=tensor)
 
-        def additional_scores(self, outputs):
-            return MultiClassScores(self)(outputs)
 
 except ImportError:  # pragma: do not provide model class
     print('You want to use `performer_pytorch` plugins which are not installed yet,'  # pragma: no-cover
@@ -109,9 +109,3 @@ class Tester(CombinedModelMixins,
         tensor = self.to_cls_token(tensor[:, 0])
         tensor = self.mlp_head(tensor)
         return Namespace(main_out=tensor, attn_weights=None)
-
-    def additional_scores(self, outputs):
-        if self.params.n_classes > 2:
-            return MultiClassScores(self)(outputs)
-        else:
-            return BinaryScores(self)(outputs)
@@ -73,9 +73,9 @@ class VisualTransformer(CombinedModelMixins,
         logits = self.params.n_classes if self.params.n_classes > 2 else 1
 
         if return_logits:
-            outbound_activation = nn.Identity()
+            outbound_activation = nn.Identity
         else:
-            outbound_activation = nn.Softmax() if logits > 1 else nn.Sigmoid()
+            outbound_activation = nn.Softmax if logits > 1 else nn.Sigmoid
 
 
         self.mlp_head = nn.Sequential(
@@ -84,7 +84,7 @@ class VisualTransformer(CombinedModelMixins,
             self.params.activation(),
             nn.Dropout(self.params.dropout),
             nn.Linear(self.params.lat_dim, logits),
-            outbound_activation
+            outbound_activation()
         )
 
     def forward(self, x, mask=None, return_attn_weights=False):
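Note the pattern change in these two hunks: outbound_activation now holds the activation class (nn.Identity, nn.Softmax, or nn.Sigmoid) and the single instantiation happens inside nn.Sequential via outbound_activation(), matching how self.params.activation() is already used. A minimal illustration of the pattern:

from torch import nn

logits = 5                                                      # e.g. a 5-class head
outbound_activation = nn.Softmax if logits > 1 else nn.Sigmoid  # pick the class once
head = nn.Sequential(
    nn.Linear(32, logits),
    outbound_activation(),                                      # instantiate where it is used
)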
@@ -128,8 +128,3 @@ class VisualTransformer(CombinedModelMixins,
         tensor = self.mlp_head(tensor)
         return Namespace(main_out=tensor, attn_weights=attn_weights)
 
-    def additional_scores(self, outputs):
-        if self.params.n_classes <= 2:
-            return BinaryScores(self)(outputs)
-        else:
-            return MultiClassScores(self)(outputs)
@@ -116,5 +116,3 @@ class HorizontalVisualTransformer(CombinedModelMixins,
         tensor = self.mlp_head(tensor)
         return Namespace(main_out=tensor, attn_weights=attn_weights)
 
-    def additional_scores(self, outputs):
-        return MultiClassScores(self)(outputs)
@@ -18,9 +18,10 @@ MIN_NUM_PATCHES = 16
 class VerticalVisualTransformer(CombinedModelMixins, LightningBaseModule):
 
     def __init__(self, in_shape, n_classes, weight_init, activation,
-                 embedding_size, heads, attn_depth, patch_size, use_residual,
-                 use_bias, use_norm, dropout, lat_dim, features, loss, scheduler,
-                 lr, weight_decay, sto_weight_avg, lr_warm_restart_epochs, opt_reset_interval):
+                 embedding_size, heads, attn_depth, patch_size, use_residual, variable_length,
+                 use_bias, use_norm, dropout, lat_dim, loss, scheduler, mlp_dim, head_dim,
+                 lr, weight_decay, sto_weight_avg, lr_scheduler_parameter, opt_reset_interval,
+                 return_logits=False):
 
         # TODO: Move this to the parent class, or make it much easier to access... But how...
         a = dict(locals())
@@ -47,14 +48,6 @@ class VerticalVisualTransformer(CombinedModelMixins, LightningBaseModule):
         assert num_patches >= MIN_NUM_PATCHES, f'your number of patches ({num_patches}) is way too small for ' + \
                                                f'attention. Try decreasing your patch size'
 
-        # Correct the Embedding Dim
-        if not self.embed_dim % self.params.heads == 0:
-            self.embed_dim = (self.embed_dim // self.params.heads) * self.params.heads
-            message = ('Embedding Dimension was fixed to be devideable by the number' +
-                       f' of attention heads, is now: {self.embed_dim}')
-            for func in print, warnings.warn:
-                func(message)
-
         # Utility Modules
         self.autopad = AutoPadToShape((self.height, self.new_width))
         self.dropout = nn.Dropout(self.params.dropout)
@@ -62,10 +55,11 @@ class VerticalVisualTransformer(CombinedModelMixins, LightningBaseModule):
                                           keepdim=False)
 
         # Modules with Parameters
-        self.transformer = TransformerModule(in_shape=self.embed_dim, mlp_dim=self.params.lat_dim,
+        self.transformer = TransformerModule(in_shape=self.embed_dim, mlp_dim=self.params.mlp_dim,
+                                             head_dim=self.params.head_dim,
                                              heads=self.params.heads, depth=self.params.attn_depth,
                                              dropout=self.params.dropout, use_norm=self.params.use_norm,
-                                             activation=self.params.activation
+                                             activation=self.params.activation, use_residual=self.params.use_residual
                                              )
 
         self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, self.embed_dim))
@@ -74,13 +68,17 @@ class VerticalVisualTransformer(CombinedModelMixins, LightningBaseModule):
         self.cls_token = nn.Parameter(torch.randn(1, 1, self.embed_dim))
         self.to_cls_token = nn.Identity()
 
+        logits = self.params.n_classes if self.params.n_classes > 2 else 1
+
+        outbound_activation = nn.Softmax if logits > 1 else nn.Sigmoid
+
         self.mlp_head = nn.Sequential(
             nn.LayerNorm(self.embed_dim),
             nn.Linear(self.embed_dim, self.params.lat_dim),
-            nn.GELU(),
+            self.params.activation(),
             nn.Dropout(self.params.dropout),
-            nn.Linear(self.params.lat_dim, self.n_classes),
-            nn.Softmax()
+            nn.Linear(self.params.lat_dim, logits),
+            outbound_activation()
         )
 
     def forward(self, x, mask=None, return_attn_weights=False):
@@ -112,5 +110,3 @@ class VerticalVisualTransformer(CombinedModelMixins, LightningBaseModule):
         tensor = self.mlp_head(tensor)
         return Namespace(main_out=tensor, attn_weights=attn_weights)
 
-    def additional_scores(self, outputs):
-        return MultiClassScores(self)(outputs)
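VerticalVisualTransformer.__init__ now takes mlp_dim and head_dim explicitly and forwards them (together with use_residual) to TransformerModule, instead of reusing lat_dim as the MLP width; the head also respects logits and the configured activation rather than hard-coding self.n_classes, nn.GELU and nn.Softmax. With a decoupled head_dim, attention can project to heads * head_dim internally and back to embed_dim, in which case embedding_size need not be divisible by heads — plausibly why the auto-correction block was dropped above. Standard multi-head bookkeeping, independent of the ml_lib internals:

import torch
from torch import nn

embed_dim, heads, head_dim = 30, 4, 6          # values from the updated config
inner = heads * head_dim                       # 24; q/k/v width, independent of embed_dim
to_qkv = nn.Linear(embed_dim, inner * 3)
to_out = nn.Linear(inner, embed_dim)
x = torch.randn(2, 17, embed_dim)              # (batch, tokens, embedding)
q, k, v = to_qkv(x).chunk(3, dim=-1)
q = q.view(2, 17, heads, head_dim)             # fine although 30 % 4 != 0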
multi_run.py (61 lines changed)
@@ -10,35 +10,54 @@ import itertools
 if __name__ == '__main__':
 
     # Set new values
-    hparams_dict = dict(seed=range(1, 11),
-                        model_name=['CNNBaseline'],
-                        data_name=['CCSLibrosaDatamodule'],  # 'CCSLibrosaDatamodule'],
-                        batch_size=[50],
-                        max_epochs=[200],
-                        variable_length=[False],  # THIS IS NEXT
-                        target_mel_length_in_seconds=[0.7],
-                        random_apply_chance=[0.5],  # trial.suggest_float('random_apply_chance', 0.1, 0.5, step=0.1),
-                        loudness_ratio=[0],  # trial.suggest_float('loudness_ratio', 0.0, 0.5, step=0.1),
-                        shift_ratio=[0.3],  # trial.suggest_float('shift_ratio', 0.0, 0.5, step=0.1),
-                        noise_ratio=[0.3],  # trial.suggest_float('noise_ratio', 0.0, 0.5, step=0.1),
-                        mask_ratio=[0.3],  # trial.suggest_float('mask_ratio', 0.0, 0.5, step=0.1),
-                        lr=[1e-3],  # trial.suggest_uniform('lr', 1e-3, 3e-3),
+    hparams_dict = dict(seed=range(1, 6),
+                        # BandwiseConvClassifier, CNNBaseline, VisualTransformer, VerticalVisualTransformer
+                        model_name=['VisualTransformer'],
+                        # CCSLibrosaDatamodule, PrimatesLibrosaDatamodule,
+                        data_name=['PrimatesLibrosaDatamodule'],
+                        batch_size=[30],
+                        max_epochs=[150],
+                        target_mel_length_in_seconds=[0.5],
+                        outpath=['head_exp'],
                         dropout=[0.2],  # trial.suggest_float('dropout', 0.0, 0.3, step=0.05),
-                        lat_dim=[32],  # 2 ** trial.suggest_int('lat_dim', 1, 5, step=1),
-                        mlp_dim=[16],  # 2 ** trial.suggest_int('mlp_dim', 1, 5, step=1),
-                        head_dim=[6],  # 2 ** trial.suggest_int('head_dim', 1, 5, step=1),
-                        patch_size=[12],  # trial.suggest_int('patch_size', 6, 12, step=3),
-                        attn_depth=[12],  # trial.suggest_int('attn_depth', 2, 14, step=4),
-                        heads=[6],  # trial.suggest_int('heads', 2, 16, step=2),
                         scheduler=['LambdaLR'],  # trial.suggest_categorical('scheduler', [None, 'LambdaLR']),
-                        lr_scheduler_parameter=[0.94, 0.93, 0.95],  # [0.98],
-                        embedding_size=[30],  # trial.suggest_int('embedding_size', 12, 64, step=12),
+                        lr_scheduler_parameter=[0.95],  # [0.95],
                         loss=['ce_loss'],
                         sampler=['WeightedRandomSampler'],
                         # trial.suggest_categorical('sampler', [None, 'WeightedRandomSampler']),
                         weight_decay=[0],  # trial.suggest_loguniform('weight_decay', 1e-20, 1e-1),
                         )
 
+    # Data Aug Parameters
+    hparams_dict.update(random_apply_chance=[0.3],  # trial.suggest_float('random_apply_chance', 0.1, 0.5, step=0.1),
+                        loudness_ratio=[0],  # trial.suggest_float('loudness_ratio', 0.0, 0.5, step=0.1),
+                        shift_ratio=[0.2],  # trial.suggest_float('shift_ratio', 0.0, 0.5, step=0.1),
+                        noise_ratio=[0.4],  # trial.suggest_float('noise_ratio', 0.0, 0.5, step=0.1),
+                        mask_ratio=[0.2],  # trial.suggest_float('mask_ratio', 0.0, 0.5, step=0.1),
+                        )
+    if False:
+        # CNN Parameters:
+        hparams_dict.update(filters=[[16, 32, 64, 32]],
+                            lr=[1e-3],  # trial.suggest_uniform('lr', 1e-3, 3e-3),
+                            variable_length=[False],  # THIS does not Work
+                            lat_dim=[64],  # 2 ** trial.suggest_int('lat_dim', 1, 5, step=1),
+                            )
+    else:
+        # Transformer Parameters:
+        hparams_dict.update(lr=[1e-3],  # trial.suggest_uniform('lr', 1e-3, 3e-3),
+                            lat_dim=[32],  # 2 ** trial.suggest_int('lat_dim', 1, 5, step=1),
+                            mlp_dim=[16],
+                            head_dim=[6],  # 2 ** trial.suggest_int('head_dim', 1, 5, step=1),
+                            patch_size=[12],  # trial.suggest_int('patch_size', 6, 12, step=3),
+                            attn_depth=[14],  # trial.suggest_int('attn_depth', 2, 14, step=4),
+                            heads=[2, 4, 6, 8, 10],  # trial.suggest_int('heads', 2, 16, step=2),
+                            embedding_size=[30],  # trial.suggest_int('embedding_size', 12, 64, step=12),
+                            variable_length=[False],  # THIS does not Work
+                            )
+
     keys, values = zip(*hparams_dict.items())
     permutations_dicts = [dict(zip(keys, v)) for v in itertools.product(*values)]
     for permutations_dict in tqdm(permutations_dicts, total=len(permutations_dicts)):
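multi_run.py launches the full Cartesian product of every list in hparams_dict, so the run count multiplies fast: with the transformer branch active, seed=range(1, 6) times heads=[2, 4, 6, 8, 10] already gives 25 runs. A quick size check before launching:

import math

# Insert just before the tqdm loop; hparams_dict is the grid assembled above.
n_runs = math.prod(len(list(v)) for v in hparams_dict.values())
print(f'about to launch {n_runs} runs')   # 25 for the settings above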
notebooks/Study Plots.ipynb (new file, 2644 lines)
File diff suppressed because one or more lines are too long
@@ -1,6 +1,7 @@
 import pickle
 from argparse import Namespace
 from pathlib import Path
+from typing import Union
 
 import optuna as optuna
 from optuna.integration import PyTorchLightningPruningCallback
@@ -8,6 +9,56 @@ from optuna.integration import PyTorchLightningPruningCallback
 from main import run_lightning_loop
 from ml_lib.utils.config import parse_comandline_args_add_defaults
 
+
+class ContiniousSavingCallback:
+
+    @property
+    def study(self):
+        return self._study
+
+    @property
+    def tmp_study_path(self):
+        return Path(self.root) / f'TMP_{self.study.study_name}_trial{self.study.trials[-1].number}.pkl'
+
+    @property
+    def final_study_path(self):
+        return Path(self.root) / f'FINAL_{self.study.study_name}_' \
+                                 f'best_{self.study.best_trial.number}_' \
+                                 f'score_{self.study.best_value}.pkl'
+
+    def __init__(self, root: Union[str, Path], study: optuna.Study):
+        self._study = study
+        self.root = Path(root)
+        pass
+
+    @staticmethod
+    def _write_to_disk(object, path):
+        path = Path(path)
+        path.parent.mkdir(exist_ok=True)
+        if path.exists():
+            path.unlink(missing_ok=True)
+        with path.open(mode='wb') as f:
+            pickle.dump(object, f)
+
+    def save_final(self):
+        self._write_to_disk(self.study, self.final_study_path())
+
+    def clean_up(self):
+        temp_study_files = self.root.glob(f'TMP_{self.study.study_name}*')
+        for temp_study_file in temp_study_files:
+            temp_study_file.unlink(missing_ok=True)
+
+    def __call__(self, study: optuna.study.Study, trial: optuna.trial.FrozenTrial) -> None:
+        self._write_to_disk(study, self.tmp_study_path(trial.number))
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self.save_final()
+        self.clean_up()
+
+
 def optimize(trial: optuna.Trial):
     # Optuna configuration
     folder = Path('study')
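ContiniousSavingCallback pickles the study after every trial (registered via Optuna's callbacks= argument, which passes each callback the (study, trial) pair) and, used as a context manager, writes a final snapshot and removes the temporaries on exit. Two wrinkles in the committed version: tmp_study_path and final_study_path are declared as properties yet invoked as calls (self.tmp_study_path(trial.number), self.final_study_path()), which would raise TypeError since the property already returns a Path; and Path.unlink(missing_ok=...) requires Python 3.8+. A corrected minimal sketch of the intended behaviour:

import pickle
from pathlib import Path
import optuna


class StudySaver:
    def __init__(self, root, study: optuna.Study):
        self.root, self.study = Path(root), study

    @staticmethod
    def _dump(obj, path: Path):
        path.parent.mkdir(parents=True, exist_ok=True)
        with path.open('wb') as f:
            pickle.dump(obj, f)

    def __call__(self, study: optuna.Study, trial: optuna.trial.FrozenTrial) -> None:
        # Snapshot after every finished trial.
        self._dump(study, self.root / f'TMP_{study.study_name}_trial{trial.number}.pkl')

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        best = self.study.best_trial        # raises ValueError if no trial completed
        self._dump(self.study, self.root / f'FINAL_{self.study.study_name}_'
                                           f'best_{best.number}_score_{self.study.best_value}.pkl')
        for tmp in self.root.glob(f'TMP_{self.study.study_name}*'):
            tmp.unlink()                    # drop the per-trial temporaries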
@@ -19,10 +70,10 @@ def optimize(trial: optuna.Trial):
     lr_scheduler_parameter = None
 
     optuna_suggestions = dict(
-        model_name='CNNBaseline',
-        data_name='MaskLibrosaDatamodule',
+        model_name='VisualTransformer',
+        data_name='CCSLibrosaDatamodule',
         batch_size=trial.suggest_int('batch_size', 5, 50, step=5),
-        max_epochs=75,
+        max_epochs=200,
         target_mel_length_in_seconds=trial.suggest_float('target_mel_length_in_seconds', 0.2, 1.5, step=0.1),
         random_apply_chance=trial.suggest_float('random_apply_chance', 0.1, 0.5, step=0.1),
         loudness_ratio=trial.suggest_float('loudness_ratio', 0.0, 0.5, step=0.1),
|
|||||||
for layer_idx in range(model_depth):
|
for layer_idx in range(model_depth):
|
||||||
filters.append(2 ** trial.suggest_int(f'filters_{layer_idx}', 2, 6, step=1))
|
filters.append(2 ** trial.suggest_int(f'filters_{layer_idx}', 2, 6, step=1))
|
||||||
optuna_suggestions.update(filters=filters)
|
optuna_suggestions.update(filters=filters)
|
||||||
elif optuna_suggestions['model_name'] == 'VisualTransformer':
|
elif optuna_suggestions['model_name'] in ['VisualTransformer', 'VerticalVisualTransformer']:
|
||||||
transformer_dict = dict(
|
transformer_dict = dict(
|
||||||
mlp_dim=2 ** trial.suggest_int('mlp_dim', 1, 5, step=1),
|
mlp_dim=2 ** trial.suggest_int('mlp_dim', 1, 5, step=1),
|
||||||
head_dim=2 ** trial.suggest_int('head_dim', 1, 5, step=1),
|
head_dim=2 ** trial.suggest_int('head_dim', 1, 5, step=1),
|
||||||
@@ -62,8 +113,10 @@ def optimize(trial: optuna.Trial):
                                                          '_parameters.ini', overrides=optuna_suggestions)
     h_params = Namespace(**h_params)
     try:
-        best_score = run_lightning_loop(h_params, data_class=found_data_class, model_class=found_model_class,
-                                        additional_callbacks=pruning_callback, seed=seed)
+        results = run_lightning_loop(h_params, data_class=found_data_class, model_class=found_model_class,
+                                     additional_callbacks=pruning_callback, seed=seed)
+        best_score = results.best_model_score
+
     except Exception as e:
         print(e)
         best_score = 0
@@ -71,22 +124,18 @@ def optimize(trial: optuna.Trial):
 
 
 if __name__ == '__main__':
-    study = optuna.create_study(direction='maximize', sampler=optuna.samplers.TPESampler(seed=1337))
-    # study.optimize(optimize, n_trials=50, callbacks=[opt_utils.NeptuneCallback(log_study=True, log_charts=True)])
-    study.optimize(optimize, n_trials=100)
+    optuna_study = optuna.create_study(direction='maximize', sampler=optuna.samplers.TPESampler(seed=1337))
+    with ContiniousSavingCallback('study', optuna_study) as continious_save_callback:
+        # study.optimize(optimize, n_trials=50, callbacks=[opt_utils.NeptuneCallback(log_study=True, log_charts=True)])
+        optuna_study.optimize(optimize, n_trials=200, show_progress_bar=True, callbacks=[continious_save_callback])
 
-    print("Number of finished trials: {}".format(len(study.trials)))
+    print("Number of finished trials: {}".format(len(optuna_study.trials)))
 
     print("Best trial:")
-    trial = study.best_trial
+    trial = optuna_study.best_trial
 
     print("  Value: {}".format(trial.value))
 
     print("  Params: ")
     for key, value in trial.params.items():
         print("    {}: {}".format(key, value))
-
-    optuna_study_object = Path('study') / 'study.pkl'
-    optuna_study_object.parent.mkdir(exist_ok=True)
-    with optuna_study_object.open(mode='wb') as f:
-        pickle.dump(study, f)
@@ -23,7 +23,7 @@ def rebuild_dataset(h_params, data_class):
 
 
 if __name__ == '__main__':
-    for dataset in [MaskLibrosaDatamodule]:  # [PrimatesLibrosaDatamodule, CCSLibrosaDatamodule]:
+    for dataset in [CCSLibrosaDatamodule]:
         # Parse comandline args, read config and get model
         cmd_args, _, _, _ = parse_comandline_args_add_defaults('_parameters.ini')
 
reload model.ipynb (new file, 341 lines)
@@ -0,0 +1,341 @@
The new file is notebook JSON (nbformat 4, Python 3 kernel); its cell sources and recorded outputs are reproduced below.

# %% IMPORTS
from collections import defaultdict
from pathlib import Path
from natsort import natsorted
from pytorch_lightning.core.saving import *

import torch
from sklearn.manifold import TSNE

import seaborn as sns
from matplotlib import pyplot as plt
import numpy as np
import pandas as pd

# %% Path resolving and variables
from ml_lib.metrics.binary_class_classifictaion import BinaryScores
from ml_lib.utils.tools import locate_and_import_class
_ROOT = Path()
out_path = 'output'
model_name = 'VisualTransformer'
exp_name = 'VT_7899c07a4809a45c57cba58047cefb5e'
version = 'version_7'

# %% Seaborn Settings
plt.style.use('default')
sns.set_palette('Dark2')

tex_fonts = {
    # Use LaTeX to write all text
    "text.usetex": True,
    "font.family": "serif",
    # Use 10pt font in plots, to match 10pt font in document
    "axes.labelsize": 10,
    "font.size": 10,
    # Make the legend/label fonts a little smaller
    "legend.fontsize": 8,
    "xtick.labelsize": 8,
    "ytick.labelsize": 8
}

# plt.rcParams.update(tex_fonts)

Path('figures').mkdir(exist_ok=True)

# %% Utility Functions
def reconstruct_model_data_params(yaml_file_path: str):
    hparams_dict = load_hparams_from_yaml(yaml_file_path)

    # Try to get model_name and data_name from yaml:
    model_name = hparams_dict['model_name']
    data_name = hparams_dict['data_name']
    # Try to find the original model and data class by name:
    found_data_class = locate_and_import_class(data_name, 'datasets')
    found_model_class = locate_and_import_class(model_name, 'models')
    # Possible way of automatic loading args:
    # args = inspect.signature(found_data_class)
    # then access _parameter.ini and retrieve not set parameters

    hparams_dict.update(target_mel_length_in_seconds=1, num_worker=10, data_root='data')

    h_params = Namespace(**hparams_dict)

    # Let Datamodule pull what it wants
    datamodule = found_data_class.from_argparse_args(h_params)

    hparams_dict.update(in_shape=datamodule.shape, n_classes=datamodule.n_classes, variable_length=False)

    return datamodule, found_model_class, hparams_dict

def gather_predictions_and_labels(model, data_option):
    preds = list()
    labels = list()
    filenames = list()
    with torch.no_grad():
        for file_name, x, y in datamodule.datasets[data_option]:
            preds.append(model(x.unsqueeze(0)).main_out)
            labels.append(y)
            filenames.append(file_name)
    labels = np.stack(labels).squeeze()
    preds = np.stack(preds).squeeze()
    return preds, labels, filenames

def build_tsne_dataframe(preds, labels):
    tsne = np.stack(TSNE().fit_transform(preds)).squeeze()
    tsne_dataframe = pd.DataFrame(data=tsne, columns=['x', 'y'])

    tsne_dataframe['labels'] = labels
    tsne_dataframe['labels'] = tsne_dataframe['labels'].map({val: key for key, val in datamodule.class_names.items()})
    return tsne_dataframe

def plot_scatterplot(data, option):
    p = sns.scatterplot(data=data, x='x', y='y', hue='labels', legend=True)
    p.set_title(f'TSNE - distribution of logits for {option}')
    plt.show()

def redo_predictions(experiment_path, preds, fnames, data_class):
    sorted_y = defaultdict(list)
    for idx, (pred, fname) in enumerate(zip(preds, fnames)):
        sorted_y[fname].append(pred)
    sorted_y = dict(sorted_y)

    for file_name in sorted_y:
        sorted_y.update({file_name: np.stack(sorted_y[file_name])})

    if data_class.n_classes > 2:
        pred = np.stack(
            [np.argmax(x.mean(axis=0)) if x.shape[0] > 1 else np.argmax(x) for x in sorted_y.values()]
        ).squeeze()
        class_names = {val: key for val, key in
                       enumerate(['background', 'chimpanze', 'geunon', 'mandrille', 'redcap'])}
    else:
        pred = [x.mean(axis=0) if x.shape[0] > 1 else x.squeeze().unsqueeze(-1) for x in sorted_y.values()]
        pred = np.stack(pred).squeeze()
        pred = np.where(pred > 0.5, 1, 0)
        class_names = {val: key for val, key in enumerate(['negative', 'positive'])}

    df = pd.DataFrame(data=dict(filename=[Path(x).name.replace('.npy', '.wav') for x in sorted_y.keys()],
                                prediction=[class_names[x.item()] for x in pred]))
    result_file = Path(experiment_path / 'predictions_new.csv')
    if result_file.exists():
        try:
            result_file.unlink()
        except:
            print('File already existed')
            pass
    with result_file.open(mode='wb') as csv_file:
        df.to_csv(index=False, path_or_buf=csv_file)

def re_valida(preds, labels, fnames, data_class):
    sorted_y = defaultdict(list)
    for idx, (pred, fname) in enumerate(zip(preds, fnames)):
        sorted_y[fname].append(pred)
    sorted_y = dict(sorted_y)

    for file_name in sorted_y:
        sorted_y.update({file_name: np.stack(sorted_y[file_name])})

    for key, val in list(sorted_y.items()):
        if val.ndim > 1:
            val = val.mean(axis=0)
            print(val.ndim)
        if not val[0] > 0.8:
            val[0] = 0
        sorted_y[key] = val

    pred = np.stack(
        [np.argmax(x) if x.shape[0] > 1 else np.argmax(x) for x in sorted_y.values()]
    ).squeeze()

    one_hot_targets = np.eye(data_class.n_classes)[pred]

    # Sklearn Scores
    print(BinaryScores(dict(y=one_hot_targets, batch_y=labels)))

# %%
exp_path = _ROOT / out_path / model_name / exp_name / version
checkpoint = natsorted(exp_path.glob('*.ckpt'))[-4]
print(f'Selected Checkpoint is {checkpoint}')
hparams_yaml = next(exp_path.glob('*.yaml'))
print(load_hparams_from_yaml(hparams_yaml)['data_name'])
# Load the model here by hand from the class that was stored
datamodule, model_class, h_params = reconstruct_model_data_params(hparams_yaml.__str__())
# h_params.update(return_logits=True)

# -> Selected Checkpoint is output\VisualTransformer\VT_7899c07a4809a45c57cba58047cefb5e\version_7\ckpt_weights-v1.ckpt
# -> PrimatesLibrosaDatamodule

# %%
model = model_class.load_from_checkpoint(checkpoint, **h_params).eval()
datamodule.prepare_data()

# -> RuntimeError: Error(s) in loading state_dict for VisualTransformer:
#        size mismatch for pos_embedding: copying a param with shape torch.Size([1, 67, 30])
#        from checkpoint, the shape in current model is torch.Size([1, 122, 30]).
#    (raised from model.load_state_dict inside pytorch_lightning's load_from_checkpoint)

# %%
predictions, labels_y, filenames = gather_predictions_and_labels(model, 'devel')

# %%
# tsne_dataframe = build_tsne_dataframe(predictions, labels_y)
# plot_scatterplot(tsne_dataframe, data_option)

# %%
re_valida(predictions, labels_y, filenames, datamodule)

# %% (final cell empty; standard nbformat metadata omitted)
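The recorded failure is worth decoding: the checkpoint stores pos_embedding with shape [1, 67, 30] (66 patches plus the CLS token) while the freshly built model expects [1, 122, 30]. Since reconstruct_model_data_params hard-overrides target_mel_length_in_seconds=1, the rebuilt datamodule presumably produces wider spectrograms — hence more patches — than the model was trained on. One way around it (a hedged suggestion, key name assumed to be present in the stored yaml) is to reuse the training-time value instead of the override:

# Inside reconstruct_model_data_params, replace the hard-coded override with the
# value the checkpoint was actually trained with:
trained = load_hparams_from_yaml(yaml_file_path)
hparams_dict.update(target_mel_length_in_seconds=trained['target_mel_length_in_seconds'])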
@@ -5,6 +5,7 @@ from abc import ABC
 
 import torch
 import pandas as pd
+from matplotlib import pyplot as plt
 
 from ml_lib.modules.util import LightningBaseModule
 from util.loss_mixin import LossMixin
@@ -59,7 +60,7 @@ class ValMixin:
 
         target_y = torch.stack(tuple(sorted_batch_y.values())).long()
         if self.params.n_classes <= 2:
-            mean_sorted_y = torch.stack([x.mean(dim=0) if x.shape[0] > 1 else x for x in sorted_y.values()])
+            mean_sorted_y = torch.stack([x.mean(dim=0) if x.shape[0] > 1 else x.squeeze(-1) for x in sorted_y.values()])
             self.metrics.update(mean_sorted_y, target_y)
         else:
             y_max = torch.stack(
@@ -97,8 +98,11 @@ class ValMixin:
             sorted_y.update({file_name: torch.stack(sorted_y[file_name])})
 
         if self.params.n_classes <= 2:
-            mean_sorted_y = torch.stack([x.mean(dim=0) if x.shape[0] > 1 else x for x in sorted_y.values()]).squeeze()
+            mean_sorted_y = [x.mean(dim=0) if x.shape[0] > 1 else x.squeeze().unsqueeze(-1) for x in sorted_y.values()]
+            mean_sorted_y = torch.stack(mean_sorted_y).squeeze(1)
+            # mean_sorted_y = mean_sorted_y if mean_sorted_y.numel() > 1 else mean_sorted_y.unsqueeze(-1)
             max_vote_loss = self.bce_loss(mean_sorted_y.float(), sorted_batch_y.float())
 
         # Sklearn Scores
         additional_scores = self.additional_scores(dict(y=mean_sorted_y, batch_y=sorted_batch_y))
 
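The binary-case change here (and the matching ones in the first ValMixin hunk and in TestMixin below) handles files with a single prediction window: x.mean(dim=0) on an [n, 1] tensor yields shape [1], while a lone prediction kept as-is has shape [1, 1], so torch.stack over a mix of both raises a size-mismatch error; x.squeeze().unsqueeze(-1) normalises the lone case to [1] first. A minimal reproduction:

import torch

multi = torch.randn(3, 1)                 # several windows for one file
single = torch.randn(1, 1)                # exactly one window

# old code path: shapes [1] and [1, 1] cannot be stacked
try:
    torch.stack([multi.mean(dim=0), single])
except RuntimeError as e:
    print('old path fails:', e)

# new code path: both entries end up with shape [1]
fixed = torch.stack([multi.mean(dim=0), single.squeeze().unsqueeze(-1)])
print(fixed.shape)                        # torch.Size([2, 1])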
@@ -129,6 +133,7 @@ class ValMixin:
 
         for name, image in pl_images.items():
             self.logger.log_image(name, image, step=self.global_step)
+            plt.close(image)
         pass
 
 
@@ -162,12 +167,13 @@ class TestMixin:
             class_names = {val: key for val, key in
                            enumerate(['background', 'chimpanze', 'geunon', 'mandrille', 'redcap'])}
         else:
-            pred = torch.stack([x.mean(dim=0) if x.shape[0] > 1 else x for x in sorted_y.values()]).squeeze()
+            pred = [x.mean(dim=0) if x.shape[0] > 1 else x.squeeze().unsqueeze(-1) for x in sorted_y.values()]
+            pred = torch.stack(pred).squeeze()
             pred = torch.where(pred > 0.5, 1, 0)
             class_names = {val: key for val, key in enumerate(['negative', 'positive'])}
 
 
-        df = pd.DataFrame(data=dict(filename=[Path(x).name for x in sorted_y.keys()],
+        df = pd.DataFrame(data=dict(filename=[Path(x).name.replace('.npy', '.wav') for x in sorted_y.keys()],
                                     prediction=[class_names[x.item()] for x in pred.cpu()]))
         result_file = Path(self.logger.log_dir / 'predictions.csv')
         if result_file.exists():
@@ -178,12 +184,13 @@ class TestMixin:
                 pass
         with result_file.open(mode='wb') as csv_file:
             df.to_csv(index=False, path_or_buf=csv_file)
-        with result_file.open(mode='rb') as csv_file:
-            try:
-                self.logger.neptunelogger.log_artifact(csv_file)
-            except:
-                print('Not possible to send to neptune')
-                pass
+        if False:
+            with result_file.open(mode='rb') as csv_file:
+                try:
+                    self.logger.neptunelogger.log_artifact(csv_file)
+                except:
+                    print('Not possible to send to neptune')
+                    pass
 
 
 class CombinedModelMixins(LossMixin,
@@ -5,5 +5,6 @@ sr = 16000
 
 PRIMATES_Root = Path(__file__).parent / 'data' / 'primates'
 CCS_Root = Path(__file__).parent / 'data' / 'ComParE2021_CCS'
+MASK_Root = Path(__file__).parent / 'data' / 'ComParE2020_Mask'
 
 N_CLASS_multi = 5