diff --git a/_parameters.py b/_parameters.py index 751b801..16dcfee 100644 --- a/_parameters.py +++ b/_parameters.py @@ -27,8 +27,8 @@ main_arg_parser.add_argument("--data_additional_resource_root", type=str, defaul main_arg_parser.add_argument("--data_use_preprocessed", type=strtobool, default=True, help="") # Transformations -main_arg_parser.add_argument("--transformations_to_tensor", type=strtobool, default=False, help="") -main_arg_parser.add_argument("--transformations_normalize", type=strtobool, default=False, help="") +# main_arg_parser.add_argument("--transformations_to_tensor", type=strtobool, default=False, help="") +# main_arg_parser.add_argument("--transformations_normalize", type=strtobool, default=False, help="") # Transformations main_arg_parser.add_argument("--train_outpath", type=str, default="output", help="") diff --git a/main.py b/main.py index 1c7b357..9b1010b 100644 --- a/main.py +++ b/main.py @@ -7,10 +7,10 @@ import torch from pytorch_lightning import Trainer from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping -from modules.utils import LightningBaseModule -from utils.config import Config -from utils.logging import Logger -from utils.model_io import SavedLightningModels +from ml_lib.modules.util import LightningBaseModule +from ml_lib.utils.config import Config +from ml_lib.utils.logging import Logger +from ml_lib.utils.model_io import SavedLightningModels warnings.filterwarnings('ignore', category=FutureWarning) warnings.filterwarnings('ignore', category=UserWarning) @@ -76,6 +76,6 @@ def run_lightning_loop(config_obj): if __name__ == "__main__": - from _templates.new_project._parameters import args + from ._parameters import args config = Config.read_namespace(args) trained_model = run_lightning_loop(config)