module_mixins.py updated with tensor return

This commit is contained in:
Si11ium 2020-05-21 12:33:36 +02:00
parent 5292b6d986
commit b529d130df
2 changed files with 2 additions and 2 deletions

View File

@ -46,7 +46,7 @@ if __name__ == '__main__':
data_noise_ratio=0.0, data_shift_ratio=0.0, data_loudness_ratio=0.4, data_noise_ratio=0.0, data_shift_ratio=0.0, data_loudness_ratio=0.4,
data_stretch=True, train_epochs=101) data_stretch=True, train_epochs=101)
for dicts in [raw_conf, all_conf, speed_conf, mask_conf,noise_conf, shift_conf, loudness_conf]: for dicts in [raw_conf, all_conf, speed_conf, mask_conf, noise_conf, shift_conf, loudness_conf]:
arg_dict.update(dicts) arg_dict.update(dicts)
config = config.update(arg_dict) config = config.update(arg_dict)

View File

@ -100,7 +100,7 @@ class BaseValMixin:
uar_score = sklearn.metrics.recall_score(y_true, y_pred, labels=[0, 1], average='macro', uar_score = sklearn.metrics.recall_score(y_true, y_pred, labels=[0, 1], average='macro',
sample_weight=None, zero_division='warn') sample_weight=None, zero_division='warn')
uar_score = torch.as_tensor(uar_score)
summary_dict['log'].update({f'uar{ident}_score': uar_score}) summary_dict['log'].update({f'uar{ident}_score': uar_score})
return summary_dict return summary_dict