paper preperations and notebooks, optuna callbacks
This commit is contained in:
@ -52,6 +52,10 @@ class NormalizeLocal(object):
|
||||
return f'{self.__class__.__name__}({self.__dict__})'
|
||||
|
||||
def __call__(self, x: np.ndarray):
|
||||
|
||||
x[np.isnan(x)] = 0
|
||||
x[np.isinf(x)] = 0
|
||||
|
||||
mean = x.mean()
|
||||
std = x.std() + 0.0001
|
||||
|
||||
|
@ -190,10 +190,6 @@ class BaseCNNEncoder(ShapeMixin, nn.Module):
|
||||
kernels = kernels if not isinstance(kernels, int) else [kernels] * len(filters)
|
||||
assert len(kernels) == len(filters), 'Length of "Filters" and "Kernels" has to be same.'
|
||||
|
||||
# Optional Padding for odd image-sizes
|
||||
# Obsolet, cdan be done by autopadding module on incoming tensors
|
||||
# in_shape = [tensor+1 if tensor % 2 != 0 and idx else tensor for idx, tensor in enumerate(in_shape)]
|
||||
|
||||
# Parameters
|
||||
self.lat_dim = lat_dim
|
||||
self.in_shape = in_shape
|
||||
|
@ -14,6 +14,8 @@ from sklearn.metrics import ConfusionMatrixDisplay
|
||||
|
||||
# Utility - Modules
|
||||
###################
|
||||
from ..metrics.binary_class_classifictaion import BinaryScores
|
||||
from ..metrics.multi_class_classification import MultiClassScores
|
||||
from ..utils.model_io import ModelParameters
|
||||
from ..utils.tools import add_argparse_args
|
||||
|
||||
@ -133,9 +135,6 @@ try:
|
||||
def size(self):
|
||||
return self.shape
|
||||
|
||||
def additional_scores(self, outputs):
|
||||
raise NotImplementedError
|
||||
|
||||
def save_to_disk(self, model_path):
|
||||
Path(model_path, exist_ok=True).mkdir(parents=True, exist_ok=True)
|
||||
if not (model_path / 'model_class.obj').exists():
|
||||
@ -174,6 +173,12 @@ try:
|
||||
weight_initializer = WeightInit(in_place_init_function=self._weight_init)
|
||||
self.apply(weight_initializer)
|
||||
|
||||
def additional_scores(self, outputs):
|
||||
if self.params.n_classes > 2:
|
||||
return MultiClassScores(self)(outputs)
|
||||
else:
|
||||
return BinaryScores(self)(outputs)
|
||||
|
||||
module_types = (LightningBaseModule, nn.Module,)
|
||||
|
||||
except ImportError:
|
||||
|
@ -1,3 +1,4 @@
|
||||
import torch
|
||||
from pytorch_lightning import Callback, Trainer, LightningModule
|
||||
|
||||
|
||||
@ -17,6 +18,10 @@ class BestScoresCallback(Callback):
|
||||
current_score = trainer.callback_metrics.get(monitor)
|
||||
if current_score is None:
|
||||
pass
|
||||
elif torch.isinf(current_score):
|
||||
pass
|
||||
elif torch.isnan(current_score):
|
||||
pass
|
||||
else:
|
||||
self.best_scores[monitor] = max(self.best_scores[monitor], current_score)
|
||||
if self.best_scores[monitor] == current_score:
|
||||
|
@ -37,12 +37,6 @@ def parse_comandline_args_add_defaults(filepath, overrides=None):
|
||||
defaults = config[key]
|
||||
new_defaults.update({key: auto_cast(val) for key, val in defaults.items()})
|
||||
|
||||
if new_defaults['debug']:
|
||||
new_defaults.update(
|
||||
max_epochs=2,
|
||||
max_steps=2 # The seems to be the new "fast_dev_run"
|
||||
)
|
||||
|
||||
args, _ = parser.parse_known_args()
|
||||
overrides = overrides or dict()
|
||||
default_data = overrides.get('data_name', None) or new_defaults['data_name']
|
||||
@ -71,13 +65,20 @@ def parse_comandline_args_add_defaults(filepath, overrides=None):
|
||||
args.update(gpus=[0] if torch.cuda.is_available() and not args['debug'] else None,
|
||||
row_log_interval=1000, # TODO: Better Value / Setting
|
||||
log_save_interval=10000, # TODO: Better Value / Setting
|
||||
auto_lr_find=not args['debug'],
|
||||
weights_summary='top',
|
||||
check_val_every_n_epoch=1 if args['debug'] else args.get('check_val_every_n_epoch', 1),
|
||||
)
|
||||
|
||||
if overrides is not None and isinstance(overrides, (Mapping, Dict)):
|
||||
args.update(**overrides)
|
||||
if args['debug']:
|
||||
args.update(
|
||||
# The seems to be the new "fast_dev_run"
|
||||
val_check_interval=1,
|
||||
max_epochs=2,
|
||||
max_steps=2,
|
||||
auto_lr_find=False,
|
||||
check_val_every_n_epoch=1
|
||||
)
|
||||
return args, found_data_class, found_model_class, found_seed
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user