New Model, Many Changes
This commit is contained in:
parent
7bac9e984b
commit
be097a111a
@ -29,35 +29,44 @@ main_arg_parser.add_argument("--data_n_fft", type=int, default=512, help="")
|
|||||||
main_arg_parser.add_argument("--data_stretch", type=strtobool, default=True, help="")
|
main_arg_parser.add_argument("--data_stretch", type=strtobool, default=True, help="")
|
||||||
|
|
||||||
# Transformation Parameters
|
# Transformation Parameters
|
||||||
main_arg_parser.add_argument("--data_loudness_ratio", type=float, default=0, help="") # 0.4
|
main_arg_parser.add_argument("--data_loudness_ratio", type=float, default=0.0, help="") # 0.4
|
||||||
main_arg_parser.add_argument("--data_shift_ratio", type=float, default=0.3, help="") # 0.4
|
main_arg_parser.add_argument("--data_shift_ratio", type=float, default=0.0, help="") # 0.4
|
||||||
main_arg_parser.add_argument("--data_noise_ratio", type=float, default=0, help="") # 0.4
|
main_arg_parser.add_argument("--data_noise_ratio", type=float, default=0, help="") # 0.4
|
||||||
main_arg_parser.add_argument("--data_mask_ratio", type=float, default=0, help="") # 0.2
|
main_arg_parser.add_argument("--data_mask_ratio", type=float, default=0.3, help="") # 0.2
|
||||||
main_arg_parser.add_argument("--data_speed_amount", type=float, default=0, help="") # 0.4
|
main_arg_parser.add_argument("--data_speed_amount", type=float, default=0, help="") # 0.4
|
||||||
main_arg_parser.add_argument("--data_speed_min", type=float, default=0, help="") # 0.7
|
main_arg_parser.add_argument("--data_speed_min", type=float, default=0, help="") # 0.7
|
||||||
main_arg_parser.add_argument("--data_speed_max", type=float, default=0, help="") # 1.7
|
main_arg_parser.add_argument("--data_speed_max", type=float, default=0, help="") # 1.7
|
||||||
|
|
||||||
# Model Parameters
|
# Model Parameters
|
||||||
main_arg_parser.add_argument("--model_type", type=str, default="ViT", help="")
|
# General
|
||||||
|
main_arg_parser.add_argument("--model_type", type=str, default="SequentialVisualTransformer", help="")
|
||||||
main_arg_parser.add_argument("--model_weight_init", type=str, default="xavier_normal_", help="")
|
main_arg_parser.add_argument("--model_weight_init", type=str, default="xavier_normal_", help="")
|
||||||
main_arg_parser.add_argument("--model_activation", type=str, default="leaky_relu", help="")
|
main_arg_parser.add_argument("--model_activation", type=str, default="gelu", help="")
|
||||||
main_arg_parser.add_argument("--model_filters", type=str, default="[32, 64, 128, 64]", help="")
|
|
||||||
main_arg_parser.add_argument("--model_classes", type=int, default=2, help="")
|
|
||||||
main_arg_parser.add_argument("--model_lat_dim", type=int, default=128, help="")
|
|
||||||
main_arg_parser.add_argument("--model_bias", type=strtobool, default=True, help="")
|
main_arg_parser.add_argument("--model_bias", type=strtobool, default=True, help="")
|
||||||
main_arg_parser.add_argument("--model_norm", type=strtobool, default=True, help="")
|
main_arg_parser.add_argument("--model_norm", type=strtobool, default=True, help="")
|
||||||
main_arg_parser.add_argument("--model_dropout", type=float, default=0.2, help="")
|
main_arg_parser.add_argument("--model_dropout", type=float, default=0.2, help="")
|
||||||
|
main_arg_parser.add_argument("--model_lat_dim", type=int, default=128, help="")
|
||||||
|
main_arg_parser.add_argument("--model_features", type=int, default=64, help="")
|
||||||
|
|
||||||
|
# CNN Specific
|
||||||
|
main_arg_parser.add_argument("--model_filters", type=str, default="[32, 64, 128, 64]", help="")
|
||||||
|
|
||||||
|
# Transformer Specific
|
||||||
|
main_arg_parser.add_argument("--model_patch_size", type=int, default=9, help="")
|
||||||
|
main_arg_parser.add_argument("--model_attn_depth", type=int, default=3, help="")
|
||||||
|
main_arg_parser.add_argument("--model_heads", type=int, default=8, help="")
|
||||||
|
main_arg_parser.add_argument("--model_embedding_size", type=int, default=64, help="")
|
||||||
|
|
||||||
# Training Parameters
|
# Training Parameters
|
||||||
main_arg_parser.add_argument("--train_outpath", type=str, default="output", help="")
|
main_arg_parser.add_argument("--train_outpath", type=str, default="output", help="")
|
||||||
main_arg_parser.add_argument("--train_version", type=strtobool, required=False, help="")
|
main_arg_parser.add_argument("--train_version", type=strtobool, required=False, help="")
|
||||||
# FIXME: Stochastic Weight Averaging does not perform well; maybe it's my implementation?
|
|
||||||
main_arg_parser.add_argument("--train_sto_weight_avg", type=strtobool, default=False, help="")
|
main_arg_parser.add_argument("--train_sto_weight_avg", type=strtobool, default=False, help="")
|
||||||
main_arg_parser.add_argument("--train_weight_decay", type=float, default=1e-7, help="")
|
main_arg_parser.add_argument("--train_weight_decay", type=float, default=0, help="")
|
||||||
main_arg_parser.add_argument("--train_opt_reset_interval", type=int, default=0, help="")
|
main_arg_parser.add_argument("--train_opt_reset_interval", type=int, default=0, help="")
|
||||||
main_arg_parser.add_argument("--train_epochs", type=int, default=51, help="")
|
main_arg_parser.add_argument("--train_epochs", type=int, default=100, help="")
|
||||||
main_arg_parser.add_argument("--train_batch_size", type=int, default=300, help="")
|
main_arg_parser.add_argument("--train_batch_size", type=int, default=250, help="")
|
||||||
main_arg_parser.add_argument("--train_lr", type=float, default=1e-4, help="")
|
main_arg_parser.add_argument("--train_lr", type=float, default=1e-3, help="")
|
||||||
|
main_arg_parser.add_argument("--train_lr_warmup_steps", type=int, default=10, help="")
|
||||||
main_arg_parser.add_argument("--train_num_sanity_val_steps", type=int, default=0, help="")
|
main_arg_parser.add_argument("--train_num_sanity_val_steps", type=int, default=0, help="")
|
||||||
|
|
||||||
# Project Parameters
|
# Project Parameters
|
||||||
|
95
datasets/urban_8k.py
Normal file
95
datasets/urban_8k.py
Normal file
@ -0,0 +1,95 @@
|
|||||||
|
import pickle
|
||||||
|
from collections import defaultdict
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import librosa as librosa
|
||||||
|
from torch.utils.data import Dataset
|
||||||
|
import torch
|
||||||
|
|
||||||
|
import variables as V
|
||||||
|
from ml_lib.modules.util import F_x
|
||||||
|
|
||||||
|
|
||||||
|
class BinaryMasksDataset(Dataset):
    """Map-style dataset yielding ``(transformed mel sample, label)`` pairs.

    Layout expected below ``data_root``:

    * ``lab/labels.csv`` -- a header line followed by ``filename,label`` rows;
      only rows containing ``setting`` are used.
    * ``wav/``           -- the audio files referenced by the label file.
    * ``mel/``           -- cache of pickled mel samples (created on demand).

    For the ``'test'`` setting the "label" of a sample is its filename.

    :param data_root: root folder of the dataset (anything ``Path`` accepts).
    :param setting: one of ``V.DATA_OPTIONS``; selects the rows of the label file.
    :param mel_transforms: callable applied to the raw audio to obtain the mel sample.
    :param transforms: optional callable applied to every mel sample on access;
        defaults to ``F_x(in_shape=None)``.
    :param stretch_dataset: if true and ``setting`` is the train option, every
        sample is listed three more times under ``X``/``XX``/``XXX`` prefixed keys.
    :param use_preprocessed: if false, cached mel files of the selected samples
        are deleted on construction so that they are recomputed.
    """

    # Maps the lower-cased label strings of the label file to class ids.
    # Unknown labels resolve to -1.
    _to_label = defaultdict(lambda: -1)
    _to_label.update(dict(clear=V.CLEAR, mask=V.MASK))

    # Key prefixes used to list a sample several times when the dataset is stretched.
    _STRETCH_PREFIXES = ('X', 'XX', 'XXX')

    @property
    def sample_shape(self):
        """Shape of the first (transformed) sample."""
        return self[0][0].shape

    def __init__(self, data_root, setting, mel_transforms, transforms=None, stretch_dataset=False,
                 use_preprocessed=True):
        self.use_preprocessed = use_preprocessed
        self.stretch = stretch_dataset
        assert isinstance(setting, str), f'Setting has to be a string, but was: {type(setting)}.'
        assert setting in V.DATA_OPTIONS, f'Setting must match one of: {V.DATA_OPTIONS}.'
        super(BinaryMasksDataset, self).__init__()

        self.data_root = Path(data_root)
        self.setting = setting
        self._wav_folder = self.data_root / 'wav'
        self._mel_folder = self.data_root / 'mel'
        self.container_ext = '.pik'
        self._mel_transform = mel_transforms

        self._labels = self._build_labels()
        # Index -> key lookup in label (insertion) order. Built once, so that
        # __getitem__ does not have to materialise the key list on every access.
        self._keys = list(self._labels)
        self._wav_files = list(sorted(self._labels.keys()))
        self._transforms = transforms or F_x(in_shape=None)

    def _mel_path(self, filename):
        """Return the cache file path of the mel sample belonging to ``filename``."""
        return self._mel_folder / (filename + self.container_ext)

    def _resolve_wav_path(self, filename):
        """Return the wav file backing ``filename`` (a key without ``.wav``).

        Stretched keys carry one of ``_STRETCH_PREFIXES`` in front of the real
        name. Only such a leading prefix is removed, so an ``X`` that is part of
        the real file name is kept. The first candidate that exists on disk wins.
        """
        candidates = [filename]
        candidates.extend(filename[len(prefix):] for prefix in self._STRETCH_PREFIXES
                          if filename.startswith(prefix))
        for candidate in candidates:
            wav_path = self._wav_folder / (candidate + '.wav')
            if wav_path.exists():
                return wav_path
        # Nothing found: fall back to the legacy lookup so the loader reports the missing file.
        return self._wav_folder / (filename.replace('X', '') + '.wav')

    def _build_labels(self):
        """Read the label file and return the ``{key: label}`` mapping for ``self.setting``.

        Keys are the file names as written in the label file. Values are class ids,
        or the file name itself for the ``'test'`` setting.
        """
        labeldict = dict()
        with (self.data_root / 'lab' / 'labels.csv').open(mode='r') as f:
            # Exclude the header (tolerating an empty file).
            next(f, None)
            for row in f:
                if self.setting not in row:
                    continue
                filename, label = row.strip().split(',')
                if self.setting == 'test':
                    labeldict[filename] = filename
                else:
                    labeldict[filename] = self._to_label[label.lower()]

        if self.stretch and self.setting == V.DATA_OPTIONS.train:
            # List every sample once more per prefix; all copies share the label.
            additional_dict = {f'{prefix}{key}': val
                               for prefix in self._STRETCH_PREFIXES
                               for key, val in labeldict.items()}
            labeldict.update(additional_dict)

        # Delete File if one exists.
        if not self.use_preprocessed:
            for key in labeldict:
                try:
                    self._mel_path(key.replace('.wav', '')).unlink()
                except FileNotFoundError:
                    pass

        return labeldict

    def __len__(self):
        """Number of listed samples (including stretched copies)."""
        return len(self._labels)

    def _compute_or_retrieve(self, filename):
        """Return the mel sample for ``filename``, computing and caching it if necessary."""
        mel_path = self._mel_path(filename)

        if not mel_path.exists():
            raw_sample, _ = librosa.core.load(self._resolve_wav_path(filename))
            mel_sample = self._mel_transform(raw_sample)
            self._mel_folder.mkdir(exist_ok=True, parents=True)
            with mel_path.open(mode='wb') as f:
                pickle.dump(mel_sample, f, protocol=pickle.HIGHEST_PROTOCOL)

        # NOTE: unpickling executes arbitrary code; the mel cache folder must only
        # contain files written by this class.
        with mel_path.open(mode='rb') as f:
            mel_sample = pickle.load(f, fix_imports=True)
        return mel_sample

    def __getitem__(self, item):
        """Return ``(transformed mel sample, label)`` for index ``item``."""
        key: str = self._keys[item]
        filename = key.replace('.wav', '')
        mel_sample = self._compute_or_retrieve(filename)
        label = self._labels[key]

        transformed_samples = self._transforms(mel_sample)

        if self.setting != V.DATA_OPTIONS.test:
            # In test, filenames instead of labels are returned. This is a little hacky though.
            label = torch.as_tensor(label, dtype=torch.float)

        return transformed_samples, label
