Merge branch 'master' into newmetric/nrmse
SkafteNicki authored Apr 24, 2024
2 parents 1100cda + 3d52192 commit 8b962d0
Showing 56 changed files with 1,249 additions and 111 deletions.
6 changes: 3 additions & 3 deletions .pre-commit-config.yaml
@@ -23,7 +23,7 @@ ci:

repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.5.0
rev: v4.6.0
hooks:
- id: end-of-file-fixer
- id: trailing-whitespace
@@ -46,7 +46,7 @@ repos:
exclude: pyproject.toml

- repo: https://github.com/crate-ci/typos
rev: v1.16.26
rev: v1.20.7
hooks:
- id: typos
# empty to do not write fixes
@@ -112,7 +112,7 @@ repos:
- id: text-unicode-replacement-char

- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.3.2
rev: v0.3.5
hooks:
# try to fix what is possible
- id: ruff
36 changes: 34 additions & 2 deletions CHANGELOG.md
@@ -12,6 +12,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

- Added `GeneralizedDiceScore` to segmentation package ([#1090](https://github.com/Lightning-AI/metrics/pull/1090))


- Added `SensitivityAtSpecificity` metric to classification subpackage ([#2217](https://github.com/Lightning-AI/torchmetrics/pull/2217))


@@ -24,21 +27,50 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added support for calculating segmentation quality and recognition quality in `PanopticQuality` metric ([#2381](https://github.com/Lightning-AI/torchmetrics/pull/2381))


- Added a new segmentation metric `MeanIoU` ([#1236](https://github.com/PyTorchLightning/metrics/pull/1236))


- Added `pretty-errors` for improving error prints ([#2431](https://github.com/Lightning-AI/torchmetrics/pull/2431))


- Added support for `torch.float` weighted networks for FID and KID calculations ([#2483](https://github.com/Lightning-AI/torchmetrics/pull/2483))


### Changed

- Made `__getattr__` and `__setattr__` of `ClasswiseWrapper` more general ([#2424](https://github.com/Lightning-AI/torchmetrics/pull/2424))


### Deprecated

-


### Fixed

- Fix getitem for metric collection when prefix/postfix is set ([#2430](https://github.com/Lightning-AI/torchmetrics/pull/2430))


- Fixed axis names with Precision-Recall curve ([#2462](https://github.com/Lightning-AI/torchmetrics/pull/2462))


- Fixed list synchronization with partly empty lists ([#2468](https://github.com/Lightning-AI/torchmetrics/pull/2468))


- Fixed memory leak in metrics using list states ([#2492](https://github.com/Lightning-AI/torchmetrics/pull/2492))


- Fixed bug in computation of `ERGAS` metric ([#2498](https://github.com/Lightning-AI/torchmetrics/pull/2498))


- Fixed `BootStrapper` wrapper not working with `kwargs` provided argument ([#2503](https://github.com/Lightning-AI/torchmetrics/pull/2503))


- Fixed warnings being suppressed in `MeanAveragePrecision` when requested ([#2501](https://github.com/Lightning-AI/torchmetrics/pull/2501))


- Fixed cornercase in `binary_average_precision` when only negative samples are provided ([#2507](https://github.com/Lightning-AI/torchmetrics/pull/2507))


## [1.3.2] - 2024-03-18

### Fixed
@@ -111,7 +143,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Changed

- Change default state of `SpectralAngleMapper` and `UniversalImageQualityIndex` to be tensors ([#2089](https://github.com/Lightning-AI/torchmetrics/pull/2089))
- Use `arange` and repeat for deterministic bincount ([#2184](https://github.com/Lightning-AI/torchmetrics/pull/2184))
- Use `torch` range func and repeat for deterministic bincount ([#2184](https://github.com/Lightning-AI/torchmetrics/pull/2184))
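The deterministic-bincount entry refers to replacing `torch.bincount`, whose CUDA scatter-add path is non-deterministic, with comparisons against an `arange` of bin indices. A torch-free sketch of the idea:

```python
# Torch-free sketch of the "arange + repeat" trick: instead of scatter-adds,
# compare every value against each bin index and sum the matches.
# In torch this is roughly (arange(minlength).unsqueeze(1) == x).sum(dim=1).
def deterministic_bincount(values, minlength):
    return [sum(1 for v in values if v == b) for b in range(minlength)]

counts = deterministic_bincount([0, 2, 2, 1], minlength=4)
# counts == [1, 1, 2, 0]
```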

### Removed

12 changes: 8 additions & 4 deletions Makefile
@@ -1,4 +1,4 @@
.PHONY: test clean docs env data
.PHONY: clean test pull-template docs env data

export FREEZE_REQUIREMENTS=1
# assume you have installed the needed packages
@@ -28,10 +28,14 @@ test: clean env data
cd tests && python -m pytest unittests -v --cov=torchmetrics
cd tests && python -m coverage report

docs: clean
pip install -e . --quiet -r requirements/_docs.txt
pull-template:
pip install -q awscli
aws s3 sync --no-sign-request s3://sphinx-packages/ dist/

docs: clean pull-template
pip install -e . --quiet -r requirements/_docs.txt -f dist/
# apt-get install -y texlive-latex-extra dvipng texlive-pictures texlive-fonts-recommended cm-super
TOKENIZERS_PARALLELISM=false python -m sphinx -b html -W --keep-going docs/source docs/build
cd docs && make html --debug --jobs $(nproc) SPHINXOPTS="-W --keep-going"

env:
pip install -e . -U -r requirements/_devel.txt
1 change: 1 addition & 0 deletions README.md
@@ -288,6 +288,7 @@ covers the following domains:
- Multimodal (Image-Text)
- Nominal
- Regression
- Segmentation
- Text

Each domain may require some additional dependencies which can be installed with `pip install torchmetrics[audio]`,
1 change: 1 addition & 0 deletions dockers/ubuntu-cuda/Dockerfile
@@ -41,6 +41,7 @@ RUN \
git \
wget \
curl \
zip \
unzip \
g++ \
cmake \
8 changes: 8 additions & 0 deletions docs/source/index.rst
@@ -222,6 +222,14 @@ Or directly from conda

retrieval/*

.. toctree::
:maxdepth: 2
:name: segmentation
:caption: Segmentation
:glob:

segmentation/*

.. toctree::
:maxdepth: 2
:name: text
1 change: 1 addition & 0 deletions docs/source/links.rst
@@ -171,3 +171,4 @@
.. _averaging curve objects: https://scikit-learn.org/stable/auto_examples/model_selection/plot_roc.html
.. _SCC: https://www.ingentaconnect.com/content/tandf/tres/1998/00000019/00000004/art00013
.. _Normalized Root Mean Squared Error: https://onlinelibrary.wiley.com/doi/abs/10.1111/1365-2478.12109
.. _Generalized Dice Score: https://arxiv.org/abs/1707.03237
4 changes: 4 additions & 0 deletions docs/source/pages/implement.rst
@@ -124,6 +124,10 @@ A few important things to note for this example:
``dim_zero_cat`` helper function which will standardize the list states to be a single concatenate tensor regardless
of the mode.

* Calling the ``reset`` method will clear the list state, deleting any values inserted into it. For this reason, care
must be taken when referencing list states. If you require the values after your metric is reset, you must first
copy the attribute to another object (e.g. using `copy.deepcopy`).
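A minimal torch-free sketch of that caveat — the `ListStateMetric` class below is a hypothetical stand-in, not the real `Metric` base class, which manages states via `add_state`:

```python
import copy

class ListStateMetric:
    """Hypothetical stand-in for a metric holding a list state."""

    def __init__(self):
        self.values = []  # plays the role of a list state

    def update(self, x):
        self.values.append(x)

    def reset(self):
        self.values.clear()  # reset deletes everything inserted so far

m = ListStateMetric()
m.update(1.0)
m.update(2.0)

saved = copy.deepcopy(m.values)  # copy BEFORE reset to keep the values
m.reset()
# m.values is now [], but saved still holds [1.0, 2.0]
```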

*****************
Metric attributes
*****************
22 changes: 22 additions & 0 deletions docs/source/segmentation/generalized_dice.rst
@@ -0,0 +1,22 @@
.. customcarditem::
:header: Generalized Dice Score
:image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/tabular_classification.svg
:tags: segmentation

.. include:: ../links.rst

######################
Generalized Dice Score
######################

Module Interface
________________

.. autoclass:: torchmetrics.segmentation.GeneralizedDiceScore
:noindex:

Functional Interface
____________________

.. autofunction:: torchmetrics.functional.segmentation.generalized_dice_score
:noindex:
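As a rough illustration of what the metric computes: the linked paper (Sudre et al., arXiv:1707.03237) weights each class by the inverse of its squared reference volume. A torch-free sketch on flat class-index masks follows; torchmetrics' actual implementation, input formats, and weighting options may differ:

```python
def generalized_dice(preds, target, num_classes, eps=1e-8):
    """Sketch of the generalized Dice score on flat class-index lists."""
    num = den = 0.0
    for c in range(num_classes):
        p = [int(x == c) for x in preds]  # binary mask for class c
        t = [int(x == c) for x in target]
        w = 1.0 / (sum(t) ** 2 + eps)     # inverse squared class volume
        num += w * sum(pi * ti for pi, ti in zip(p, t))  # intersection
        den += w * (sum(p) + sum(t))
    return 2.0 * num / (den + eps)

score = generalized_dice([0, 0, 1, 1], [0, 0, 1, 1], num_classes=2)
# a perfect prediction scores (almost exactly) 1.0
```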
19 changes: 19 additions & 0 deletions docs/source/segmentation/mean_iou.rst
@@ -0,0 +1,19 @@
.. customcarditem::
:header: Mean Intersection over Union (mIoU)
:image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/object_detection.svg
:tags: segmentation

###################################
Mean Intersection over Union (mIoU)
###################################

Module Interface
________________

.. autoclass:: torchmetrics.segmentation.MeanIoU
:exclude-members: update, compute

Functional Interface
____________________

.. autofunction:: torchmetrics.functional.segmentation.mean_iou
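For orientation, a torch-free sketch of the quantity being computed, on flat class-index masks (the actual `MeanIoU` implementation works on tensors and supports more input formats):

```python
def mean_iou_sketch(preds, target, num_classes):
    """Mean intersection-over-union across classes present in either mask."""
    ious = []
    for c in range(num_classes):
        inter = sum(1 for p, t in zip(preds, target) if p == c and t == c)
        union = sum(1 for p, t in zip(preds, target) if p == c or t == c)
        if union:  # skip classes absent from both prediction and target
            ious.append(inter / union)
    return sum(ious) / len(ious)

miou = mean_iou_sketch([0, 0, 1, 1], [0, 1, 1, 1], num_classes=2)
# IoU(class 0) = 1/2, IoU(class 1) = 2/3, mean = 7/12
```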
1 change: 1 addition & 0 deletions requirements/base.txt
@@ -6,3 +6,4 @@ packaging >17.1
torch >=1.10.0, <2.3.0
typing-extensions; python_version < '3.9'
lightning-utilities >=0.8.0, <0.12.0
pretty-errors ==1.2.25
2 changes: 1 addition & 1 deletion requirements/multimodal.txt
@@ -1,5 +1,5 @@
# NOTE: the upper bound for the package version is only set for CI stability, and it is dropped while installing this package
# in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment

transformers >=4.10.0, <4.40.0
transformers >=4.10.0, <4.41.0
piq <=0.8.0
6 changes: 2 additions & 4 deletions requirements/text.txt
@@ -3,10 +3,8 @@

nltk >=3.6, <=3.8.1
tqdm >=4.41.0, <4.67.0
regex >=2021.9.24, <=2023.12.25
transformers >4.4.0, <4.40.0
regex >=2021.9.24, <=2024.4.16
transformers >4.4.0, <4.41.0
mecab-python3 >=1.0.6, <1.1.0
mecab-ko >=1.0.0, <1.1.0
mecab-ko-dic >=1.0.0, <1.1.0
ipadic >=1.0.0, <1.1.0
sentencepiece >=0.2.0, <0.3.0
3 changes: 3 additions & 0 deletions requirements/text_test.txt
@@ -6,3 +6,6 @@ rouge-score >0.1.0, <=0.1.2
bert_score ==0.3.13
huggingface-hub <0.23
sacrebleu >=2.3.0, <2.5.0

mecab-ko >=1.0.0, <1.1.0
mecab-ko-dic >=1.0.0, <1.1.0
3 changes: 3 additions & 0 deletions src/torchmetrics/__init__.py
@@ -14,6 +14,9 @@
_PACKAGE_ROOT = os.path.dirname(__file__)
_PROJECT_ROOT = os.path.dirname(_PACKAGE_ROOT)

if package_available("pretty_errors"):
import pretty_errors # noqa: F401

if package_available("PIL"):
import PIL
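A sketch of the optional-import guard used above, built on the standard library only (the real `package_available` helper ships with `lightning-utilities`; this stand-in merely approximates it):

```python
from importlib.util import find_spec

def package_available(name: str) -> bool:
    """Return True if a top-level package can be imported."""
    try:
        return find_spec(name) is not None
    except (ImportError, ValueError):
        return False

if package_available("pretty_errors"):
    import pretty_errors  # noqa: F401  # nicer tracebacks, when installed
```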

2 changes: 1 addition & 1 deletion src/torchmetrics/classification/cohen_kappa.py
@@ -55,7 +55,7 @@ class labels.
As output to ``forward`` and ``compute`` the metric returns the following output:
- ``bck`` (:class:`~torch.Tensor`): A tensor containing cohen kappa score
- ``bc_kappa`` (:class:`~torch.Tensor`): A tensor containing cohen kappa score
Args:
threshold: Threshold for transforming probability to binary (0,1) predictions
17 changes: 13 additions & 4 deletions src/torchmetrics/classification/confusion_matrix.py
@@ -38,7 +38,7 @@
from torchmetrics.metric import Metric
from torchmetrics.utilities.enums import ClassificationTask
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE, plot_confusion_matrix
from torchmetrics.utilities.plot import _AX_TYPE, _CMAP_TYPE, _PLOT_OUT_TYPE, plot_confusion_matrix

if not _MATPLOTLIB_AVAILABLE:
__doctest_skip__ = [
@@ -151,6 +151,7 @@ def plot(
ax: Optional[_AX_TYPE] = None,
add_text: bool = True,
labels: Optional[List[str]] = None,
cmap: Optional[_CMAP_TYPE] = None,
) -> _PLOT_OUT_TYPE:
"""Plot a single or multiple values from the metric.
@@ -160,6 +161,8 @@
ax: An matplotlib axis object. If provided will add plot to that axis
add_text: if the value of each cell should be added to the plot
labels: a list of strings, if provided will be added to the plot to indicate the different classes
cmap: matplotlib colormap to use for the confusion matrix
https://matplotlib.org/stable/users/explain/colors/colormaps.html
Returns:
Figure and Axes object
@@ -181,7 +184,7 @@
val = val if val is not None else self.compute()
if not isinstance(val, Tensor):
raise TypeError(f"Expected val to be a single tensor but got {val}")
fig, ax = plot_confusion_matrix(val, ax=ax, add_text=add_text, labels=labels)
fig, ax = plot_confusion_matrix(val, ax=ax, add_text=add_text, labels=labels, cmap=cmap)
return fig, ax
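The values colored by the new `cmap` argument are plain per-class counts. A torch-free sketch of the underlying matrix (the helper below is illustrative, and `cmap` is then presumably any matplotlib colormap name or `Colormap` object):

```python
def confusion_counts(preds, target, num_classes):
    """Rows index the true class, columns the predicted class."""
    mat = [[0] * num_classes for _ in range(num_classes)]
    for p, t in zip(preds, target):
        mat[t][p] += 1
    return mat

mat = confusion_counts([0, 1, 1], [0, 0, 1], num_classes=2)
# mat == [[1, 1], [0, 1]]: one true-0 correct, one true-0 predicted as 1,
# and one true-1 correct
```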


@@ -292,6 +295,7 @@ def plot(
ax: Optional[_AX_TYPE] = None,
add_text: bool = True,
labels: Optional[List[str]] = None,
cmap: Optional[_CMAP_TYPE] = None,
) -> _PLOT_OUT_TYPE:
"""Plot a single or multiple values from the metric.
@@ -301,6 +305,8 @@
ax: An matplotlib axis object. If provided will add plot to that axis
add_text: if the value of each cell should be added to the plot
labels: a list of strings, if provided will be added to the plot to indicate the different classes
cmap: matplotlib colormap to use for the confusion matrix
https://matplotlib.org/stable/users/explain/colors/colormaps.html
Returns:
Figure and Axes object
@@ -322,7 +328,7 @@
val = val if val is not None else self.compute()
if not isinstance(val, Tensor):
raise TypeError(f"Expected val to be a single tensor but got {val}")
fig, ax = plot_confusion_matrix(val, ax=ax, add_text=add_text, labels=labels)
fig, ax = plot_confusion_matrix(val, ax=ax, add_text=add_text, labels=labels, cmap=cmap)
return fig, ax


@@ -436,6 +442,7 @@ def plot(
ax: Optional[_AX_TYPE] = None,
add_text: bool = True,
labels: Optional[List[str]] = None,
cmap: Optional[_CMAP_TYPE] = None,
) -> _PLOT_OUT_TYPE:
"""Plot a single or multiple values from the metric.
@@ -445,6 +452,8 @@
ax: An matplotlib axis object. If provided will add plot to that axis
add_text: if the value of each cell should be added to the plot
labels: a list of strings, if provided will be added to the plot to indicate the different classes
cmap: matplotlib colormap to use for the confusion matrix
https://matplotlib.org/stable/users/explain/colors/colormaps.html
Returns:
Figure and Axes object
@@ -466,7 +475,7 @@
val = val if val is not None else self.compute()
if not isinstance(val, Tensor):
raise TypeError(f"Expected val to be a single tensor but got {val}")
fig, ax = plot_confusion_matrix(val, ax=ax, add_text=add_text, labels=labels)
fig, ax = plot_confusion_matrix(val, ax=ax, add_text=add_text, labels=labels, cmap=cmap)
return fig, ax


2 changes: 1 addition & 1 deletion src/torchmetrics/classification/group_fairness.py
@@ -177,7 +177,7 @@ class BinaryFairness(_AbstractGroupStatScores):
Args:
num_groups: The number of groups.
task: The task to compute. Can be either ``demographic_parity`` or ``equal_oppotunity`` or ``all``.
task: The task to compute. Can be either ``demographic_parity`` or ``equal_opportunity`` or ``all``.
threshold: Threshold for transforming probability to binary {0,1} predictions.
ignore_index: Specifies a target value that is ignored and does not contribute to the metric calculation
validate_args: bool indicating if input arguments and tensors should be validated for correctness.