Skip to content

Commit

Permalink
Bugfix so multilabel confusion matrix can plot for 2 or more labels (#…
Browse files Browse the repository at this point in the history
…2858)

* added matplotlib to test requirements
* added new test for plotting in multilabel classifier
* added bugfix
* fix errors
* changelog

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Nicki Skafte Detlefsen <skaftenicki@gmail.com>
Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
  • Loading branch information
5 people authored Dec 21, 2024
1 parent a7284e2 commit 8827e64
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 1 deletion.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed

- Fixed plotting of multilabel confusion matrix ([#2858](https://github.com/PyTorchLightning/metrics/pull/2858))


- Delete `Device2Host` caused by comm with device and host ([#2840](https://github.com/PyTorchLightning/metrics/pull/2840))


Expand Down
2 changes: 1 addition & 1 deletion src/torchmetrics/utilities/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,7 @@ def plot_confusion_matrix(
fig, axs = plt.subplots(nrows=rows, ncols=cols, constrained_layout=True) if ax is None else (ax.get_figure(), ax)
axs = trim_axs(axs, nb)
for i in range(nb):
ax = axs[i] if rows != 1 and cols != 1 else axs
ax = axs[i] if (rows != 1 or cols != 1) else axs
if fig_label is not None:
ax.set_title(f"Label {fig_label[i]}", fontsize=15)
im = ax.imshow(confmat[i].cpu().detach() if confmat.ndim == 3 else confmat.cpu().detach(), cmap=cmap)
Expand Down
10 changes: 10 additions & 0 deletions tests/unittests/classification/test_confusion_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,16 @@ def test_multilabel_confusion_matrix_dtype_gpu(self, inputs, dtype):
dtype=dtype,
)

@pytest.mark.parametrize("num_labels", [2, NUM_CLASSES])
def test_multilabel_confusion_matrix_plot(self, num_labels, inputs):
"""Test multilabel cm plots."""
multi_label_confusion_matrix = MultilabelConfusionMatrix(num_labels=num_labels)
preds = target = torch.ones(1, num_labels).int()
multi_label_confusion_matrix.update(preds, target)
fig, ax = multi_label_confusion_matrix.plot()
assert fig is not None
assert ax is not None


def test_warning_on_nan():
"""Test that a warning is given if division by zero happens during normalization of confusion matrix."""
Expand Down

0 comments on commit 8827e64

Please sign in to comment.