|
36
main.py
36
main.py
@ -6,14 +6,13 @@ import warnings
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
from pytorch_lightning import Trainer
|
from pytorch_lightning import Trainer
|
||||||
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
|
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
|
||||||
|
|
||||||
from ml_lib.modules.util import LightningBaseModule
|
from ml_lib.modules.util import LightningBaseModule
|
||||||
|
from ml_lib.utils.config import Config
|
||||||
from ml_lib.utils.logging import Logger
|
from ml_lib.utils.logging import Logger
|
||||||
|
|
||||||
# Project Specific Logger SubClasses
|
# Project Specific Logger SubClasses
|
||||||
from util.config import MConfig
|
|
||||||
|
|
||||||
|
|
||||||
warnings.filterwarnings('ignore', category=FutureWarning)
|
warnings.filterwarnings('ignore', category=FutureWarning)
|
||||||
warnings.filterwarnings('ignore', category=UserWarning)
|
warnings.filterwarnings('ignore', category=UserWarning)
|
||||||
@ -37,35 +36,30 @@ def run_lightning_loop(config_obj):
|
|||||||
# Callbacks
|
# Callbacks
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# Checkpoint Saving
|
# Checkpoint Saving
|
||||||
checkpoint_callback = ModelCheckpoint(
|
ckpt_callback = ModelCheckpoint(
|
||||||
monitor='uar_score',
|
monitor='mean_loss',
|
||||||
filepath=str(logger.log_dir / 'ckpt_weights'),
|
filepath=str(logger.log_dir / 'ckpt_weights'),
|
||||||
verbose=False,
|
verbose=False,
|
||||||
save_top_k=5,
|
save_top_k=5,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Early Stopping
|
# Learning Rate Logger
|
||||||
# TODO: For This to work, set a validation step and End Eval and Score
|
lr_logger = LearningRateMonitor(logging_interval='epoch')
|
||||||
early_stopping_callback = EarlyStopping(
|
|
||||||
monitor='uar_score',
|
|
||||||
min_delta=0.01,
|
|
||||||
patience=10,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Trainer
|
# Trainer
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
trainer = Trainer(max_epochs=config_obj.train.epochs,
|
trainer = Trainer(max_epochs=config_obj.train.epochs,
|
||||||
show_progress_bar=True,
|
|
||||||
weights_save_path=logger.log_dir,
|
weights_save_path=logger.log_dir,
|
||||||
gpus=[0] if torch.cuda.is_available() else None,
|
gpus=[0] if torch.cuda.is_available() else None,
|
||||||
check_val_every_n_epoch=10,
|
check_val_every_n_epoch=10,
|
||||||
# num_sanity_val_steps=config_obj.train.num_sanity_val_steps,
|
# num_sanity_val_steps=config_obj.train.num_sanity_val_steps,
|
||||||
# row_log_interval=(model.n_train_batches * 0.1), # TODO: Better Value / Setting
|
# row_log_interval=(model.n_train_batches * 0.1), # TODO: Better Value / Setting
|
||||||
# log_save_interval=(model.n_train_batches * 0.2), # TODO: Better Value / Setting
|
# log_save_interval=(model.n_train_batches * 0.2), # TODO: Better Value / Setting
|
||||||
checkpoint_callback=checkpoint_callback,
|
checkpoint_callback=True,
|
||||||
|
callbacks=[lr_logger, ckpt_callback],
|
||||||
logger=logger,
|
logger=logger,
|
||||||
fast_dev_run=config_obj.main.debug,
|
fast_dev_run=config_obj.main.debug,
|
||||||
early_stop_callback=None
|
auto_lr_find=not config_obj.main.debug
|
||||||
)
|
)
|
||||||
|
|
||||||
# Model
|
# Model
|
||||||
@ -78,10 +72,15 @@ def run_lightning_loop(config_obj):
|
|||||||
|
|
||||||
# Train It
|
# Train It
|
||||||
if config_obj.model.type.lower() != 'ensemble':
|
if config_obj.model.type.lower() != 'ensemble':
|
||||||
|
if not config_obj.main.debug and not config_obj.train.lr:
|
||||||
|
trainer.tune(model)
|
||||||
|
# ToDo: LR Finder Plot
|
||||||
|
# fig = lr_finder.plot(suggest=True)
|
||||||
|
|
||||||
trainer.fit(model)
|
trainer.fit(model)
|
||||||
|
|
||||||
# Save the last state & all parameters
|
# Save the last state & all parameters
|
||||||
trainer.save_checkpoint(logger.log_dir / 'weights.ckpt')
|
trainer.save_checkpoint(str(logger.log_dir / 'weights.ckpt'))
|
||||||
model.save_to_disk(logger.log_dir)
|
model.save_to_disk(logger.log_dir)
|
||||||
|
|
||||||
# Evaluate It
|
# Evaluate It
|
||||||
@ -99,8 +98,7 @@ def run_lightning_loop(config_obj):
|
|||||||
outputs.append(
|
outputs.append(
|
||||||
model.validation_step((batch_x, label), idx, 1)
|
model.validation_step((batch_x, label), idx, 1)
|
||||||
)
|
)
|
||||||
summary_dict = model.validation_epoch_end([outputs])
|
model.validation_epoch_end([outputs])
|
||||||
print(summary_dict['log']['uar_score'])
|
|
||||||
|
|
||||||
# trainer.test()
|
# trainer.test()
|
||||||
outpath = Path(config_obj.train.outpath)
|
outpath = Path(config_obj.train.outpath)
|
||||||
@ -132,6 +130,6 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
from _paramters import main_arg_parser
|
from _paramters import main_arg_parser
|
||||||
|
|
||||||
config = MConfig.read_argparser(main_arg_parser)
|
config = Config.read_argparser(main_arg_parser)
|
||||||
fix_all_random_seeds(config)
|
fix_all_random_seeds(config)
|
||||||
trained_model = run_lightning_loop(config)
|
trained_model = run_lightning_loop(config)
|
||||||
|
@ -15,11 +15,10 @@ from ml_lib.audio_toolset.audio_io import AudioToMel, NormalizeLocal, MelToImage
|
|||||||
|
|
||||||
# Transforms
|
# Transforms
|
||||||
from ml_lib.audio_toolset.mel_augmentation import NoiseInjection, LoudnessManipulator, ShiftTime, MaskAug
|
from ml_lib.audio_toolset.mel_augmentation import NoiseInjection, LoudnessManipulator, ShiftTime, MaskAug
|
||||||
from ml_lib.utils.logging import Logger
|
from ml_lib.utils.config import Config
|
||||||
from ml_lib.utils.model_io import SavedLightningModels
|
from ml_lib.utils.model_io import SavedLightningModels
|
||||||
from ml_lib.utils.transforms import ToTensor
|
from ml_lib.utils.transforms import ToTensor
|
||||||
from ml_lib.visualization.tools import Plotter
|
|
||||||
from util.config import MConfig
|
|
||||||
|
|
||||||
# Datasets
|
# Datasets
|
||||||
from datasets.binar_masks import BinaryMasksDataset
|
from datasets.binar_masks import BinaryMasksDataset
|
||||||
@ -66,8 +65,8 @@ if __name__ == '__main__':
|
|||||||
config_filename = 'config.ini'
|
config_filename = 'config.ini'
|
||||||
inference_out = 'manual_test_out.csv'
|
inference_out = 'manual_test_out.csv'
|
||||||
|
|
||||||
config = MConfig()
|
config = Config()
|
||||||
config.read_file((Path(model_path) / config_filename).open('r'))
|
config.read_file((Path(model_path) / config_filename).open())
|
||||||
test_dataloader = prepare_dataloader(config)
|
test_dataloader = prepare_dataloader(config)
|
||||||
|
|
||||||
loaded_model = restore_logger_and_model(model_path)
|
loaded_model = restore_logger_and_model(model_path)
|
||||||
|
@ -4,7 +4,7 @@ from torch import nn
|
|||||||
from torch.nn import ModuleList
|
from torch.nn import ModuleList
|
||||||
|
|
||||||
from ml_lib.modules.blocks import ConvModule, LinearModule
|
from ml_lib.modules.blocks import ConvModule, LinearModule
|
||||||
from ml_lib.modules.util import (LightningBaseModule, HorizontalSplitter, HorizontalMerger)
|
from ml_lib.modules.util import (LightningBaseModule, Splitter, Merger)
|
||||||
from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, BinaryMaskDatasetMixin,
|
from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, BinaryMaskDatasetMixin,
|
||||||
BaseDataloadersMixin)
|
BaseDataloadersMixin)
|
||||||
|
|
||||||
@ -33,7 +33,7 @@ class BandwiseConvClassifier(BinaryMaskDatasetMixin,
|
|||||||
|
|
||||||
# Modules
|
# Modules
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
self.split = HorizontalSplitter(self.in_shape, self.n_band_sections)
|
self.split = Splitter(self.in_shape, self.n_band_sections)
|
||||||
|
|
||||||
k = 3
|
k = 3
|
||||||
self.band_list = ModuleList()
|
self.band_list = ModuleList()
|
||||||
@ -48,7 +48,7 @@ class BandwiseConvClassifier(BinaryMaskDatasetMixin,
|
|||||||
# last_shape = self.conv_list[-1].shape
|
# last_shape = self.conv_list[-1].shape
|
||||||
self.band_list.append(conv_list)
|
self.band_list.append(conv_list)
|
||||||
|
|
||||||
self.merge = HorizontalMerger(self.band_list[-1][-1].shape, self.n_band_sections)
|
self.merge = Merger(self.band_list[-1][-1].shape, self.n_band_sections)
|
||||||
|
|
||||||
self.full_1 = LinearModule(self.merge.shape, self.params.lat_dim, **self.params.module_kwargs)
|
self.full_1 = LinearModule(self.merge.shape, self.params.lat_dim, **self.params.module_kwargs)
|
||||||
self.full_2 = LinearModule(self.full_1.shape, self.full_1.shape * 2, **self.params.module_kwargs)
|
self.full_2 = LinearModule(self.full_1.shape, self.full_1.shape * 2, **self.params.module_kwargs)
|
||||||
|
@ -5,7 +5,7 @@ from torch import nn
|
|||||||
from torch.nn import ModuleList
|
from torch.nn import ModuleList
|
||||||
|
|
||||||
from ml_lib.modules.blocks import ConvModule, LinearModule
|
from ml_lib.modules.blocks import ConvModule, LinearModule
|
||||||
from ml_lib.modules.util import (LightningBaseModule, Flatten, HorizontalSplitter)
|
from ml_lib.modules.util import (LightningBaseModule, Splitter)
|
||||||
from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, BinaryMaskDatasetMixin,
|
from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, BinaryMaskDatasetMixin,
|
||||||
BaseDataloadersMixin)
|
BaseDataloadersMixin)
|
||||||
|
|
||||||
@ -69,7 +69,7 @@ class BandwiseConvMultiheadClassifier(BinaryMaskDatasetMixin,
|
|||||||
|
|
||||||
# Modules
|
# Modules
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
self.split = HorizontalSplitter(self.in_shape, self.n_band_sections)
|
self.split = Splitter(self.in_shape, self.n_band_sections)
|
||||||
|
|
||||||
self.band_list = ModuleList()
|
self.band_list = ModuleList()
|
||||||
for band in range(self.n_band_sections):
|
for band in range(self.n_band_sections):
|
||||||
|
@ -1,16 +1,19 @@
|
|||||||
import variables as V
|
from argparse import Namespace
|
||||||
|
|
||||||
|
import warnings
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
|
from einops import rearrange, repeat
|
||||||
|
|
||||||
from ml_lib.modules.blocks import TransformerModule
|
from ml_lib.modules.blocks import TransformerModule
|
||||||
from ml_lib.modules.util import (LightningBaseModule, AutoPadToShape)
|
from ml_lib.modules.util import (LightningBaseModule, AutoPadToShape, F_x)
|
||||||
from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, BinaryMaskDatasetMixin,
|
from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, BinaryMaskDatasetMixin,
|
||||||
BaseDataloadersMixin)
|
BaseDataloadersMixin)
|
||||||
|
|
||||||
MIN_NUM_PATCHES = 16
|
MIN_NUM_PATCHES = 16
|
||||||
|
|
||||||
|
|
||||||
class VisualTransformer(BinaryMaskDatasetMixin,
|
class VisualTransformer(BinaryMaskDatasetMixin,
|
||||||
BaseDataloadersMixin,
|
BaseDataloadersMixin,
|
||||||
BaseTrainMixin,
|
BaseTrainMixin,
|
||||||
@ -22,69 +25,83 @@ class VisualTransformer(BinaryMaskDatasetMixin,
|
|||||||
def __init__(self, hparams):
|
def __init__(self, hparams):
|
||||||
super(VisualTransformer, self).__init__(hparams)
|
super(VisualTransformer, self).__init__(hparams)
|
||||||
|
|
||||||
self.in_shape = self.dataset.train_dataset.sample_shape
|
|
||||||
assert len(self.in_shape) == 3, 'There need to be three Dimensions'
|
|
||||||
channels, height, width = self.in_shape
|
|
||||||
|
|
||||||
# Automatic Image Shaping
|
|
||||||
image_size = (max(height, width) // self.params.patch_size) * self.params.patch_size
|
|
||||||
self.image_size = image_size + self.params.patch_size if image_size < max(height, width) else image_size
|
|
||||||
|
|
||||||
# This should be obsolete
|
|
||||||
assert self.image_size % self.params.patch_size == 0, 'image dimensions must be divisible by the patch size'
|
|
||||||
|
|
||||||
num_patches = (self.image_size // self.params.patch_size) ** 2
|
|
||||||
patch_dim = channels * self.params.patch_size ** 2
|
|
||||||
assert num_patches > MIN_NUM_PATCHES, f'your number of patches ({num_patches}) is way too small for ' \
|
|
||||||
f'attention. Try decreasing your patch size'
|
|
||||||
|
|
||||||
# Dataset
|
# Dataset
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
self.dataset = self.build_dataset()
|
self.dataset = self.build_dataset()
|
||||||
|
|
||||||
|
self.in_shape = self.dataset.train_dataset.sample_shape
|
||||||
|
assert len(self.in_shape) == 3, 'There need to be three Dimensions'
|
||||||
|
channels, height, width = self.in_shape
|
||||||
|
|
||||||
# Model Parameters
|
# Model Parameters
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# Additional parameters
|
# Additional parameters
|
||||||
self.attention_dim = self.params.features
|
self.embed_dim = self.params.embedding_size
|
||||||
|
|
||||||
|
# Automatic Image Shaping
|
||||||
|
self.patch_size = self.params.patch_size
|
||||||
|
image_size = (max(height, width) // self.patch_size) * self.patch_size
|
||||||
|
self.image_size = image_size + self.patch_size if image_size < max(height, width) else image_size
|
||||||
|
|
||||||
|
# This should be obsolete
|
||||||
|
assert self.image_size % self.patch_size == 0, 'image dimensions must be divisible by the patch size'
|
||||||
|
|
||||||
|
num_patches = (self.image_size // self.patch_size) ** 2
|
||||||
|
patch_dim = channels * self.patch_size ** 2
|
||||||
|
assert num_patches >= MIN_NUM_PATCHES, f'your number of patches ({num_patches}) is way too small for ' + \
|
||||||
|
f'attention. Try decreasing your patch size'
|
||||||
|
|
||||||
|
# Correct the Embedding Dim
|
||||||
|
if not self.embed_dim % self.params.heads == 0:
|
||||||
|
self.embed_dim = (self.embed_dim // self.params.heads) * self.params.heads
|
||||||
|
message = ('Embedding Dimension was fixed to be devideable by the number' +
|
||||||
|
f' of attention heads, is now: {self.embed_dim}')
|
||||||
|
for func in print, warnings.warn:
|
||||||
|
func(message)
|
||||||
|
|
||||||
# Utility Modules
|
# Utility Modules
|
||||||
self.autopad = AutoPadToShape((self.image_size, self.image_size))
|
self.autopad = AutoPadToShape((self.image_size, self.image_size))
|
||||||
|
|
||||||
# Modules with Parameters
|
# Modules with Parameters
|
||||||
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, self.attention_dim), False)
|
self.transformer = TransformerModule(in_shape=self.embed_dim, hidden_size=self.params.lat_dim,
|
||||||
self.embedding = nn.Linear(patch_dim, self.attention_dim)
|
n_heads=self.params.heads, num_layers=self.params.attn_depth,
|
||||||
self.cls_token = nn.Parameter(torch.randn(1, 1, self.attention_dim), False)
|
dropout=self.params.dropout, use_norm=self.params.use_norm,
|
||||||
self.dropout = nn.Dropout(self.params.dropout)
|
activation=self.params.activation_as_string
|
||||||
|
)
|
||||||
|
|
||||||
self.transformer = TransformerModule(self.attention_dim, self.params.attn_depth, self.params.heads,
|
|
||||||
self.params.lat_dim, self.params.dropout)
|
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, self.embed_dim))
|
||||||
|
self.patch_to_embedding = nn.Linear(patch_dim, self.embed_dim) if self.params.embedding_size \
|
||||||
|
else F_x(self.embed_dim)
|
||||||
|
self.cls_token = nn.Parameter(torch.randn(1, 1, self.embed_dim))
|
||||||
|
self.dropout = nn.Dropout(self.params.dropout)
|
||||||
|
|
||||||
self.to_cls_token = nn.Identity()
|
self.to_cls_token = nn.Identity()
|
||||||
|
|
||||||
self.mlp_head = nn.Sequential(
|
self.mlp_head = nn.Sequential(
|
||||||
nn.LayerNorm(self.attention_dim),
|
nn.LayerNorm(self.embed_dim),
|
||||||
nn.Linear(self.attention_dim, self.params.lat_dim),
|
nn.Linear(self.embed_dim, self.params.lat_dim),
|
||||||
nn.GELU(),
|
nn.GELU(),
|
||||||
nn.Dropout(self.params.dropout),
|
nn.Dropout(self.params.dropout),
|
||||||
nn.Linear(self.params.lat_dim, V.NUM_CLASSES)
|
nn.Linear(self.params.lat_dim, 1),
|
||||||
|
nn.Sigmoid()
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, x, mask=None):
|
def forward(self, x, mask=None):
|
||||||
"""
|
"""
|
||||||
:param tensor: the sequence to the encoder (required).
|
:param x: the sequence to the encoder (required).
|
||||||
:param mask: the mask for the src sequence (optional).
|
:param mask: the mask for the src sequence (optional).
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
|
tensor = self.autopad(x)
|
||||||
p = self.params.patch_size
|
p = self.params.patch_size
|
||||||
# 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p, p2 = p
|
tensor = rearrange(tensor, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p, p2 = p)
|
||||||
tensor = torch.reshape(x, (-1, self.image_size * self.image_size, p * p * self.in_shape[0]))
|
|
||||||
|
|
||||||
tensor = self.patch_to_embedding(tensor)
|
tensor = self.patch_to_embedding(tensor)
|
||||||
b, n, _ = tensor.shape
|
b, n, _ = tensor.shape
|
||||||
|
|
||||||
# '() n d -> b n d', b = b
|
cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
|
||||||
cls_tokens = tensor.repeat(self.cls_token, b)
|
|
||||||
tensor = torch.cat((cls_tokens, tensor), dim=1)
|
tensor = torch.cat((cls_tokens, tensor), dim=1)
|
||||||
tensor += self.pos_embedding[:, :(n + 1)]
|
tensor += self.pos_embedding[:, :(n + 1)]
|
||||||
tensor = self.dropout(tensor)
|
tensor = self.dropout(tensor)
|
||||||
@ -93,4 +110,4 @@ class VisualTransformer(BinaryMaskDatasetMixin,
|
|||||||
|
|
||||||
tensor = self.to_cls_token(tensor[:, 0])
|
tensor = self.to_cls_token(tensor[:, 0])
|
||||||
tensor = self.mlp_head(tensor)
|
tensor = self.mlp_head(tensor)
|
||||||
return tensor
|
return Namespace(main_out=tensor)
|
||||||
|
114
models/transformer_model_sequential.py
Normal file
114
models/transformer_model_sequential.py
Normal file
@ -0,0 +1,114 @@
|
|||||||
|
from argparse import Namespace
|
||||||
|
|
||||||
|
import warnings
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
from einops import repeat
|
||||||
|
|
||||||
|
from ml_lib.modules.blocks import TransformerModule
|
||||||
|
from ml_lib.modules.util import (LightningBaseModule, AutoPadToShape, F_x, SlidingWindow)
|
||||||
|
from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, BinaryMaskDatasetMixin,
|
||||||
|
BaseDataloadersMixin)
|
||||||
|
|
||||||
|
MIN_NUM_PATCHES = 16
|
||||||
|
|
||||||
|
class SequentialVisualTransformer(BinaryMaskDatasetMixin,
|
||||||
|
BaseDataloadersMixin,
|
||||||
|
BaseTrainMixin,
|
||||||
|
BaseValMixin,
|
||||||
|
BaseOptimizerMixin,
|
||||||
|
LightningBaseModule
|
||||||
|
):
|
||||||
|
|
||||||
|
def __init__(self, hparams):
|
||||||
|
super(SequentialVisualTransformer, self).__init__(hparams)
|
||||||
|
|
||||||
|
# Dataset
|
||||||
|
# =============================================================================
|
||||||
|
self.dataset = self.build_dataset()
|
||||||
|
|
||||||
|
self.in_shape = self.dataset.train_dataset.sample_shape
|
||||||
|
assert len(self.in_shape) == 3, 'There need to be three Dimensions'
|
||||||
|
channels, height, width = self.in_shape
|
||||||
|
|
||||||
|
# Model Paramters
|
||||||
|
# =============================================================================
|
||||||
|
# Additional parameters
|
||||||
|
self.embed_dim = self.params.embedding_size
|
||||||
|
self.patch_size = self.params.patch_size
|
||||||
|
self.height = height
|
||||||
|
|
||||||
|
# Automatic Image Shaping
|
||||||
|
image_size = (max(height, width) // self.patch_size) * self.patch_size
|
||||||
|
self.image_size = image_size + self.patch_size if image_size < max(height, width) else image_size
|
||||||
|
|
||||||
|
# This should be obsolete
|
||||||
|
assert self.image_size % self.patch_size == 0, 'image dimensions must be divisible by the patch size'
|
||||||
|
|
||||||
|
num_patches = (self.image_size // self.patch_size) ** 2
|
||||||
|
patch_dim = channels * self.patch_size * self.image_size
|
||||||
|
assert num_patches >= MIN_NUM_PATCHES, f'your number of patches ({num_patches}) is way too small for ' + \
|
||||||
|
f'attention. Try decreasing your patch size'
|
||||||
|
|
||||||
|
# Correct the Embedding Dim
|
||||||
|
if not self.embed_dim % self.params.heads == 0:
|
||||||
|
self.embed_dim = (self.embed_dim // self.params.heads) * self.params.heads
|
||||||
|
message = ('Embedding Dimension was fixed to be devideable by the number' +
|
||||||
|
f' of attention heads, is now: {self.embed_dim}')
|
||||||
|
for func in print, warnings.warn:
|
||||||
|
func(message)
|
||||||
|
|
||||||
|
# Utility Modules
|
||||||
|
self.autopad = AutoPadToShape((self.image_size, self.image_size))
|
||||||
|
self.dropout = nn.Dropout(self.params.dropout)
|
||||||
|
self.slider = SlidingWindow((self.image_size, self.patch_size), keepdim=False)
|
||||||
|
|
||||||
|
# Modules with Parameters
|
||||||
|
self.transformer = TransformerModule(in_shape=self.embed_dim, hidden_size=self.params.lat_dim,
|
||||||
|
n_heads=self.params.heads, num_layers=self.params.attn_depth,
|
||||||
|
dropout=self.params.dropout, use_norm=self.params.use_norm,
|
||||||
|
activation=self.params.activation_as_string
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, self.embed_dim))
|
||||||
|
self.patch_to_embedding = nn.Linear(patch_dim, self.embed_dim) if self.params.embedding_size \
|
||||||
|
else F_x(self.embed_dim)
|
||||||
|
self.cls_token = nn.Parameter(torch.randn(1, 1, self.embed_dim))
|
||||||
|
self.to_cls_token = nn.Identity()
|
||||||
|
|
||||||
|
self.mlp_head = nn.Sequential(
|
||||||
|
nn.LayerNorm(self.embed_dim),
|
||||||
|
nn.Linear(self.embed_dim, self.params.lat_dim),
|
||||||
|
nn.GELU(),
|
||||||
|
nn.Dropout(self.params.dropout),
|
||||||
|
nn.Linear(self.params.lat_dim, 1),
|
||||||
|
nn.Sigmoid()
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x, mask=None):
    """Run the input through padding, patching, embedding and the transformer.

    :param x: the sequence to the encoder (required).
    :param mask: the mask for the src sequence (optional); it is handed to
        ``self.transformer`` unchanged.
    :return: a ``Namespace`` whose ``main_out`` attribute holds the output of
        ``self.mlp_head`` applied to the first (class-token) sequence position.
    """
    # Pad the input and cut it into windows, then embed every window.
    # NOTE(review): the embedded result must be 3-D, since it is unpacked
    # into three dimensions below -- presumably (batch, patches, embed_dim).
    patches = self.slider(self.autopad(x))
    embedded = self.patch_to_embedding(patches)
    batch_size, n_patches, _ = embedded.shape

    # Prepend one copy of the learnable class token to every batch element.
    cls_tokens = self.cls_token.repeat((batch_size, 1, 1))
    sequence = torch.cat((cls_tokens, embedded), dim=1)

    # Add the positional embedding, cropped to the actual sequence length
    # (patches plus the class token), then apply dropout.
    sequence += self.pos_embedding[:, :(n_patches + 1)]
    sequence = self.dropout(sequence)

    encoded = self.transformer(sequence, mask)

    # Only the class-token position (index 0) is passed on to the head.
    cls_out = self.to_cls_token(encoded[:, 0])
    return Namespace(main_out=self.mlp_head(cls_out))
|
@ -1,7 +1,7 @@
|
|||||||
import shutil
|
import shutil
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
from util.config import MConfig
|
from ml_lib.utils.config import Config
|
||||||
|
|
||||||
warnings.filterwarnings('ignore', category=FutureWarning)
|
warnings.filterwarnings('ignore', category=FutureWarning)
|
||||||
warnings.filterwarnings('ignore', category=UserWarning)
|
warnings.filterwarnings('ignore', category=UserWarning)
|
||||||
@ -17,12 +17,12 @@ if __name__ == '__main__':
|
|||||||
|
|
||||||
args = main_arg_parser.parse_args()
|
args = main_arg_parser.parse_args()
|
||||||
# Model Settings
|
# Model Settings
|
||||||
config = MConfig().read_namespace(args)
|
config = Config().read_namespace(args)
|
||||||
|
|
||||||
arg_dict = dict()
|
arg_dict = dict()
|
||||||
for seed in range(0, 10):
|
for seed in range(0, 10):
|
||||||
arg_dict.update(main_seed=seed)
|
arg_dict.update(main_seed=seed)
|
||||||
for model in ['CC', 'BCMC', 'BCC', 'RCC']:
|
for model in ['VisualTransformer']:
|
||||||
arg_dict.update(model_type=model)
|
arg_dict.update(model_type=model)
|
||||||
raw_conf = dict(data_speed_amount=0.0, data_speed_min=0.0, data_speed_max=0.0,
|
raw_conf = dict(data_speed_amount=0.0, data_speed_min=0.0, data_speed_max=0.0,
|
||||||
data_mask_ratio=0.0, data_noise_ratio=0.0, data_shift_ratio=0.0, data_loudness_ratio=0.0,
|
data_mask_ratio=0.0, data_noise_ratio=0.0, data_shift_ratio=0.0, data_loudness_ratio=0.0,
|
||||||
|
@ -42,7 +42,7 @@ msgpack-python==0.5.6
|
|||||||
natsort==7.0.1
|
natsort==7.0.1
|
||||||
neptune-client==0.4.109
|
neptune-client==0.4.109
|
||||||
numba==0.49.1
|
numba==0.49.1
|
||||||
numpy==1.18.4
|
numpy~=1.18.2
|
||||||
oauthlib==3.1.0
|
oauthlib==3.1.0
|
||||||
packaging==20.3
|
packaging==20.3
|
||||||
pandas==1.0.3
|
pandas==1.0.3
|
||||||
@ -68,7 +68,7 @@ resampy==0.2.2
|
|||||||
retrying==1.3.3
|
retrying==1.3.3
|
||||||
rfc3987==1.3.8
|
rfc3987==1.3.8
|
||||||
rsa==4.0
|
rsa==4.0
|
||||||
scikit-learn==0.23.1
|
scikit-learn~=0.22.2.post1
|
||||||
scipy==1.4.1
|
scipy==1.4.1
|
||||||
simplejson==3.17.0
|
simplejson==3.17.0
|
||||||
six==1.14.0
|
six==1.14.0
|
||||||
@ -91,3 +91,5 @@ webencodings==0.5.1
|
|||||||
websocket-client==0.57.0
|
websocket-client==0.57.0
|
||||||
Werkzeug==1.0.1
|
Werkzeug==1.0.1
|
||||||
xmltodict==0.12.0
|
xmltodict==0.12.0
|
||||||
|
|
||||||
|
einops~=0.3.0
|
@ -1,26 +0,0 @@
|
|||||||
from ml_lib.utils.config import Config
|
|
||||||
from models.conv_classifier import ConvClassifier
|
|
||||||
from models.bandwise_conv_classifier import BandwiseConvClassifier
|
|
||||||
from models.bandwise_conv_multihead_classifier import BandwiseConvMultiheadClassifier
|
|
||||||
from models.ensemble import Ensemble
|
|
||||||
from models.residual_conv_classifier import ResidualConvClassifier
|
|
||||||
from models.transformer_model import VisualTransformer
|
|
||||||
|
|
||||||
|
|
||||||
class MConfig(Config):
    """Project-specific ``Config`` that supplies the model lookup table.

    NOTE(review): how the base ``Config`` consumes ``_model_map`` is not
    visible here -- presumably it resolves the configured model type string
    to a class; confirm against ``ml_lib.utils.config``.
    """
    # TODO: There should be a way to automate this.

    @property
    def _model_map(self):
        """Return a dict mapping model names and short aliases to model classes.

        Each classifier is registered under its full class name and under an
        abbreviation; the visual transformer is registered as ``ViT`` only.
        """
        return {
            'ConvClassifier': ConvClassifier,
            'CC': ConvClassifier,
            'BandwiseConvClassifier': BandwiseConvClassifier,
            'BCC': BandwiseConvClassifier,
            'BandwiseConvMultiheadClassifier': BandwiseConvMultiheadClassifier,
            'BCMC': BandwiseConvMultiheadClassifier,
            'Ensemble': Ensemble,
            'E': Ensemble,
            'ResidualConvClassifier': ResidualConvClassifier,
            'RCC': ResidualConvClassifier,
            'ViT': VisualTransformer,
        }
|
|
@ -8,7 +8,8 @@ import torch
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.optim import Adam
|
from torch.optim import Adam
|
||||||
from torch.utils.data import DataLoader, RandomSampler
|
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
from torchcontrib.optim import SWA
|
from torchcontrib.optim import SWA
|
||||||
from torchvision.transforms import Compose, RandomApply
|
from torchvision.transforms import Compose, RandomApply
|
||||||
|
|
||||||
@ -25,10 +26,23 @@ class BaseOptimizerMixin:
|
|||||||
|
|
||||||
def configure_optimizers(self):
|
def configure_optimizers(self):
|
||||||
assert isinstance(self, LightningBaseModule)
|
assert isinstance(self, LightningBaseModule)
|
||||||
opt = Adam(params=self.parameters(), lr=self.params.lr, weight_decay=self.params.weight_decay)
|
optimizer_dict = dict(
|
||||||
|
# 'optimizer':optimizer, # The Optimizer
|
||||||
|
# 'lr_scheduler': scheduler, # The LR scheduler
|
||||||
|
frequency=1, # The frequency of the scheduler
|
||||||
|
interval='epoch', # The unit of the scheduler's step size
|
||||||
|
# 'reduce_on_plateau': False, # For ReduceLROnPlateau scheduler
|
||||||
|
# 'monitor': 'mean_val_loss' # Metric to monitor
|
||||||
|
)
|
||||||
|
|
||||||
|
optimizer = Adam(params=self.parameters(), lr=self.params.lr, weight_decay=self.params.weight_decay)
|
||||||
if self.params.sto_weight_avg:
|
if self.params.sto_weight_avg:
|
||||||
opt = SWA(opt, swa_start=10, swa_freq=5, swa_lr=0.05)
|
optimizer = SWA(optimizer, swa_start=10, swa_freq=5, swa_lr=0.05)
|
||||||
return opt
|
optimizer_dict.update(optimizer=optimizer)
|
||||||
|
if self.params.lr_warmup_steps:
|
||||||
|
scheduler = CosineAnnealingWarmRestarts(optimizer, self.params.lr_warmup_steps)
|
||||||
|
optimizer_dict.update(lr_scheduler=scheduler)
|
||||||
|
return optimizer_dict
|
||||||
|
|
||||||
def on_train_end(self):
|
def on_train_end(self):
|
||||||
assert isinstance(self, LightningBaseModule)
|
assert isinstance(self, LightningBaseModule)
|
||||||
@ -54,17 +68,18 @@ class BaseTrainMixin:
|
|||||||
assert isinstance(self, LightningBaseModule)
|
assert isinstance(self, LightningBaseModule)
|
||||||
batch_x, batch_y = batch_xy
|
batch_x, batch_y = batch_xy
|
||||||
y = self(batch_x).main_out
|
y = self(batch_x).main_out
|
||||||
bce_loss = self.bce_loss(y, batch_y)
|
bce_loss = self.bce_loss(y.squeeze(), batch_y)
|
||||||
return dict(loss=bce_loss)
|
return dict(loss=bce_loss)
|
||||||
|
|
||||||
def training_epoch_end(self, outputs):
|
def training_epoch_end(self, outputs):
|
||||||
assert isinstance(self, LightningBaseModule)
|
assert isinstance(self, LightningBaseModule)
|
||||||
keys = list(outputs[0].keys())
|
keys = list(outputs[0].keys())
|
||||||
|
|
||||||
summary_dict = dict(log={f'mean_{key}': torch.mean(torch.stack([output[key]
|
summary_dict = {f'mean_{key}': torch.mean(torch.stack([output[key]
|
||||||
for output in outputs]))
|
for output in outputs]))
|
||||||
for key in keys if 'loss' in key})
|
for key in keys if 'loss' in key}
|
||||||
return summary_dict
|
for key in summary_dict.keys():
|
||||||
|
self.log(key, summary_dict[key])
|
||||||
|
|
||||||
|
|
||||||
class BaseValMixin:
|
class BaseValMixin:
|
||||||
@ -77,17 +92,17 @@ class BaseValMixin:
|
|||||||
assert isinstance(self, LightningBaseModule)
|
assert isinstance(self, LightningBaseModule)
|
||||||
batch_x, batch_y = batch_xy
|
batch_x, batch_y = batch_xy
|
||||||
y = self(batch_x).main_out
|
y = self(batch_x).main_out
|
||||||
val_bce_loss = self.bce_loss(y, batch_y)
|
val_bce_loss = self.bce_loss(y.squeeze(), batch_y)
|
||||||
return dict(val_bce_loss=val_bce_loss,
|
return dict(val_bce_loss=val_bce_loss,
|
||||||
batch_idx=batch_idx, y=y, batch_y=batch_y)
|
batch_idx=batch_idx, y=y, batch_y=batch_y)
|
||||||
|
|
||||||
def validation_epoch_end(self, outputs, *args, **kwargs):
|
def validation_epoch_end(self, outputs, *_, **__):
|
||||||
assert isinstance(self, LightningBaseModule)
|
assert isinstance(self, LightningBaseModule)
|
||||||
summary_dict = dict(log=dict())
|
summary_dict = dict()
|
||||||
for output_idx, output in enumerate(outputs):
|
for output_idx, output in enumerate(outputs):
|
||||||
keys = list(output[0].keys())
|
keys = list(output[0].keys())
|
||||||
ident = '' if output_idx == 0 else '_train'
|
ident = '' if output_idx == 0 else '_train'
|
||||||
summary_dict['log'].update({f'mean{ident}_{key}': torch.mean(torch.stack([output[key]
|
summary_dict.update({f'mean{ident}_{key}': torch.mean(torch.stack([output[key]
|
||||||
for output in output]))
|
for output in output]))
|
||||||
for key in keys if 'loss' in key}
|
for key in keys if 'loss' in key}
|
||||||
)
|
)
|
||||||
@ -101,8 +116,9 @@ class BaseValMixin:
|
|||||||
uar_score = sklearn.metrics.recall_score(y_true, y_pred, labels=[0, 1], average='macro',
|
uar_score = sklearn.metrics.recall_score(y_true, y_pred, labels=[0, 1], average='macro',
|
||||||
sample_weight=None, zero_division='warn')
|
sample_weight=None, zero_division='warn')
|
||||||
uar_score = torch.as_tensor(uar_score)
|
uar_score = torch.as_tensor(uar_score)
|
||||||
summary_dict['log'].update({f'uar{ident}_score': uar_score})
|
summary_dict.update({f'uar{ident}_score': uar_score})
|
||||||
return summary_dict
|
for key in summary_dict.keys():
|
||||||
|
self.log(key, summary_dict[key])
|
||||||
|
|
||||||
|
|
||||||
class BinaryMaskDatasetMixin:
|
class BinaryMaskDatasetMixin:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user