paper preperations and notebooks, optuna callbacks
This commit is contained in:
+15
-4
@@ -10,7 +10,7 @@ data_name = CCSLibrosaDatamodule
|
|||||||
[data]
|
[data]
|
||||||
num_worker = 10
|
num_worker = 10
|
||||||
data_root = data
|
data_root = data
|
||||||
variable_length = True
|
variable_length = False
|
||||||
target_mel_length_in_seconds = 0.7
|
target_mel_length_in_seconds = 0.7
|
||||||
|
|
||||||
n_mels = 128
|
n_mels = 128
|
||||||
@@ -49,6 +49,15 @@ dropout = 0.2
|
|||||||
lat_dim = 32
|
lat_dim = 32
|
||||||
filters = [16, 32, 64, 128]
|
filters = [16, 32, 64, 128]
|
||||||
|
|
||||||
|
[BandwiseConvClassifier]
|
||||||
|
weight_init = xavier_normal_
|
||||||
|
activation = gelu
|
||||||
|
use_bias = True
|
||||||
|
use_norm = True
|
||||||
|
dropout = 0.2
|
||||||
|
lat_dim = 32
|
||||||
|
filters = [16, 32, 64, 128]
|
||||||
|
|
||||||
[VisualTransformer]
|
[VisualTransformer]
|
||||||
weight_init = xavier_normal_
|
weight_init = xavier_normal_
|
||||||
activation = gelu
|
activation = gelu
|
||||||
@@ -73,11 +82,13 @@ use_norm = True
|
|||||||
use_residual = True
|
use_residual = True
|
||||||
dropout = 0.2
|
dropout = 0.2
|
||||||
|
|
||||||
lat_dim = 32
|
mlp_dim = 6
|
||||||
|
lat_dim = 6
|
||||||
|
head_dim = 6
|
||||||
patch_size = 8
|
patch_size = 8
|
||||||
attn_depth = 12
|
attn_depth = 6
|
||||||
heads = 4
|
heads = 4
|
||||||
embedding_size = 128
|
embedding_size = 30
|
||||||
|
|
||||||
[HorizontalVisualTransformer]
|
[HorizontalVisualTransformer]
|
||||||
weight_init = xavier_normal_
|
weight_init = xavier_normal_
|
||||||
|
|||||||
@@ -118,7 +118,9 @@ class CompareBase(_BaseDataModule):
|
|||||||
lab_file = None
|
lab_file = None
|
||||||
|
|
||||||
for data_option in data_options:
|
for data_option in data_options:
|
||||||
if lab_file is not None:
|
if lab_file is None:
|
||||||
|
lab_file = f'{data_option}.csv'
|
||||||
|
elif lab_file is not None:
|
||||||
if any([x in lab_file for x in data_options]):
|
if any([x in lab_file for x in data_options]):
|
||||||
lab_file = f'{data_option}.csv'
|
lab_file = f'{data_option}.csv'
|
||||||
dataset = self._load_from_file(lab_file, data_option, rebuild=True)
|
dataset = self._load_from_file(lab_file, data_option, rebuild=True)
|
||||||
|
|||||||
@@ -76,15 +76,12 @@ def run_lightning_loop(h_params :Namespace, data_class, model_class, seed=69, ad
|
|||||||
trainer.fit(model, datamodule)
|
trainer.fit(model, datamodule)
|
||||||
trainer.save_checkpoint(logger.save_dir / 'last_weights.ckpt')
|
trainer.save_checkpoint(logger.save_dir / 'last_weights.ckpt')
|
||||||
|
|
||||||
|
trainer.test(model=model, datamodule=datamodule, ckpt_path='best')
|
||||||
|
|
||||||
trainer.test(model=model, datamodule=datamodule)
|
return Namespace(model=model,
|
||||||
#except:
|
best_model_path=ckpt_callback.best_model_path,
|
||||||
# print('Test did not Suceed!')
|
best_model_score=ckpt_callback.best_model_score.item(),
|
||||||
# pass
|
max_score_monitor=score_callback.best_scores)
|
||||||
|
|
||||||
logger.log_metrics(score_callback.best_scores, step=trainer.global_step+1)
|
|
||||||
|
|
||||||
return score_callback.best_scores['PL_recall_score']
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
@@ -96,6 +93,6 @@ if __name__ == '__main__':
|
|||||||
|
|
||||||
# Start
|
# Start
|
||||||
# -----------------
|
# -----------------
|
||||||
run_lightning_loop(hparams, found_data_class, found_model_class, found_seed)
|
print(run_lightning_loop(hparams, found_data_class, found_model_class, found_seed))
|
||||||
print('done')
|
print('done')
|
||||||
pass
|
pass
|
||||||
|
|||||||
@@ -0,0 +1,69 @@
|
|||||||
|
import inspect
|
||||||
|
from argparse import Namespace
|
||||||
|
|
||||||
|
from torch import nn
|
||||||
|
from torch.nn import ModuleList
|
||||||
|
|
||||||
|
from ml_lib.modules.blocks import ConvModule, LinearModule
|
||||||
|
from ml_lib.modules.util import (LightningBaseModule, Splitter, Merger)
|
||||||
|
from util.module_mixins import CombinedModelMixins
|
||||||
|
|
||||||
|
|
||||||
|
class BandwiseConvClassifier(CombinedModelMixins,
|
||||||
|
LightningBaseModule
|
||||||
|
):
|
||||||
|
def __init__(self, in_shape, n_classes, weight_init, activation,
|
||||||
|
use_bias, use_norm, dropout, lat_dim, filters,
|
||||||
|
lr, weight_decay, sto_weight_avg, lr_warm_restart_epochs, opt_reset_interval,
|
||||||
|
loss, scheduler, lr_scheduler_parameter
|
||||||
|
):
|
||||||
|
# TODO: Move this to parent class, or make it much easieer to access....
|
||||||
|
a = dict(locals())
|
||||||
|
params = {arg: a[arg] for arg in inspect.signature(self.__init__).parameters.keys() if arg != 'self'}
|
||||||
|
super(BandwiseConvClassifier, self).__init__(params)
|
||||||
|
|
||||||
|
# Model Paramters
|
||||||
|
# =============================================================================
|
||||||
|
# Additional parameters
|
||||||
|
self.n_band_sections = 8
|
||||||
|
|
||||||
|
# Modules
|
||||||
|
# =============================================================================
|
||||||
|
self.split = Splitter(in_shape, self.n_band_sections)
|
||||||
|
|
||||||
|
k = 3
|
||||||
|
self.band_list = ModuleList()
|
||||||
|
for band in range(self.n_band_sections):
|
||||||
|
last_shape = self.split.shape[band]
|
||||||
|
conv_list = ModuleList()
|
||||||
|
for conv_filters in self.params.filters:
|
||||||
|
conv_list.append(ConvModule(last_shape, conv_filters, (k, k), conv_stride=(2, 2), conv_padding=2,
|
||||||
|
**self.params.module_kwargs))
|
||||||
|
last_shape = conv_list[-1].shape
|
||||||
|
# self.conv_list.append(ConvModule(last_shape, 1, 1, conv_stride=1, **self.params.module_kwargs))
|
||||||
|
# last_shape = self.conv_list[-1].shape
|
||||||
|
self.band_list.append(conv_list)
|
||||||
|
|
||||||
|
self.merge = Merger(self.band_list[-1][-1].shape, self.n_band_sections)
|
||||||
|
|
||||||
|
self.full_1 = LinearModule(self.merge.shape, self.params.lat_dim, **self.params.module_kwargs)
|
||||||
|
self.full_2 = LinearModule(self.full_1.shape, self.params.lat_dim, **self.params.module_kwargs)
|
||||||
|
|
||||||
|
# Make Decision between binary and Multiclass Classification
|
||||||
|
logits = n_classes if n_classes > 2 else 1
|
||||||
|
module_kwargs = self.params.module_kwargs
|
||||||
|
module_kwargs.update(activation=(nn.Softmax if logits > 1 else nn.Sigmoid))
|
||||||
|
self.full_out = LinearModule(self.full_2.shape, logits, **module_kwargs)
|
||||||
|
|
||||||
|
def forward(self, batch, **kwargs):
|
||||||
|
tensors = self.split(batch)
|
||||||
|
for idx, (tensor, convs) in enumerate(zip(tensors, self.band_list)):
|
||||||
|
for conv in convs:
|
||||||
|
tensor = conv(tensor)
|
||||||
|
tensors[idx] = tensor
|
||||||
|
|
||||||
|
tensor = self.merge(tensors)
|
||||||
|
tensor = self.full_1(tensor)
|
||||||
|
tensor = self.full_2(tensor)
|
||||||
|
tensor = self.full_out(tensor)
|
||||||
|
return Namespace(main_out=tensor)
|
||||||
@@ -3,8 +3,6 @@ from argparse import Namespace
|
|||||||
|
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from ml_lib.metrics.binary_class_classifictaion import BinaryScores
|
|
||||||
from ml_lib.metrics.multi_class_classification import MultiClassScores
|
|
||||||
from ml_lib.modules.blocks import LinearModule
|
from ml_lib.modules.blocks import LinearModule
|
||||||
from ml_lib.modules.model_parts import CNNEncoder
|
from ml_lib.modules.model_parts import CNNEncoder
|
||||||
from ml_lib.modules.util import (LightningBaseModule)
|
from ml_lib.modules.util import (LightningBaseModule)
|
||||||
@@ -52,9 +50,3 @@ class CNNBaseline(CombinedModelMixins,
|
|||||||
|
|
||||||
tensor = self.classifier(tensor)
|
tensor = self.classifier(tensor)
|
||||||
return Namespace(main_out=tensor)
|
return Namespace(main_out=tensor)
|
||||||
|
|
||||||
def additional_scores(self, outputs):
|
|
||||||
if self.params.n_classes > 2:
|
|
||||||
return MultiClassScores(self)(outputs)
|
|
||||||
else:
|
|
||||||
return BinaryScores(self)(outputs)
|
|
||||||
|
|||||||
@@ -130,8 +130,6 @@ try:
|
|||||||
tensor = self.mlp_head(tensor)
|
tensor = self.mlp_head(tensor)
|
||||||
return Namespace(main_out=tensor)
|
return Namespace(main_out=tensor)
|
||||||
|
|
||||||
def additional_scores(self, outputs):
|
|
||||||
return MultiClassScores(self)(outputs)
|
|
||||||
|
|
||||||
except ImportError: # pragma: do not provide model class
|
except ImportError: # pragma: do not provide model class
|
||||||
print('You want to use `performer_pytorch` plugins which are not installed yet,' # pragma: no-cover
|
print('You want to use `performer_pytorch` plugins which are not installed yet,' # pragma: no-cover
|
||||||
|
|||||||
@@ -109,9 +109,3 @@ class Tester(CombinedModelMixins,
|
|||||||
tensor = self.to_cls_token(tensor[:, 0])
|
tensor = self.to_cls_token(tensor[:, 0])
|
||||||
tensor = self.mlp_head(tensor)
|
tensor = self.mlp_head(tensor)
|
||||||
return Namespace(main_out=tensor, attn_weights=None)
|
return Namespace(main_out=tensor, attn_weights=None)
|
||||||
|
|
||||||
def additional_scores(self, outputs):
|
|
||||||
if self.params.n_classes > 2:
|
|
||||||
return MultiClassScores(self)(outputs)
|
|
||||||
else:
|
|
||||||
return BinaryScores(self)(outputs)
|
|
||||||
|
|||||||
@@ -73,9 +73,9 @@ class VisualTransformer(CombinedModelMixins,
|
|||||||
logits = self.params.n_classes if self.params.n_classes > 2 else 1
|
logits = self.params.n_classes if self.params.n_classes > 2 else 1
|
||||||
|
|
||||||
if return_logits:
|
if return_logits:
|
||||||
outbound_activation = nn.Identity()
|
outbound_activation = nn.Identity
|
||||||
else:
|
else:
|
||||||
outbound_activation = nn.Softmax() if logits > 1 else nn.Sigmoid()
|
outbound_activation = nn.Softmax if logits > 1 else nn.Sigmoid
|
||||||
|
|
||||||
|
|
||||||
self.mlp_head = nn.Sequential(
|
self.mlp_head = nn.Sequential(
|
||||||
@@ -84,7 +84,7 @@ class VisualTransformer(CombinedModelMixins,
|
|||||||
self.params.activation(),
|
self.params.activation(),
|
||||||
nn.Dropout(self.params.dropout),
|
nn.Dropout(self.params.dropout),
|
||||||
nn.Linear(self.params.lat_dim, logits),
|
nn.Linear(self.params.lat_dim, logits),
|
||||||
outbound_activation
|
outbound_activation()
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, x, mask=None, return_attn_weights=False):
|
def forward(self, x, mask=None, return_attn_weights=False):
|
||||||
@@ -128,8 +128,3 @@ class VisualTransformer(CombinedModelMixins,
|
|||||||
tensor = self.mlp_head(tensor)
|
tensor = self.mlp_head(tensor)
|
||||||
return Namespace(main_out=tensor, attn_weights=attn_weights)
|
return Namespace(main_out=tensor, attn_weights=attn_weights)
|
||||||
|
|
||||||
def additional_scores(self, outputs):
|
|
||||||
if self.params.n_classes <= 2:
|
|
||||||
return BinaryScores(self)(outputs)
|
|
||||||
else:
|
|
||||||
return MultiClassScores(self)(outputs)
|
|
||||||
|
|||||||
@@ -116,5 +116,3 @@ class HorizontalVisualTransformer(CombinedModelMixins,
|
|||||||
tensor = self.mlp_head(tensor)
|
tensor = self.mlp_head(tensor)
|
||||||
return Namespace(main_out=tensor, attn_weights=attn_weights)
|
return Namespace(main_out=tensor, attn_weights=attn_weights)
|
||||||
|
|
||||||
def additional_scores(self, outputs):
|
|
||||||
return MultiClassScores(self)(outputs)
|
|
||||||
|
|||||||
@@ -18,9 +18,10 @@ MIN_NUM_PATCHES = 16
|
|||||||
class VerticalVisualTransformer(CombinedModelMixins, LightningBaseModule):
|
class VerticalVisualTransformer(CombinedModelMixins, LightningBaseModule):
|
||||||
|
|
||||||
def __init__(self, in_shape, n_classes, weight_init, activation,
|
def __init__(self, in_shape, n_classes, weight_init, activation,
|
||||||
embedding_size, heads, attn_depth, patch_size, use_residual,
|
embedding_size, heads, attn_depth, patch_size, use_residual, variable_length,
|
||||||
use_bias, use_norm, dropout, lat_dim, features, loss, scheduler,
|
use_bias, use_norm, dropout, lat_dim, loss, scheduler, mlp_dim, head_dim,
|
||||||
lr, weight_decay, sto_weight_avg, lr_warm_restart_epochs, opt_reset_interval):
|
lr, weight_decay, sto_weight_avg, lr_scheduler_parameter, opt_reset_interval,
|
||||||
|
return_logits=False):
|
||||||
|
|
||||||
# TODO: Move this to parent class, or make it much easieer to access... But How...
|
# TODO: Move this to parent class, or make it much easieer to access... But How...
|
||||||
a = dict(locals())
|
a = dict(locals())
|
||||||
@@ -47,14 +48,6 @@ class VerticalVisualTransformer(CombinedModelMixins, LightningBaseModule):
|
|||||||
assert num_patches >= MIN_NUM_PATCHES, f'your number of patches ({num_patches}) is way too small for ' + \
|
assert num_patches >= MIN_NUM_PATCHES, f'your number of patches ({num_patches}) is way too small for ' + \
|
||||||
f'attention. Try decreasing your patch size'
|
f'attention. Try decreasing your patch size'
|
||||||
|
|
||||||
# Correct the Embedding Dim
|
|
||||||
if not self.embed_dim % self.params.heads == 0:
|
|
||||||
self.embed_dim = (self.embed_dim // self.params.heads) * self.params.heads
|
|
||||||
message = ('Embedding Dimension was fixed to be devideable by the number' +
|
|
||||||
f' of attention heads, is now: {self.embed_dim}')
|
|
||||||
for func in print, warnings.warn:
|
|
||||||
func(message)
|
|
||||||
|
|
||||||
# Utility Modules
|
# Utility Modules
|
||||||
self.autopad = AutoPadToShape((self.height, self.new_width))
|
self.autopad = AutoPadToShape((self.height, self.new_width))
|
||||||
self.dropout = nn.Dropout(self.params.dropout)
|
self.dropout = nn.Dropout(self.params.dropout)
|
||||||
@@ -62,10 +55,11 @@ class VerticalVisualTransformer(CombinedModelMixins, LightningBaseModule):
|
|||||||
keepdim=False)
|
keepdim=False)
|
||||||
|
|
||||||
# Modules with Parameters
|
# Modules with Parameters
|
||||||
self.transformer = TransformerModule(in_shape=self.embed_dim, mlp_dim=self.params.lat_dim,
|
self.transformer = TransformerModule(in_shape=self.embed_dim, mlp_dim=self.params.mlp_dim,
|
||||||
|
head_dim=self.params.head_dim,
|
||||||
heads=self.params.heads, depth=self.params.attn_depth,
|
heads=self.params.heads, depth=self.params.attn_depth,
|
||||||
dropout=self.params.dropout, use_norm=self.params.use_norm,
|
dropout=self.params.dropout, use_norm=self.params.use_norm,
|
||||||
activation=self.params.activation
|
activation=self.params.activation, use_residual=self.params.use_residual
|
||||||
)
|
)
|
||||||
|
|
||||||
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, self.embed_dim))
|
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, self.embed_dim))
|
||||||
@@ -74,13 +68,17 @@ class VerticalVisualTransformer(CombinedModelMixins, LightningBaseModule):
|
|||||||
self.cls_token = nn.Parameter(torch.randn(1, 1, self.embed_dim))
|
self.cls_token = nn.Parameter(torch.randn(1, 1, self.embed_dim))
|
||||||
self.to_cls_token = nn.Identity()
|
self.to_cls_token = nn.Identity()
|
||||||
|
|
||||||
|
logits = self.params.n_classes if self.params.n_classes > 2 else 1
|
||||||
|
|
||||||
|
outbound_activation = nn.Softmax if logits > 1 else nn.Sigmoid
|
||||||
|
|
||||||
self.mlp_head = nn.Sequential(
|
self.mlp_head = nn.Sequential(
|
||||||
nn.LayerNorm(self.embed_dim),
|
nn.LayerNorm(self.embed_dim),
|
||||||
nn.Linear(self.embed_dim, self.params.lat_dim),
|
nn.Linear(self.embed_dim, self.params.lat_dim),
|
||||||
nn.GELU(),
|
self.params.activation(),
|
||||||
nn.Dropout(self.params.dropout),
|
nn.Dropout(self.params.dropout),
|
||||||
nn.Linear(self.params.lat_dim, self.n_classes),
|
nn.Linear(self.params.lat_dim, logits),
|
||||||
nn.Softmax()
|
outbound_activation()
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, x, mask=None, return_attn_weights=False):
|
def forward(self, x, mask=None, return_attn_weights=False):
|
||||||
@@ -112,5 +110,3 @@ class VerticalVisualTransformer(CombinedModelMixins, LightningBaseModule):
|
|||||||
tensor = self.mlp_head(tensor)
|
tensor = self.mlp_head(tensor)
|
||||||
return Namespace(main_out=tensor, attn_weights=attn_weights)
|
return Namespace(main_out=tensor, attn_weights=attn_weights)
|
||||||
|
|
||||||
def additional_scores(self, outputs):
|
|
||||||
return MultiClassScores(self)(outputs)
|
|
||||||
|
|||||||
+40
-21
@@ -10,35 +10,54 @@ import itertools
|
|||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
||||||
# Set new values
|
# Set new values
|
||||||
hparams_dict = dict(seed=range(1, 11),
|
hparams_dict = dict(seed=range(1, 6),
|
||||||
model_name=['CNNBaseline'],
|
# BandwiseConvClassifier, CNNBaseline, VisualTransformer, VerticalVisualTransformer
|
||||||
data_name=['CCSLibrosaDatamodule'], # 'CCSLibrosaDatamodule'],
|
model_name=['VisualTransformer'],
|
||||||
batch_size=[50],
|
# CCSLibrosaDatamodule, PrimatesLibrosaDatamodule,
|
||||||
max_epochs=[200],
|
data_name=['PrimatesLibrosaDatamodule'],
|
||||||
variable_length=[False], # THIS IS NEXT
|
batch_size=[30],
|
||||||
target_mel_length_in_seconds=[0.7],
|
max_epochs=[150],
|
||||||
random_apply_chance=[0.5], # trial.suggest_float('random_apply_chance', 0.1, 0.5, step=0.1),
|
target_mel_length_in_seconds=[0.5],
|
||||||
loudness_ratio=[0], # trial.suggest_float('loudness_ratio', 0.0, 0.5, step=0.1),
|
outpath=['head_exp'],
|
||||||
shift_ratio=[0.3], # trial.suggest_float('shift_ratio', 0.0, 0.5, step=0.1),
|
|
||||||
noise_ratio=[0.3], # trial.suggest_float('noise_ratio', 0.0, 0.5, step=0.1),
|
|
||||||
mask_ratio=[0.3], # trial.suggest_float('mask_ratio', 0.0, 0.5, step=0.1),
|
|
||||||
lr=[1e-3], # trial.suggest_uniform('lr', 1e-3, 3e-3),
|
|
||||||
dropout=[0.2], # trial.suggest_float('dropout', 0.0, 0.3, step=0.05),
|
dropout=[0.2], # trial.suggest_float('dropout', 0.0, 0.3, step=0.05),
|
||||||
lat_dim=[32], # 2 ** trial.suggest_int('lat_dim', 1, 5, step=1),
|
|
||||||
mlp_dim=[16], # 2 ** trial.suggest_int('mlp_dim', 1, 5, step=1),
|
|
||||||
head_dim=[6], # 2 ** trial.suggest_int('head_dim', 1, 5, step=1),
|
|
||||||
patch_size=[12], # trial.suggest_int('patch_size', 6, 12, step=3),
|
|
||||||
attn_depth=[12], # trial.suggest_int('attn_depth', 2, 14, step=4),
|
|
||||||
heads=[6], # trial.suggest_int('heads', 2, 16, step=2),
|
|
||||||
scheduler=['LambdaLR'], # trial.suggest_categorical('scheduler', [None, 'LambdaLR']),
|
scheduler=['LambdaLR'], # trial.suggest_categorical('scheduler', [None, 'LambdaLR']),
|
||||||
lr_scheduler_parameter=[0.94, 0.93, 0.95], # [0.98],
|
lr_scheduler_parameter=[0.95], # [0.95],
|
||||||
embedding_size=[30], # trial.suggest_int('embedding_size', 12, 64, step=12),
|
|
||||||
loss=['ce_loss'],
|
loss=['ce_loss'],
|
||||||
sampler=['WeightedRandomSampler'],
|
sampler=['WeightedRandomSampler'],
|
||||||
# trial.suggest_categorical('sampler', [None, 'WeightedRandomSampler']),
|
# trial.suggest_categorical('sampler', [None, 'WeightedRandomSampler']),
|
||||||
weight_decay=[0], # trial.suggest_loguniform('weight_decay', 1e-20, 1e-1),
|
weight_decay=[0], # trial.suggest_loguniform('weight_decay', 1e-20, 1e-1),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Data Aug Parameters
|
||||||
|
hparams_dict.update(random_apply_chance=[0.3], # trial.suggest_float('random_apply_chance', 0.1, 0.5, step=0.1),
|
||||||
|
loudness_ratio=[0], # trial.suggest_float('loudness_ratio', 0.0, 0.5, step=0.1),
|
||||||
|
shift_ratio=[0.2], # trial.suggest_float('shift_ratio', 0.0, 0.5, step=0.1),
|
||||||
|
noise_ratio=[0.4], # trial.suggest_float('noise_ratio', 0.0, 0.5, step=0.1),
|
||||||
|
mask_ratio=[0.2], # triaSl.suggest_float('mask_ratio', 0.0, 0.5, step=0.1),)
|
||||||
|
)
|
||||||
|
if False:
|
||||||
|
# CNN Parameters:
|
||||||
|
hparams_dict.update(filters=[[16, 32, 64, 32]],
|
||||||
|
lr=[1e-3], # trial.suggest_uniform('lr', 1e-3, 3e-3),
|
||||||
|
variable_length=[False], # THIS does not Work
|
||||||
|
lat_dim=[64], # 2 ** trial.suggest_int('lat_dim', 1, 5, step=1),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Transfornmer Parameters:
|
||||||
|
hparams_dict.update(lr=[1e-3], # trial.suggest_uniform('lr', 1e-3, 3e-3),
|
||||||
|
lat_dim=[32], # 2 ** trial.suggest_int('lat_dim', 1, 5, step=1),
|
||||||
|
mlp_dim=[16],
|
||||||
|
head_dim=[6], # 2 ** trial.suggest_int('head_dim', 1, 5, step=1),
|
||||||
|
patch_size=[12], # trial.suggest_int('patch_size', 6, 12, step=3),
|
||||||
|
attn_depth=[14], # trial.suggest_int('attn_depth', 2, 14, step=4),
|
||||||
|
heads=[2,4,6,8,10], # trial.suggest_int('heads', 2, 16, step=2),
|
||||||
|
embedding_size=[30], # trial.suggest_int('embedding_size', 12, 64, step=12),
|
||||||
|
variable_length=[False], # THIS does not Work
|
||||||
|
)
|
||||||
|
|
||||||
keys, values = zip(*hparams_dict.items())
|
keys, values = zip(*hparams_dict.items())
|
||||||
permutations_dicts = [dict(zip(keys, v)) for v in itertools.product(*values)]
|
permutations_dicts = [dict(zip(keys, v)) for v in itertools.product(*values)]
|
||||||
for permutations_dict in tqdm(permutations_dicts, total=len(permutations_dicts)):
|
for permutations_dict in tqdm(permutations_dicts, total=len(permutations_dicts)):
|
||||||
|
|||||||
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
+397
-636
File diff suppressed because one or more lines are too long
+68
-22
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
+64
-15
@@ -1,6 +1,7 @@
|
|||||||
import pickle
|
import pickle
|
||||||
from argparse import Namespace
|
from argparse import Namespace
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
import optuna as optuna
|
import optuna as optuna
|
||||||
from optuna.integration import PyTorchLightningPruningCallback
|
from optuna.integration import PyTorchLightningPruningCallback
|
||||||
@@ -8,6 +9,56 @@ from optuna.integration import PyTorchLightningPruningCallback
|
|||||||
from main import run_lightning_loop
|
from main import run_lightning_loop
|
||||||
from ml_lib.utils.config import parse_comandline_args_add_defaults
|
from ml_lib.utils.config import parse_comandline_args_add_defaults
|
||||||
|
|
||||||
|
|
||||||
|
class ContiniousSavingCallback:
|
||||||
|
|
||||||
|
@property
|
||||||
|
def study(self):
|
||||||
|
return self._study
|
||||||
|
|
||||||
|
@property
|
||||||
|
def tmp_study_path(self):
|
||||||
|
return Path(self.root) / f'TMP_{self.study.study_name}_trial{self.study.trials[-1].number}.pkl'
|
||||||
|
|
||||||
|
@property
|
||||||
|
def final_study_path(self):
|
||||||
|
return Path(self.root) / f'FINAL_{self.study.study_name}_' \
|
||||||
|
f'best_{self.study.best_trial.number}_' \
|
||||||
|
f'score_{self.study.best_value}.pkl'
|
||||||
|
|
||||||
|
def __init__(self, root:Union[str, Path], study: optuna.Study):
|
||||||
|
self._study = study
|
||||||
|
self.root = Path(root)
|
||||||
|
pass
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _write_to_disk(object, path):
|
||||||
|
path = Path(path)
|
||||||
|
path.parent.mkdir(exist_ok=True)
|
||||||
|
if path.exists():
|
||||||
|
path.unlink(missing_ok=True)
|
||||||
|
with path.open(mode='wb') as f:
|
||||||
|
pickle.dump(object, f)
|
||||||
|
|
||||||
|
def save_final(self):
|
||||||
|
self._write_to_disk(self.study, self.final_study_path())
|
||||||
|
|
||||||
|
def clean_up(self):
|
||||||
|
temp_study_files = self.root.glob(f'TMP_{self.study.study_name}*')
|
||||||
|
for temp_study_file in temp_study_files:
|
||||||
|
temp_study_file.unlink(missing_ok=True)
|
||||||
|
|
||||||
|
def __call__(self, study: optuna.study.Study, trial: optuna.trial.FrozenTrial) -> None:
|
||||||
|
self._write_to_disk(study, self.tmp_study_path(trial.number))
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||||
|
self.save_final()
|
||||||
|
self.clean_up()
|
||||||
|
|
||||||
|
|
||||||
def optimize(trial: optuna.Trial):
|
def optimize(trial: optuna.Trial):
|
||||||
# Optuna configuration
|
# Optuna configuration
|
||||||
folder = Path('study')
|
folder = Path('study')
|
||||||
@@ -19,10 +70,10 @@ def optimize(trial: optuna.Trial):
|
|||||||
lr_scheduler_parameter = None
|
lr_scheduler_parameter = None
|
||||||
|
|
||||||
optuna_suggestions = dict(
|
optuna_suggestions = dict(
|
||||||
model_name='CNNBaseline',
|
model_name='VisualTransformer',
|
||||||
data_name='MaskLibrosaDatamodule',
|
data_name='CCSLibrosaDatamodule',
|
||||||
batch_size=trial.suggest_int('batch_size', 5, 50, step=5),
|
batch_size=trial.suggest_int('batch_size', 5, 50, step=5),
|
||||||
max_epochs=75,
|
max_epochs=200,
|
||||||
target_mel_length_in_seconds=trial.suggest_float('target_mel_length_in_seconds', 0.2, 1.5, step=0.1),
|
target_mel_length_in_seconds=trial.suggest_float('target_mel_length_in_seconds', 0.2, 1.5, step=0.1),
|
||||||
random_apply_chance=trial.suggest_float('random_apply_chance', 0.1, 0.5, step=0.1),
|
random_apply_chance=trial.suggest_float('random_apply_chance', 0.1, 0.5, step=0.1),
|
||||||
loudness_ratio=trial.suggest_float('loudness_ratio', 0.0, 0.5, step=0.1),
|
loudness_ratio=trial.suggest_float('loudness_ratio', 0.0, 0.5, step=0.1),
|
||||||
@@ -44,7 +95,7 @@ def optimize(trial: optuna.Trial):
|
|||||||
for layer_idx in range(model_depth):
|
for layer_idx in range(model_depth):
|
||||||
filters.append(2 ** trial.suggest_int(f'filters_{layer_idx}', 2, 6, step=1))
|
filters.append(2 ** trial.suggest_int(f'filters_{layer_idx}', 2, 6, step=1))
|
||||||
optuna_suggestions.update(filters=filters)
|
optuna_suggestions.update(filters=filters)
|
||||||
elif optuna_suggestions['model_name'] == 'VisualTransformer':
|
elif optuna_suggestions['model_name'] in ['VisualTransformer', 'VerticalVisualTransformer']:
|
||||||
transformer_dict = dict(
|
transformer_dict = dict(
|
||||||
mlp_dim=2 ** trial.suggest_int('mlp_dim', 1, 5, step=1),
|
mlp_dim=2 ** trial.suggest_int('mlp_dim', 1, 5, step=1),
|
||||||
head_dim=2 ** trial.suggest_int('head_dim', 1, 5, step=1),
|
head_dim=2 ** trial.suggest_int('head_dim', 1, 5, step=1),
|
||||||
@@ -62,8 +113,10 @@ def optimize(trial: optuna.Trial):
|
|||||||
'_parameters.ini', overrides=optuna_suggestions)
|
'_parameters.ini', overrides=optuna_suggestions)
|
||||||
h_params = Namespace(**h_params)
|
h_params = Namespace(**h_params)
|
||||||
try:
|
try:
|
||||||
best_score = run_lightning_loop(h_params, data_class=found_data_class, model_class=found_model_class,
|
results = run_lightning_loop(h_params, data_class=found_data_class, model_class=found_model_class,
|
||||||
additional_callbacks=pruning_callback, seed=seed)
|
additional_callbacks=pruning_callback, seed=seed)
|
||||||
|
best_score = results.best_model_score
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(e)
|
print(e)
|
||||||
best_score = 0
|
best_score = 0
|
||||||
@@ -71,22 +124,18 @@ def optimize(trial: optuna.Trial):
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
study = optuna.create_study(direction='maximize', sampler=optuna.samplers.TPESampler(seed=1337))
|
optuna_study = optuna.create_study(direction='maximize', sampler=optuna.samplers.TPESampler(seed=1337))
|
||||||
# study.optimize(optimize, n_trials=50, callbacks=[opt_utils.NeptuneCallback(log_study=True, log_charts=True)])
|
with ContiniousSavingCallback('study', optuna_study) as continious_save_callback:
|
||||||
study.optimize(optimize, n_trials=100)
|
# study.optimize(optimize, n_trials=50, callbacks=[opt_utils.NeptuneCallback(log_study=True, log_charts=True)])
|
||||||
|
optuna_study.optimize(optimize, n_trials=200, show_progress_bar=True, callbacks=[continious_save_callback])
|
||||||
|
|
||||||
print("Number of finished trials: {}".format(len(study.trials)))
|
print("Number of finished trials: {}".format(len(optuna_study.trials)))
|
||||||
|
|
||||||
print("Best trial:")
|
print("Best trial:")
|
||||||
trial = study.best_trial
|
trial = optuna_study.best_trial
|
||||||
|
|
||||||
print(" Value: {}".format(trial.value))
|
print(" Value: {}".format(trial.value))
|
||||||
|
|
||||||
print(" Params: ")
|
print(" Params: ")
|
||||||
for key, value in trial.params.items():
|
for key, value in trial.params.items():
|
||||||
print(" {}: {}".format(key, value))
|
print(" {}: {}".format(key, value))
|
||||||
|
|
||||||
optuna_study_object = Path('study') / 'study.pkl'
|
|
||||||
optuna_study_object.parent.mkdir(exist_ok=True)
|
|
||||||
with optuna_study_object.open(mode='wb') as f:
|
|
||||||
pickle.dump(study, f)
|
|
||||||
|
|||||||
+1
-1
@@ -23,7 +23,7 @@ def rebuild_dataset(h_params, data_class):
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
for dataset in [MaskLibrosaDatamodule]: # [PrimatesLibrosaDatamodule, CCSLibrosaDatamodule]:
|
for dataset in [CCSLibrosaDatamodule]:
|
||||||
# Parse comandline args, read config and get model
|
# Parse comandline args, read config and get model
|
||||||
cmd_args, _, _, _ = parse_comandline_args_add_defaults('_parameters.ini')
|
cmd_args, _, _, _ = parse_comandline_args_add_defaults('_parameters.ini')
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,341 @@
|
|||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 1,
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": true,
|
||||||
|
"pycharm": {
|
||||||
|
"name": "#%% IMPORTS\n"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from collections import defaultdict\n",
|
||||||
|
"from pathlib import Path\n",
|
||||||
|
"from natsort import natsorted\n",
|
||||||
|
"from pytorch_lightning.core.saving import *\n",
|
||||||
|
"\n",
|
||||||
|
"import torch\n",
|
||||||
|
"from sklearn.manifold import TSNE\n",
|
||||||
|
"\n",
|
||||||
|
"import seaborn as sns\n",
|
||||||
|
"from matplotlib import pyplot as plt\n",
|
||||||
|
"import numpy as np\n",
|
||||||
|
"import pandas as pd"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 2,
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from ml_lib.metrics.binary_class_classifictaion import BinaryScores\n",
|
||||||
|
"from ml_lib.utils.tools import locate_and_import_class\n",
|
||||||
|
"_ROOT = Path()\n",
|
||||||
|
"out_path = 'output'\n",
|
||||||
|
"model_name = 'VisualTransformer'\n",
|
||||||
|
"exp_name = 'VT_7899c07a4809a45c57cba58047cefb5e'\n",
|
||||||
|
"version = 'version_7'"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"pycharm": {
|
||||||
|
"name": "#%%M Path resolving and variables\n"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 3,
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"plt.style.use('default')\n",
|
||||||
|
"sns.set_palette('Dark2')\n",
|
||||||
|
"\n",
|
||||||
|
"tex_fonts = {\n",
|
||||||
|
" # Use LaTeX to write all text\n",
|
||||||
|
" \"text.usetex\": True,\n",
|
||||||
|
" \"font.family\": \"serif\",\n",
|
||||||
|
" # Use 10pt font in plots, to match 10pt font in document\n",
|
||||||
|
" \"axes.labelsize\": 10,\n",
|
||||||
|
" \"font.size\": 10,\n",
|
||||||
|
" # Make the legend/label fonts a little smaller\n",
|
||||||
|
" \"legend.fontsize\": 8,\n",
|
||||||
|
" \"xtick.labelsize\": 8,\n",
|
||||||
|
" \"ytick.labelsize\": 8\n",
|
||||||
|
"}\n",
|
||||||
|
"\n",
|
||||||
|
"# plt.rcParams.update(tex_fonts)\n",
|
||||||
|
"\n",
|
||||||
|
"Path('figures').mkdir(exist_ok=True)"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"pycharm": {
|
||||||
|
"name": "#%% Seaborn Settings\n"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 4,
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"def reconstruct_model_data_params(yaml_file_path: str):\n",
|
||||||
|
" hparams_dict = load_hparams_from_yaml(yaml_file_path)\n",
|
||||||
|
"\n",
|
||||||
|
" # Try to get model_name and data_name from yaml:\n",
|
||||||
|
" model_name = hparams_dict['model_name']\n",
|
||||||
|
" data_name = hparams_dict['data_name']\n",
|
||||||
|
" # Try to find the original model and data class by name:\n",
|
||||||
|
" found_data_class = locate_and_import_class(data_name, 'datasets')\n",
|
||||||
|
" found_model_class = locate_and_import_class(model_name, 'models')\n",
|
||||||
|
" # Possible way of automatic loading args:\n",
|
||||||
|
" # args = inspect.signature(found_data_class)\n",
|
||||||
|
" # then access _parameter.ini and retrieve not set parameters\n",
|
||||||
|
"\n",
|
||||||
|
" hparams_dict.update(target_mel_length_in_seconds=1, num_worker=10, data_root='data')\n",
|
||||||
|
"\n",
|
||||||
|
" h_params = Namespace(**hparams_dict)\n",
|
||||||
|
"\n",
|
||||||
|
" # Let Datamodule pull what it wants\n",
|
||||||
|
" datamodule = found_data_class.from_argparse_args(h_params)\n",
|
||||||
|
"\n",
|
||||||
|
" hparams_dict.update(in_shape=datamodule.shape, n_classes=datamodule.n_classes, variable_length=False)\n",
|
||||||
|
"\n",
|
||||||
|
" return datamodule, found_model_class, hparams_dict\n",
|
||||||
|
"\n",
|
||||||
|
"def gather_predictions_and_labels(model, data_option):\n",
|
||||||
|
" preds = list()\n",
|
||||||
|
" labels = list()\n",
|
||||||
|
" filenames = list()\n",
|
||||||
|
" with torch.no_grad():\n",
|
||||||
|
" for file_name, x, y in datamodule.datasets[data_option]:\n",
|
||||||
|
" preds.append(model(x.unsqueeze(0)).main_out)\n",
|
||||||
|
" labels.append(y)\n",
|
||||||
|
" filenames.append(file_name)\n",
|
||||||
|
" labels = np.stack(labels).squeeze()\n",
|
||||||
|
" preds = np.stack(preds).squeeze()\n",
|
||||||
|
" return preds, labels, filenames\n",
|
||||||
|
"\n",
|
||||||
|
"def build_tsne_dataframe(preds, labels):\n",
|
||||||
|
" tsne = np.stack(TSNE().fit_transform(preds)).squeeze()\n",
|
||||||
|
" tsne_dataframe = pd.DataFrame(data=tsne, columns=['x', 'y'])\n",
|
||||||
|
"\n",
|
||||||
|
" tsne_dataframe['labels'] = labels\n",
|
||||||
|
" tsne_dataframe['labels'] = tsne_dataframe['labels'].map({val: key for key, val in datamodule.class_names.items()})\n",
|
||||||
|
" return tsne_dataframe\n",
|
||||||
|
"\n",
|
||||||
|
"def plot_scatterplot(data, option):\n",
|
||||||
|
" p = sns.scatterplot(data=data, x='x', y='y', hue='labels', legend=True)\n",
|
||||||
|
" p.set_title(f'TSNE - distribution of logits for {option}')\n",
|
||||||
|
" plt.show()\n",
|
||||||
|
"\n",
|
||||||
|
"def redo_predictions(experiment_path, preds, fnames, data_class):\n",
|
||||||
|
" sorted_y = defaultdict(list)\n",
|
||||||
|
" for idx, (pred, fname) in enumerate(zip(preds, fnames)):\n",
|
||||||
|
" sorted_y[fname].append(pred)\n",
|
||||||
|
" sorted_y = dict(sorted_y)\n",
|
||||||
|
"\n",
|
||||||
|
" for file_name in sorted_y:\n",
|
||||||
|
" sorted_y.update({file_name: np.stack(sorted_y[file_name])})\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
" if data_class.n_classes > 2:\n",
|
||||||
|
" pred = np.stack(\n",
|
||||||
|
" [np.argmax(x.mean(axis=0)) if x.shape[0] > 1 else np.argmax(x) for x in sorted_y.values()]\n",
|
||||||
|
" ).squeeze()\n",
|
||||||
|
" class_names = {val: key for val, key in\n",
|
||||||
|
" enumerate(['background', 'chimpanze', 'geunon', 'mandrille', 'redcap'])}\n",
|
||||||
|
" else:\n",
|
||||||
|
" pred = [x.mean(axis=0) if x.shape[0] > 1 else x.squeeze().unsqueeze(-1) for x in sorted_y.values()]\n",
|
||||||
|
" pred = np.stack(pred).squeeze()\n",
|
||||||
|
" pred = np.where(pred > 0.5, 1, 0)\n",
|
||||||
|
" class_names = {val: key for val, key in enumerate(['negative', 'positive'])}\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
" df = pd.DataFrame(data=dict(filename=[Path(x).name.replace('.npy', '.wav') for x in sorted_y.keys()],\n",
|
||||||
|
" prediction=[class_names[x.item()] for x in pred]))\n",
|
||||||
|
" result_file = Path(experiment_path / 'predictions_new.csv')\n",
|
||||||
|
" if result_file.exists():\n",
|
||||||
|
" try:\n",
|
||||||
|
" result_file.unlink()\n",
|
||||||
|
" except:\n",
|
||||||
|
" print('File already existed')\n",
|
||||||
|
" pass\n",
|
||||||
|
" with result_file.open(mode='wb') as csv_file:\n",
|
||||||
|
" df.to_csv(index=False, path_or_buf=csv_file)\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"def re_valida(preds, labels, fnames, data_class):\n",
|
||||||
|
" sorted_y = defaultdict(list)\n",
|
||||||
|
" for idx, (pred, fname) in enumerate(zip(preds, fnames)):\n",
|
||||||
|
" sorted_y[fname].append(pred)\n",
|
||||||
|
" sorted_y = dict(sorted_y)\n",
|
||||||
|
"\n",
|
||||||
|
" for file_name in sorted_y:\n",
|
||||||
|
" sorted_y.update({file_name: np.stack(sorted_y[file_name])})\n",
|
||||||
|
"\n",
|
||||||
|
" for key, val in list(sorted_y.items()):\n",
|
||||||
|
" if val.ndim > 1:\n",
|
||||||
|
" val = val.mean(axis=0)\n",
|
||||||
|
" print(val.ndim)\n",
|
||||||
|
" if not val[0] > 0.8:\n",
|
||||||
|
" val[0] = 0\n",
|
||||||
|
" sorted_y[key] = val\n",
|
||||||
|
"\n",
|
||||||
|
" pred = np.stack(\n",
|
||||||
|
" [np.argmax(x) if x.shape[0] > 1 else np.argmax(x) for x in sorted_y.values()]\n",
|
||||||
|
" ).squeeze()\n",
|
||||||
|
"\n",
|
||||||
|
" one_hot_targets = np.eye(data_class.n_classes)[pred]\n",
|
||||||
|
"\n",
|
||||||
|
" # Sklearn Scores\n",
|
||||||
|
" print(BinaryScores(dict(y=one_hot_targets, batch_y=labels)))\n"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"pycharm": {
|
||||||
|
"name": "#%% Utility Functions\n"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 5,
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Selected Checkopint is output\\VisualTransformer\\VT_7899c07a4809a45c57cba58047cefb5e\\version_7\\ckpt_weights-v1.ckpt\n",
|
||||||
|
"PrimatesLibrosaDatamodule\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"exp_path = _ROOT / out_path / model_name / exp_name / version\n",
|
||||||
|
"checkpoint = natsorted(exp_path.glob('*.ckpt'))[-4]\n",
|
||||||
|
"print(f'Selected Checkopint is {checkpoint}')\n",
|
||||||
|
"hparams_yaml = next(exp_path.glob('*.yaml'))\n",
|
||||||
|
"print(load_hparams_from_yaml(hparams_yaml)['data_name'])\n",
|
||||||
|
"# LADE DAS MODELL HIER VON HAND AUS DER KLASSE DIE ABGELEGT WURDE\n",
|
||||||
|
"datamodule, model_class, h_params = reconstruct_model_data_params(hparams_yaml.__str__())\n",
|
||||||
|
"# h_params.update(return_logits=True)"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"pycharm": {
|
||||||
|
"name": "#%%\n"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 6,
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"ename": "RuntimeError",
|
||||||
|
"evalue": "Error(s) in loading state_dict for VisualTransformer:\n\tsize mismatch for pos_embedding: copying a param with shape torch.Size([1, 67, 30]) from checkpoint, the shape in current model is torch.Size([1, 122, 30]).",
|
||||||
|
"output_type": "error",
|
||||||
|
"traceback": [
|
||||||
|
"\u001B[1;31m---------------------------------------------------------------------------\u001B[0m",
|
||||||
|
"\u001B[1;31mRuntimeError\u001B[0m Traceback (most recent call last)",
|
||||||
|
"\u001B[1;32m<ipython-input-6-c8e208607217>\u001B[0m in \u001B[0;36m<module>\u001B[1;34m\u001B[0m\n\u001B[1;32m----> 1\u001B[1;33m \u001B[0mmodel\u001B[0m \u001B[1;33m=\u001B[0m \u001B[0mmodel_class\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mload_from_checkpoint\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mcheckpoint\u001B[0m\u001B[1;33m,\u001B[0m \u001B[1;33m**\u001B[0m\u001B[0mh_params\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0meval\u001B[0m\u001B[1;33m(\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0m\u001B[0;32m 2\u001B[0m \u001B[0mdatamodule\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mprepare_data\u001B[0m\u001B[1;33m(\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m 3\u001B[0m \u001B[1;33m\u001B[0m\u001B[0m\n",
|
||||||
|
"\u001B[1;32mc:\\users\\steff\\envs\\compare_21\\lib\\site-packages\\pytorch_lightning\\core\\saving.py\u001B[0m in \u001B[0;36mload_from_checkpoint\u001B[1;34m(cls, checkpoint_path, map_location, hparams_file, strict, **kwargs)\u001B[0m\n\u001B[0;32m 154\u001B[0m \u001B[0mcheckpoint\u001B[0m\u001B[1;33m[\u001B[0m\u001B[0mcls\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mCHECKPOINT_HYPER_PARAMS_KEY\u001B[0m\u001B[1;33m]\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mupdate\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mkwargs\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m 155\u001B[0m \u001B[1;33m\u001B[0m\u001B[0m\n\u001B[1;32m--> 156\u001B[1;33m \u001B[0mmodel\u001B[0m \u001B[1;33m=\u001B[0m \u001B[0mcls\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0m_load_model_state\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mcheckpoint\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0mstrict\u001B[0m\u001B[1;33m=\u001B[0m\u001B[0mstrict\u001B[0m\u001B[1;33m,\u001B[0m \u001B[1;33m**\u001B[0m\u001B[0mkwargs\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0m\u001B[0;32m 157\u001B[0m \u001B[1;32mreturn\u001B[0m \u001B[0mmodel\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m 158\u001B[0m \u001B[1;33m\u001B[0m\u001B[0m\n",
|
||||||
|
"\u001B[1;32mc:\\users\\steff\\envs\\compare_21\\lib\\site-packages\\pytorch_lightning\\core\\saving.py\u001B[0m in \u001B[0;36m_load_model_state\u001B[1;34m(cls, checkpoint, strict, **cls_kwargs_new)\u001B[0m\n\u001B[0;32m 202\u001B[0m \u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m 203\u001B[0m \u001B[1;31m# load the state_dict on the model automatically\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[1;32m--> 204\u001B[1;33m \u001B[0mmodel\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mload_state_dict\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mcheckpoint\u001B[0m\u001B[1;33m[\u001B[0m\u001B[1;34m'state_dict'\u001B[0m\u001B[1;33m]\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0mstrict\u001B[0m\u001B[1;33m=\u001B[0m\u001B[0mstrict\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0m\u001B[0;32m 205\u001B[0m \u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m 206\u001B[0m \u001B[1;32mreturn\u001B[0m \u001B[0mmodel\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n",
|
||||||
|
"\u001B[1;32mc:\\users\\steff\\envs\\compare_21\\lib\\site-packages\\torch\\nn\\modules\\module.py\u001B[0m in \u001B[0;36mload_state_dict\u001B[1;34m(self, state_dict, strict)\u001B[0m\n\u001B[0;32m 1049\u001B[0m \u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m 1050\u001B[0m \u001B[1;32mif\u001B[0m \u001B[0mlen\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0merror_msgs\u001B[0m\u001B[1;33m)\u001B[0m \u001B[1;33m>\u001B[0m \u001B[1;36m0\u001B[0m\u001B[1;33m:\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[1;32m-> 1051\u001B[1;33m raise RuntimeError('Error(s) in loading state_dict for {}:\\n\\t{}'.format(\n\u001B[0m\u001B[0;32m 1052\u001B[0m self.__class__.__name__, \"\\n\\t\".join(error_msgs)))\n\u001B[0;32m 1053\u001B[0m \u001B[1;32mreturn\u001B[0m \u001B[0m_IncompatibleKeys\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mmissing_keys\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0munexpected_keys\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n",
|
||||||
|
"\u001B[1;31mRuntimeError\u001B[0m: Error(s) in loading state_dict for VisualTransformer:\n\tsize mismatch for pos_embedding: copying a param with shape torch.Size([1, 67, 30]) from checkpoint, the shape in current model is torch.Size([1, 122, 30])."
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"model = model_class.load_from_checkpoint(checkpoint, **h_params).eval()\n",
|
||||||
|
"datamodule.prepare_data()"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"pycharm": {
|
||||||
|
"name": "#%%\n"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"predictions, labels_y, filenames = gather_predictions_and_labels(model, 'devel')"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"pycharm": {
|
||||||
|
"name": "#%%\n"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# tsne_dataframe = build_tsne_dataframe(predictions, labels_y)\n",
|
||||||
|
"# plot_scatterplot(tsne_dataframe, data_option)"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"pycharm": {
|
||||||
|
"name": "#%%\n"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"re_valida(predictions,labels_y, filenames, datamodule)"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"pycharm": {
|
||||||
|
"name": "#%%\n"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"outputs": [],
|
||||||
|
"source": [],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"pycharm": {
|
||||||
|
"name": "#%%\n"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "Python 3",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 2
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython2",
|
||||||
|
"version": "2.7.6"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 0
|
||||||
|
}
|
||||||
+17
-10
@@ -5,6 +5,7 @@ from abc import ABC
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
from matplotlib import pyplot as plt
|
||||||
|
|
||||||
from ml_lib.modules.util import LightningBaseModule
|
from ml_lib.modules.util import LightningBaseModule
|
||||||
from util.loss_mixin import LossMixin
|
from util.loss_mixin import LossMixin
|
||||||
@@ -59,7 +60,7 @@ class ValMixin:
|
|||||||
|
|
||||||
target_y = torch.stack(tuple(sorted_batch_y.values())).long()
|
target_y = torch.stack(tuple(sorted_batch_y.values())).long()
|
||||||
if self.params.n_classes <= 2:
|
if self.params.n_classes <= 2:
|
||||||
mean_sorted_y = torch.stack([x.mean(dim=0) if x.shape[0] > 1 else x for x in sorted_y.values()])
|
mean_sorted_y = torch.stack([x.mean(dim=0) if x.shape[0] > 1 else x.squeeze(-1) for x in sorted_y.values()])
|
||||||
self.metrics.update(mean_sorted_y, target_y)
|
self.metrics.update(mean_sorted_y, target_y)
|
||||||
else:
|
else:
|
||||||
y_max = torch.stack(
|
y_max = torch.stack(
|
||||||
@@ -97,8 +98,11 @@ class ValMixin:
|
|||||||
sorted_y.update({file_name: torch.stack(sorted_y[file_name])})
|
sorted_y.update({file_name: torch.stack(sorted_y[file_name])})
|
||||||
|
|
||||||
if self.params.n_classes <= 2:
|
if self.params.n_classes <= 2:
|
||||||
mean_sorted_y = torch.stack([x.mean(dim=0) if x.shape[0] > 1 else x for x in sorted_y.values()]).squeeze()
|
mean_sorted_y = [x.mean(dim=0) if x.shape[0] > 1 else x.squeeze().unsqueeze(-1) for x in sorted_y.values()]
|
||||||
|
mean_sorted_y = torch.stack(mean_sorted_y).squeeze(1)
|
||||||
|
# mean_sorted_y = mean_sorted_y if mean_sorted_y.numel() > 1 else mean_sorted_y.unsqueeze(-1)
|
||||||
max_vote_loss = self.bce_loss(mean_sorted_y.float(), sorted_batch_y.float())
|
max_vote_loss = self.bce_loss(mean_sorted_y.float(), sorted_batch_y.float())
|
||||||
|
|
||||||
# Sklearn Scores
|
# Sklearn Scores
|
||||||
additional_scores = self.additional_scores(dict(y=mean_sorted_y, batch_y=sorted_batch_y))
|
additional_scores = self.additional_scores(dict(y=mean_sorted_y, batch_y=sorted_batch_y))
|
||||||
|
|
||||||
@@ -129,6 +133,7 @@ class ValMixin:
|
|||||||
|
|
||||||
for name, image in pl_images.items():
|
for name, image in pl_images.items():
|
||||||
self.logger.log_image(name, image, step=self.global_step)
|
self.logger.log_image(name, image, step=self.global_step)
|
||||||
|
plt.close(image)
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
@@ -162,12 +167,13 @@ class TestMixin:
|
|||||||
class_names = {val: key for val, key in
|
class_names = {val: key for val, key in
|
||||||
enumerate(['background', 'chimpanze', 'geunon', 'mandrille', 'redcap'])}
|
enumerate(['background', 'chimpanze', 'geunon', 'mandrille', 'redcap'])}
|
||||||
else:
|
else:
|
||||||
pred = torch.stack([x.mean(dim=0) if x.shape[0] > 1 else x for x in sorted_y.values()]).squeeze()
|
pred = [x.mean(dim=0) if x.shape[0] > 1 else x.squeeze().unsqueeze(-1) for x in sorted_y.values()]
|
||||||
|
pred = torch.stack(pred).squeeze()
|
||||||
pred = torch.where(pred > 0.5, 1, 0)
|
pred = torch.where(pred > 0.5, 1, 0)
|
||||||
class_names = {val: key for val, key in enumerate(['negative', 'positive'])}
|
class_names = {val: key for val, key in enumerate(['negative', 'positive'])}
|
||||||
|
|
||||||
|
|
||||||
df = pd.DataFrame(data=dict(filename=[Path(x).name for x in sorted_y.keys()],
|
df = pd.DataFrame(data=dict(filename=[Path(x).name.replace('.npy', '.wav') for x in sorted_y.keys()],
|
||||||
prediction=[class_names[x.item()] for x in pred.cpu()]))
|
prediction=[class_names[x.item()] for x in pred.cpu()]))
|
||||||
result_file = Path(self.logger.log_dir / 'predictions.csv')
|
result_file = Path(self.logger.log_dir / 'predictions.csv')
|
||||||
if result_file.exists():
|
if result_file.exists():
|
||||||
@@ -178,12 +184,13 @@ class TestMixin:
|
|||||||
pass
|
pass
|
||||||
with result_file.open(mode='wb') as csv_file:
|
with result_file.open(mode='wb') as csv_file:
|
||||||
df.to_csv(index=False, path_or_buf=csv_file)
|
df.to_csv(index=False, path_or_buf=csv_file)
|
||||||
with result_file.open(mode='rb') as csv_file:
|
if False:
|
||||||
try:
|
with result_file.open(mode='rb') as csv_file:
|
||||||
self.logger.neptunelogger.log_artifact(csv_file)
|
try:
|
||||||
except:
|
self.logger.neptunelogger.log_artifact(csv_file)
|
||||||
print('No possible to send to neptune')
|
except:
|
||||||
pass
|
print('No possible to send to neptune')
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class CombinedModelMixins(LossMixin,
|
class CombinedModelMixins(LossMixin,
|
||||||
|
|||||||
@@ -5,5 +5,6 @@ sr = 16000
|
|||||||
|
|
||||||
PRIMATES_Root = Path(__file__).parent / 'data' / 'primates'
|
PRIMATES_Root = Path(__file__).parent / 'data' / 'primates'
|
||||||
CCS_Root = Path(__file__).parent / 'data' / 'ComParE2021_CCS'
|
CCS_Root = Path(__file__).parent / 'data' / 'ComParE2021_CCS'
|
||||||
|
MASK_Root = Path(__file__).parent / 'data' / 'ComParE2020_Mask'
|
||||||
|
|
||||||
N_CLASS_multi = 5
|
N_CLASS_multi = 5
|
||||||
|
|||||||
Reference in New Issue
Block a user