New Model, Many Changes

This commit is contained in:
Si11ium 2020-11-21 09:28:26 +01:00
parent 7bac9e984b
commit be097a111a
12 changed files with 349 additions and 125 deletions

View File

@@ -29,35 +29,44 @@ main_arg_parser.add_argument("--data_n_fft", type=int, default=512, help="")
main_arg_parser.add_argument("--data_stretch", type=strtobool, default=True, help="")
# Transformation Parameters
main_arg_parser.add_argument("--data_loudness_ratio", type=float, default=0, help="") # 0.4
main_arg_parser.add_argument("--data_shift_ratio", type=float, default=0.3, help="") # 0.4
main_arg_parser.add_argument("--data_loudness_ratio", type=float, default=0.0, help="") # 0.4
main_arg_parser.add_argument("--data_shift_ratio", type=float, default=0.0, help="") # 0.4
main_arg_parser.add_argument("--data_noise_ratio", type=float, default=0, help="") # 0.4
main_arg_parser.add_argument("--data_mask_ratio", type=float, default=0, help="") # 0.2
main_arg_parser.add_argument("--data_mask_ratio", type=float, default=0.3, help="") # 0.2
main_arg_parser.add_argument("--data_speed_amount", type=float, default=0, help="") # 0.4
main_arg_parser.add_argument("--data_speed_min", type=float, default=0, help="") # 0.7
main_arg_parser.add_argument("--data_speed_max", type=float, default=0, help="") # 1.7
# Model Parameters
main_arg_parser.add_argument("--model_type", type=str, default="ViT", help="")
# General
main_arg_parser.add_argument("--model_type", type=str, default="SequentialVisualTransformer", help="")
main_arg_parser.add_argument("--model_weight_init", type=str, default="xavier_normal_", help="")
main_arg_parser.add_argument("--model_activation", type=str, default="leaky_relu", help="")
main_arg_parser.add_argument("--model_filters", type=str, default="[32, 64, 128, 64]", help="")
main_arg_parser.add_argument("--model_classes", type=int, default=2, help="")
main_arg_parser.add_argument("--model_lat_dim", type=int, default=128, help="")
main_arg_parser.add_argument("--model_activation", type=str, default="gelu", help="")
main_arg_parser.add_argument("--model_bias", type=strtobool, default=True, help="")
main_arg_parser.add_argument("--model_norm", type=strtobool, default=True, help="")
main_arg_parser.add_argument("--model_dropout", type=float, default=0.2, help="")
main_arg_parser.add_argument("--model_lat_dim", type=int, default=128, help="")
main_arg_parser.add_argument("--model_features", type=int, default=64, help="")
# CNN Specific
main_arg_parser.add_argument("--model_filters", type=str, default="[32, 64, 128, 64]", help="")
# Transformer Specific
main_arg_parser.add_argument("--model_patch_size", type=int, default=9, help="")
main_arg_parser.add_argument("--model_attn_depth", type=int, default=3, help="")
main_arg_parser.add_argument("--model_heads", type=int, default=8, help="")
main_arg_parser.add_argument("--model_embedding_size", type=int, default=64, help="")
# Training Parameters
main_arg_parser.add_argument("--train_outpath", type=str, default="output", help="")
main_arg_parser.add_argument("--train_version", type=strtobool, required=False, help="")
# FIXME: Stochastic Weight Averaging is not working well; maybe it's my implementation?
main_arg_parser.add_argument("--train_sto_weight_avg", type=strtobool, default=False, help="")
main_arg_parser.add_argument("--train_weight_decay", type=float, default=1e-7, help="")
main_arg_parser.add_argument("--train_weight_decay", type=float, default=0, help="")
main_arg_parser.add_argument("--train_opt_reset_interval", type=int, default=0, help="")
main_arg_parser.add_argument("--train_epochs", type=int, default=51, help="")
main_arg_parser.add_argument("--train_batch_size", type=int, default=300, help="")
main_arg_parser.add_argument("--train_lr", type=float, default=1e-4, help="")
main_arg_parser.add_argument("--train_epochs", type=int, default=100, help="")
main_arg_parser.add_argument("--train_batch_size", type=int, default=250, help="")
main_arg_parser.add_argument("--train_lr", type=float, default=1e-3, help="")
main_arg_parser.add_argument("--train_lr_warmup_steps", type=int, default=10, help="")
main_arg_parser.add_argument("--train_num_sanity_val_steps", type=int, default=0, help="")
# Project Parameters
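
For context, a minimal sketch (not part of the commit) of how the strtobool-typed flags above behave; it assumes strtobool is imported from distutils.util, matching the parser's usage:

from argparse import ArgumentParser
from distutils.util import strtobool

parser = ArgumentParser()
# argparse applies `type` only to string inputs, so default=True passes through untouched
parser.add_argument("--model_bias", type=strtobool, default=True)

args = parser.parse_args(["--model_bias", "no"])
assert args.model_bias == 0  # strtobool returns 0/1 ints, not bools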

datasets/urban_8k.py Normal file
View File

@@ -0,0 +1,95 @@
import pickle
from collections import defaultdict
from pathlib import Path
import librosa as librosa
from torch.utils.data import Dataset
import torch
import variables as V
from ml_lib.modules.util import F_x
class BinaryMasksDataset(Dataset):
_to_label = defaultdict(lambda: -1)
_to_label.update(dict(clear=V.CLEAR, mask=V.MASK))
@property
def sample_shape(self):
return self[0][0].shape
def __init__(self, data_root, setting, mel_transforms, transforms=None, stretch_dataset=False,
use_preprocessed=True):
self.use_preprocessed = use_preprocessed
self.stretch = stretch_dataset
assert isinstance(setting, str), f'Setting has to be a string, but was: {type(setting)}.'
assert setting in V.DATA_OPTIONS, f'Setting must match one of: {V.DATA_OPTIONS}.'
super(BinaryMasksDataset, self).__init__()
self.data_root = Path(data_root)
self.setting = setting
self._wav_folder = self.data_root / 'wav'
self._mel_folder = self.data_root / 'mel'
self.container_ext = '.pik'
self._mel_transform = mel_transforms
self._labels = self._build_labels()
self._wav_files = list(sorted(self._labels.keys()))
self._transforms = transforms or F_x(in_shape=None)
def _build_labels(self):
labeldict = dict()
with open(Path(self.data_root) / 'lab' / 'labels.csv', mode='r') as f:
# Exclude the header
_ = next(f)
for row in f:
if self.setting not in row:
continue
filename, label = row.strip().split(',')
labeldict[filename] = self._to_label[label.lower()] if not self.setting == 'test' else filename
if self.stretch and self.setting == V.DATA_OPTIONS.train:
additional_dict = ({f'X{key}': val for key, val in labeldict.items()})
additional_dict.update({f'XX{key}': val for key, val in labeldict.items()})
additional_dict.update({f'XXX{key}': val for key, val in labeldict.items()})
labeldict.update(additional_dict)
# Not using preprocessed data: delete any cached mel files.
if not self.use_preprocessed:
for key in labeldict.keys():
try:
(self._mel_folder / (key.replace('.wav', '') + self.container_ext)).unlink()
except FileNotFoundError:
pass
return labeldict
def __len__(self):
return len(self._labels)
def _compute_or_retrieve(self, filename):
if not (self._mel_folder / (filename + self.container_ext)).exists():
raw_sample, sr = librosa.core.load(self._wav_folder / (filename.replace('X', '') + '.wav'))
mel_sample = self._mel_transform(raw_sample)
self._mel_folder.mkdir(exist_ok=True, parents=True)
with (self._mel_folder / (filename + self.container_ext)).open(mode='wb') as f:
pickle.dump(mel_sample, f, protocol=pickle.HIGHEST_PROTOCOL)
with (self._mel_folder / (filename + self.container_ext)).open(mode='rb') as f:
mel_sample = pickle.load(f, fix_imports=True)
return mel_sample
def __getitem__(self, item):
key: str = list(self._labels.keys())[item]
filename = key.replace('.wav', '')
mel_sample = self._compute_or_retrieve(filename)
label = self._labels[key]
transformed_samples = self._transforms(mel_sample)
if self.setting != V.DATA_OPTIONS.test:
# In test, filenames instead of labels are returned. This is a little hacky though.
label = torch.as_tensor(label, dtype=torch.float)
return transformed_samples, label
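
The _compute_or_retrieve method above implements a pickle-backed cache: compute the mel spectrogram once, dump it next to the wavs, and reload the pickle on every later access. A standalone sketch of that pattern (paths and names here are illustrative, not from the repo):

import pickle
from pathlib import Path

def compute_or_retrieve(cache_dir: Path, name: str, compute):
    # compute() runs only on a cache miss; afterwards the pickle is reused
    cache_file = cache_dir / f'{name}.pik'
    if not cache_file.exists():
        cache_dir.mkdir(exist_ok=True, parents=True)
        with cache_file.open('wb') as f:
            pickle.dump(compute(), f, protocol=pickle.HIGHEST_PROTOCOL)
    with cache_file.open('rb') as f:
        return pickle.load(f)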

main.py
View File

@@ -6,14 +6,13 @@ import warnings
import torch
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
from ml_lib.modules.util import LightningBaseModule
from ml_lib.utils.config import Config
from ml_lib.utils.logging import Logger
# Project Specific Logger SubClasses
from util.config import MConfig
warnings.filterwarnings('ignore', category=FutureWarning)
warnings.filterwarnings('ignore', category=UserWarning)
@@ -37,35 +36,30 @@ def run_lightning_loop(config_obj):
# Callbacks
# =============================================================================
# Checkpoint Saving
checkpoint_callback = ModelCheckpoint(
monitor='uar_score',
ckpt_callback = ModelCheckpoint(
monitor='mean_loss',
filepath=str(logger.log_dir / 'ckpt_weights'),
verbose=False,
save_top_k=5,
)
# Early Stopping
# TODO: For this to work, set up a validation step plus an end-of-run eval and score
early_stopping_callback = EarlyStopping(
monitor='uar_score',
min_delta=0.01,
patience=10,
)
# Learning Rate Logger
lr_logger = LearningRateMonitor(logging_interval='epoch')
# Trainer
# =============================================================================
trainer = Trainer(max_epochs=config_obj.train.epochs,
show_progress_bar=True,
weights_save_path=logger.log_dir,
gpus=[0] if torch.cuda.is_available() else None,
check_val_every_n_epoch=10,
# num_sanity_val_steps=config_obj.train.num_sanity_val_steps,
# row_log_interval=(model.n_train_batches * 0.1), # TODO: Better Value / Setting
# log_save_interval=(model.n_train_batches * 0.2), # TODO: Better Value / Setting
checkpoint_callback=checkpoint_callback,
checkpoint_callback=True,
callbacks=[lr_logger, ckpt_callback],
logger=logger,
fast_dev_run=config_obj.main.debug,
early_stop_callback=None
auto_lr_find=not config_obj.main.debug
)
# Model
@@ -78,10 +72,15 @@ def run_lightning_loop(config_obj):
# Train It
if config_obj.model.type.lower() != 'ensemble':
if not config_obj.main.debug and not config_obj.train.lr:
trainer.tune(model)
# TODO: LR finder plot
# fig = lr_finder.plot(suggest=True)
trainer.fit(model)
# Save the last state & all parameters
trainer.save_checkpoint(logger.log_dir / 'weights.ckpt')
trainer.save_checkpoint(str(logger.log_dir / 'weights.ckpt'))
model.save_to_disk(logger.log_dir)
# Evaluate It
@@ -99,8 +98,7 @@ def run_lightning_loop(config_obj):
outputs.append(
model.validation_step((batch_x, label), idx, 1)
)
summary_dict = model.validation_epoch_end([outputs])
print(summary_dict['log']['uar_score'])
model.validation_epoch_end([outputs])
# trainer.test()
outpath = Path(config_obj.train.outpath)
@@ -132,6 +130,6 @@ if __name__ == "__main__":
from _paramters import main_arg_parser
config = MConfig.read_argparser(main_arg_parser)
config = Config.read_argparser(main_arg_parser)
fix_all_random_seeds(config)
trained_model = run_lightning_loop(config)
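
As a reading aid for the Trainer changes above: in the pytorch-lightning versions current at the time of this commit, ModelCheckpoint and LearningRateMonitor instances go into the callbacks list, while checkpoint_callback=True simply keeps checkpointing enabled. A hedged sketch of that wiring (argument values mirror the diff; paths omitted):

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor

ckpt_callback = ModelCheckpoint(monitor='mean_loss', save_top_k=5, verbose=False)
lr_logger = LearningRateMonitor(logging_interval='epoch')
trainer = Trainer(max_epochs=100,
                  callbacks=[lr_logger, ckpt_callback],
                  checkpoint_callback=True)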

View File

@@ -15,11 +15,10 @@ from ml_lib.audio_toolset.audio_io import AudioToMel, NormalizeLocal, MelToImage
# Transforms
from ml_lib.audio_toolset.mel_augmentation import NoiseInjection, LoudnessManipulator, ShiftTime, MaskAug
from ml_lib.utils.logging import Logger
from ml_lib.utils.config import Config
from ml_lib.utils.model_io import SavedLightningModels
from ml_lib.utils.transforms import ToTensor
from ml_lib.visualization.tools import Plotter
from util.config import MConfig
# Datasets
from datasets.binar_masks import BinaryMasksDataset
@@ -66,8 +65,8 @@ if __name__ == '__main__':
config_filename = 'config.ini'
inference_out = 'manual_test_out.csv'
config = MConfig()
config.read_file((Path(model_path) / config_filename).open('r'))
config = Config()
config.read_file((Path(model_path) / config_filename).open())
test_dataloader = prepare_dataloader(config)
loaded_model = restore_logger_and_model(model_path)

View File

@@ -4,7 +4,7 @@ from torch import nn
from torch.nn import ModuleList
from ml_lib.modules.blocks import ConvModule, LinearModule
from ml_lib.modules.util import (LightningBaseModule, HorizontalSplitter, HorizontalMerger)
from ml_lib.modules.util import (LightningBaseModule, Splitter, Merger)
from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, BinaryMaskDatasetMixin,
BaseDataloadersMixin)
@@ -33,7 +33,7 @@ class BandwiseConvClassifier(BinaryMaskDatasetMixin,
# Modules
# =============================================================================
self.split = HorizontalSplitter(self.in_shape, self.n_band_sections)
self.split = Splitter(self.in_shape, self.n_band_sections)
k = 3
self.band_list = ModuleList()
@@ -48,7 +48,7 @@ class BandwiseConvClassifier(BinaryMaskDatasetMixin,
# last_shape = self.conv_list[-1].shape
self.band_list.append(conv_list)
self.merge = HorizontalMerger(self.band_list[-1][-1].shape, self.n_band_sections)
self.merge = Merger(self.band_list[-1][-1].shape, self.n_band_sections)
self.full_1 = LinearModule(self.merge.shape, self.params.lat_dim, **self.params.module_kwargs)
self.full_2 = LinearModule(self.full_1.shape, self.full_1.shape * 2, **self.params.module_kwargs)
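
Splitter and Merger live in ml_lib and are only renamed in this commit. A plausible stand-in for the banding idea they implement (the helper below is illustrative, not the ml_lib code): chunk the spectrogram into n horizontal frequency bands, run a module per band, and concatenate the per-band features afterwards.

import torch

def split_bands(x: torch.Tensor, n_bands: int):
    # split along the frequency axis of a (batch, channels, freq, time) tensor
    return torch.chunk(x, n_bands, dim=-2)

bands = split_bands(torch.randn(2, 1, 128, 40), 4)
print(len(bands), bands[0].shape)  # 4 torch.Size([2, 1, 32, 40])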

View File

@@ -5,7 +5,7 @@ from torch import nn
from torch.nn import ModuleList
from ml_lib.modules.blocks import ConvModule, LinearModule
from ml_lib.modules.util import (LightningBaseModule, Flatten, HorizontalSplitter)
from ml_lib.modules.util import (LightningBaseModule, Splitter)
from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, BinaryMaskDatasetMixin,
BaseDataloadersMixin)
@@ -69,7 +69,7 @@ class BandwiseConvMultiheadClassifier(BinaryMaskDatasetMixin,
# Modules
# =============================================================================
self.split = HorizontalSplitter(self.in_shape, self.n_band_sections)
self.split = Splitter(self.in_shape, self.n_band_sections)
self.band_list = ModuleList()
for band in range(self.n_band_sections):

View File

@@ -1,16 +1,19 @@
import variables as V
from argparse import Namespace
import warnings
import torch
from torch import nn
from einops import rearrange, repeat
from ml_lib.modules.blocks import TransformerModule
from ml_lib.modules.util import (LightningBaseModule, AutoPadToShape)
from ml_lib.modules.util import (LightningBaseModule, AutoPadToShape, F_x)
from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, BinaryMaskDatasetMixin,
BaseDataloadersMixin)
MIN_NUM_PATCHES = 16
class VisualTransformer(BinaryMaskDatasetMixin,
BaseDataloadersMixin,
BaseTrainMixin,
@@ -22,69 +25,83 @@ class VisualTransformer(BinaryMaskDatasetMixin,
def __init__(self, hparams):
super(VisualTransformer, self).__init__(hparams)
self.in_shape = self.dataset.train_dataset.sample_shape
assert len(self.in_shape) == 3, 'There need to be exactly three dimensions (channels, height, width)'
channels, height, width = self.in_shape
# Automatic Image Shaping
image_size = (max(height, width) // self.params.patch_size) * self.params.patch_size
self.image_size = image_size + self.params.patch_size if image_size < max(height, width) else image_size
# This should be obsolete
assert self.image_size % self.params.patch_size == 0, 'image dimensions must be divisible by the patch size'
num_patches = (self.image_size // self.params.patch_size) ** 2
patch_dim = channels * self.params.patch_size ** 2
assert num_patches > MIN_NUM_PATCHES, f'your number of patches ({num_patches}) is way too small for ' \
f'attention. Try decreasing your patch size'
# Dataset
# =============================================================================
self.dataset = self.build_dataset()
self.in_shape = self.dataset.train_dataset.sample_shape
assert len(self.in_shape) == 3, 'There need to be exactly three dimensions (channels, height, width)'
channels, height, width = self.in_shape
# Model Parameters
# =============================================================================
# Additional parameters
self.attention_dim = self.params.features
self.embed_dim = self.params.embedding_size
# Automatic Image Shaping
self.patch_size = self.params.patch_size
image_size = (max(height, width) // self.patch_size) * self.patch_size
self.image_size = image_size + self.patch_size if image_size < max(height, width) else image_size
# This should be obsolete
assert self.image_size % self.patch_size == 0, 'image dimensions must be divisible by the patch size'
num_patches = (self.image_size // self.patch_size) ** 2
patch_dim = channels * self.patch_size ** 2
assert num_patches >= MIN_NUM_PATCHES, f'your number of patches ({num_patches}) is way too small for ' + \
f'attention. Try decreasing your patch size'
# Correct the Embedding Dim
if not self.embed_dim % self.params.heads == 0:
self.embed_dim = (self.embed_dim // self.params.heads) * self.params.heads
message = ('Embedding dimension was adjusted to be divisible by the number' +
f' of attention heads; it is now: {self.embed_dim}')
for func in print, warnings.warn:
func(message)
# Utility Modules
self.autopad = AutoPadToShape((self.image_size, self.image_size))
# Modules with Parameters
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, self.attention_dim), False)
self.embedding = nn.Linear(patch_dim, self.attention_dim)
self.cls_token = nn.Parameter(torch.randn(1, 1, self.attention_dim), False)
self.dropout = nn.Dropout(self.params.dropout)
self.transformer = TransformerModule(in_shape=self.embed_dim, hidden_size=self.params.lat_dim,
n_heads=self.params.heads, num_layers=self.params.attn_depth,
dropout=self.params.dropout, use_norm=self.params.use_norm,
activation=self.params.activation_as_string
)
self.transformer = TransformerModule(self.attention_dim, self.params.attn_depth, self.params.heads,
self.params.lat_dim, self.params.dropout)
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, self.embed_dim))
self.patch_to_embedding = nn.Linear(patch_dim, self.embed_dim) if self.params.embedding_size \
else F_x(self.embed_dim)
self.cls_token = nn.Parameter(torch.randn(1, 1, self.embed_dim))
self.dropout = nn.Dropout(self.params.dropout)
self.to_cls_token = nn.Identity()
self.mlp_head = nn.Sequential(
nn.LayerNorm(self.attention_dim),
nn.Linear(self.attention_dim, self.params.lat_dim),
nn.LayerNorm(self.embed_dim),
nn.Linear(self.embed_dim, self.params.lat_dim),
nn.GELU(),
nn.Dropout(self.params.dropout),
nn.Linear(self.params.lat_dim, V.NUM_CLASSES)
nn.Linear(self.params.lat_dim, 1),
nn.Sigmoid()
)
def forward(self, x, mask=None):
"""
:param tensor: the sequence to the encoder (required).
:param x: the sequence to the encoder (required).
:param mask: the mask for the src sequence (optional).
:return:
"""
tensor = self.autopad(x)
p = self.params.patch_size
# 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p, p2 = p
tensor = torch.reshape(x, (-1, self.image_size * self.image_size, p * p * self.in_shape[0]))
tensor = rearrange(tensor, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p, p2 = p)
tensor = self.patch_to_embedding(tensor)
b, n, _ = tensor.shape
# '() n d -> b n d', b = b
cls_tokens = tensor.repeat(self.cls_token, b)
cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
tensor = torch.cat((cls_tokens, tensor), dim=1)
tensor += self.pos_embedding[:, :(n + 1)]
tensor = self.dropout(tensor)
@@ -93,4 +110,4 @@ class VisualTransformer(BinaryMaskDatasetMixin,
tensor = self.to_cls_token(tensor[:, 0])
tensor = self.mlp_head(tensor)
return tensor
return Namespace(main_out=tensor)
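
A quick sanity sketch (illustrative, not from the commit) of the einops rearrange now used in forward(): it cuts a (b, c, H, W) image into non-overlapping p x p patches and flattens each patch into one vector, which a plain reshape cannot do, since reshape would not gather spatially adjacent pixels into the same patch.

import torch
from einops import rearrange

b, c, p = 2, 1, 9
image = torch.randn(b, c, 4 * p, 4 * p)  # H = W = 36, divisible by the patch size
patches = rearrange(image, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=p, p2=p)
assert patches.shape == (b, 16, p * p * c)  # 4 x 4 patches, 81 values each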

View File

@@ -0,0 +1,114 @@
from argparse import Namespace
import warnings
import torch
from torch import nn
from einops import repeat
from ml_lib.modules.blocks import TransformerModule
from ml_lib.modules.util import (LightningBaseModule, AutoPadToShape, F_x, SlidingWindow)
from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, BinaryMaskDatasetMixin,
BaseDataloadersMixin)
MIN_NUM_PATCHES = 16
class SequentialVisualTransformer(BinaryMaskDatasetMixin,
BaseDataloadersMixin,
BaseTrainMixin,
BaseValMixin,
BaseOptimizerMixin,
LightningBaseModule
):
def __init__(self, hparams):
super(SequentialVisualTransformer, self).__init__(hparams)
# Dataset
# =============================================================================
self.dataset = self.build_dataset()
self.in_shape = self.dataset.train_dataset.sample_shape
assert len(self.in_shape) == 3, 'There need to be exactly three dimensions (channels, height, width)'
channels, height, width = self.in_shape
# Model Parameters
# =============================================================================
# Additional parameters
self.embed_dim = self.params.embedding_size
self.patch_size = self.params.patch_size
self.height = height
# Automatic Image Shaping
image_size = (max(height, width) // self.patch_size) * self.patch_size
self.image_size = image_size + self.patch_size if image_size < max(height, width) else image_size
# This should be obsolete
assert self.image_size % self.patch_size == 0, 'image dimensions must be divisible by the patch size'
num_patches = (self.image_size // self.patch_size) ** 2
patch_dim = channels * self.patch_size * self.image_size
assert num_patches >= MIN_NUM_PATCHES, f'your number of patches ({num_patches}) is way too small for ' + \
f'attention. Try decreasing your patch size'
# Correct the Embedding Dim
if not self.embed_dim % self.params.heads == 0:
self.embed_dim = (self.embed_dim // self.params.heads) * self.params.heads
message = ('Embedding dimension was adjusted to be divisible by the number' +
f' of attention heads; it is now: {self.embed_dim}')
for func in print, warnings.warn:
func(message)
# Utility Modules
self.autopad = AutoPadToShape((self.image_size, self.image_size))
self.dropout = nn.Dropout(self.params.dropout)
self.slider = SlidingWindow((self.image_size, self.patch_size), keepdim=False)
# Modules with Parameters
self.transformer = TransformerModule(in_shape=self.embed_dim, hidden_size=self.params.lat_dim,
n_heads=self.params.heads, num_layers=self.params.attn_depth,
dropout=self.params.dropout, use_norm=self.params.use_norm,
activation=self.params.activation_as_string
)
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, self.embed_dim))
self.patch_to_embedding = nn.Linear(patch_dim, self.embed_dim) if self.params.embedding_size \
else F_x(self.embed_dim)
self.cls_token = nn.Parameter(torch.randn(1, 1, self.embed_dim))
self.to_cls_token = nn.Identity()
self.mlp_head = nn.Sequential(
nn.LayerNorm(self.embed_dim),
nn.Linear(self.embed_dim, self.params.lat_dim),
nn.GELU(),
nn.Dropout(self.params.dropout),
nn.Linear(self.params.lat_dim, 1),
nn.Sigmoid()
)
def forward(self, x, mask=None):
"""
:param x: the sequence to the encoder (required).
:param mask: the mask for the src sequence (optional).
:return:
"""
tensor = self.autopad(x)
tensor = self.slider(tensor)
tensor = self.patch_to_embedding(tensor)
b, n, _ = tensor.shape
# cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
cls_tokens = self.cls_token.repeat((b, 1, 1))
tensor = torch.cat((cls_tokens, tensor), dim=1)
tensor += self.pos_embedding[:, :(n + 1)]
tensor = self.dropout(tensor)
tensor = self.transformer(tensor, mask)
tensor = self.to_cls_token(tensor[:, 0])
tensor = self.mlp_head(tensor)
return Namespace(main_out=tensor)
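
SlidingWindow comes from ml_lib and is not shown in this commit; the following is a plausible stand-in (an assumption, not the actual ml_lib implementation) using torch.Tensor.unfold. It cuts the padded (b, c, H, W) spectrogram into vertical strips of width patch_size, each flattened to channels * H * patch_size values, matching the patch_dim computed above.

import torch

def sliding_strips(x: torch.Tensor, patch_size: int) -> torch.Tensor:
    # (b, c, h, w) -> (b, w // patch_size, c * h * patch_size)
    b, c, h, w = x.shape
    strips = x.unfold(dimension=3, size=patch_size, step=patch_size)  # (b, c, h, n, p)
    return strips.permute(0, 3, 1, 2, 4).reshape(b, w // patch_size, -1)

print(sliding_strips(torch.randn(2, 1, 36, 36), 9).shape)  # torch.Size([2, 4, 324])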

View File

@@ -1,7 +1,7 @@
import shutil
import warnings
from util.config import MConfig
from ml_lib.utils.config import Config
warnings.filterwarnings('ignore', category=FutureWarning)
warnings.filterwarnings('ignore', category=UserWarning)
@@ -17,12 +17,12 @@ if __name__ == '__main__':
args = main_arg_parser.parse_args()
# Model Settings
config = MConfig().read_namespace(args)
config = Config().read_namespace(args)
arg_dict = dict()
for seed in range(0, 10):
arg_dict.update(main_seed=seed)
for model in ['CC', 'BCMC', 'BCC', 'RCC']:
for model in ['VisualTransformer']:
arg_dict.update(model_type=model)
raw_conf = dict(data_speed_amount=0.0, data_speed_min=0.0, data_speed_max=0.0,
data_mask_ratio=0.0, data_noise_ratio=0.0, data_shift_ratio=0.0, data_loudness_ratio=0.0,

View File

@@ -42,7 +42,7 @@ msgpack-python==0.5.6
natsort==7.0.1
neptune-client==0.4.109
numba==0.49.1
numpy==1.18.4
numpy~=1.18.2
oauthlib==3.1.0
packaging==20.3
pandas==1.0.3
@@ -68,7 +68,7 @@ resampy==0.2.2
retrying==1.3.3
rfc3987==1.3.8
rsa==4.0
scikit-learn==0.23.1
scikit-learn~=0.22.2.post1
scipy==1.4.1
simplejson==3.17.0
six==1.14.0
@@ -91,3 +91,5 @@ webencodings==0.5.1
websocket-client==0.57.0
Werkzeug==1.0.1
xmltodict==0.12.0
einops~=0.3.0

View File

@@ -1,26 +0,0 @@
from ml_lib.utils.config import Config
from models.conv_classifier import ConvClassifier
from models.bandwise_conv_classifier import BandwiseConvClassifier
from models.bandwise_conv_multihead_classifier import BandwiseConvMultiheadClassifier
from models.ensemble import Ensemble
from models.residual_conv_classifier import ResidualConvClassifier
from models.transformer_model import VisualTransformer
class MConfig(Config):
# TODO: There should be a way to automate this.
@property
def _model_map(self):
return dict(ConvClassifier=ConvClassifier,
CC=ConvClassifier,
BandwiseConvClassifier=BandwiseConvClassifier,
BCC=BandwiseConvClassifier,
BandwiseConvMultiheadClassifier=BandwiseConvMultiheadClassifier,
BCMC=BandwiseConvMultiheadClassifier,
Ensemble=Ensemble,
E=Ensemble,
ResidualConvClassifier=ResidualConvClassifier,
RCC=ResidualConvClassifier,
ViT=VisualTransformer
)
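
The deleted _model_map resolved the --model_type string to a model class. A minimal sketch of the registry pattern it used (the importlib-based resolver below is illustrative; the commit presumably moves this duty into ml_lib's generic Config):

import importlib

def resolve_model(name: str, module: str = 'models.transformer_model'):
    # e.g. resolve_model('VisualTransformer') returns the class defined above
    return getattr(importlib.import_module(module), name)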

View File

@@ -8,7 +8,8 @@ import torch
import numpy as np
from torch import nn
from torch.optim import Adam
from torch.utils.data import DataLoader, RandomSampler
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
from torch.utils.data import DataLoader
from torchcontrib.optim import SWA
from torchvision.transforms import Compose, RandomApply
@@ -25,10 +26,23 @@ class BaseOptimizerMixin:
def configure_optimizers(self):
assert isinstance(self, LightningBaseModule)
opt = Adam(params=self.parameters(), lr=self.params.lr, weight_decay=self.params.weight_decay)
optimizer_dict = dict(
# 'optimizer':optimizer, # The Optimizer
# 'lr_scheduler': scheduler, # The LR scheduler
frequency=1, # The frequency of the scheduler
interval='epoch', # The unit of the scheduler's step size
# 'reduce_on_plateau': False, # For ReduceLROnPlateau scheduler
# 'monitor': 'mean_val_loss' # Metric to monitor
)
optimizer = Adam(params=self.parameters(), lr=self.params.lr, weight_decay=self.params.weight_decay)
if self.params.sto_weight_avg:
opt = SWA(opt, swa_start=10, swa_freq=5, swa_lr=0.05)
return opt
optimizer = SWA(optimizer, swa_start=10, swa_freq=5, swa_lr=0.05)
optimizer_dict.update(optimizer=optimizer)
if self.params.lr_warmup_steps:
scheduler = CosineAnnealingWarmRestarts(optimizer, self.params.lr_warmup_steps)
optimizer_dict.update(lr_scheduler=scheduler)
return optimizer_dict
def on_train_end(self):
assert isinstance(self, LightningBaseModule)
@@ -54,17 +68,18 @@
assert isinstance(self, LightningBaseModule)
batch_x, batch_y = batch_xy
y = self(batch_x).main_out
bce_loss = self.bce_loss(y, batch_y)
bce_loss = self.bce_loss(y.squeeze(), batch_y)
return dict(loss=bce_loss)
def training_epoch_end(self, outputs):
assert isinstance(self, LightningBaseModule)
keys = list(outputs[0].keys())
summary_dict = dict(log={f'mean_{key}': torch.mean(torch.stack([output[key]
summary_dict = {f'mean_{key}': torch.mean(torch.stack([output[key]
for output in outputs]))
for key in keys if 'loss' in key})
return summary_dict
for key in keys if 'loss' in key}
for key in summary_dict.keys():
self.log(key, summary_dict[key])
class BaseValMixin:
@@ -77,17 +92,17 @@ class BaseValMixin:
assert isinstance(self, LightningBaseModule)
batch_x, batch_y = batch_xy
y = self(batch_x).main_out
val_bce_loss = self.bce_loss(y, batch_y)
val_bce_loss = self.bce_loss(y.squeeze(), batch_y)
return dict(val_bce_loss=val_bce_loss,
batch_idx=batch_idx, y=y, batch_y=batch_y)
def validation_epoch_end(self, outputs, *args, **kwargs):
def validation_epoch_end(self, outputs, *_, **__):
assert isinstance(self, LightningBaseModule)
summary_dict = dict(log=dict())
summary_dict = dict()
for output_idx, output in enumerate(outputs):
keys = list(output[0].keys())
ident = '' if output_idx == 0 else '_train'
summary_dict['log'].update({f'mean{ident}_{key}': torch.mean(torch.stack([output[key]
summary_dict.update({f'mean{ident}_{key}': torch.mean(torch.stack([output[key]
for output in output]))
for key in keys if 'loss' in key}
)
@@ -101,8 +116,9 @@ class BaseValMixin:
uar_score = sklearn.metrics.recall_score(y_true, y_pred, labels=[0, 1], average='macro',
sample_weight=None, zero_division='warn')
uar_score = torch.as_tensor(uar_score)
summary_dict['log'].update({f'uar{ident}_score': uar_score})
return summary_dict
summary_dict.update({f'uar{ident}_score': uar_score})
for key in summary_dict.keys():
self.log(key, summary_dict[key])
class BinaryMaskDatasetMixin:
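
Referring back to the configure_optimizers hunk above: pytorch-lightning accepts a dict carrying an 'optimizer' entry, an optional 'lr_scheduler', and stepping metadata such as 'interval' and 'frequency'. A minimal self-contained sketch of that contract (the dummy parameter is illustrative):

import torch
from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = Adam(params, lr=1e-3, weight_decay=0)
scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=10)  # first restart after 10 epochs
optimizer_dict = dict(optimizer=optimizer, lr_scheduler=scheduler,
                      interval='epoch', frequency=1)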