New Model, Many Changes
This commit is contained in:
@ -61,10 +61,11 @@ class BaseTrainMixin:
|
||||
assert isinstance(self, LightningBaseModule)
|
||||
keys = list(outputs[0].keys())
|
||||
|
||||
summary_dict = dict(log={f'mean_{key}': torch.mean(torch.stack([output[key]
|
||||
for output in outputs]))
|
||||
for key in keys if 'loss' in key})
|
||||
return summary_dict
|
||||
summary_dict = {f'mean_{key}': torch.mean(torch.stack([output[key]
|
||||
for output in outputs]))
|
||||
for key in keys if 'loss' in key}
|
||||
for key in summary_dict.keys():
|
||||
self.log(key, summary_dict[key])
|
||||
|
||||
|
||||
class BaseValMixin:
|
||||
@ -83,16 +84,16 @@ class BaseValMixin:
|
||||
|
||||
def validation_epoch_end(self, outputs, *_, **__):
|
||||
assert isinstance(self, LightningBaseModule)
|
||||
summary_dict = dict(log=dict())
|
||||
summary_dict = dict()
|
||||
# In case of Multiple given dataloader this will outputs will be: list[list[dict[]]]
|
||||
# for output_idx, output in enumerate(outputs):
|
||||
# else:list[dict[]]
|
||||
keys = list(outputs.keys())
|
||||
# Add Every Value das has a "loss" in it, by calc. mean over all occurences.
|
||||
summary_dict['log'].update({f'mean_{key}': torch.mean(torch.stack([output[key]
|
||||
for output in outputs]))
|
||||
for key in keys if 'loss' in key}
|
||||
)
|
||||
summary_dict.update({f'mean_{key}': torch.mean(torch.stack([output[key]
|
||||
for output in outputs]))
|
||||
for key in keys if 'loss' in key}
|
||||
)
|
||||
"""
|
||||
# Additional Score like the unweighted Average Recall:
|
||||
# UnweightedAverageRecall
|
||||
@ -107,7 +108,8 @@ class BaseValMixin:
|
||||
summary_dict['log'].update({f'uar_score': uar_score})
|
||||
"""
|
||||
|
||||
return summary_dict
|
||||
for key in summary_dict.keys():
|
||||
self.log(key, summary_dict[key])
|
||||
|
||||
|
||||
class BinaryMaskDatasetMixin:
|
||||
|
@ -1,8 +1,5 @@
|
||||
from argparse import Namespace
|
||||
|
||||
from ml_lib.utils.config import Config
|
||||
|
||||
|
||||
class GlobalVar(Namespace):
|
||||
# Labels for classes
|
||||
LEFT = 1
|
||||
@ -21,10 +18,3 @@ class GlobalVar(Namespace):
|
||||
train='train',
|
||||
vali='vali',
|
||||
test='test'
|
||||
|
||||
|
||||
class ThisConfig(Config):
|
||||
|
||||
@property
|
||||
def _model_map(self):
|
||||
return dict()
|
||||
|
Reference in New Issue
Block a user