paper preperations and notebooks, optuna callbacks

This commit is contained in:
Steffen Illium
2021-04-02 08:45:11 +02:00
parent 7c88602776
commit cec3a07d60
21 changed files with 3818 additions and 1059 deletions

View File

@@ -0,0 +1,69 @@
import inspect
from argparse import Namespace
from torch import nn
from torch.nn import ModuleList
from ml_lib.modules.blocks import ConvModule, LinearModule
from ml_lib.modules.util import (LightningBaseModule, Splitter, Merger)
from util.module_mixins import CombinedModelMixins
class BandwiseConvClassifier(CombinedModelMixins,
LightningBaseModule
):
def __init__(self, in_shape, n_classes, weight_init, activation,
use_bias, use_norm, dropout, lat_dim, filters,
lr, weight_decay, sto_weight_avg, lr_warm_restart_epochs, opt_reset_interval,
loss, scheduler, lr_scheduler_parameter
):
# TODO: Move this to parent class, or make it much easieer to access....
a = dict(locals())
params = {arg: a[arg] for arg in inspect.signature(self.__init__).parameters.keys() if arg != 'self'}
super(BandwiseConvClassifier, self).__init__(params)
# Model Paramters
# =============================================================================
# Additional parameters
self.n_band_sections = 8
# Modules
# =============================================================================
self.split = Splitter(in_shape, self.n_band_sections)
k = 3
self.band_list = ModuleList()
for band in range(self.n_band_sections):
last_shape = self.split.shape[band]
conv_list = ModuleList()
for conv_filters in self.params.filters:
conv_list.append(ConvModule(last_shape, conv_filters, (k, k), conv_stride=(2, 2), conv_padding=2,
**self.params.module_kwargs))
last_shape = conv_list[-1].shape
# self.conv_list.append(ConvModule(last_shape, 1, 1, conv_stride=1, **self.params.module_kwargs))
# last_shape = self.conv_list[-1].shape
self.band_list.append(conv_list)
self.merge = Merger(self.band_list[-1][-1].shape, self.n_band_sections)
self.full_1 = LinearModule(self.merge.shape, self.params.lat_dim, **self.params.module_kwargs)
self.full_2 = LinearModule(self.full_1.shape, self.params.lat_dim, **self.params.module_kwargs)
# Make Decision between binary and Multiclass Classification
logits = n_classes if n_classes > 2 else 1
module_kwargs = self.params.module_kwargs
module_kwargs.update(activation=(nn.Softmax if logits > 1 else nn.Sigmoid))
self.full_out = LinearModule(self.full_2.shape, logits, **module_kwargs)
def forward(self, batch, **kwargs):
tensors = self.split(batch)
for idx, (tensor, convs) in enumerate(zip(tensors, self.band_list)):
for conv in convs:
tensor = conv(tensor)
tensors[idx] = tensor
tensor = self.merge(tensors)
tensor = self.full_1(tensor)
tensor = self.full_2(tensor)
tensor = self.full_out(tensor)
return Namespace(main_out=tensor)

View File

