steffen 21ff075626 Experiment parameters
(cherry picked from commit 98ccecac04c203ddfe2f01f10cf5e3e4031509ed)
2020-11-23 08:48:10 +01:00

77 lines
3.6 KiB
Python

# Imports
# =============================================================================
from _paramters import main_arg_parser
from main import run_lightning_loop
import warnings
from ml_lib.utils.config import Config
warnings.filterwarnings('ignore', category=FutureWarning)
warnings.filterwarnings('ignore', category=UserWarning)
if __name__ == '__main__':
    # Script entry point: sweeps a grid of model hyper-parameters and launches
    # one training run per combination via ``run_lightning_loop``.
    #
    # ``shutil`` is required by the cleanup step below but is not imported at
    # the top of the file; importing it here keeps the script runnable
    # (otherwise the ``rmtree`` call raises ``NameError``).
    import itertools
    import shutil

    args = main_arg_parser.parse_args()
    # Model Settings
    config = Config().read_namespace(args)

    # Overrides layered on top of the parsed command-line configuration.
    # The dict is shared across iterations, so keys set once stay set.
    arg_dict = dict()
    for seed in range(1):
        arg_dict.update(main_seed=seed)

        # Disabled experiment branch (kept for reference, toggled by hand).
        # NOTE(review): as written this branch only mutates ``arg_dict`` and
        # does not launch any run itself -- confirm before re-enabling.
        if False:
            for patch_size in [3, 5, 9]:
                for model in ['SequentialVisualTransformer']:
                    arg_dict.update(model_type=model, model_patch_size=patch_size)

                    # Baseline: every data_* ratio set to zero, stretch off.
                    raw_conf = dict(data_speed_amount=0.0, data_speed_min=0.0, data_speed_max=0.0,
                                    data_mask_ratio=0.0, data_noise_ratio=0.0, data_shift_ratio=0.0,
                                    data_loudness_ratio=0.0,
                                    data_stretch=False, train_epochs=401)
                    # Every data_* setting enabled at once.
                    all_conf = dict(data_speed_amount=0.4, data_speed_min=0.7, data_speed_max=1.7,
                                    data_mask_ratio=0.2, data_noise_ratio=0.4, data_shift_ratio=0.4,
                                    data_loudness_ratio=0.4,
                                    data_stretch=True, train_epochs=101)

                    # One configuration per single setting, derived from the baseline.
                    speed_conf = raw_conf.copy()
                    speed_conf.update(data_speed_amount=0.4, data_speed_min=0.7, data_speed_max=1.7,
                                      data_stretch=True, train_epochs=101)
                    mask_conf = raw_conf.copy()
                    mask_conf.update(data_mask_ratio=0.2, data_stretch=True, train_epochs=101)
                    noise_conf = raw_conf.copy()
                    noise_conf.update(data_noise_ratio=0.4, data_stretch=True, train_epochs=101)
                    shift_conf = raw_conf.copy()
                    shift_conf.update(data_shift_ratio=0.4, data_stretch=True, train_epochs=101)
                    loudness_conf = raw_conf.copy()
                    loudness_conf.update(data_loudness_ratio=0.4, data_stretch=True, train_epochs=101)

                    for dicts in [raw_conf, all_conf, speed_conf, mask_conf,
                                  noise_conf, shift_conf, loudness_conf]:
                        arg_dict.update(dicts)

        # Active experiment branch: full grid over the model hyper-parameters.
        if True:
            # ``itertools.product`` varies the right-most iterable fastest,
            # i.e. the same visiting order as the equivalent nested loops.
            grid = itertools.product(
                [3, 7],                                                # model_patch_size
                [4, 32],                                               # model_lat_dim
                [2, 4],                                                # model_heads
                [32, 64],                                              # model_embedding_size
                [1, 3],                                                # model_attn_depth
                ['SequentialVisualTransformer', 'VisualTransformer'],  # model_type
            )
            for patch_size, lat_dim, heads, embedding_size, attn_depth, model in grid:
                arg_dict.update(
                    model_type=model,
                    model_patch_size=patch_size,
                    model_lat_dim=lat_dim,
                    model_heads=heads,
                    model_embedding_size=embedding_size,
                    model_attn_depth=attn_depth
                )
                # NOTE(review): relies on ``Config.update`` returning the
                # config object -- confirm against ml_lib.
                config = config.update(arg_dict)

                version_path = config.exp_path / config.version
                if version_path.exists():
                    if (version_path / 'weights.ckpt').exists():
                        # A checkpoint is already present for this version
                        # (presumably a completed run): skip it.
                        continue
                    # Directory exists without a checkpoint: remove it so the
                    # run below starts from a clean version directory.
                    shutil.rmtree(version_path)

                run_lightning_loop(config)