CCS intergration training running

notebooks
This commit is contained in:
Steffen 2021-03-24 08:03:11 +01:00
parent d3e7bf7efb
commit 6816e423ff
3 changed files with 34 additions and 1 deletions

View File

@ -61,6 +61,9 @@ def parse_comandline_args_add_defaults(filepath, overrides=None):
for module in [Logger, Trainer, found_data_class, found_model_class]:
parser = module.add_argparse_args(parser)
# This is obsolete
# new_defaults.update(data_name=data_name, model_name=model_name)
args, _ = parser.parse_known_args(namespace=Namespace(**new_defaults))
args = vars(args)

View File

@ -54,8 +54,12 @@ def check_path(file_path):
def locate_and_import_class(class_name, folder_path: Union[str, PurePath] = ''):
"""Locate an object by name or dotted path, importing as necessary."""
import sys
sys.path.append("..")
folder_path = Path(folder_path)
module_paths = [x for x in folder_path.rglob('*.py') if x.is_file() and '__init__' not in x.name]
# possible_package_path = folder_path / '__init__.py'
# package = str(possible_package_path) if possible_package_path.exists() else None
for module_path in module_paths:
mod = importlib.import_module('.'.join([x.replace('.py', '') for x in module_path.parts]))
try:

View File

@ -1,12 +1,38 @@
try:
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
from matplotlib import pyplot as plt
except ImportError: # pragma: no-cover
raise ImportError('You want to use `matplotlib` plugins which are not installed yet,' # pragma: no-cover
raise ImportError('You want to use `matplotlib` which is not installed yet,' # pragma: no-cover
' install it with `pip install matplotlib`.')
from pathlib import Path
def prettyfy_sns():
plt.style.use('default')
try:
import seaborn as sns
except ImportError:
raise ImportError('You want to use `seaborn` which is not installed yet,' # pragma: no-cover
' install it with `pip install seaborn`.')
sns.set_palette('Dark2')
tex_fonts = {
# Use LaTeX to write all text
"text.usetex": True,
"font.family": "serif",
# Use 10pt font in plots, to match 10pt font in document
"axes.labelsize": 10,
"font.size": 10,
# Make the legend/label fonts a little smaller
"legend.fontsize": 8,
"xtick.labelsize": 8,
"ytick.labelsize": 8
}
plt.rcParams.update(tex_fonts)
class Plotter(object):
def __init__(self, root_path=''):