From a079a196affbb9c75df7b3ca8e53d34ab396ae96 Mon Sep 17 00:00:00 2001
From: Si11ium
Date: Sun, 22 Nov 2020 16:24:00 +0100
Subject: [PATCH] New Model, Many Changes

---
 _paramters.py                          |  2 +-
 models/transformer_model_sequential.py | 19 +++-----
 multi_run.py                           | 61 +++++++++++++-------------
 3 files changed, 39 insertions(+), 43 deletions(-)

diff --git a/_paramters.py b/_paramters.py
index db30f6f..5962bc7 100644
--- a/_paramters.py
+++ b/_paramters.py
@@ -52,7 +52,7 @@
 main_arg_parser.add_argument("--model_features", type=int, default=64, help="")
 main_arg_parser.add_argument("--model_filters", type=str, default="[32, 64, 128, 64]", help="")
 # Transformer Specific
-main_arg_parser.add_argument("--model_patch_size", type=int, default=9, help="")
+main_arg_parser.add_argument("--model_patch_size", type=int, default=3, help="")
 main_arg_parser.add_argument("--model_attn_depth", type=int, default=3, help="")
 main_arg_parser.add_argument("--model_heads", type=int, default=8, help="")
 main_arg_parser.add_argument("--model_embedding_size", type=int, default=64, help="")
diff --git a/models/transformer_model_sequential.py b/models/transformer_model_sequential.py
index 54ec1c7..2a9bc97 100644
--- a/models/transformer_model_sequential.py
+++ b/models/transformer_model_sequential.py
@@ -5,8 +5,6 @@
 import warnings
 import torch
 from torch import nn
-from einops import repeat
-
 from ml_lib.modules.blocks import TransformerModule
 from ml_lib.modules.util import (LightningBaseModule, AutoPadToShape, F_x, SlidingWindow)
 from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, BinaryMaskDatasetMixin,
@@ -39,16 +37,13 @@ class SequentialVisualTransformer(BinaryMaskDatasetMixin,
         self.embed_dim = self.params.embedding_size
         self.patch_size = self.params.patch_size
         self.height = height
+        self.width = width
+        self.channels = channels
 
-        # Automatic Image Shaping
-        image_size = (max(height, width) // self.patch_size) * self.patch_size
-        self.image_size = image_size + self.patch_size if image_size < max(height, width) else image_size
+        self.new_width = ((self.width - self.patch_size)//1) + 1
 
-        # This should be obsolete
-        assert self.image_size % self.patch_size == 0, 'image dimensions must be divisible by the patch size'
-
-        num_patches = (self.image_size // self.patch_size) ** 2
-        patch_dim = channels * self.patch_size * self.image_size
+        num_patches = self.new_width - (self.patch_size // 2)
+        patch_dim = channels * self.patch_size * self.height
         assert num_patches >= MIN_NUM_PATCHES, f'your number of patches ({num_patches}) is way too small for ' + \
                                                f'attention. Try decreasing your patch size'
 
@@ -61,9 +56,9 @@ class SequentialVisualTransformer(BinaryMaskDatasetMixin,
             func(message)
 
         # Utility Modules
-        self.autopad = AutoPadToShape((self.image_size, self.image_size))
+        self.autopad = AutoPadToShape((self.height, self.new_width))
         self.dropout = nn.Dropout(self.params.dropout)
-        self.slider = SlidingWindow((self.image_size, self.patch_size), keepdim=False)
+        self.slider = SlidingWindow((channels, *self.autopad.target_shape), (self.height, self.patch_size), keepdim=False)
 
         # Modules with Parameters
         self.transformer = TransformerModule(in_shape=self.embed_dim, hidden_size=self.params.lat_dim,
diff --git a/multi_run.py b/multi_run.py
index 22f8e78..bce16fd 100644
--- a/multi_run.py
+++ b/multi_run.py
@@ -20,42 +20,43 @@ if __name__ == '__main__':
     config = Config().read_namespace(args)
 
     arg_dict = dict()
-    for seed in range(0, 10):
+    for seed in range(0, 3):
         arg_dict.update(main_seed=seed)
-        for model in ['VisualTransformer']:
-            arg_dict.update(model_type=model)
-            raw_conf = dict(data_speed_amount=0.0, data_speed_min=0.0, data_speed_max=0.0,
-                            data_mask_ratio=0.0, data_noise_ratio=0.0, data_shift_ratio=0.0, data_loudness_ratio=0.0,
-                            data_stretch=False, train_epochs=401)
+        for patch_size in [3, 5, 9]:
+            for model in ['SequentialVisualTransformer']:
+                arg_dict.update(model_type=model, model_patch_size=patch_size)
+                raw_conf = dict(data_speed_amount=0.0, data_speed_min=0.0, data_speed_max=0.0,
+                                data_mask_ratio=0.0, data_noise_ratio=0.0, data_shift_ratio=0.0, data_loudness_ratio=0.0,
+                                data_stretch=False, train_epochs=401)
 
-            all_conf = dict(data_speed_amount=0.4, data_speed_min=0.7, data_speed_max=1.7,
-                            data_mask_ratio=0.2, data_noise_ratio=0.4, data_shift_ratio=0.4, data_loudness_ratio=0.4,
-                            data_stretch=True, train_epochs=101)
+                all_conf = dict(data_speed_amount=0.4, data_speed_min=0.7, data_speed_max=1.7,
+                                data_mask_ratio=0.2, data_noise_ratio=0.4, data_shift_ratio=0.4, data_loudness_ratio=0.4,
+                                data_stretch=True, train_epochs=101)
 
-            speed_conf = raw_conf.copy()
-            speed_conf.update(data_speed_amount=0.4, data_speed_min=0.7, data_speed_max=1.7,
-                              data_stretch=True, train_epochs=101)
+                speed_conf = raw_conf.copy()
+                speed_conf.update(data_speed_amount=0.4, data_speed_min=0.7, data_speed_max=1.7,
+                                  data_stretch=True, train_epochs=101)
 
-            mask_conf = raw_conf.copy()
-            mask_conf.update(data_mask_ratio=0.2, data_stretch=True, train_epochs=101)
+                mask_conf = raw_conf.copy()
+                mask_conf.update(data_mask_ratio=0.2, data_stretch=True, train_epochs=101)
 
-            noise_conf = raw_conf.copy()
-            noise_conf.update(data_noise_ratio=0.4, data_stretch=True, train_epochs=101)
+                noise_conf = raw_conf.copy()
+                noise_conf.update(data_noise_ratio=0.4, data_stretch=True, train_epochs=101)
 
-            shift_conf = raw_conf.copy()
-            shift_conf.update(data_shift_ratio=0.4, data_stretch=True, train_epochs=101)
+                shift_conf = raw_conf.copy()
+                shift_conf.update(data_shift_ratio=0.4, data_stretch=True, train_epochs=101)
 
-            loudness_conf = raw_conf.copy()
-            loudness_conf.update(data_loudness_ratio=0.4, data_stretch=True, train_epochs=101)
+                loudness_conf = raw_conf.copy()
+                loudness_conf.update(data_loudness_ratio=0.4, data_stretch=True, train_epochs=101)
 
-            for dicts in [raw_conf, all_conf, speed_conf, mask_conf, noise_conf, shift_conf, loudness_conf]:
+                for dicts in [raw_conf, all_conf, speed_conf, mask_conf, noise_conf, shift_conf, loudness_conf]:
 
-                arg_dict.update(dicts)
-                config = config.update(arg_dict)
-                version_path = config.exp_path / config.version
-                if version_path.exists():
-                    if not (version_path / 'weights.ckpt').exists():
-                        shutil.rmtree(version_path)
-                    else:
-                        continue
-                run_lightning_loop(config)
+                    arg_dict.update(dicts)
+                    config = config.update(arg_dict)
+                    version_path = config.exp_path / config.version
+                    if version_path.exists():
+                        if not (version_path / 'weights.ckpt').exists():
+                            shutil.rmtree(version_path)
+                        else:
+                            continue
+                    run_lightning_loop(config)
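
For reference, the shape bookkeeping introduced in SequentialVisualTransformer above reduces to three quantities. The sketch below is illustrative only: the helper name sequential_patch_shapes is hypothetical, and it assumes a (channels, height, width) spectrogram-like input scanned by a width-patch_size window with horizontal stride 1, mirroring the expressions in the diff.

    # Hypothetical standalone sketch of the shape arithmetic used above (not part of the patch).
    def sequential_patch_shapes(height: int, width: int, channels: int, patch_size: int = 3):
        # number of stride-1 window positions along the width (time) axis
        new_width = ((width - patch_size) // 1) + 1
        # number of patches handed to the transformer, as computed in the diff
        num_patches = new_width - (patch_size // 2)
        # flattened size of one patch: all channels x full height x patch_size columns
        patch_dim = channels * patch_size * height
        return new_width, num_patches, patch_dim

    # e.g. a 1-channel spectrogram with 64 bins, 128 frames, and the new default patch_size=3
    print(sequential_patch_shapes(height=64, width=128, channels=1, patch_size=3))  # (126, 125, 192)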