Small bugfixes
This commit is contained in:
parent
fecf4923c2
commit
2c9cb2e94a
10
main.py
10
main.py
@ -10,12 +10,16 @@ from ml_lib.utils.config import parse_comandline_args_add_defaults
|
||||
from ml_lib.utils.loggers import Logger
|
||||
|
||||
import variables as v
|
||||
from ml_lib.utils.tools import fix_all_random_seeds
|
||||
|
||||
warnings.filterwarnings('ignore', category=FutureWarning)
|
||||
warnings.filterwarnings('ignore', category=UserWarning)
|
||||
|
||||
|
||||
def run_lightning_loop(h_params, data_class, model_class, additional_callbacks=None):
|
||||
def run_lightning_loop(h_params, data_class, model_class, seed=69, additional_callbacks=None):
|
||||
|
||||
fix_all_random_seeds(seed)
|
||||
|
||||
with Logger.from_argparse_args(h_params) as logger:
|
||||
# Callbacks
|
||||
# =============================================================================
|
||||
@ -79,13 +83,13 @@ def run_lightning_loop(h_params, data_class, model_class, additional_callbacks=N
|
||||
|
||||
if __name__ == '__main__':
|
||||
# Parse comandline args, read config and get model
|
||||
cmd_args, found_data_class, found_model_class = parse_comandline_args_add_defaults('_parameters.ini')
|
||||
cmd_args, found_data_class, found_model_class, found_seed = parse_comandline_args_add_defaults('_parameters.ini')
|
||||
|
||||
# To NameSpace
|
||||
hparams = Namespace(**cmd_args)
|
||||
|
||||
# Start
|
||||
# -----------------
|
||||
run_lightning_loop(hparams, found_data_class, found_model_class)
|
||||
run_lightning_loop(hparams, found_data_class, found_model_class, found_seed)
|
||||
print('done')
|
||||
pass
|
||||
|
@ -4,7 +4,6 @@ from argparse import Namespace
|
||||
import warnings
|
||||
|
||||
import torch
|
||||
from performer_pytorch import Performer
|
||||
from torch import nn
|
||||
|
||||
from einops import rearrange, repeat
|
||||
@ -15,6 +14,8 @@ from util.module_mixins import CombinedModelMixins
|
||||
|
||||
MIN_NUM_PATCHES = 16
|
||||
|
||||
try:
|
||||
from performer_pytorch import Performer
|
||||
|
||||
class VisualPerformer(CombinedModelMixins,
|
||||
LightningBaseModule
|
||||
@ -131,3 +132,7 @@ class VisualPerformer(CombinedModelMixins,
|
||||
|
||||
def additional_scores(self, outputs):
|
||||
return MultiClassScores(self)(outputs)
|
||||
|
||||
except ImportError: # pragma: do not provide model class
|
||||
print('You want to use `performer_pytorch` plugins which are not installed yet,' # pragma: no-cover
|
||||
' install it with `pip install performer_pytorch`.')
|
||||
|
Loading…
x
Reference in New Issue
Block a user