New Model, Many Changes
This commit is contained in:
parent
7bac9e984b
commit
be097a111a
@ -29,35 +29,44 @@ main_arg_parser.add_argument("--data_n_fft", type=int, default=512, help="")
|
||||
main_arg_parser.add_argument("--data_stretch", type=strtobool, default=True, help="")
|
||||
|
||||
# Transformation Parameters
|
||||
main_arg_parser.add_argument("--data_loudness_ratio", type=float, default=0, help="") # 0.4
|
||||
main_arg_parser.add_argument("--data_shift_ratio", type=float, default=0.3, help="") # 0.4
|
||||
main_arg_parser.add_argument("--data_loudness_ratio", type=float, default=0.0, help="") # 0.4
|
||||
main_arg_parser.add_argument("--data_shift_ratio", type=float, default=0.0, help="") # 0.4
|
||||
main_arg_parser.add_argument("--data_noise_ratio", type=float, default=0, help="") # 0.4
|
||||
main_arg_parser.add_argument("--data_mask_ratio", type=float, default=0, help="") # 0.2
|
||||
main_arg_parser.add_argument("--data_mask_ratio", type=float, default=0.3, help="") # 0.2
|
||||
main_arg_parser.add_argument("--data_speed_amount", type=float, default=0, help="") # 0.4
|
||||
main_arg_parser.add_argument("--data_speed_min", type=float, default=0, help="") # 0.7
|
||||
main_arg_parser.add_argument("--data_speed_max", type=float, default=0, help="") # 1.7
|
||||
|
||||
# Model Parameters
|
||||
main_arg_parser.add_argument("--model_type", type=str, default="ViT", help="")
|
||||
# General
|
||||
main_arg_parser.add_argument("--model_type", type=str, default="SequentialVisualTransformer", help="")
|
||||
main_arg_parser.add_argument("--model_weight_init", type=str, default="xavier_normal_", help="")
|
||||
main_arg_parser.add_argument("--model_activation", type=str, default="leaky_relu", help="")
|
||||
main_arg_parser.add_argument("--model_filters", type=str, default="[32, 64, 128, 64]", help="")
|
||||
main_arg_parser.add_argument("--model_classes", type=int, default=2, help="")
|
||||
main_arg_parser.add_argument("--model_lat_dim", type=int, default=128, help="")
|
||||
main_arg_parser.add_argument("--model_activation", type=str, default="gelu", help="")
|
||||
main_arg_parser.add_argument("--model_bias", type=strtobool, default=True, help="")
|
||||
main_arg_parser.add_argument("--model_norm", type=strtobool, default=True, help="")
|
||||
main_arg_parser.add_argument("--model_dropout", type=float, default=0.2, help="")
|
||||
main_arg_parser.add_argument("--model_lat_dim", type=int, default=128, help="")
|
||||
main_arg_parser.add_argument("--model_features", type=int, default=64, help="")
|
||||
|
||||
# CNN Specific
|
||||
main_arg_parser.add_argument("--model_filters", type=str, default="[32, 64, 128, 64]", help="")
|
||||
|
||||
# Transformer Specific
|
||||
main_arg_parser.add_argument("--model_patch_size", type=int, default=9, help="")
|
||||
main_arg_parser.add_argument("--model_attn_depth", type=int, default=3, help="")
|
||||
main_arg_parser.add_argument("--model_heads", type=int, default=8, help="")
|
||||
main_arg_parser.add_argument("--model_embedding_size", type=int, default=64, help="")
|
||||
|
||||
# Training Parameters
|
||||
main_arg_parser.add_argument("--train_outpath", type=str, default="output", help="")
|
||||
main_arg_parser.add_argument("--train_version", type=strtobool, required=False, help="")
|
||||
# FIXME: Stochastic weight Avaraging is not good, maybe its my implementation?
|
||||
main_arg_parser.add_argument("--train_sto_weight_avg", type=strtobool, default=False, help="")
|
||||
main_arg_parser.add_argument("--train_weight_decay", type=float, default=1e-7, help="")
|
||||
main_arg_parser.add_argument("--train_weight_decay", type=float, default=0, help="")
|
||||
main_arg_parser.add_argument("--train_opt_reset_interval", type=int, default=0, help="")
|
||||
main_arg_parser.add_argument("--train_epochs", type=int, default=51, help="")
|
||||
main_arg_parser.add_argument("--train_batch_size", type=int, default=300, help="")
|
||||
main_arg_parser.add_argument("--train_lr", type=float, default=1e-4, help="")
|
||||
main_arg_parser.add_argument("--train_epochs", type=int, default=100, help="")
|
||||
main_arg_parser.add_argument("--train_batch_size", type=int, default=250, help="")
|
||||
main_arg_parser.add_argument("--train_lr", type=float, default=1e-3, help="")
|
||||
main_arg_parser.add_argument("--train_lr_warmup_steps", type=int, default=10, help="")
|
||||
main_arg_parser.add_argument("--train_num_sanity_val_steps", type=int, default=0, help="")
|
||||
|
||||
# Project Parameters
|
||||
|
95
datasets/urban_8k.py
Normal file
95
datasets/urban_8k.py
Normal file
@ -0,0 +1,95 @@
|
||||
import pickle
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
|
||||
import librosa as librosa
|
||||
from torch.utils.data import Dataset
|
||||
import torch
|
||||
|
||||
import variables as V
|
||||
from ml_lib.modules.util import F_x
|
||||
|
||||
|
||||
class BinaryMasksDataset(Dataset):
|
||||
_to_label = defaultdict(lambda: -1)
|
||||
_to_label.update(dict(clear=V.CLEAR, mask=V.MASK))
|
||||
|
||||
@property
|
||||
def sample_shape(self):
|
||||
return self[0][0].shape
|
||||
|
||||
def __init__(self, data_root, setting, mel_transforms, transforms=None, stretch_dataset=False,
|
||||
use_preprocessed=True):
|
||||
self.use_preprocessed = use_preprocessed
|
||||
self.stretch = stretch_dataset
|
||||
assert isinstance(setting, str), f'Setting has to be a string, but was: {type(setting)}.'
|
||||
assert setting in V.DATA_OPTIONS, f'Setting must match one of: {V.DATA_OPTIONS}.'
|
||||
super(BinaryMasksDataset, self).__init__()
|
||||
|
||||
self.data_root = Path(data_root)
|
||||
self.setting = setting
|
||||
self._wav_folder = self.data_root / 'wav'
|
||||
self._mel_folder = self.data_root / 'mel'
|
||||
self.container_ext = '.pik'
|
||||
self._mel_transform = mel_transforms
|
||||
|
||||
self._labels = self._build_labels()
|
||||
self._wav_files = list(sorted(self._labels.keys()))
|
||||
self._transforms = transforms or F_x(in_shape=None)
|
||||
|
||||
def _build_labels(self):
|
||||
labeldict = dict()
|
||||
with open(Path(self.data_root) / 'lab' / 'labels.csv', mode='r') as f:
|
||||
# Exclude the header
|
||||
_ = next(f)
|
||||
for row in f:
|
||||
if self.setting not in row:
|
||||
continue
|
||||
filename, label = row.strip().split(',')
|
||||
labeldict[filename] = self._to_label[label.lower()] if not self.setting == 'test' else filename
|
||||
if self.stretch and self.setting == V.DATA_OPTIONS.train:
|
||||
additional_dict = ({f'X{key}': val for key, val in labeldict.items()})
|
||||
additional_dict.update({f'XX{key}': val for key, val in labeldict.items()})
|
||||
additional_dict.update({f'XXX{key}': val for key, val in labeldict.items()})
|
||||
labeldict.update(additional_dict)
|
||||
|
||||
# Delete File if one exists.
|
||||
if not self.use_preprocessed:
|
||||
for key in labeldict.keys():
|
||||
try:
|
||||
(self._mel_folder / (key.replace('.wav', '') + self.container_ext)).unlink()
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
|
||||
return labeldict
|
||||
|
||||
def __len__(self):
|
||||
return len(self._labels)
|
||||
|
||||
def _compute_or_retrieve(self, filename):
|
||||
|
||||
if not (self._mel_folder / (filename + self.container_ext)).exists():
|
||||
raw_sample, sr = librosa.core.load(self._wav_folder / (filename.replace('X', '') + '.wav'))
|
||||
mel_sample = self._mel_transform(raw_sample)
|
||||
self._mel_folder.mkdir(exist_ok=True, parents=True)
|
||||
with (self._mel_folder / (filename + self.container_ext)).open(mode='wb') as f:
|
||||
pickle.dump(mel_sample, f, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
|
||||
with (self._mel_folder / (filename + self.container_ext)).open(mode='rb') as f:
|
||||
mel_sample = pickle.load(f, fix_imports=True)
|
||||
return mel_sample
|
||||
|
||||
def __getitem__(self, item):
|
||||
|
||||
key: str = list(self._labels.keys())[item]
|
||||
filename = key.replace('.wav', '')
|
||||
mel_sample = self._compute_or_retrieve(filename)
|
||||
label = self._labels[key]
|
||||
|
||||
transformed_samples = self._transforms(mel_sample)
|
||||
|
||||
if self.setting != V.DATA_OPTIONS.test:
|
||||
# In test, filenames instead of labels are returned. This is a little hacky though.
|
||||
label = torch.as_tensor(label, dtype=torch.float)
|
||||
|
||||
return transformed_samples, label
|
36
main.py
36
main.py
@ -6,14 +6,13 @@ import warnings
|
||||
|
||||
import torch
|
||||
from pytorch_lightning import Trainer
|
||||
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
|
||||
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
|
||||
|
||||
from ml_lib.modules.util import LightningBaseModule
|
||||
from ml_lib.utils.config import Config
|
||||
from ml_lib.utils.logging import Logger
|
||||
|
||||
# Project Specific Logger SubClasses
|
||||
from util.config import MConfig
|
||||
|
||||
|
||||
warnings.filterwarnings('ignore', category=FutureWarning)
|
||||
warnings.filterwarnings('ignore', category=UserWarning)
|
||||
@ -37,35 +36,30 @@ def run_lightning_loop(config_obj):
|
||||
# Callbacks
|
||||
# =============================================================================
|
||||
# Checkpoint Saving
|
||||
checkpoint_callback = ModelCheckpoint(
|
||||
monitor='uar_score',
|
||||
ckpt_callback = ModelCheckpoint(
|
||||
monitor='mean_loss',
|
||||
filepath=str(logger.log_dir / 'ckpt_weights'),
|
||||
verbose=False,
|
||||
save_top_k=5,
|
||||
)
|
||||
|
||||
# Early Stopping
|
||||
# TODO: For This to work, set a validation step and End Eval and Score
|
||||
early_stopping_callback = EarlyStopping(
|
||||
monitor='uar_score',
|
||||
min_delta=0.01,
|
||||
patience=10,
|
||||
)
|
||||
# Learning Rate Logger
|
||||
lr_logger = LearningRateMonitor(logging_interval='epoch')
|
||||
|
||||
# Trainer
|
||||
# =============================================================================
|
||||
trainer = Trainer(max_epochs=config_obj.train.epochs,
|
||||
show_progress_bar=True,
|
||||
weights_save_path=logger.log_dir,
|
||||
gpus=[0] if torch.cuda.is_available() else None,
|
||||
check_val_every_n_epoch=10,
|
||||
# num_sanity_val_steps=config_obj.train.num_sanity_val_steps,
|
||||
# row_log_interval=(model.n_train_batches * 0.1), # TODO: Better Value / Setting
|
||||
# log_save_interval=(model.n_train_batches * 0.2), # TODO: Better Value / Setting
|
||||
checkpoint_callback=checkpoint_callback,
|
||||
checkpoint_callback=True,
|
||||
callbacks=[lr_logger, ckpt_callback],
|
||||
logger=logger,
|
||||
fast_dev_run=config_obj.main.debug,
|
||||
early_stop_callback=None
|
||||
auto_lr_find=not config_obj.main.debug
|
||||
)
|
||||
|
||||
# Model
|
||||
@ -78,10 +72,15 @@ def run_lightning_loop(config_obj):
|
||||
|
||||
# Train It
|
||||
if config_obj.model.type.lower() != 'ensemble':
|
||||
if not config_obj.main.debug and not config_obj.train.lr:
|
||||
trainer.tune(model)
|
||||
# ToDo: LR Finder Plot
|
||||
# fig = lr_finder.plot(suggest=True)
|
||||
|
||||
trainer.fit(model)
|
||||
|
||||
# Save the last state & all parameters
|
||||
trainer.save_checkpoint(logger.log_dir / 'weights.ckpt')
|
||||
trainer.save_checkpoint(str(logger.log_dir / 'weights.ckpt'))
|
||||
model.save_to_disk(logger.log_dir)
|
||||
|
||||
# Evaluate It
|
||||
@ -99,8 +98,7 @@ def run_lightning_loop(config_obj):
|
||||
outputs.append(
|
||||
model.validation_step((batch_x, label), idx, 1)
|
||||
)
|
||||
summary_dict = model.validation_epoch_end([outputs])
|
||||
print(summary_dict['log']['uar_score'])
|
||||
model.validation_epoch_end([outputs])
|
||||
|
||||
# trainer.test()
|
||||
outpath = Path(config_obj.train.outpath)
|
||||
@ -132,6 +130,6 @@ if __name__ == "__main__":
|
||||
|
||||
from _paramters import main_arg_parser
|
||||
|
||||
config = MConfig.read_argparser(main_arg_parser)
|
||||
config = Config.read_argparser(main_arg_parser)
|
||||
fix_all_random_seeds(config)
|
||||
trained_model = run_lightning_loop(config)
|
||||
|
@ -15,11 +15,10 @@ from ml_lib.audio_toolset.audio_io import AudioToMel, NormalizeLocal, MelToImage
|
||||
|
||||
# Transforms
|
||||
from ml_lib.audio_toolset.mel_augmentation import NoiseInjection, LoudnessManipulator, ShiftTime, MaskAug
|
||||
from ml_lib.utils.logging import Logger
|
||||
from ml_lib.utils.config import Config
|
||||
from ml_lib.utils.model_io import SavedLightningModels
|
||||
from ml_lib.utils.transforms import ToTensor
|
||||
from ml_lib.visualization.tools import Plotter
|
||||
from util.config import MConfig
|
||||
|
||||
|
||||
# Datasets
|
||||
from datasets.binar_masks import BinaryMasksDataset
|
||||
@ -66,8 +65,8 @@ if __name__ == '__main__':
|
||||
config_filename = 'config.ini'
|
||||
inference_out = 'manual_test_out.csv'
|
||||
|
||||
config = MConfig()
|
||||
config.read_file((Path(model_path) / config_filename).open('r'))
|
||||
config = Config()
|
||||
config.read_file((Path(model_path) / config_filename).open())
|
||||
test_dataloader = prepare_dataloader(config)
|
||||
|
||||
loaded_model = restore_logger_and_model(model_path)
|
||||
|
@ -4,7 +4,7 @@ from torch import nn
|
||||
from torch.nn import ModuleList
|
||||
|
||||
from ml_lib.modules.blocks import ConvModule, LinearModule
|
||||
from ml_lib.modules.util import (LightningBaseModule, HorizontalSplitter, HorizontalMerger)
|
||||
from ml_lib.modules.util import (LightningBaseModule, Splitter, Merger)
|
||||
from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, BinaryMaskDatasetMixin,
|
||||
BaseDataloadersMixin)
|
||||
|
||||
@ -33,7 +33,7 @@ class BandwiseConvClassifier(BinaryMaskDatasetMixin,
|
||||
|
||||
# Modules
|
||||
# =============================================================================
|
||||
self.split = HorizontalSplitter(self.in_shape, self.n_band_sections)
|
||||
self.split = Splitter(self.in_shape, self.n_band_sections)
|
||||
|
||||
k = 3
|
||||
self.band_list = ModuleList()
|
||||
@ -48,7 +48,7 @@ class BandwiseConvClassifier(BinaryMaskDatasetMixin,
|
||||
# last_shape = self.conv_list[-1].shape
|
||||
self.band_list.append(conv_list)
|
||||
|
||||
self.merge = HorizontalMerger(self.band_list[-1][-1].shape, self.n_band_sections)
|
||||
self.merge = Merger(self.band_list[-1][-1].shape, self.n_band_sections)
|
||||
|
||||
self.full_1 = LinearModule(self.merge.shape, self.params.lat_dim, **self.params.module_kwargs)
|
||||
self.full_2 = LinearModule(self.full_1.shape, self.full_1.shape * 2, **self.params.module_kwargs)
|
||||
|
@ -5,7 +5,7 @@ from torch import nn
|
||||
from torch.nn import ModuleList
|
||||
|
||||
from ml_lib.modules.blocks import ConvModule, LinearModule
|
||||
from ml_lib.modules.util import (LightningBaseModule, Flatten, HorizontalSplitter)
|
||||
from ml_lib.modules.util import (LightningBaseModule, Splitter)
|
||||
from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, BinaryMaskDatasetMixin,
|
||||
BaseDataloadersMixin)
|
||||
|
||||
@ -69,7 +69,7 @@ class BandwiseConvMultiheadClassifier(BinaryMaskDatasetMixin,
|
||||
|
||||
# Modules
|
||||
# =============================================================================
|
||||
self.split = HorizontalSplitter(self.in_shape, self.n_band_sections)
|
||||
self.split = Splitter(self.in_shape, self.n_band_sections)
|
||||
|
||||
self.band_list = ModuleList()
|
||||
for band in range(self.n_band_sections):
|
||||
|
@ -1,16 +1,19 @@
|
||||
import variables as V
|
||||
from argparse import Namespace
|
||||
|
||||
import warnings
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from einops import rearrange, repeat
|
||||
|
||||
from ml_lib.modules.blocks import TransformerModule
|
||||
from ml_lib.modules.util import (LightningBaseModule, AutoPadToShape)
|
||||
from ml_lib.modules.util import (LightningBaseModule, AutoPadToShape, F_x)
|
||||
from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, BinaryMaskDatasetMixin,
|
||||
BaseDataloadersMixin)
|
||||
|
||||
MIN_NUM_PATCHES = 16
|
||||
|
||||
|
||||
class VisualTransformer(BinaryMaskDatasetMixin,
|
||||
BaseDataloadersMixin,
|
||||
BaseTrainMixin,
|
||||
@ -22,69 +25,83 @@ class VisualTransformer(BinaryMaskDatasetMixin,
|
||||
def __init__(self, hparams):
|
||||
super(VisualTransformer, self).__init__(hparams)
|
||||
|
||||
self.in_shape = self.dataset.train_dataset.sample_shape
|
||||
assert len(self.in_shape) == 3, 'There need to be three Dimensions'
|
||||
channels, height, width = self.in_shape
|
||||
|
||||
# Automatic Image Shaping
|
||||
image_size = (max(height, width) // self.params.patch_size) * self.params.patch_size
|
||||
self.image_size = image_size + self.params.patch_size if image_size < max(height, width) else image_size
|
||||
|
||||
# This should be obsolete
|
||||
assert self.image_size % self.params.patch_size == 0, 'image dimensions must be divisible by the patch size'
|
||||
|
||||
num_patches = (self.image_size // self.params.patch_size) ** 2
|
||||
patch_dim = channels * self.params.patch_size ** 2
|
||||
assert num_patches > MIN_NUM_PATCHES, f'your number of patches ({num_patches}) is way too small for ' \
|
||||
f'attention. Try decreasing your patch size'
|
||||
|
||||
# Dataset
|
||||
# =============================================================================
|
||||
self.dataset = self.build_dataset()
|
||||
|
||||
self.in_shape = self.dataset.train_dataset.sample_shape
|
||||
assert len(self.in_shape) == 3, 'There need to be three Dimensions'
|
||||
channels, height, width = self.in_shape
|
||||
|
||||
# Model Paramters
|
||||
# =============================================================================
|
||||
# Additional parameters
|
||||
self.attention_dim = self.params.features
|
||||
self.embed_dim = self.params.embedding_size
|
||||
|
||||
# Automatic Image Shaping
|
||||
self.patch_size = self.params.patch_size
|
||||
image_size = (max(height, width) // self.patch_size) * self.patch_size
|
||||
self.image_size = image_size + self.patch_size if image_size < max(height, width) else image_size
|
||||
|
||||
# This should be obsolete
|
||||
assert self.image_size % self.patch_size == 0, 'image dimensions must be divisible by the patch size'
|
||||
|
||||
num_patches = (self.image_size // self.patch_size) ** 2
|
||||
patch_dim = channels * self.patch_size ** 2
|
||||
assert num_patches >= MIN_NUM_PATCHES, f'your number of patches ({num_patches}) is way too small for ' + \
|
||||
f'attention. Try decreasing your patch size'
|
||||
|
||||
# Correct the Embedding Dim
|
||||
if not self.embed_dim % self.params.heads == 0:
|
||||
self.embed_dim = (self.embed_dim // self.params.heads) * self.params.heads
|
||||
message = ('Embedding Dimension was fixed to be devideable by the number' +
|
||||
f' of attention heads, is now: {self.embed_dim}')
|
||||
for func in print, warnings.warn:
|
||||
func(message)
|
||||
|
||||
# Utility Modules
|
||||
self.autopad = AutoPadToShape((self.image_size, self.image_size))
|
||||
|
||||
# Modules with Parameters
|
||||
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, self.attention_dim), False)
|
||||
self.embedding = nn.Linear(patch_dim, self.attention_dim)
|
||||
self.cls_token = nn.Parameter(torch.randn(1, 1, self.attention_dim), False)
|
||||
self.dropout = nn.Dropout(self.params.dropout)
|
||||
self.transformer = TransformerModule(in_shape=self.embed_dim, hidden_size=self.params.lat_dim,
|
||||
n_heads=self.params.heads, num_layers=self.params.attn_depth,
|
||||
dropout=self.params.dropout, use_norm=self.params.use_norm,
|
||||
activation=self.params.activation_as_string
|
||||
)
|
||||
|
||||
self.transformer = TransformerModule(self.attention_dim, self.params.attn_depth, self.params.heads,
|
||||
self.params.lat_dim, self.params.dropout)
|
||||
|
||||
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, self.embed_dim))
|
||||
self.patch_to_embedding = nn.Linear(patch_dim, self.embed_dim) if self.params.embedding_size \
|
||||
else F_x(self.embed_dim)
|
||||
self.cls_token = nn.Parameter(torch.randn(1, 1, self.embed_dim))
|
||||
self.dropout = nn.Dropout(self.params.dropout)
|
||||
|
||||
self.to_cls_token = nn.Identity()
|
||||
|
||||
self.mlp_head = nn.Sequential(
|
||||
nn.LayerNorm(self.attention_dim),
|
||||
nn.Linear(self.attention_dim, self.params.lat_dim),
|
||||
nn.LayerNorm(self.embed_dim),
|
||||
nn.Linear(self.embed_dim, self.params.lat_dim),
|
||||
nn.GELU(),
|
||||
nn.Dropout(self.params.dropout),
|
||||
nn.Linear(self.params.lat_dim, V.NUM_CLASSES)
|
||||
nn.Linear(self.params.lat_dim, 1),
|
||||
nn.Sigmoid()
|
||||
)
|
||||
|
||||
def forward(self, x, mask=None):
|
||||
"""
|
||||
:param tensor: the sequence to the encoder (required).
|
||||
:param x: the sequence to the encoder (required).
|
||||
:param mask: the mask for the src sequence (optional).
|
||||
:return:
|
||||
"""
|
||||
|
||||
tensor = self.autopad(x)
|
||||
p = self.params.patch_size
|
||||
# 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p, p2 = p
|
||||
tensor = torch.reshape(x, (-1, self.image_size * self.image_size, p * p * self.in_shape[0]))
|
||||
tensor = rearrange(tensor, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p, p2 = p)
|
||||
|
||||
tensor = self.patch_to_embedding(tensor)
|
||||
b, n, _ = tensor.shape
|
||||
|
||||
# '() n d -> b n d', b = b
|
||||
cls_tokens = tensor.repeat(self.cls_token, b)
|
||||
cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
|
||||
|
||||
tensor = torch.cat((cls_tokens, tensor), dim=1)
|
||||
tensor += self.pos_embedding[:, :(n + 1)]
|
||||
tensor = self.dropout(tensor)
|
||||
@ -93,4 +110,4 @@ class VisualTransformer(BinaryMaskDatasetMixin,
|
||||
|
||||
tensor = self.to_cls_token(tensor[:, 0])
|
||||
tensor = self.mlp_head(tensor)
|
||||
return tensor
|
||||
return Namespace(main_out=tensor)
|
||||
|
114
models/transformer_model_sequential.py
Normal file
114
models/transformer_model_sequential.py
Normal file
@ -0,0 +1,114 @@
|
||||
from argparse import Namespace
|
||||
|
||||
import warnings
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from einops import repeat
|
||||
|
||||
from ml_lib.modules.blocks import TransformerModule
|
||||
from ml_lib.modules.util import (LightningBaseModule, AutoPadToShape, F_x, SlidingWindow)
|
||||
from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, BinaryMaskDatasetMixin,
|
||||
BaseDataloadersMixin)
|
||||
|
||||
MIN_NUM_PATCHES = 16
|
||||
|
||||
class SequentialVisualTransformer(BinaryMaskDatasetMixin,
|
||||
BaseDataloadersMixin,
|
||||
BaseTrainMixin,
|
||||
BaseValMixin,
|
||||
BaseOptimizerMixin,
|
||||
LightningBaseModule
|
||||
):
|
||||
|
||||
def __init__(self, hparams):
|
||||
super(SequentialVisualTransformer, self).__init__(hparams)
|
||||
|
||||
# Dataset
|
||||
# =============================================================================
|
||||
self.dataset = self.build_dataset()
|
||||
|
||||
self.in_shape = self.dataset.train_dataset.sample_shape
|
||||
assert len(self.in_shape) == 3, 'There need to be three Dimensions'
|
||||
channels, height, width = self.in_shape
|
||||
|
||||
# Model Paramters
|
||||
# =============================================================================
|
||||
# Additional parameters
|
||||
self.embed_dim = self.params.embedding_size
|
||||
self.patch_size = self.params.patch_size
|
||||
self.height = height
|
||||
|
||||
# Automatic Image Shaping
|
||||
image_size = (max(height, width) // self.patch_size) * self.patch_size
|
||||
self.image_size = image_size + self.patch_size if image_size < max(height, width) else image_size
|
||||
|
||||
# This should be obsolete
|
||||
assert self.image_size % self.patch_size == 0, 'image dimensions must be divisible by the patch size'
|
||||
|
||||
num_patches = (self.image_size // self.patch_size) ** 2
|
||||
patch_dim = channels * self.patch_size * self.image_size
|
||||
assert num_patches >= MIN_NUM_PATCHES, f'your number of patches ({num_patches}) is way too small for ' + \
|
||||
f'attention. Try decreasing your patch size'
|
||||
|
||||
# Correct the Embedding Dim
|
||||
if not self.embed_dim % self.params.heads == 0:
|
||||
self.embed_dim = (self.embed_dim // self.params.heads) * self.params.heads
|
||||
message = ('Embedding Dimension was fixed to be devideable by the number' +
|
||||
f' of attention heads, is now: {self.embed_dim}')
|
||||
for func in print, warnings.warn:
|
||||
func(message)
|
||||
|
||||
# Utility Modules
|
||||
self.autopad = AutoPadToShape((self.image_size, self.image_size))
|
||||
self.dropout = nn.Dropout(self.params.dropout)
|
||||
self.slider = SlidingWindow((self.image_size, self.patch_size), keepdim=False)
|
||||
|
||||
# Modules with Parameters
|
||||
self.transformer = TransformerModule(in_shape=self.embed_dim, hidden_size=self.params.lat_dim,
|
||||
n_heads=self.params.heads, num_layers=self.params.attn_depth,
|
||||
dropout=self.params.dropout, use_norm=self.params.use_norm,
|
||||
activation=self.params.activation_as_string
|
||||
)
|
||||
|
||||
|
||||
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, self.embed_dim))
|
||||
self.patch_to_embedding = nn.Linear(patch_dim, self.embed_dim) if self.params.embedding_size \
|
||||
else F_x(self.embed_dim)
|
||||
self.cls_token = nn.Parameter(torch.randn(1, 1, self.embed_dim))
|
||||
self.to_cls_token = nn.Identity()
|
||||
|
||||
self.mlp_head = nn.Sequential(
|
||||
nn.LayerNorm(self.embed_dim),
|
||||
nn.Linear(self.embed_dim, self.params.lat_dim),
|
||||
nn.GELU(),
|
||||
nn.Dropout(self.params.dropout),
|
||||
nn.Linear(self.params.lat_dim, 1),
|
||||
nn.Sigmoid()
|
||||
)
|
||||
|
||||
def forward(self, x, mask=None):
|
||||
"""
|
||||
:param x: the sequence to the encoder (required).
|
||||
:param mask: the mask for the src sequence (optional).
|
||||
:return:
|
||||
"""
|
||||
tensor = self.autopad(x)
|
||||
tensor = self.slider(tensor)
|
||||
|
||||
tensor = self.patch_to_embedding(tensor)
|
||||
b, n, _ = tensor.shape
|
||||
|
||||
# cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
|
||||
cls_tokens = self.cls_token.repeat((b, 1, 1))
|
||||
|
||||
tensor = torch.cat((cls_tokens, tensor), dim=1)
|
||||
tensor += self.pos_embedding[:, :(n + 1)]
|
||||
tensor = self.dropout(tensor)
|
||||
|
||||
tensor = self.transformer(tensor, mask)
|
||||
|
||||
tensor = self.to_cls_token(tensor[:, 0])
|
||||
tensor = self.mlp_head(tensor)
|
||||
return Namespace(main_out=tensor)
|
@ -1,7 +1,7 @@
|
||||
import shutil
|
||||
import warnings
|
||||
|
||||
from util.config import MConfig
|
||||
from ml_lib.utils.config import Config
|
||||
|
||||
warnings.filterwarnings('ignore', category=FutureWarning)
|
||||
warnings.filterwarnings('ignore', category=UserWarning)
|
||||
@ -17,12 +17,12 @@ if __name__ == '__main__':
|
||||
|
||||
args = main_arg_parser.parse_args()
|
||||
# Model Settings
|
||||
config = MConfig().read_namespace(args)
|
||||
config = Config().read_namespace(args)
|
||||
|
||||
arg_dict = dict()
|
||||
for seed in range(0, 10):
|
||||
arg_dict.update(main_seed=seed)
|
||||
for model in ['CC', 'BCMC', 'BCC', 'RCC']:
|
||||
for model in ['VisualTransformer']:
|
||||
arg_dict.update(model_type=model)
|
||||
raw_conf = dict(data_speed_amount=0.0, data_speed_min=0.0, data_speed_max=0.0,
|
||||
data_mask_ratio=0.0, data_noise_ratio=0.0, data_shift_ratio=0.0, data_loudness_ratio=0.0,
|
||||
|
@ -42,7 +42,7 @@ msgpack-python==0.5.6
|
||||
natsort==7.0.1
|
||||
neptune-client==0.4.109
|
||||
numba==0.49.1
|
||||
numpy==1.18.4
|
||||
numpy~=1.18.2
|
||||
oauthlib==3.1.0
|
||||
packaging==20.3
|
||||
pandas==1.0.3
|
||||
@ -68,7 +68,7 @@ resampy==0.2.2
|
||||
retrying==1.3.3
|
||||
rfc3987==1.3.8
|
||||
rsa==4.0
|
||||
scikit-learn==0.23.1
|
||||
scikit-learn~=0.22.2.post1
|
||||
scipy==1.4.1
|
||||
simplejson==3.17.0
|
||||
six==1.14.0
|
||||
@ -91,3 +91,5 @@ webencodings==0.5.1
|
||||
websocket-client==0.57.0
|
||||
Werkzeug==1.0.1
|
||||
xmltodict==0.12.0
|
||||
|
||||
einops~=0.3.0
|
@ -1,26 +0,0 @@
|
||||
from ml_lib.utils.config import Config
|
||||
from models.conv_classifier import ConvClassifier
|
||||
from models.bandwise_conv_classifier import BandwiseConvClassifier
|
||||
from models.bandwise_conv_multihead_classifier import BandwiseConvMultiheadClassifier
|
||||
from models.ensemble import Ensemble
|
||||
from models.residual_conv_classifier import ResidualConvClassifier
|
||||
from models.transformer_model import VisualTransformer
|
||||
|
||||
|
||||
class MConfig(Config):
|
||||
# TODO: There should be a way to automate this.
|
||||
|
||||
@property
|
||||
def _model_map(self):
|
||||
return dict(ConvClassifier=ConvClassifier,
|
||||
CC=ConvClassifier,
|
||||
BandwiseConvClassifier=BandwiseConvClassifier,
|
||||
BCC=BandwiseConvClassifier,
|
||||
BandwiseConvMultiheadClassifier=BandwiseConvMultiheadClassifier,
|
||||
BCMC=BandwiseConvMultiheadClassifier,
|
||||
Ensemble=Ensemble,
|
||||
E=Ensemble,
|
||||
ResidualConvClassifier=ResidualConvClassifier,
|
||||
RCC=ResidualConvClassifier,
|
||||
ViT=VisualTransformer
|
||||
)
|
@ -8,7 +8,8 @@ import torch
|
||||
import numpy as np
|
||||
from torch import nn
|
||||
from torch.optim import Adam
|
||||
from torch.utils.data import DataLoader, RandomSampler
|
||||
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
|
||||
from torch.utils.data import DataLoader
|
||||
from torchcontrib.optim import SWA
|
||||
from torchvision.transforms import Compose, RandomApply
|
||||
|
||||
@ -25,10 +26,23 @@ class BaseOptimizerMixin:
|
||||
|
||||
def configure_optimizers(self):
|
||||
assert isinstance(self, LightningBaseModule)
|
||||
opt = Adam(params=self.parameters(), lr=self.params.lr, weight_decay=self.params.weight_decay)
|
||||
optimizer_dict = dict(
|
||||
# 'optimizer':optimizer, # The Optimizer
|
||||
# 'lr_scheduler': scheduler, # The LR scheduler
|
||||
frequency=1, # The frequency of the scheduler
|
||||
interval='epoch', # The unit of the scheduler's step size
|
||||
# 'reduce_on_plateau': False, # For ReduceLROnPlateau scheduler
|
||||
# 'monitor': 'mean_val_loss' # Metric to monitor
|
||||
)
|
||||
|
||||
optimizer = Adam(params=self.parameters(), lr=self.params.lr, weight_decay=self.params.weight_decay)
|
||||
if self.params.sto_weight_avg:
|
||||
opt = SWA(opt, swa_start=10, swa_freq=5, swa_lr=0.05)
|
||||
return opt
|
||||
optimizer = SWA(optimizer, swa_start=10, swa_freq=5, swa_lr=0.05)
|
||||
optimizer_dict.update(optimizer=optimizer)
|
||||
if self.params.lr_warmup_steps:
|
||||
scheduler = CosineAnnealingWarmRestarts(optimizer, self.params.lr_warmup_steps)
|
||||
optimizer_dict.update(lr_scheduler=scheduler)
|
||||
return optimizer_dict
|
||||
|
||||
def on_train_end(self):
|
||||
assert isinstance(self, LightningBaseModule)
|
||||
@ -54,17 +68,18 @@ class BaseTrainMixin:
|
||||
assert isinstance(self, LightningBaseModule)
|
||||
batch_x, batch_y = batch_xy
|
||||
y = self(batch_x).main_out
|
||||
bce_loss = self.bce_loss(y, batch_y)
|
||||
bce_loss = self.bce_loss(y.squeeze(), batch_y)
|
||||
return dict(loss=bce_loss)
|
||||
|
||||
def training_epoch_end(self, outputs):
|
||||
assert isinstance(self, LightningBaseModule)
|
||||
keys = list(outputs[0].keys())
|
||||
|
||||
summary_dict = dict(log={f'mean_{key}': torch.mean(torch.stack([output[key]
|
||||
summary_dict = {f'mean_{key}': torch.mean(torch.stack([output[key]
|
||||
for output in outputs]))
|
||||
for key in keys if 'loss' in key})
|
||||
return summary_dict
|
||||
for key in keys if 'loss' in key}
|
||||
for key in summary_dict.keys():
|
||||
self.log(key, summary_dict[key])
|
||||
|
||||
|
||||
class BaseValMixin:
|
||||
@ -77,17 +92,17 @@ class BaseValMixin:
|
||||
assert isinstance(self, LightningBaseModule)
|
||||
batch_x, batch_y = batch_xy
|
||||
y = self(batch_x).main_out
|
||||
val_bce_loss = self.bce_loss(y, batch_y)
|
||||
val_bce_loss = self.bce_loss(y.squeeze(), batch_y)
|
||||
return dict(val_bce_loss=val_bce_loss,
|
||||
batch_idx=batch_idx, y=y, batch_y=batch_y)
|
||||
|
||||
def validation_epoch_end(self, outputs, *args, **kwargs):
|
||||
def validation_epoch_end(self, outputs, *_, **__):
|
||||
assert isinstance(self, LightningBaseModule)
|
||||
summary_dict = dict(log=dict())
|
||||
summary_dict = dict()
|
||||
for output_idx, output in enumerate(outputs):
|
||||
keys = list(output[0].keys())
|
||||
ident = '' if output_idx == 0 else '_train'
|
||||
summary_dict['log'].update({f'mean{ident}_{key}': torch.mean(torch.stack([output[key]
|
||||
summary_dict.update({f'mean{ident}_{key}': torch.mean(torch.stack([output[key]
|
||||
for output in output]))
|
||||
for key in keys if 'loss' in key}
|
||||
)
|
||||
@ -101,8 +116,9 @@ class BaseValMixin:
|
||||
uar_score = sklearn.metrics.recall_score(y_true, y_pred, labels=[0, 1], average='macro',
|
||||
sample_weight=None, zero_division='warn')
|
||||
uar_score = torch.as_tensor(uar_score)
|
||||
summary_dict['log'].update({f'uar{ident}_score': uar_score})
|
||||
return summary_dict
|
||||
summary_dict.update({f'uar{ident}_score': uar_score})
|
||||
for key in summary_dict.keys():
|
||||
self.log(key, summary_dict[key])
|
||||
|
||||
|
||||
class BinaryMaskDatasetMixin:
|
||||
|
Loading…
x
Reference in New Issue
Block a user