New Model, Many Changes
This commit is contained in:
parent
be097a111a
commit
a079a196af
@ -52,7 +52,7 @@ main_arg_parser.add_argument("--model_features", type=int, default=64, help="")
|
|||||||
main_arg_parser.add_argument("--model_filters", type=str, default="[32, 64, 128, 64]", help="")
|
main_arg_parser.add_argument("--model_filters", type=str, default="[32, 64, 128, 64]", help="")
|
||||||
|
|
||||||
# Transformer Specific
|
# Transformer Specific
|
||||||
main_arg_parser.add_argument("--model_patch_size", type=int, default=9, help="")
|
main_arg_parser.add_argument("--model_patch_size", type=int, default=3, help="")
|
||||||
main_arg_parser.add_argument("--model_attn_depth", type=int, default=3, help="")
|
main_arg_parser.add_argument("--model_attn_depth", type=int, default=3, help="")
|
||||||
main_arg_parser.add_argument("--model_heads", type=int, default=8, help="")
|
main_arg_parser.add_argument("--model_heads", type=int, default=8, help="")
|
||||||
main_arg_parser.add_argument("--model_embedding_size", type=int, default=64, help="")
|
main_arg_parser.add_argument("--model_embedding_size", type=int, default=64, help="")
|
||||||
|
@ -5,8 +5,6 @@ import warnings
|
|||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from einops import repeat
|
|
||||||
|
|
||||||
from ml_lib.modules.blocks import TransformerModule
|
from ml_lib.modules.blocks import TransformerModule
|
||||||
from ml_lib.modules.util import (LightningBaseModule, AutoPadToShape, F_x, SlidingWindow)
|
from ml_lib.modules.util import (LightningBaseModule, AutoPadToShape, F_x, SlidingWindow)
|
||||||
from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, BinaryMaskDatasetMixin,
|
from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, BinaryMaskDatasetMixin,
|
||||||
@ -39,16 +37,13 @@ class SequentialVisualTransformer(BinaryMaskDatasetMixin,
|
|||||||
self.embed_dim = self.params.embedding_size
|
self.embed_dim = self.params.embedding_size
|
||||||
self.patch_size = self.params.patch_size
|
self.patch_size = self.params.patch_size
|
||||||
self.height = height
|
self.height = height
|
||||||
|
self.width = width
|
||||||
|
self.channels = channels
|
||||||
|
|
||||||
# Automatic Image Shaping
|
self.new_width = ((self.width - self.patch_size)//1) + 1
|
||||||
image_size = (max(height, width) // self.patch_size) * self.patch_size
|
|
||||||
self.image_size = image_size + self.patch_size if image_size < max(height, width) else image_size
|
|
||||||
|
|
||||||
# This should be obsolete
|
num_patches = self.new_width - (self.patch_size // 2)
|
||||||
assert self.image_size % self.patch_size == 0, 'image dimensions must be divisible by the patch size'
|
patch_dim = channels * self.patch_size * self.height
|
||||||
|
|
||||||
num_patches = (self.image_size // self.patch_size) ** 2
|
|
||||||
patch_dim = channels * self.patch_size * self.image_size
|
|
||||||
assert num_patches >= MIN_NUM_PATCHES, f'your number of patches ({num_patches}) is way too small for ' + \
|
assert num_patches >= MIN_NUM_PATCHES, f'your number of patches ({num_patches}) is way too small for ' + \
|
||||||
f'attention. Try decreasing your patch size'
|
f'attention. Try decreasing your patch size'
|
||||||
|
|
||||||
@ -61,9 +56,9 @@ class SequentialVisualTransformer(BinaryMaskDatasetMixin,
|
|||||||
func(message)
|
func(message)
|
||||||
|
|
||||||
# Utility Modules
|
# Utility Modules
|
||||||
self.autopad = AutoPadToShape((self.image_size, self.image_size))
|
self.autopad = AutoPadToShape((self.height, self.new_width))
|
||||||
self.dropout = nn.Dropout(self.params.dropout)
|
self.dropout = nn.Dropout(self.params.dropout)
|
||||||
self.slider = SlidingWindow((self.image_size, self.patch_size), keepdim=False)
|
self.slider = SlidingWindow((channels, *self.autopad.target_shape), (self.height, self.patch_size), keepdim=False)
|
||||||
|
|
||||||
# Modules with Parameters
|
# Modules with Parameters
|
||||||
self.transformer = TransformerModule(in_shape=self.embed_dim, hidden_size=self.params.lat_dim,
|
self.transformer = TransformerModule(in_shape=self.embed_dim, hidden_size=self.params.lat_dim,
|
||||||
|
61
multi_run.py
61
multi_run.py
@ -20,42 +20,43 @@ if __name__ == '__main__':
|
|||||||
config = Config().read_namespace(args)
|
config = Config().read_namespace(args)
|
||||||
|
|
||||||
arg_dict = dict()
|
arg_dict = dict()
|
||||||
for seed in range(0, 10):
|
for seed in range(0, 3):
|
||||||
arg_dict.update(main_seed=seed)
|
arg_dict.update(main_seed=seed)
|
||||||
for model in ['VisualTransformer']:
|
for patch_size in [3, 5 , 9]:
|
||||||
arg_dict.update(model_type=model)
|
for model in ['SequentialVisualTransformer']:
|
||||||
raw_conf = dict(data_speed_amount=0.0, data_speed_min=0.0, data_speed_max=0.0,
|
arg_dict.update(model_type=model, model_patch_size=patch_size)
|
||||||
data_mask_ratio=0.0, data_noise_ratio=0.0, data_shift_ratio=0.0, data_loudness_ratio=0.0,
|
raw_conf = dict(data_speed_amount=0.0, data_speed_min=0.0, data_speed_max=0.0,
|
||||||
data_stretch=False, train_epochs=401)
|
data_mask_ratio=0.0, data_noise_ratio=0.0, data_shift_ratio=0.0, data_loudness_ratio=0.0,
|
||||||
|
data_stretch=False, train_epochs=401)
|
||||||
|
|
||||||
all_conf = dict(data_speed_amount=0.4, data_speed_min=0.7, data_speed_max=1.7,
|
all_conf = dict(data_speed_amount=0.4, data_speed_min=0.7, data_speed_max=1.7,
|
||||||
data_mask_ratio=0.2, data_noise_ratio=0.4, data_shift_ratio=0.4, data_loudness_ratio=0.4,
|
data_mask_ratio=0.2, data_noise_ratio=0.4, data_shift_ratio=0.4, data_loudness_ratio=0.4,
|
||||||
data_stretch=True, train_epochs=101)
|
data_stretch=True, train_epochs=101)
|
||||||
|
|
||||||
speed_conf = raw_conf.copy()
|
speed_conf = raw_conf.copy()
|
||||||
speed_conf.update(data_speed_amount=0.4, data_speed_min=0.7, data_speed_max=1.7,
|
speed_conf.update(data_speed_amount=0.4, data_speed_min=0.7, data_speed_max=1.7,
|
||||||
data_stretch=True, train_epochs=101)
|
data_stretch=True, train_epochs=101)
|
||||||
|
|
||||||
mask_conf = raw_conf.copy()
|
mask_conf = raw_conf.copy()
|
||||||
mask_conf.update(data_mask_ratio=0.2, data_stretch=True, train_epochs=101)
|
mask_conf.update(data_mask_ratio=0.2, data_stretch=True, train_epochs=101)
|
||||||
|
|
||||||
noise_conf = raw_conf.copy()
|
noise_conf = raw_conf.copy()
|
||||||
noise_conf.update(data_noise_ratio=0.4, data_stretch=True, train_epochs=101)
|
noise_conf.update(data_noise_ratio=0.4, data_stretch=True, train_epochs=101)
|
||||||
|
|
||||||
shift_conf = raw_conf.copy()
|
shift_conf = raw_conf.copy()
|
||||||
shift_conf.update(data_shift_ratio=0.4, data_stretch=True, train_epochs=101)
|
shift_conf.update(data_shift_ratio=0.4, data_stretch=True, train_epochs=101)
|
||||||
|
|
||||||
loudness_conf = raw_conf.copy()
|
loudness_conf = raw_conf.copy()
|
||||||
loudness_conf.update(data_loudness_ratio=0.4, data_stretch=True, train_epochs=101)
|
loudness_conf.update(data_loudness_ratio=0.4, data_stretch=True, train_epochs=101)
|
||||||
|
|
||||||
for dicts in [raw_conf, all_conf, speed_conf, mask_conf, noise_conf, shift_conf, loudness_conf]:
|
for dicts in [raw_conf, all_conf, speed_conf, mask_conf, noise_conf, shift_conf, loudness_conf]:
|
||||||
|
|
||||||
arg_dict.update(dicts)
|
arg_dict.update(dicts)
|
||||||
config = config.update(arg_dict)
|
config = config.update(arg_dict)
|
||||||
version_path = config.exp_path / config.version
|
version_path = config.exp_path / config.version
|
||||||
if version_path.exists():
|
if version_path.exists():
|
||||||
if not (version_path / 'weights.ckpt').exists():
|
if not (version_path / 'weights.ckpt').exists():
|
||||||
shutil.rmtree(version_path)
|
shutil.rmtree(version_path)
|
||||||
else:
|
else:
|
||||||
continue
|
continue
|
||||||
run_lightning_loop(config)
|
run_lightning_loop(config)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user