module_mixins.py updated with tensor return
This commit is contained in:
parent
5292b6d986
commit
b529d130df
@ -46,7 +46,7 @@ if __name__ == '__main__':
|
|||||||
data_noise_ratio=0.0, data_shift_ratio=0.0, data_loudness_ratio=0.4,
|
data_noise_ratio=0.0, data_shift_ratio=0.0, data_loudness_ratio=0.4,
|
||||||
data_stretch=True, train_epochs=101)
|
data_stretch=True, train_epochs=101)
|
||||||
|
|
||||||
for dicts in [raw_conf, all_conf, speed_conf, mask_conf,noise_conf, shift_conf, loudness_conf]:
|
for dicts in [raw_conf, all_conf, speed_conf, mask_conf, noise_conf, shift_conf, loudness_conf]:
|
||||||
|
|
||||||
arg_dict.update(dicts)
|
arg_dict.update(dicts)
|
||||||
config = config.update(arg_dict)
|
config = config.update(arg_dict)
|
||||||
|
@ -100,7 +100,7 @@ class BaseValMixin:
|
|||||||
|
|
||||||
uar_score = sklearn.metrics.recall_score(y_true, y_pred, labels=[0, 1], average='macro',
|
uar_score = sklearn.metrics.recall_score(y_true, y_pred, labels=[0, 1], average='macro',
|
||||||
sample_weight=None, zero_division='warn')
|
sample_weight=None, zero_division='warn')
|
||||||
|
uar_score = torch.as_tensor(uar_score)
|
||||||
summary_dict['log'].update({f'uar{ident}_score': uar_score})
|
summary_dict['log'].update({f'uar{ident}_score': uar_score})
|
||||||
return summary_dict
|
return summary_dict
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user