Skip to content

Commit

Permalink
Merge branch 'master' into add_mean_iou
Browse files Browse the repository at this point in the history
  • Loading branch information
SkafteNicki authored Apr 23, 2024
2 parents 6cd0d74 + 82ab513 commit 57df85a
Show file tree
Hide file tree
Showing 6 changed files with 49 additions and 3 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fixed axis names with Precision-Recall curve ([#2462](https://github.com/Lightning-AI/torchmetrics/pull/2462))


- Fixed list synchronization with partly empty lists ([#2468](https://github.com/Lightning-AI/torchmetrics/pull/2468))


- Fixed memory leak in metrics using list states ([#2492](https://github.com/Lightning-AI/torchmetrics/pull/2492))


Expand Down
2 changes: 1 addition & 1 deletion requirements/text.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

nltk >=3.6, <=3.8.1
tqdm >=4.41.0, <4.67.0
regex >=2021.9.24, <=2023.12.25
regex >=2021.9.24, <=2024.4.16
transformers >4.4.0, <4.40.0
mecab-python3 >=1.0.6, <1.1.0
ipadic >=1.0.0, <1.1.0
Expand Down
6 changes: 6 additions & 0 deletions src/torchmetrics/classification/roc.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,7 @@ def plot(
curve: Optional[Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]] = None,
score: Optional[Union[Tensor, bool]] = None,
ax: Optional[_AX_TYPE] = None,
labels: Optional[List[str]] = None,
) -> _PLOT_OUT_TYPE:
"""Plot a single or multiple values from the metric.
Expand All @@ -307,6 +308,7 @@ def plot(
will automatically compute the score. The score is computed by using the trapezoidal rule to compute the
area under the curve.
ax: An matplotlib axis object. If provided will add plot to that axis
labels: a list of strings, if provided will be added to the plot to indicate the different classes
Returns:
Figure and Axes object
Expand Down Expand Up @@ -337,6 +339,7 @@ def plot(
ax=ax,
label_names=("False positive rate", "True positive rate"),
name=self.__class__.__name__,
labels=labels,
)


Expand Down Expand Up @@ -456,6 +459,7 @@ def plot(
curve: Optional[Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]] = None,
score: Optional[Union[Tensor, bool]] = None,
ax: Optional[_AX_TYPE] = None,
labels: Optional[List[str]] = None,
) -> _PLOT_OUT_TYPE:
"""Plot a single or multiple values from the metric.
Expand All @@ -466,6 +470,7 @@ def plot(
will automatically compute the score. The score is computed by using the trapezoidal rule to compute the
area under the curve.
ax: An matplotlib axis object. If provided will add plot to that axis
labels: a list of strings, if provided will be added to the plot to indicate the different classes
Returns:
Figure and Axes object
Expand Down Expand Up @@ -496,6 +501,7 @@ def plot(
ax=ax,
label_names=("False positive rate", "True positive rate"),
name=self.__class__.__name__,
labels=labels,
)


Expand Down
10 changes: 10 additions & 0 deletions src/torchmetrics/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
)
from torchmetrics.utilities.distributed import gather_all_tensors
from torchmetrics.utilities.exceptions import TorchMetricsUserError
from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_1
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE, plot_single_or_multi_val
from torchmetrics.utilities.prints import rank_zero_warn

Expand Down Expand Up @@ -438,6 +439,15 @@ def _sync_dist(self, dist_sync_fn: Callable = gather_all_tensors, process_group:
if reduction_fn == dim_zero_cat and isinstance(input_dict[attr], list) and len(input_dict[attr]) > 1:
input_dict[attr] = [dim_zero_cat(input_dict[attr])]

# cornor case in distributed settings where a rank have not received any data, create empty to concatenate
if (
_TORCH_GREATER_EQUAL_2_1
and reduction_fn == dim_zero_cat
and isinstance(input_dict[attr], list)
and len(input_dict[attr]) == 0
):
input_dict[attr] = [torch.tensor([], device=self.device, dtype=self.dtype)]

output_dict = apply_to_collection(
input_dict,
Tensor,
Expand Down
11 changes: 10 additions & 1 deletion src/torchmetrics/utilities/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,7 @@ def plot_curve(
label_names: Optional[Tuple[str, str]] = None,
legend_name: Optional[str] = None,
name: Optional[str] = None,
labels: Optional[List[Union[int, str]]] = None,
) -> _PLOT_OUT_TYPE:
"""Inspired by: https://github.com/scikit-learn/scikit-learn/blob/main/sklearn/metrics/_plot/roc_curve.py.
Expand All @@ -287,6 +288,7 @@ def plot_curve(
label_names: Tuple containing the names of the x and y axis
legend_name: Name of the curve to be used in the legend
name: Custom name to describe the metric
labels: Optional labels for the different curves that will be added to the plot
Returns:
A tuple consisting of the figure and respective ax objects (or array of ax objects) of the generated figure
Expand All @@ -312,8 +314,15 @@ def plot_curve(
elif (isinstance(x, list) and isinstance(y, list)) or (
isinstance(x, Tensor) and isinstance(y, Tensor) and x.ndim == 2 and y.ndim == 2
):
n_classes = len(x)
if labels is not None and len(labels) != n_classes:
raise ValueError(
"Expected number of elements in arg `labels` to match number of labels in roc curves but "
f"got {len(labels)} and {n_classes}"
)

for i, (x_, y_) in enumerate(zip(x, y)):
label = f"{legend_name}_{i}" if legend_name is not None else str(i)
label = f"{legend_name}_{i}" if legend_name is not None else str(i) if labels is None else str(labels[i])
label += f" AUC={score[i].item():0.3f}" if score is not None else ""
ax.plot(x_.detach().cpu(), y_.detach().cpu(), linestyle="-", linewidth=2, label=label)
ax.legend()
Expand Down
20 changes: 19 additions & 1 deletion tests/unittests/bases/test_ddp.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from torchmetrics import Metric
from torchmetrics.utilities.distributed import gather_all_tensors
from torchmetrics.utilities.exceptions import TorchMetricsUserError
from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_1

from unittests import NUM_PROCESSES
from unittests._helpers import seed_all
Expand Down Expand Up @@ -269,11 +270,28 @@ def test_sync_on_compute(sync_on_compute, test_func):
def _test_sync_with_empty_lists(rank):
dummy = DummyListMetric()
val = dummy.compute()
assert val == []
assert torch.allclose(val, tensor([]))


@pytest.mark.DDP()
@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_2_1, reason="test only works on newer torch versions")
@pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows")
def test_sync_with_empty_lists():
"""Test that synchronization of states can be enabled and disabled for compute."""
pytest.pool.map(_test_sync_with_empty_lists, range(NUM_PROCESSES))


def _test_sync_with_unequal_size_lists(rank):
"""Test that synchronization of list states work even when some ranks have not received any data yet."""
dummy = DummyListMetric()
if rank == 0:
dummy.update(torch.zeros(2))
assert torch.all(dummy.compute() == tensor([0.0, 0.0]))


@pytest.mark.DDP()
@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_2_1, reason="test only works on newer torch versions")
@pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows")
def test_sync_with_unequal_size_lists():
"""Test that synchronization of states can be enabled and disabled for compute."""
pytest.pool.map(_test_sync_with_unequal_size_lists, range(NUM_PROCESSES))

0 comments on commit 57df85a

Please sign in to comment.