diff --git a/metrics/f1/f1.py b/metrics/f1/f1.py index fe7683489..05b0baadc 100644 --- a/metrics/f1/f1.py +++ b/metrics/f1/f1.py @@ -127,4 +127,4 @@ def _compute(self, predictions, references, labels=None, pos_label=1, average="b score = f1_score( references, predictions, labels=labels, pos_label=pos_label, average=average, sample_weight=sample_weight ) - return {"f1": float(score) if score.size == 1 else score} + return {"f1": score if getattr(score, "size", 1) > 1 else float(score)} diff --git a/metrics/precision/precision.py b/metrics/precision/precision.py index 4b35aa7e4..170d0e5dd 100644 --- a/metrics/precision/precision.py +++ b/metrics/precision/precision.py @@ -142,4 +142,4 @@ def _compute( sample_weight=sample_weight, zero_division=zero_division, ) - return {"precision": float(score) if score.size == 1 else score} + return {"precision": score if getattr(score, "size", 1) > 1 else float(score)} diff --git a/metrics/recall/recall.py b/metrics/recall/recall.py index 8522cfcf6..1c20afc46 100644 --- a/metrics/recall/recall.py +++ b/metrics/recall/recall.py @@ -132,4 +132,4 @@ def _compute( sample_weight=sample_weight, zero_division=zero_division, ) - return {"recall": float(score) if score.size == 1 else score} + return {"recall": score if getattr(score, "size", 1) > 1 else float(score)}