Final Train Runs
@@ -15,7 +15,9 @@ class CNNBaseline(CombinedModelMixins,
                  ):

    def __init__(self, in_shape, n_classes, weight_init, activation, use_bias, use_norm, dropout, lat_dim, features,
-                filters, lr, weight_decay, sto_weight_avg, lr_warm_restart_epochs, opt_reset_interval, loss):
+                filters,
+                lr, weight_decay, sto_weight_avg, lr_warm_restart_epochs, opt_reset_interval,
+                loss, scheduler):

        # TODO: Move this to parent class, or make it much easier to access....
        a = dict(locals())
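The TODO above recurs in every model touched or added by this commit: each __init__ snapshots locals() and filters it against its own signature before handing the result to the parent constructor. A minimal sketch of how that capture could be hoisted into a shared helper; the helper name and placement are hypothetical and not part of this commit:

import inspect


def collect_init_params(init_locals, init_fn):
    # Hypothetical helper: reduce a dict(locals()) snapshot taken at the top of
    # an __init__ to exactly that __init__'s own arguments, excluding self.
    return {arg: init_locals[arg]
            for arg in inspect.signature(init_fn).parameters
            if arg != 'self'}


# Each model's __init__ could then shrink to something like:
#     params = collect_init_params(dict(locals()), self.__init__)
#     super().__init__(params)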
@@ -0,0 +1,133 @@
import inspect
from argparse import Namespace

import warnings

import torch
from performer_pytorch import Performer
from torch import nn

from einops import rearrange, repeat

from ml_lib.metrics.multi_class_classification import MultiClassScores
from ml_lib.modules.util import (LightningBaseModule, AutoPadToShape, F_x)
from util.module_mixins import CombinedModelMixins

MIN_NUM_PATCHES = 16


class VisualPerformer(CombinedModelMixins,
                      LightningBaseModule
                      ):

    def __init__(self, in_shape, n_classes, weight_init, activation,
                 embedding_size, heads, attn_depth, patch_size, use_residual,
                 use_bias, use_norm, dropout, lat_dim, loss, scheduler,
                 lr, weight_decay, sto_weight_avg, lr_warm_restart_epochs, opt_reset_interval):

        # TODO: Move this to parent class, or make it much easier to access... But How...
        a = dict(locals())
        params = {arg: a[arg] for arg in inspect.signature(self.__init__).parameters.keys() if arg != 'self'}
        super(VisualPerformer, self).__init__(params)

        self.in_shape = in_shape
        assert len(self.in_shape) == 3, 'Input shape must have three dimensions (channels, height, width)'
        channels, height, width = self.in_shape

        # Model Parameters
        # =============================================================================
        # Additional parameters
        self.embed_dim = self.params.embedding_size

        # Automatic Image Shaping
        self.patch_size = self.params.patch_size
        image_size = (max(height, width) // self.patch_size) * self.patch_size
        self.image_size = image_size + self.patch_size if image_size < max(height, width) else image_size

        # This should be obsolete
        assert self.image_size % self.patch_size == 0, 'image dimensions must be divisible by the patch size'

        num_patches = (self.image_size // self.patch_size) ** 2
        patch_dim = channels * self.patch_size ** 2
        assert num_patches >= MIN_NUM_PATCHES, f'your number of patches ({num_patches}) is way too small for ' + \
                                               f'attention. Try decreasing your patch size'

        # Correct the Embedding Dim
        if not self.embed_dim % self.params.heads == 0:
            self.embed_dim = (self.embed_dim // self.params.heads) * self.params.heads
            message = ('Embedding Dimension was fixed to be divisible by the number' +
                       f' of attention heads, is now: {self.embed_dim}')
            for func in print, warnings.warn:
                func(message)

        # Utility Modules
        self.autopad = AutoPadToShape((self.image_size, self.image_size))

        # Modules with Parameters
        self.performer = Performer(
            dim=self.embed_dim,                           # dimension
            depth=self.params.attn_depth,                 # layers
            heads=self.params.heads,                      # heads
            causal=True,                                  # auto-regressive or not
            nb_features=None,                             # 256,  # number of random features, if not set, will default to
                                                          # (d * log(d)), where d is the dimension of each head
            feature_redraw_interval=1000,                 # how frequently to redraw the projection matrix,
                                                          # the more frequent, the slower the training
            generalized_attention=False,                  # defaults to softmax approximation,
                                                          # but can be set to True for generalized attention
            kernel_fn=self.params.activation(),           # the kernel function to be used,
                                                          # if generalized attention is turned on, defaults to Relu
            reversible=True,                              # reversible layers, from Reformer paper
            ff_chunks=10,                                 # chunk feedforward layer, from Reformer paper
            use_scalenorm=False,                          # use scale norm, from 'Transformers without Tears' paper
            use_rezero=False,                             # use rezero, from 'Rezero is all you need' paper
            ff_glu=True,                                  # use GLU variant for feedforward
            ff_dropout=self.params.dropout,               # feedforward dropout
            attn_dropout=self.params.dropout,             # post-attn dropout
            local_attn_heads=self.params.heads // 2,      # half of the heads use local attention, the rest are global Performer heads
            local_window_size=(patch_dim // self.params.heads) * 2  # window size of local attention
        )

        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, self.embed_dim))
        self.patch_to_embedding = nn.Linear(patch_dim, self.embed_dim) if self.params.embedding_size \
            else F_x(self.embed_dim)
        self.cls_token = nn.Parameter(torch.randn(1, 1, self.embed_dim))
        self.dropout = nn.Dropout(self.params.dropout)

        self.to_cls_token = nn.Identity()

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(self.embed_dim),
            nn.Linear(self.embed_dim, self.params.lat_dim),
            nn.GELU(),
            nn.Dropout(self.params.dropout),
            nn.Linear(self.params.lat_dim, n_classes),
            nn.Softmax(dim=-1)
        )

    def forward(self, x):
        """
        :param x: the sequence to the encoder (required).
        :return: a Namespace with the class scores in main_out.
        """
        tensor = self.autopad(x)
        p = self.params.patch_size
        tensor = rearrange(tensor, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=p, p2=p)

        tensor = self.patch_to_embedding(tensor)
        b, n, _ = tensor.shape

        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b=b)

        tensor = torch.cat((cls_tokens, tensor), dim=1)
        tensor += self.pos_embedding[:, :(n + 1)]
        tensor = self.dropout(tensor)

        tensor = self.performer(tensor)

        tensor = self.to_cls_token(tensor[:, 0])
        tensor = self.mlp_head(tensor)
        return Namespace(main_out=tensor)

    def additional_scores(self, outputs):
        return MultiClassScores(self)(outputs)
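The automatic image shaping above rounds the larger input dimension up to the next multiple of patch_size before the patch grid is computed. A quick numeric check with made-up values (the numbers are illustrative assumptions, not taken from these experiments):

height, width, channels, patch_size = 58, 40, 1, 8    # hypothetical input shape

image_size = (max(height, width) // patch_size) * patch_size   # 56
if image_size < max(height, width):                            # 56 < 58, so pad up
    image_size = image_size + patch_size                       # -> 64

num_patches = (image_size // patch_size) ** 2                  # 8 * 8 = 64 patches
patch_dim = channels * patch_size ** 2                         # 64 values per flattened patch
assert image_size % patch_size == 0 and num_patches >= 16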
models/testing.py (new file, +120 lines)
@@ -0,0 +1,120 @@
import inspect
from argparse import Namespace

import warnings

import torch
from torch import nn

from einops import rearrange, repeat

from ml_lib.metrics.multi_class_classification import MultiClassScores
from ml_lib.modules.blocks import TransformerModule
from ml_lib.modules.util import (LightningBaseModule, AutoPadToShape, F_x)
from util.module_mixins import CombinedModelMixins

MIN_NUM_PATCHES = 16


class Tester(CombinedModelMixins,
             LightningBaseModule
             ):

    def __init__(self, in_shape, n_classes, weight_init, activation,
                 embedding_size, heads, attn_depth, patch_size, use_residual,
                 use_bias, use_norm, dropout, lat_dim, loss, scheduler, mlp_dim,
                 lr, weight_decay, sto_weight_avg, lr_scheduler_parameter, opt_reset_interval):

        # TODO: Move this to parent class, or make it much easier to access... But How...
        a = dict(locals())
        params = {arg: a[arg] for arg in inspect.signature(self.__init__).parameters.keys() if arg != 'self'}
        super(Tester, self).__init__(params)

        self.in_shape = in_shape
        assert len(self.in_shape) == 3, 'Input shape must have three dimensions (channels, height, width)'
        channels, height, width = self.in_shape

        # Model Parameters
        # =============================================================================
        # Additional parameters
        self.embed_dim = self.params.embedding_size

        # Automatic Image Shaping
        self.patch_size = self.params.patch_size
        image_size = (max(height, width) // self.patch_size) * self.patch_size
        self.image_size = image_size + self.patch_size if image_size < max(height, width) else image_size

        # This should be obsolete
        assert self.image_size % self.patch_size == 0, 'image dimensions must be divisible by the patch size'

        num_patches = (self.image_size // self.patch_size) ** 2
        patch_dim = channels * self.patch_size ** 2
        assert num_patches >= MIN_NUM_PATCHES, f'your number of patches ({num_patches}) is way too small for ' + \
                                               f'attention. Try decreasing your patch size'

        # Correct the Embedding Dim
        if not self.embed_dim % self.params.heads == 0:
            self.embed_dim = (self.embed_dim // self.params.heads) * self.params.heads
            message = ('Embedding Dimension was fixed to be divisible by the number' +
                       f' of attention heads, is now: {self.embed_dim}')
            for func in print, warnings.warn:
                func(message)

        # Utility Modules
        self.autopad = AutoPadToShape((self.image_size, self.image_size))

        # Modules with Parameters
        self.transformer = TransformerModule(in_shape=self.embed_dim, mlp_dim=self.params.mlp_dim,
                                             heads=self.params.heads, depth=self.params.attn_depth,
                                             dropout=self.params.dropout, use_norm=self.params.use_norm,
                                             activation=self.params.activation, use_residual=self.params.use_residual
                                             )

        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, self.embed_dim))
        self.patch_to_embedding = nn.Linear(patch_dim, self.embed_dim) if self.params.embedding_size \
            else F_x(self.embed_dim)
        self.cls_token = nn.Parameter(torch.randn(1, 1, self.embed_dim))
        self.dropout = nn.Dropout(self.params.dropout)

        self.to_cls_token = nn.Identity()

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(self.embed_dim),
            nn.Linear(self.embed_dim, self.params.lat_dim),
            nn.GELU(),
            nn.Dropout(self.params.dropout),
            nn.Linear(self.params.lat_dim, n_classes),
            nn.Softmax(dim=-1)
        )

    def forward(self, x, mask=None, return_attn_weights=False):
        """
        :param x: the sequence to the encoder (required).
        :param mask: the mask for the src sequence (optional).
        :param return_attn_weights: whether to also return the attention weights (optional).
        :return: a Namespace with the class scores in main_out and the attention weights (or None) in attn_weights.
        """
        tensor = self.autopad(x)
        p = self.params.patch_size
        tensor = rearrange(tensor, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=p, p2=p)

        tensor = self.patch_to_embedding(tensor)
        b, n, _ = tensor.shape

        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b=b)

        tensor = torch.cat((cls_tokens, tensor), dim=1)
        tensor += self.pos_embedding[:, :(n + 1)]
        tensor = self.dropout(tensor)

        if return_attn_weights:
            tensor, attn_weights = self.transformer(tensor, mask, return_attn_weights)
        else:
            attn_weights = None
            tensor = self.transformer(tensor, mask)

        tensor = self.to_cls_token(tensor[:, 0])
        tensor = self.mlp_head(tensor)
        return Namespace(main_out=tensor, attn_weights=attn_weights)

    def additional_scores(self, outputs):
        return MultiClassScores(self)(outputs)
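Both new model files flatten the padded image into a sequence of patch vectors with the same einops pattern used in their forward() methods. A tiny standalone demonstration with assumed shapes (not values from this commit):

import torch
from einops import rearrange

patch = 4
x = torch.randn(2, 3, 8, 8)                 # (batch, channels, height, width)
# 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)' cuts the image into a 2x2 grid of
# 4x4 patches and flattens each patch together with its channels into one vector.
patches = rearrange(x, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=patch, p2=patch)
print(patches.shape)                        # torch.Size([2, 4, 48]) -> 4 patches of dim 3*4*4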
@@ -21,9 +21,9 @@ class VisualTransformer(CombinedModelMixins,
                       ):

    def __init__(self, in_shape, n_classes, weight_init, activation,
-                embedding_size, heads, attn_depth, patch_size,use_residual,
-                use_bias, use_norm, dropout, lat_dim, loss,
-                lr, weight_decay, sto_weight_avg, lr_warm_restart_epochs, opt_reset_interval):
+                embedding_size, heads, attn_depth, patch_size, use_residual,
+                use_bias, use_norm, dropout, lat_dim, loss, scheduler, mlp_dim, head_dim,
+                lr, weight_decay, sto_weight_avg, lr_scheduler_parameter, opt_reset_interval):

        # TODO: Move this to parent class, or make it much easier to access... But How...
        a = dict(locals())

@@ -53,26 +53,26 @@ class VisualTransformer(CombinedModelMixins,
                                               f'attention. Try decreasing your patch size'

        # Correct the Embedding Dim
-       if not self.embed_dim % self.params.heads == 0:
-           self.embed_dim = (self.embed_dim // self.params.heads) * self.params.heads
-           message = ('Embedding Dimension was fixed to be divisible by the number' +
-                      f' of attention heads, is now: {self.embed_dim}')
-           for func in print, warnings.warn:
-               func(message)
+       # if not self.embed_dim % self.params.heads == 0:
+       #     self.embed_dim = (self.embed_dim // self.params.heads) * self.params.heads
+       #     message = ('Embedding Dimension was fixed to be divisible by the number' +
+       #                f' of attention heads, is now: {self.embed_dim}')
+       #     for func in print, warnings.warn:
+       #         func(message)

        # Utility Modules
        self.autopad = AutoPadToShape((self.image_size, self.image_size))

        # Modules with Parameters
-       self.transformer = TransformerModule(in_shape=self.embed_dim, mlp_dim=self.params.lat_dim,
+       self.transformer = TransformerModule(in_shape=self.embed_dim, mlp_dim=self.params.mlp_dim,
+                                            head_dim=self.params.head_dim,
                                             heads=self.params.heads, depth=self.params.attn_depth,
                                             dropout=self.params.dropout, use_norm=self.params.use_norm,
                                             activation=self.params.activation, use_residual=self.params.use_residual
                                             )

        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, self.embed_dim))
-       self.patch_to_embedding = nn.Linear(patch_dim, self.embed_dim) if self.params.embedding_size \
-           else F_x(self.embed_dim)
+       self.patch_to_embedding = nn.Linear(patch_dim, self.embed_dim)
        self.cls_token = nn.Parameter(torch.randn(1, 1, self.embed_dim))
        self.dropout = nn.Dropout(self.params.dropout)

@@ -117,4 +117,4 @@ class VisualTransformer(CombinedModelMixins,
        return Namespace(main_out=tensor, attn_weights=attn_weights)

    def additional_scores(self, outputs):
-       return MultiClassScores(self)(outputs)
+       return MultiClassScores(self)(outputs)
@@ -21,8 +21,8 @@ class HorizontalVisualTransformer(CombinedModelMixins,
                                 ):

    def __init__(self, in_shape, n_classes, weight_init, activation,
-                embedding_size, heads, attn_depth, patch_size,use_residual,
-                use_bias, use_norm, dropout, lat_dim, features, loss,
+                embedding_size, heads, attn_depth, patch_size, use_residual,
+                use_bias, use_norm, dropout, lat_dim, features, loss, scheduler,
                 lr, weight_decay, sto_weight_avg, lr_warm_restart_epochs, opt_reset_interval):

        # TODO: Move this to parent class, or make it much easier to access... But How...
models/transformer_model_vertical.py (new file, +116 lines)
@@ -0,0 +1,116 @@
import inspect
from argparse import Namespace

import warnings

import torch
from einops import repeat
from torch import nn

from ml_lib.metrics.multi_class_classification import MultiClassScores
from ml_lib.modules.blocks import TransformerModule
from ml_lib.modules.util import (LightningBaseModule, AutoPadToShape, F_x, SlidingWindow)
from util.module_mixins import CombinedModelMixins

MIN_NUM_PATCHES = 16


class VerticalVisualTransformer(CombinedModelMixins, LightningBaseModule):

    def __init__(self, in_shape, n_classes, weight_init, activation,
                 embedding_size, heads, attn_depth, patch_size, use_residual,
                 use_bias, use_norm, dropout, lat_dim, features, loss, scheduler,
                 lr, weight_decay, sto_weight_avg, lr_warm_restart_epochs, opt_reset_interval):

        # TODO: Move this to parent class, or make it much easier to access... But How...
        a = dict(locals())
        params = {arg: a[arg] for arg in inspect.signature(self.__init__).parameters.keys() if arg != 'self'}
        super(VerticalVisualTransformer, self).__init__(params)

        self.in_shape = in_shape
        self.n_classes = n_classes

        assert len(self.in_shape) == 3, 'Input shape must have three dimensions (channels, height, width)'
        channels, height, width = self.in_shape

        # Model Parameters
        # =============================================================================
        # Additional parameters
        self.embed_dim = self.params.embedding_size
        self.height = height
        self.channels = channels

        self.new_width = ((width - self.params.patch_size) // 1) + 1

        num_patches = self.new_width - (self.params.patch_size // 2)
        patch_dim = channels * self.params.patch_size * self.height
        assert num_patches >= MIN_NUM_PATCHES, f'your number of patches ({num_patches}) is way too small for ' + \
                                               f'attention. Try decreasing your patch size'

        # Correct the Embedding Dim
        if not self.embed_dim % self.params.heads == 0:
            self.embed_dim = (self.embed_dim // self.params.heads) * self.params.heads
            message = ('Embedding Dimension was fixed to be divisible by the number' +
                       f' of attention heads, is now: {self.embed_dim}')
            for func in print, warnings.warn:
                func(message)

        # Utility Modules
        self.autopad = AutoPadToShape((self.height, self.new_width))
        self.dropout = nn.Dropout(self.params.dropout)
        self.slider = SlidingWindow((channels, *self.autopad.target_shape), (self.height, self.params.patch_size),
                                    keepdim=False)

        # Modules with Parameters
        self.transformer = TransformerModule(in_shape=self.embed_dim, mlp_dim=self.params.lat_dim,
                                             heads=self.params.heads, depth=self.params.attn_depth,
                                             dropout=self.params.dropout, use_norm=self.params.use_norm,
                                             activation=self.params.activation
                                             )

        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, self.embed_dim))
        self.patch_to_embedding = nn.Linear(patch_dim, self.embed_dim) if self.params.embedding_size \
            else F_x(self.embed_dim)
        self.cls_token = nn.Parameter(torch.randn(1, 1, self.embed_dim))
        self.to_cls_token = nn.Identity()

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(self.embed_dim),
            nn.Linear(self.embed_dim, self.params.lat_dim),
            nn.GELU(),
            nn.Dropout(self.params.dropout),
            nn.Linear(self.params.lat_dim, self.n_classes),
            nn.Softmax(dim=-1)
        )

    def forward(self, x, mask=None, return_attn_weights=False):
        """
        :param x: the sequence to the encoder (required).
        :param mask: the mask for the src sequence (optional).
        :param return_attn_weights: whether to return the attention weights (optional).
        :return: a Namespace with the class scores in main_out and the attention weights (or None) in attn_weights.
        """
        tensor = self.autopad(x)
        tensor = self.slider(tensor)

        tensor = self.patch_to_embedding(tensor)
        b, n, _ = tensor.shape

        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b=b)

        tensor = torch.cat((cls_tokens, tensor), dim=1)
        tensor += self.pos_embedding[:, :(n + 1)]
        tensor = self.dropout(tensor)

        if return_attn_weights:
            tensor, attn_weights = self.transformer(tensor, mask, return_attn_weights)
        else:
            attn_weights = None
            tensor = self.transformer(tensor, mask)

        tensor = self.to_cls_token(tensor[:, 0])
        tensor = self.mlp_head(tensor)
        return Namespace(main_out=tensor, attn_weights=attn_weights)

    def additional_scores(self, outputs):
        return MultiClassScores(self)(outputs)
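Unlike the square-patch models above, the vertical variant slides a full-height window of width patch_size across the input. Plugging assumed numbers into the formulas exactly as written (illustrative only, not values from these experiments):

height, width, channels, patch_size = 58, 40, 1, 8    # hypothetical input shape

new_width = ((width - patch_size) // 1) + 1            # 33 window positions at stride 1
num_patches = new_width - (patch_size // 2)            # 29, the sequence length used for pos_embedding
patch_dim = channels * patch_size * height             # 464 values per full-height vertical slice
assert num_patches >= 16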