@@ -3,8 +3,6 @@ from argparse import Namespace
from torch import nn
from ml_lib.metrics.binary_class_classifictaion import BinaryScores
from ml_lib.metrics.multi_class_classification import MultiClassScores
from ml_lib.modules.blocks import LinearModule
from ml_lib.modules.model_parts import CNNEncoder
from ml_lib.modules.util import (LightningBaseModule)
@@ -52,9 +50,3 @@ class CNNBaseline(CombinedModelMixins,
tensor = self.classifier(tensor)
return Namespace(main_out=tensor)
def additional_scores(self, outputs):
if self.params.n_classes > 2:
return MultiClassScores(self)(outputs)
else:
return BinaryScores(self)(outputs)

View File

@@ -130,8 +130,6 @@ try:
tensor = self.mlp_head(tensor)
return Namespace(main_out=tensor)
def additional_scores(self, outputs):
return MultiClassScores(self)(outputs)
except ImportError: # pragma: do not provide model class
print('You want to use `performer_pytorch` plugins which are not installed yet,' # pragma: no-cover

View File

@@ -109,9 +109,3 @@ class Tester(CombinedModelMixins,
tensor = self.to_cls_token(tensor[:, 0])
tensor = self.mlp_head(tensor)
return Namespace(main_out=tensor, attn_weights=None)
def additional_scores(self, outputs):
if self.params.n_classes > 2:
return MultiClassScores(self)(outputs)
else:
return BinaryScores(self)(outputs)

View File

@@ -73,9 +73,9 @@ class VisualTransformer(CombinedModelMixins,
logits = self.params.n_classes if self.params.n_classes > 2 else 1
if return_logits:
outbound_activation = nn.Identity()
outbound_activation = nn.Identity
else:
outbound_activation = nn.Softmax() if logits > 1 else nn.Sigmoid()
outbound_activation = nn.Softmax if logits > 1 else nn.Sigmoid
self.mlp_head = nn.Sequential(
@@ -84,7 +84,7 @@ class VisualTransformer(CombinedModelMixins,
self.params.activation(),
nn.Dropout(self.params.dropout),
nn.Linear(self.params.lat_dim, logits),
outbound_activation
outbound_activation()
)
def forward(self, x, mask=None, return_attn_weights=False):
@@ -128,8 +128,3 @@ class VisualTransformer(CombinedModelMixins,
tensor = self.mlp_head(tensor)
return Namespace(main_out=tensor, attn_weights=attn_weights)
def additional_scores(self, outputs):
if self.params.n_classes <= 2:
return BinaryScores(self)(outputs)
else:
return MultiClassScores(self)(outputs)

View File

@@ -116,5 +116,3 @@ class HorizontalVisualTransformer(CombinedModelMixins,
tensor = self.mlp_head(tensor)
return Namespace(main_out=tensor, attn_weights=attn_weights)
def additional_scores(self, outputs):
return MultiClassScores(self)(outputs)

View File

@@ -18,9 +18,10 @@ MIN_NUM_PATCHES = 16
class VerticalVisualTransformer(CombinedModelMixins, LightningBaseModule):
def __init__(self, in_shape, n_classes, weight_init, activation,
embedding_size, heads, attn_depth, patch_size, use_residual,
use_bias, use_norm, dropout, lat_dim, features, loss, scheduler,
lr, weight_decay, sto_weight_avg, lr_warm_restart_epochs, opt_reset_interval):
embedding_size, heads, attn_depth, patch_size, use_residual, variable_length,
use_bias, use_norm, dropout, lat_dim, loss, scheduler, mlp_dim, head_dim,
lr, weight_decay, sto_weight_avg, lr_scheduler_parameter, opt_reset_interval,
return_logits=False):
# TODO: Move this to parent class, or make it much easieer to access... But How...
a = dict(locals())
@@ -47,14 +48,6 @@ class VerticalVisualTransformer(CombinedModelMixins, LightningBaseModule):
assert num_patches >= MIN_NUM_PATCHES, f'your number of patches ({num_patches}) is way too small for ' + \
f'attention. Try decreasing your patch size'
# Correct the Embedding Dim
if not self.embed_dim % self.params.heads == 0:
self.embed_dim = (self.embed_dim // self.params.heads) * self.params.heads
message = ('Embedding Dimension was fixed to be devideable by the number' +
f' of attention heads, is now: {self.embed_dim}')
for func in print, warnings.warn:
func(message)
# Utility Modules
self.autopad = AutoPadToShape((self.height, self.new_width))
self.dropout = nn.Dropout(self.params.dropout)
@@ -62,10 +55,11 @@ class VerticalVisualTransformer(CombinedModelMixins, LightningBaseModule):
keepdim=False)
# Modules with Parameters
self.transformer = TransformerModule(in_shape=self.embed_dim, mlp_dim=self.params.lat_dim,
self.transformer = TransformerModule(in_shape=self.embed_dim, mlp_dim=self.params.mlp_dim,
head_dim=self.params.head_dim,
heads=self.params.heads, depth=self.params.attn_depth,
dropout=self.params.dropout, use_norm=self.params.use_norm,
activation=self.params.activation
activation=self.params.activation, use_residual=self.params.use_residual
)
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, self.embed_dim))
@@ -74,13 +68,17 @@ class VerticalVisualTransformer(CombinedModelMixins, LightningBaseModule):
self.cls_token = nn.Parameter(torch.randn(1, 1, self.embed_dim))
self.to_cls_token = nn.Identity()
logits = self.params.n_classes if self.params.n_classes > 2 else 1
outbound_activation = nn.Softmax if logits > 1 else nn.Sigmoid
self.mlp_head = nn.Sequential(
nn.LayerNorm(self.embed_dim),
nn.Linear(self.embed_dim, self.params.lat_dim),
nn.GELU(),
self.params.activation(),
nn.Dropout(self.params.dropout),
nn.Linear(self.params.lat_dim, self.n_classes),
nn.Softmax()
nn.Linear(self.params.lat_dim, logits),
outbound_activation()
)
def forward(self, x, mask=None, return_attn_weights=False):
@@ -112,5 +110,3 @@ class VerticalVisualTransformer(CombinedModelMixins, LightningBaseModule):
tensor = self.mlp_head(tensor)
return Namespace(main_out=tensor, attn_weights=attn_weights)
def additional_scores(self, outputs):
return MultiClassScores(self)(outputs)