Small bugfixes

This commit is contained in:
Steffen 2021-03-18 12:12:43 +01:00
parent fecf4923c2
commit 2c9cb2e94a
2 changed files with 108 additions and 99 deletions

10
main.py
View File

@ -10,12 +10,16 @@ from ml_lib.utils.config import parse_comandline_args_add_defaults
from ml_lib.utils.loggers import Logger
import variables as v
from ml_lib.utils.tools import fix_all_random_seeds
warnings.filterwarnings('ignore', category=FutureWarning)
warnings.filterwarnings('ignore', category=UserWarning)
def run_lightning_loop(h_params, data_class, model_class, additional_callbacks=None):
def run_lightning_loop(h_params, data_class, model_class, seed=69, additional_callbacks=None):
fix_all_random_seeds(seed)
with Logger.from_argparse_args(h_params) as logger:
# Callbacks
# =============================================================================
@ -79,13 +83,13 @@ def run_lightning_loop(h_params, data_class, model_class, additional_callbacks=N
if __name__ == '__main__':
# Parse comandline args, read config and get model
cmd_args, found_data_class, found_model_class = parse_comandline_args_add_defaults('_parameters.ini')
cmd_args, found_data_class, found_model_class, found_seed = parse_comandline_args_add_defaults('_parameters.ini')
# To NameSpace
hparams = Namespace(**cmd_args)
# Start
# -----------------
run_lightning_loop(hparams, found_data_class, found_model_class)
run_lightning_loop(hparams, found_data_class, found_model_class, found_seed)
print('done')
pass

View File

@ -4,7 +4,6 @@ from argparse import Namespace
import warnings
import torch
from performer_pytorch import Performer
from torch import nn
from einops import rearrange, repeat
@ -15,8 +14,10 @@ from util.module_mixins import CombinedModelMixins
MIN_NUM_PATCHES = 16
try:
from performer_pytorch import Performer
class VisualPerformer(CombinedModelMixins,
class VisualPerformer(CombinedModelMixins,
LightningBaseModule
):
@ -131,3 +132,7 @@ class VisualPerformer(CombinedModelMixins,
def additional_scores(self, outputs):
return MultiClassScores(self)(outputs)
except ImportError: # pragma: do not provide model class
print('You want to use `performer_pytorch` plugins which are not installed yet,' # pragma: no-cover
' install it with `pip install performer_pytorch`.')