Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add zero_division option to the precision, recall, f1, fbeta. #2198

Merged
merged 40 commits into from
May 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
462b276
Add support of zero_division parameter
i-aki-y Nov 2, 2023
a79cea9
fix overlooked
i-aki-y Nov 2, 2023
b698f87
Fix type error
i-aki-y Nov 2, 2023
c6fbe6a
Fix type error
i-aki-y Nov 2, 2023
3fbd633
Fix missing comma
i-aki-y Nov 2, 2023
cf6d3f1
Doc fix wrong math expression
i-aki-y Nov 4, 2023
cc69671
Merge branch 'master' into add-zerodivision-support
Borda Nov 7, 2023
bea5bd2
Fixed StatScores to have zero_division
i-aki-y Nov 17, 2023
ef881c5
Merge branch 'master' into add-zerodivision-support
Borda Nov 26, 2023
ab8081a
Merge branch 'master' into add-zerodivision-support
Borda Nov 28, 2023
9d943a6
Merge branch 'master' into add-zerodivision-support
Borda Dec 18, 2023
cec6adc
fix missing zero_division arg
i-aki-y Dec 19, 2023
f68ec72
fix device mismatch
i-aki-y Dec 19, 2023
306ebbc
Merge branch 'master' into add-zerodivision-support
Borda Dec 19, 2023
737443d
Merge branch 'master' into add-zerodivision-support
Borda Dec 21, 2023
8b1f1ec
Merge branch 'master' into add-zerodivision-support
Borda Jan 9, 2024
1d90bc9
Merge branch 'master' into add-zerodivision-support
i-aki-y Jan 21, 2024
198bcbd
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 21, 2024
389dfa7
use scikit-learn 1.4.0
i-aki-y Jan 21, 2024
acf6799
fix scikit-learn min ver
i-aki-y Jan 21, 2024
6f6cc86
fix for new sklearn version
SkafteNicki Jan 21, 2024
d17744a
fix scikit-learn requirements
i-aki-y Jan 22, 2024
60f3e4d
fix incorrect requirements condition
i-aki-y Jan 23, 2024
98e8bb8
fix test code to pass in multiple sklearn versions
i-aki-y Jan 25, 2024
8fd617c
Merge branch 'master' into add-zerodivision-support
i-aki-y Jan 25, 2024
f539d46
Merge branch 'master' into add-zerodivision-support
Borda Jan 30, 2024
5bd64de
Merge branch 'master' into add-zerodivision-support
Borda Feb 6, 2024
306f4fe
Merge branch 'master' into add-zerodivision-support
Borda Feb 15, 2024
bb4641d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 15, 2024
d32ce82
Merge branch 'master' into add-zerodivision-support
Borda Feb 26, 2024
fa20b74
Merge branch 'master' into add-zerodivision-support
i-aki-y Mar 16, 2024
36a82e3
Merge branch 'master' into add-zerodivision-support
Borda Mar 29, 2024
b738292
Merge branch 'master' into add-zerodivision-support
Borda Apr 17, 2024
67ff903
Merge branch 'master' into add-zerodivision-support
Borda Apr 23, 2024
0bbe57a
Merge branch 'master' into add-zerodivision-support
SkafteNicki May 2, 2024
640b3fe
changelog
SkafteNicki May 2, 2024
48f2d1c
better docstring
SkafteNicki May 2, 2024
905cc7e
add jaccardindex
SkafteNicki May 2, 2024
b82bc0e
fix tests
SkafteNicki May 2, 2024
e26e748
skip for old sklearn versions
SkafteNicki May 2, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added support for `torch.float` weighted networks for FID and KID calculations ([#2483](https://github.com/Lightning-AI/torchmetrics/pull/2483))


- Added `zero_division` argument to selected classification metrics ([#2198](https://github.com/Lightning-AI/torchmetrics/pull/2198))


### Changed

- Made `__getattr__` and `__setattr__` of `ClasswiseWrapper` more general ([#2424](https://github.com/Lightning-AI/torchmetrics/pull/2424))
Expand Down
3 changes: 2 additions & 1 deletion requirements/_tests.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,6 @@ pyGithub ==2.3.0
fire <=0.6.0

cloudpickle >1.3, <=3.0.0
scikit-learn >=1.1.1, <1.4.0
scikit-learn >=1.1.1, <1.3.0; python_version < "3.9"
scikit-learn >=1.4.0, <1.5.0; python_version >= "3.9"
Borda marked this conversation as resolved.
Show resolved Hide resolved
Borda marked this conversation as resolved.
Show resolved Hide resolved
cachier ==3.0.0
90 changes: 76 additions & 14 deletions src/torchmetrics/classification/f_beta.py

Large diffs are not rendered by default.

20 changes: 17 additions & 3 deletions src/torchmetrics/classification/jaccard.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ class BinaryJaccardIndex(BinaryConfusionMatrix):
Specifies a target value that is ignored and does not contribute to the metric calculation
validate_args: bool indicating if input arguments and tensors should be validated for correctness.
Set to ``False`` for faster computations.
zero_division:
Value to replace when there is a division by zero. Should be `0` or `1`.
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.

Example (preds is int tensor):
Expand Down Expand Up @@ -97,15 +99,17 @@ def __init__(
threshold: float = 0.5,
ignore_index: Optional[int] = None,
validate_args: bool = True,
zero_division: float = 0,
**kwargs: Any,
) -> None:
super().__init__(
threshold=threshold, ignore_index=ignore_index, normalize=None, validate_args=validate_args, **kwargs
)
self.zero_division = zero_division

def compute(self) -> Tensor:
"""Compute metric."""
return _jaccard_index_reduce(self.confmat, average="binary")
return _jaccard_index_reduce(self.confmat, average="binary", zero_division=self.zero_division)

def plot( # type: ignore[override]
self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
Expand Down Expand Up @@ -187,6 +191,8 @@ class MulticlassJaccardIndex(MulticlassConfusionMatrix):

validate_args: bool indicating if input arguments and tensors should be validated for correctness.
Set to ``False`` for faster computations.
zero_division:
Value to replace when there is a division by zero. Should be `0` or `1`.
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.

Example (pred is integer tensor):
Expand Down Expand Up @@ -224,6 +230,7 @@ def __init__(
average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro",
ignore_index: Optional[int] = None,
validate_args: bool = True,
zero_division: float = 0,
**kwargs: Any,
) -> None:
super().__init__(
Expand All @@ -233,10 +240,13 @@ def __init__(
_multiclass_jaccard_index_arg_validation(num_classes, ignore_index, average)
self.validate_args = validate_args
self.average = average
self.zero_division = zero_division

def compute(self) -> Tensor:
"""Compute metric."""
return _jaccard_index_reduce(self.confmat, average=self.average, ignore_index=self.ignore_index)
return _jaccard_index_reduce(
self.confmat, average=self.average, ignore_index=self.ignore_index, zero_division=self.zero_division
)

def plot( # type: ignore[override]
self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
Expand Down Expand Up @@ -319,6 +329,8 @@ class MultilabelJaccardIndex(MultilabelConfusionMatrix):

validate_args: bool indicating if input arguments and tensors should be validated for correctness.
Set to ``False`` for faster computations.
zero_division:
Value to replace when there is a division by zero. Should be `0` or `1`.
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.

Example (preds is int tensor):
Expand Down Expand Up @@ -354,6 +366,7 @@ def __init__(
average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro",
ignore_index: Optional[int] = None,
validate_args: bool = True,
zero_division: float = 0,
**kwargs: Any,
) -> None:
super().__init__(
Expand All @@ -368,10 +381,11 @@ def __init__(
_multilabel_jaccard_index_arg_validation(num_labels, threshold, ignore_index, average)
self.validate_args = validate_args
self.average = average
self.zero_division = zero_division

def compute(self) -> Tensor:
"""Compute metric."""
return _jaccard_index_reduce(self.confmat, average=self.average)
return _jaccard_index_reduce(self.confmat, average=self.average, zero_division=self.zero_division)

def plot( # type: ignore[override]
self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
Expand Down
88 changes: 71 additions & 17 deletions src/torchmetrics/classification/precision_recall.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@

from torchmetrics.classification.base import _ClassificationTaskWrapper
from torchmetrics.classification.stat_scores import BinaryStatScores, MulticlassStatScores, MultilabelStatScores
from torchmetrics.functional.classification.precision_recall import _precision_recall_reduce
from torchmetrics.functional.classification.precision_recall import (
_precision_recall_reduce,
)
from torchmetrics.metric import Metric
from torchmetrics.utilities.enums import ClassificationTask
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
Expand All @@ -42,7 +44,7 @@ class BinaryPrecision(BinaryStatScores):

Where :math:`\text{TP}` and :math:`\text{FP}` represent the number of true positives and false positives
respectively. The metric is only proper defined when :math:`\text{TP} + \text{FP} \neq 0`. If this case is
encountered a score of 0 is returned.
encountered a score of `zero_division` (0 or 1, default is 0) is returned.

As input to ``forward`` and ``update`` the metric accepts the following input:

Expand Down Expand Up @@ -73,6 +75,7 @@ class BinaryPrecision(BinaryStatScores):
Specifies a target value that is ignored and does not contribute to the metric calculation
validate_args: bool indicating if input arguments and tensors should be validated for correctness.
Set to ``False`` for faster computations.
zero_division: Should be `0` or `1`. The value returned when :math:`\text{TP} + \text{FP} = 0`.

Example (preds is int tensor):
>>> from torch import tensor
Expand Down Expand Up @@ -112,7 +115,14 @@ def compute(self) -> Tensor:
"""Compute metric."""
tp, fp, tn, fn = self._final_state()
return _precision_recall_reduce(
"precision", tp, fp, tn, fn, average="binary", multidim_average=self.multidim_average
"precision",
tp,
fp,
tn,
fn,
average="binary",
multidim_average=self.multidim_average,
zero_division=self.zero_division,
)

def plot(
Expand Down Expand Up @@ -165,8 +175,8 @@ class MulticlassPrecision(MulticlassStatScores):

Where :math:`\text{TP}` and :math:`\text{FP}` represent the number of true positives and false positives
respectively. The metric is only proper defined when :math:`\text{TP} + \text{FP} \neq 0`. If this case is
encountered for any class, the metric for that class will be set to 0 and the overall metric may therefore be
affected in turn.
encountered for any class, the metric for that class will be set to `zero_division` (0 or 1, default is 0) and
the overall metric may therefore be affected in turn.

As input to ``forward`` and ``update`` the metric accepts the following input:

Expand Down Expand Up @@ -217,6 +227,7 @@ class MulticlassPrecision(MulticlassStatScores):
Specifies a target value that is ignored and does not contribute to the metric calculation
validate_args: bool indicating if input arguments and tensors should be validated for correctness.
Set to ``False`` for faster computations.
zero_division: Should be `0` or `1`. The value returned when :math:`\text{TP} + \text{FP} = 0`.

Example (preds is int tensor):
>>> from torch import tensor
Expand Down Expand Up @@ -269,7 +280,15 @@ def compute(self) -> Tensor:
"""Compute metric."""
tp, fp, tn, fn = self._final_state()
return _precision_recall_reduce(
"precision", tp, fp, tn, fn, average=self.average, multidim_average=self.multidim_average, top_k=self.top_k
"precision",
tp,
fp,
tn,
fn,
average=self.average,
multidim_average=self.multidim_average,
top_k=self.top_k,
zero_division=self.zero_division,
)

def plot(
Expand Down Expand Up @@ -322,8 +341,8 @@ class MultilabelPrecision(MultilabelStatScores):

Where :math:`\text{TP}` and :math:`\text{FP}` represent the number of true positives and false positives
respectively. The metric is only proper defined when :math:`\text{TP} + \text{FP} \neq 0`. If this case is
encountered for any label, the metric for that label will be set to 0 and the overall metric may therefore be
affected in turn.
encountered for any label, the metric for that label will be set to `zero_division` (0 or 1, default is 0) and
the overall metric may therefore be affected in turn.

As input to ``forward`` and ``update`` the metric accepts the following input:

Expand Down Expand Up @@ -373,6 +392,7 @@ class MultilabelPrecision(MultilabelStatScores):
Specifies a target value that is ignored and does not contribute to the metric calculation
validate_args: bool indicating if input arguments and tensors should be validated for correctness.
Set to ``False`` for faster computations.
zero_division: Should be `0` or `1`. The value returned when :math:`\text{TP} + \text{FP} = 0`.

Example (preds is int tensor):
>>> from torch import tensor
Expand Down Expand Up @@ -423,7 +443,15 @@ def compute(self) -> Tensor:
"""Compute metric."""
tp, fp, tn, fn = self._final_state()
return _precision_recall_reduce(
"precision", tp, fp, tn, fn, average=self.average, multidim_average=self.multidim_average, multilabel=True
"precision",
tp,
fp,
tn,
fn,
average=self.average,
multidim_average=self.multidim_average,
multilabel=True,
zero_division=self.zero_division,
)

def plot(
Expand Down Expand Up @@ -476,7 +504,7 @@ class BinaryRecall(BinaryStatScores):

Where :math:`\text{TP}` and :math:`\text{FN}` represent the number of true positives and false negatives
respectively. The metric is only proper defined when :math:`\text{TP} + \text{FN} \neq 0`. If this case is
encountered a score of 0 is returned.
encountered a score of `zero_division` (0 or 1, default is 0) is returned.

As input to ``forward`` and ``update`` the metric accepts the following input:

Expand Down Expand Up @@ -507,6 +535,7 @@ class BinaryRecall(BinaryStatScores):
Specifies a target value that is ignored and does not contribute to the metric calculation
validate_args: bool indicating if input arguments and tensors should be validated for correctness.
Set to ``False`` for faster computations.
zero_division: Should be `0` or `1`. The value returned when :math:`\text{TP} + \text{FN} = 0`.

Example (preds is int tensor):
>>> from torch import tensor
Expand Down Expand Up @@ -546,7 +575,14 @@ def compute(self) -> Tensor:
"""Compute metric."""
tp, fp, tn, fn = self._final_state()
return _precision_recall_reduce(
"recall", tp, fp, tn, fn, average="binary", multidim_average=self.multidim_average
"recall",
tp,
fp,
tn,
fn,
average="binary",
multidim_average=self.multidim_average,
zero_division=self.zero_division,
)

def plot(
Expand Down Expand Up @@ -599,8 +635,8 @@ class MulticlassRecall(MulticlassStatScores):

Where :math:`\text{TP}` and :math:`\text{FN}` represent the number of true positives and false negatives
respectively. The metric is only proper defined when :math:`\text{TP} + \text{FN} \neq 0`. If this case is
encountered for any class, the metric for that class will be set to 0 and the overall metric may therefore be
affected in turn.
encountered for any class, the metric for that class will be set to `zero_division` (0 or 1, default is 0) and
the overall metric may therefore be affected in turn.

As input to ``forward`` and ``update`` the metric accepts the following input:

Expand Down Expand Up @@ -650,6 +686,7 @@ class MulticlassRecall(MulticlassStatScores):
Specifies a target value that is ignored and does not contribute to the metric calculation
validate_args: bool indicating if input arguments and tensors should be validated for correctness.
Set to ``False`` for faster computations.
zero_division: Should be `0` or `1`. The value returned when :math:`\text{TP} + \text{FN} = 0`.

Example (preds is int tensor):
>>> from torch import tensor
Expand Down Expand Up @@ -702,7 +739,15 @@ def compute(self) -> Tensor:
"""Compute metric."""
tp, fp, tn, fn = self._final_state()
return _precision_recall_reduce(
"recall", tp, fp, tn, fn, average=self.average, multidim_average=self.multidim_average, top_k=self.top_k
"recall",
tp,
fp,
tn,
fn,
average=self.average,
multidim_average=self.multidim_average,
top_k=self.top_k,
zero_division=self.zero_division,
)

def plot(
Expand Down Expand Up @@ -755,8 +800,8 @@ class MultilabelRecall(MultilabelStatScores):

Where :math:`\text{TP}` and :math:`\text{FN}` represent the number of true positives and false negatives
respectively. The metric is only proper defined when :math:`\text{TP} + \text{FN} \neq 0`. If this case is
encountered for any label, the metric for that label will be set to 0 and the overall metric may therefore be
affected in turn.
encountered for any label, the metric for that label will be set to `zero_division` (0 or 1, default is 0) and
the overall metric may therefore be affected in turn.

As input to ``forward`` and ``update`` the metric accepts the following input:

Expand Down Expand Up @@ -805,6 +850,7 @@ class MultilabelRecall(MultilabelStatScores):
Specifies a target value that is ignored and does not contribute to the metric calculation
validate_args: bool indicating if input arguments and tensors should be validated for correctness.
Set to ``False`` for faster computations.
zero_division: Should be `0` or `1`. The value returned when :math:`\text{TP} + \text{FN} = 0`.

Example (preds is int tensor):
>>> from torch import tensor
Expand Down Expand Up @@ -855,7 +901,15 @@ def compute(self) -> Tensor:
"""Compute metric."""
tp, fp, tn, fn = self._final_state()
return _precision_recall_reduce(
"recall", tp, fp, tn, fn, average=self.average, multidim_average=self.multidim_average, multilabel=True
"recall",
tp,
fp,
tn,
fn,
average=self.average,
multidim_average=self.multidim_average,
multilabel=True,
zero_division=self.zero_division,
)

def plot(
Expand Down
16 changes: 13 additions & 3 deletions src/torchmetrics/classification/stat_scores.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,13 +169,15 @@ def __init__(
validate_args: bool = True,
**kwargs: Any,
) -> None:
zero_division = kwargs.pop("zero_division", 0)
super(_AbstractStatScores, self).__init__(**kwargs)
if validate_args:
_binary_stat_scores_arg_validation(threshold, multidim_average, ignore_index)
_binary_stat_scores_arg_validation(threshold, multidim_average, ignore_index, zero_division)
self.threshold = threshold
self.multidim_average = multidim_average
self.ignore_index = ignore_index
self.validate_args = validate_args
self.zero_division = zero_division

self._create_state(size=1, multidim_average=multidim_average)

Expand Down Expand Up @@ -313,15 +315,19 @@ def __init__(
validate_args: bool = True,
**kwargs: Any,
) -> None:
zero_division = kwargs.pop("zero_division", 0)
super(_AbstractStatScores, self).__init__(**kwargs)
if validate_args:
_multiclass_stat_scores_arg_validation(num_classes, top_k, average, multidim_average, ignore_index)
_multiclass_stat_scores_arg_validation(
num_classes, top_k, average, multidim_average, ignore_index, zero_division
)
self.num_classes = num_classes
self.top_k = top_k
self.average = average
self.multidim_average = multidim_average
self.ignore_index = ignore_index
self.validate_args = validate_args
self.zero_division = zero_division

self._create_state(
size=1 if (average == "micro" and top_k == 1) else num_classes, multidim_average=multidim_average
Expand Down Expand Up @@ -461,15 +467,19 @@ def __init__(
validate_args: bool = True,
**kwargs: Any,
) -> None:
zero_division = kwargs.pop("zero_division", 0)
super(_AbstractStatScores, self).__init__(**kwargs)
if validate_args:
_multilabel_stat_scores_arg_validation(num_labels, threshold, average, multidim_average, ignore_index)
_multilabel_stat_scores_arg_validation(
num_labels, threshold, average, multidim_average, ignore_index, zero_division
)
self.num_labels = num_labels
self.threshold = threshold
self.average = average
self.multidim_average = multidim_average
self.ignore_index = ignore_index
self.validate_args = validate_args
self.zero_division = zero_division

self._create_state(size=num_labels, multidim_average=multidim_average)

Expand Down
Loading
Loading