Python 3.8 branch merged

some small template fixes
This commit is contained in:
Steffen Illium
2020-05-17 22:11:21 +02:00
parent fc93f71608
commit e423d6fe31
2 changed files with 6 additions and 12 deletions

View File

@ -44,11 +44,6 @@ def run_lightning_loop(config_obj):
# Init # Init
model: LightningBaseModule = config_obj.model_class(config_obj.model_paramters) model: LightningBaseModule = config_obj.model_class(config_obj.model_paramters)
model.init_weights(torch.nn.init.xavier_normal_) model.init_weights(torch.nn.init.xavier_normal_)
if model.name == 'CNNRouteGeneratorDiscriminated':
# ToDo: Make this dependent on the used seed
path = logger.outpath / 'classifier_cnn' / 'version_0'
disc_model = SavedLightningModels.load_checkpoint(path).restore()
model.set_discriminator(disc_model)
# Trainer # Trainer
# ============================================================================= # =============================================================================
@ -70,8 +65,8 @@ def run_lightning_loop(config_obj):
trainer.fit(model) trainer.fit(model)
# Save the last state & all parameters # Save the last state & all parameters
trainer.save_checkpoint(logger.log_dir / 'weights.ckpt') trainer.save_checkpoint(config_obj.exp_path.log_dir / 'weights.ckpt')
model.save_to_disk(logger.log_dir) model.save_to_disk(config_obj.exp_path)
# Evaluate It # Evaluate It
if config_obj.main.eval: if config_obj.main.eval:

View File

@ -1,6 +1,6 @@
import warnings import warnings
from utils.config import Config from _templates.new_project.utils.project_config import Config
warnings.filterwarnings('ignore', category=FutureWarning) warnings.filterwarnings('ignore', category=FutureWarning)
warnings.filterwarnings('ignore', category=UserWarning) warnings.filterwarnings('ignore', category=UserWarning)
@ -15,10 +15,9 @@ if __name__ == '__main__':
# Model Settings # Model Settings
config = Config().read_namespace(args) config = Config().read_namespace(args)
# bias, activation, model, norm, max_epochs, filters # bias, activation, model, norm, max_epochs
cnn_classifier = dict(train_epochs=10, model_use_bias=True, model_use_norm=True, model_activation='leaky_relu', cnn_classifier = dict(train_epochs=10, model_use_bias=True, model_use_norm=True, data_batchsize=512)
model_type='classifier_cnn', model_filters=[16, 32, 64], data_batchsize=512) # bias, activation, model, norm, max_epochs
# bias, activation, model, norm, max_epochs, sr, feature_mixed_dim, filters
for arg_dict in [cnn_classifier]: for arg_dict in [cnn_classifier]:
for seed in range(5): for seed in range(5):