New Model, Many Changes

This commit is contained in:
Si11ium
2020-11-21 09:28:25 +01:00
parent 13812b83b5
commit 14ed4e0117
8 changed files with 127 additions and 102 deletions

View File

@ -61,10 +61,11 @@ class BaseTrainMixin:
assert isinstance(self, LightningBaseModule)
keys = list(outputs[0].keys())
summary_dict = dict(log={f'mean_{key}': torch.mean(torch.stack([output[key]
for output in outputs]))
for key in keys if 'loss' in key})
return summary_dict
summary_dict = {f'mean_{key}': torch.mean(torch.stack([output[key]
for output in outputs]))
for key in keys if 'loss' in key}
for key in summary_dict.keys():
self.log(key, summary_dict[key])
class BaseValMixin:
@ -83,16 +84,16 @@ class BaseValMixin:
def validation_epoch_end(self, outputs, *_, **__):
assert isinstance(self, LightningBaseModule)
summary_dict = dict(log=dict())
summary_dict = dict()
# In case of Multiple given dataloader this will outputs will be: list[list[dict[]]]
# for output_idx, output in enumerate(outputs):
# else:list[dict[]]
keys = list(outputs.keys())
# Add Every Value das has a "loss" in it, by calc. mean over all occurences.
summary_dict['log'].update({f'mean_{key}': torch.mean(torch.stack([output[key]
for output in outputs]))
for key in keys if 'loss' in key}
)
summary_dict.update({f'mean_{key}': torch.mean(torch.stack([output[key]
for output in outputs]))
for key in keys if 'loss' in key}
)
"""
# Additional Score like the unweighted Average Recall:
# UnweightedAverageRecall
@ -107,7 +108,8 @@ class BaseValMixin:
summary_dict['log'].update({f'uar_score': uar_score})
"""
return summary_dict
for key in summary_dict.keys():
self.log(key, summary_dict[key])
class BinaryMaskDatasetMixin:

View File

@ -1,8 +1,5 @@
from argparse import Namespace
from ml_lib.utils.config import Config
class GlobalVar(Namespace):
# Labels for classes
LEFT = 1
@ -21,10 +18,3 @@ class GlobalVar(Namespace):
train='train',
vali='vali',
test='test'
class ThisConfig(Config):
@property
def _model_map(self):
return dict()