CCS intergration training running
notebooks
This commit is contained in:
parent
d3e7bf7efb
commit
6816e423ff
@ -61,6 +61,9 @@ def parse_comandline_args_add_defaults(filepath, overrides=None):
|
|||||||
for module in [Logger, Trainer, found_data_class, found_model_class]:
|
for module in [Logger, Trainer, found_data_class, found_model_class]:
|
||||||
parser = module.add_argparse_args(parser)
|
parser = module.add_argparse_args(parser)
|
||||||
|
|
||||||
|
# This is obsolete
|
||||||
|
# new_defaults.update(data_name=data_name, model_name=model_name)
|
||||||
|
|
||||||
args, _ = parser.parse_known_args(namespace=Namespace(**new_defaults))
|
args, _ = parser.parse_known_args(namespace=Namespace(**new_defaults))
|
||||||
|
|
||||||
args = vars(args)
|
args = vars(args)
|
||||||
|
@ -54,8 +54,12 @@ def check_path(file_path):
|
|||||||
|
|
||||||
def locate_and_import_class(class_name, folder_path: Union[str, PurePath] = ''):
|
def locate_and_import_class(class_name, folder_path: Union[str, PurePath] = ''):
|
||||||
"""Locate an object by name or dotted path, importing as necessary."""
|
"""Locate an object by name or dotted path, importing as necessary."""
|
||||||
|
import sys
|
||||||
|
sys.path.append("..")
|
||||||
folder_path = Path(folder_path)
|
folder_path = Path(folder_path)
|
||||||
module_paths = [x for x in folder_path.rglob('*.py') if x.is_file() and '__init__' not in x.name]
|
module_paths = [x for x in folder_path.rglob('*.py') if x.is_file() and '__init__' not in x.name]
|
||||||
|
# possible_package_path = folder_path / '__init__.py'
|
||||||
|
# package = str(possible_package_path) if possible_package_path.exists() else None
|
||||||
for module_path in module_paths:
|
for module_path in module_paths:
|
||||||
mod = importlib.import_module('.'.join([x.replace('.py', '') for x in module_path.parts]))
|
mod = importlib.import_module('.'.join([x.replace('.py', '') for x in module_path.parts]))
|
||||||
try:
|
try:
|
||||||
|
@ -1,12 +1,38 @@
|
|||||||
try:
|
try:
|
||||||
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
|
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
|
||||||
|
from matplotlib import pyplot as plt
|
||||||
except ImportError: # pragma: no-cover
|
except ImportError: # pragma: no-cover
|
||||||
raise ImportError('You want to use `matplotlib` plugins which are not installed yet,' # pragma: no-cover
|
raise ImportError('You want to use `matplotlib` which is not installed yet,' # pragma: no-cover
|
||||||
' install it with `pip install matplotlib`.')
|
' install it with `pip install matplotlib`.')
|
||||||
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
|
def prettyfy_sns():
|
||||||
|
plt.style.use('default')
|
||||||
|
try:
|
||||||
|
import seaborn as sns
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError('You want to use `seaborn` which is not installed yet,' # pragma: no-cover
|
||||||
|
' install it with `pip install seaborn`.')
|
||||||
|
sns.set_palette('Dark2')
|
||||||
|
|
||||||
|
tex_fonts = {
|
||||||
|
# Use LaTeX to write all text
|
||||||
|
"text.usetex": True,
|
||||||
|
"font.family": "serif",
|
||||||
|
# Use 10pt font in plots, to match 10pt font in document
|
||||||
|
"axes.labelsize": 10,
|
||||||
|
"font.size": 10,
|
||||||
|
# Make the legend/label fonts a little smaller
|
||||||
|
"legend.fontsize": 8,
|
||||||
|
"xtick.labelsize": 8,
|
||||||
|
"ytick.labelsize": 8
|
||||||
|
}
|
||||||
|
|
||||||
|
plt.rcParams.update(tex_fonts)
|
||||||
|
|
||||||
|
|
||||||
class Plotter(object):
|
class Plotter(object):
|
||||||
|
|
||||||
def __init__(self, root_path=''):
|
def __init__(self, root_path=''):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user