paper preparations and notebooks, Optuna callbacks
models/bandwise_conv_classifier.py  (new file, +69 lines)
@@ -0,0 +1,69 @@
import inspect
from argparse import Namespace

from torch import nn
from torch.nn import ModuleList

from ml_lib.modules.blocks import ConvModule, LinearModule
from ml_lib.modules.util import (LightningBaseModule, Splitter, Merger)
from util.module_mixins import CombinedModelMixins


class BandwiseConvClassifier(CombinedModelMixins,
                             LightningBaseModule
                             ):
    def __init__(self, in_shape, n_classes, weight_init, activation,
                 use_bias, use_norm, dropout, lat_dim, filters,
                 lr, weight_decay, sto_weight_avg, lr_warm_restart_epochs, opt_reset_interval,
                 loss, scheduler, lr_scheduler_parameter
                 ):
        # TODO: Move this to parent class, or make it much easier to access...
        a = dict(locals())
        params = {arg: a[arg] for arg in inspect.signature(self.__init__).parameters.keys() if arg != 'self'}
        super(BandwiseConvClassifier, self).__init__(params)

        # Model Parameters
        # =============================================================================
        # Additional parameters
        self.n_band_sections = 8

        # Modules
        # =============================================================================
        self.split = Splitter(in_shape, self.n_band_sections)

        k = 3
        self.band_list = ModuleList()
        for band in range(self.n_band_sections):
            last_shape = self.split.shape[band]
            conv_list = ModuleList()
            for conv_filters in self.params.filters:
                conv_list.append(ConvModule(last_shape, conv_filters, (k, k), conv_stride=(2, 2), conv_padding=2,
                                            **self.params.module_kwargs))
                last_shape = conv_list[-1].shape
                # self.conv_list.append(ConvModule(last_shape, 1, 1, conv_stride=1, **self.params.module_kwargs))
                # last_shape = self.conv_list[-1].shape
            self.band_list.append(conv_list)

        self.merge = Merger(self.band_list[-1][-1].shape, self.n_band_sections)

        self.full_1 = LinearModule(self.merge.shape, self.params.lat_dim, **self.params.module_kwargs)
        self.full_2 = LinearModule(self.full_1.shape, self.params.lat_dim, **self.params.module_kwargs)

        # Decide between binary and multi-class classification
        logits = n_classes if n_classes > 2 else 1
        module_kwargs = self.params.module_kwargs
        module_kwargs.update(activation=(nn.Softmax if logits > 1 else nn.Sigmoid))
        self.full_out = LinearModule(self.full_2.shape, logits, **module_kwargs)

    def forward(self, batch, **kwargs):
        tensors = self.split(batch)
        for idx, (tensor, convs) in enumerate(zip(tensors, self.band_list)):
            for conv in convs:
                tensor = conv(tensor)
            tensors[idx] = tensor

        tensor = self.merge(tensors)
        tensor = self.full_1(tensor)
        tensor = self.full_2(tensor)
        tensor = self.full_out(tensor)
        return Namespace(main_out=tensor)
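The idea behind the new classifier can be illustrated without the ml_lib helpers (ConvModule, Splitter, Merger). The following is an independent sketch in plain PyTorch, with made-up sizes and names rather than the repository's API: split the input spectrogram along the frequency axis into bands, run a small conv stack per band, then merge the band features and classify.

import torch
from torch import nn

class BandwiseSketch(nn.Module):
    """Illustrative only: split the frequency axis into bands and convolve each band separately."""
    def __init__(self, n_bands=8, n_classes=5):
        super().__init__()
        self.n_bands = n_bands
        self.band_convs = nn.ModuleList([
            nn.Sequential(nn.Conv2d(1, 16, 3, stride=2, padding=2), nn.ReLU(),
                          nn.Conv2d(16, 32, 3, stride=2, padding=2), nn.ReLU())
            for _ in range(n_bands)
        ])
        # single logit for binary, n_classes logits otherwise (mirrors the rule used in the diff)
        self.head = nn.LazyLinear(n_classes if n_classes > 2 else 1)

    def forward(self, x):                                  # x: (batch, 1, freq, time)
        bands = torch.chunk(x, self.n_bands, dim=2)        # split along the frequency axis
        feats = [conv(b).flatten(1) for conv, b in zip(self.band_convs, bands)]
        return self.head(torch.cat(feats, dim=1))          # merge band features, then classify

# quick shape check with a dummy spectrogram batch
out = BandwiseSketch()(torch.randn(4, 1, 128, 64))
print(out.shape)  # torch.Size([4, 5])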
@@ -3,8 +3,6 @@ from argparse import Namespace
 
 from torch import nn
 
-from ml_lib.metrics.binary_class_classifictaion import BinaryScores
-from ml_lib.metrics.multi_class_classification import MultiClassScores
 from ml_lib.modules.blocks import LinearModule
 from ml_lib.modules.model_parts import CNNEncoder
 from ml_lib.modules.util import (LightningBaseModule)
@@ -52,9 +50,3 @@ class CNNBaseline(CombinedModelMixins,
 
         tensor = self.classifier(tensor)
         return Namespace(main_out=tensor)
-
-    def additional_scores(self, outputs):
-        if self.params.n_classes > 2:
-            return MultiClassScores(self)(outputs)
-        else:
-            return BinaryScores(self)(outputs)
@@ -130,8 +130,6 @@ try:
             tensor = self.mlp_head(tensor)
             return Namespace(main_out=tensor)
 
-        def additional_scores(self, outputs):
-            return MultiClassScores(self)(outputs)
 
 except ImportError:  # pragma: do not provide model class
     print('You want to use `performer_pytorch` plugins which are not installed yet,'  # pragma: no-cover
@@ -109,9 +109,3 @@ class Tester(CombinedModelMixins,
         tensor = self.to_cls_token(tensor[:, 0])
         tensor = self.mlp_head(tensor)
         return Namespace(main_out=tensor, attn_weights=None)
-
-    def additional_scores(self, outputs):
-        if self.params.n_classes > 2:
-            return MultiClassScores(self)(outputs)
-        else:
-            return BinaryScores(self)(outputs)
@@ -73,9 +73,9 @@ class VisualTransformer(CombinedModelMixins,
         logits = self.params.n_classes if self.params.n_classes > 2 else 1
 
         if return_logits:
-            outbound_activation = nn.Identity()
+            outbound_activation = nn.Identity
         else:
-            outbound_activation = nn.Softmax() if logits > 1 else nn.Sigmoid()
+            outbound_activation = nn.Softmax if logits > 1 else nn.Sigmoid
 
 
         self.mlp_head = nn.Sequential(
@@ -84,7 +84,7 @@ class VisualTransformer(CombinedModelMixins,
             self.params.activation(),
             nn.Dropout(self.params.dropout),
             nn.Linear(self.params.lat_dim, logits),
-            outbound_activation
+            outbound_activation()
            )
 
     def forward(self, x, mask=None, return_attn_weights=False):
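The VisualTransformer change above stops assigning an already-instantiated activation and instead keeps the activation class in both branches, instantiating it exactly once where the nn.Sequential head is assembled. A standalone sketch of that pattern (sizes and names here are made up for illustration, not taken from the repository):

from torch import nn

n_logits = 3  # made-up: more than one logit -> softmax output, a single logit -> sigmoid
out_act = nn.Softmax if n_logits > 1 else nn.Sigmoid      # store the class, not an instance

head = nn.Sequential(
    nn.Linear(16, n_logits),
    out_act(dim=-1) if out_act is nn.Softmax else out_act(),  # instantiate once, here
)
print(head)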
@@ -128,8 +128,3 @@ class VisualTransformer(CombinedModelMixins,
         tensor = self.mlp_head(tensor)
         return Namespace(main_out=tensor, attn_weights=attn_weights)
 
-    def additional_scores(self, outputs):
-        if self.params.n_classes <= 2:
-            return BinaryScores(self)(outputs)
-        else:
-            return MultiClassScores(self)(outputs)
@@ -116,5 +116,3 @@ class HorizontalVisualTransformer(CombinedModelMixins,
         tensor = self.mlp_head(tensor)
         return Namespace(main_out=tensor, attn_weights=attn_weights)
 
-    def additional_scores(self, outputs):
-        return MultiClassScores(self)(outputs)
@@ -18,9 +18,10 @@ MIN_NUM_PATCHES = 16
 class VerticalVisualTransformer(CombinedModelMixins, LightningBaseModule):
 
     def __init__(self, in_shape, n_classes, weight_init, activation,
-                 embedding_size, heads, attn_depth, patch_size, use_residual,
-                 use_bias, use_norm, dropout, lat_dim, features, loss, scheduler,
-                 lr, weight_decay, sto_weight_avg, lr_warm_restart_epochs, opt_reset_interval):
+                 embedding_size, heads, attn_depth, patch_size, use_residual, variable_length,
+                 use_bias, use_norm, dropout, lat_dim, loss, scheduler, mlp_dim, head_dim,
+                 lr, weight_decay, sto_weight_avg, lr_scheduler_parameter, opt_reset_interval,
+                 return_logits=False):
 
         # TODO: Move this to parent class, or make it much easier to access... But how...
         a = dict(locals())
@@ -47,14 +48,6 @@ class VerticalVisualTransformer(CombinedModelMixins, LightningBaseModule):
         assert num_patches >= MIN_NUM_PATCHES, f'your number of patches ({num_patches}) is way too small for ' + \
                                                f'attention. Try decreasing your patch size'
 
-        # Correct the Embedding Dim
-        if not self.embed_dim % self.params.heads == 0:
-            self.embed_dim = (self.embed_dim // self.params.heads) * self.params.heads
-            message = ('Embedding Dimension was fixed to be devideable by the number' +
-                       f' of attention heads, is now: {self.embed_dim}')
-            for func in print, warnings.warn:
-                func(message)
-
         # Utility Modules
         self.autopad = AutoPadToShape((self.height, self.new_width))
         self.dropout = nn.Dropout(self.params.dropout)
@@ -62,10 +55,11 @@ class VerticalVisualTransformer(CombinedModelMixins, LightningBaseModule):
                                       keepdim=False)
 
         # Modules with Parameters
-        self.transformer = TransformerModule(in_shape=self.embed_dim, mlp_dim=self.params.lat_dim,
+        self.transformer = TransformerModule(in_shape=self.embed_dim, mlp_dim=self.params.mlp_dim,
+                                             head_dim=self.params.head_dim,
                                              heads=self.params.heads, depth=self.params.attn_depth,
                                              dropout=self.params.dropout, use_norm=self.params.use_norm,
-                                             activation=self.params.activation
+                                             activation=self.params.activation, use_residual=self.params.use_residual
                                              )
 
         self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, self.embed_dim))
@@ -74,13 +68,17 @@ class VerticalVisualTransformer(CombinedModelMixins, LightningBaseModule):
         self.cls_token = nn.Parameter(torch.randn(1, 1, self.embed_dim))
         self.to_cls_token = nn.Identity()
 
+        logits = self.params.n_classes if self.params.n_classes > 2 else 1
+
+        outbound_activation = nn.Softmax if logits > 1 else nn.Sigmoid
+
         self.mlp_head = nn.Sequential(
             nn.LayerNorm(self.embed_dim),
             nn.Linear(self.embed_dim, self.params.lat_dim),
-            nn.GELU(),
+            self.params.activation(),
             nn.Dropout(self.params.dropout),
-            nn.Linear(self.params.lat_dim, self.n_classes),
-            nn.Softmax()
+            nn.Linear(self.params.lat_dim, logits),
+            outbound_activation()
            )
 
     def forward(self, x, mask=None, return_attn_weights=False):
@@ -112,5 +110,3 @@ class VerticalVisualTransformer(CombinedModelMixins, LightningBaseModule):
         tensor = self.mlp_head(tensor)
         return Namespace(main_out=tensor, attn_weights=attn_weights)
 
-    def additional_scores(self, outputs):
-        return MultiClassScores(self)(outputs)
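The commit message also mentions Optuna callbacks, but no corresponding hunk appears above. For orientation, this is a minimal sketch of how an Optuna pruning callback is commonly attached to a PyTorch Lightning Trainer; the build_model and build_datamodule helpers are placeholders for this project's own factories, not code from this commit.

import optuna
from optuna.integration import PyTorchLightningPruningCallback
from pytorch_lightning import Trainer

def objective(trial: optuna.Trial) -> float:
    # hypothetical search space; a real study would sample the model/optimizer parameters used here
    lr = trial.suggest_float("lr", 1e-5, 1e-2, log=True)
    model = build_model(lr=lr)            # placeholder: the project's model factory
    datamodule = build_datamodule()       # placeholder: the project's data setup

    trainer = Trainer(
        max_epochs=20,
        callbacks=[PyTorchLightningPruningCallback(trial, monitor="val_loss")],  # reports val_loss, prunes bad trials
    )
    trainer.fit(model, datamodule=datamodule)
    return trainer.callback_metrics["val_loss"].item()

study = optuna.create_study(direction="minimize", pruner=optuna.pruners.MedianPruner())
study.optimize(objective, n_trials=50)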