78 lines
3.6 KiB
Python
78 lines
3.6 KiB
Python
# Imports
|
|
# =============================================================================
|
|
from _paramters import main_arg_parser
|
|
from main import run_lightning_loop
|
|
|
|
import warnings
|
|
import shutil
|
|
|
|
from ml_lib.utils.config import Config
|
|
|
|
warnings.filterwarnings('ignore', category=FutureWarning)
|
|
warnings.filterwarnings('ignore', category=UserWarning)
|
|
|
|
|
|
if __name__ == '__main__':
|
|
|
|
args = main_arg_parser.parse_args()
|
|
# Model Settings
|
|
config = Config().read_namespace(args)
|
|
|
|
arg_dict = dict()
|
|
for seed in range(1):
|
|
arg_dict.update(main_seed=seed)
|
|
if False:
|
|
for patch_size in [3, 5 , 9]:
|
|
for model in ['VerticalVisualTransformer']:
|
|
arg_dict.update(model_type=model, model_patch_size=patch_size)
|
|
raw_conf = dict(data_speed_amount=0.0, data_speed_min=0.0, data_speed_max=0.0,
|
|
data_mask_ratio=0.0, data_noise_ratio=0.0, data_shift_ratio=0.0, data_loudness_ratio=0.0,
|
|
data_stretch=False, train_epochs=401)
|
|
|
|
all_conf = dict(data_speed_amount=0.4, data_speed_min=0.7, data_speed_max=1.7,
|
|
data_mask_ratio=0.2, data_noise_ratio=0.4, data_shift_ratio=0.4, data_loudness_ratio=0.4,
|
|
data_stretch=True, train_epochs=101)
|
|
|
|
speed_conf = raw_conf.copy()
|
|
speed_conf.update(data_speed_amount=0.4, data_speed_min=0.7, data_speed_max=1.7,
|
|
data_stretch=True, train_epochs=101)
|
|
|
|
mask_conf = raw_conf.copy()
|
|
mask_conf.update(data_mask_ratio=0.2, data_stretch=True, train_epochs=101)
|
|
|
|
noise_conf = raw_conf.copy()
|
|
noise_conf.update(data_noise_ratio=0.4, data_stretch=True, train_epochs=101)
|
|
|
|
shift_conf = raw_conf.copy()
|
|
shift_conf.update(data_shift_ratio=0.4, data_stretch=True, train_epochs=101)
|
|
|
|
loudness_conf = raw_conf.copy()
|
|
loudness_conf.update(data_loudness_ratio=0.4, data_stretch=True, train_epochs=101)
|
|
|
|
for dicts in [raw_conf, all_conf, speed_conf, mask_conf, noise_conf, shift_conf, loudness_conf]:
|
|
|
|
arg_dict.update(dicts)
|
|
if True:
|
|
for patch_size in [7]:
|
|
for lat_dim in [32]:
|
|
for heads in [8]:
|
|
for embedding_size in [7**2]:
|
|
for attn_depth in [1, 3, 5, 7]:
|
|
for model in ['HorizontalVisualTransformer']:
|
|
arg_dict.update(
|
|
model_type=model,
|
|
model_patch_size=patch_size,
|
|
model_lat_dim=lat_dim,
|
|
model_heads=heads,
|
|
model_embedding_size=embedding_size,
|
|
model_attn_depth=attn_depth
|
|
)
|
|
config = config.update(arg_dict)
|
|
version_path = config.exp_path / config.version
|
|
if version_path.exists():
|
|
if not (version_path / 'weights.ckpt').exists():
|
|
shutil.rmtree(version_path)
|
|
else:
|
|
continue
|
|
run_lightning_loop(config)
|