Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Classification metrics overhaul: input formatting standardization (1/n) #4837

Merged
merged 53 commits into from
Dec 7, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
53 commits
Select commit Hold shift + click to select a range
6959ea0
Add stuff
tadejsv Nov 24, 2020
0679015
Change metrics documentation layout
tadejsv Nov 24, 2020
55fdaaf
Change testing utils
tadejsv Nov 24, 2020
5cbf56a
Replace len(*.shape) with *.ndim
tadejsv Nov 24, 2020
9c33d0b
More descriptive error message for input formatting
tadejsv Nov 24, 2020
6562205
Replace movedim with permute
tadejsv Nov 24, 2020
a04a71e
Style changes in error messages
tadejsv Nov 25, 2020
eaac5d7
More error message style improvements
tadejsv Nov 25, 2020
c1108f0
Fix typo in docs
tadejsv Nov 25, 2020
277769b
Add more descriptive variable names in utils
tadejsv Nov 25, 2020
4849298
Change internal var names
tadejsv Nov 25, 2020
22906a4
Merge remote-tracking branch 'upstream/master' into cls_metrics_input…
tadejsv Nov 25, 2020
02bd636
Break down error checking for inputs into separate functions
tadejsv Nov 25, 2020
f97145b
Remove the (N, ..., C) option in MD-MC
tadejsv Nov 25, 2020
536feaf
Simplify select_topk
tadejsv Nov 25, 2020
4241d7c
Remove detach for inputs
tadejsv Nov 25, 2020
86d6c4d
Fix typos
tadejsv Nov 25, 2020
bb11677
Merge branch 'master' into cls_metrics_input_formatting
teddykoker Nov 25, 2020
cde3997
Update pytorch_lightning/metrics/classification/utils.py
tadejsv Nov 26, 2020
05a54da
Update docs/source/metrics.rst
tadejsv Nov 26, 2020
9a43a5e
Minor error message changes
tadejsv Nov 26, 2020
3f4ad3c
Update pytorch_lightning/metrics/utils.py
tadejsv Nov 26, 2020
a654e6a
Reuse case from validation in formatting
tadejsv Nov 26, 2020
7b2ef2b
Merge branch 'cls_metrics_input_formatting' of github.com:tadejsv/pyt…
tadejsv Nov 26, 2020
16ab8f7
Refactor code in _input_format_classification
tadejsv Nov 27, 2020
558276f
Merge branch 'master' into cls_metrics_input_formatting
tchaton Nov 27, 2020
ecffe18
Small improvements
tadejsv Nov 27, 2020
725c7dd
PEP 8
tadejsv Nov 27, 2020
41ad0b7
Update pytorch_lightning/metrics/classification/utils.py
tadejsv Nov 27, 2020
ca13e76
Update pytorch_lightning/metrics/classification/utils.py
tadejsv Nov 27, 2020
ede2c7f
Update docs/source/metrics.rst
tadejsv Nov 27, 2020
c6e4de4
Update pytorch_lightning/metrics/classification/utils.py
tadejsv Nov 27, 2020
201d0de
Apply suggestions from code review
tadejsv Nov 27, 2020
f08edbc
Alphabetical reordering of regression metrics
tadejsv Nov 27, 2020
523bae3
Merge branch 'cls_metrics_input_formatting' of github.com:tadejsv/pyt…
tadejsv Nov 27, 2020
db24fae
Merge branch 'master' into cls_metrics_input_formatting
Borda Nov 27, 2020
35e3eff
Change default value of top_k and add error checking
tadejsv Nov 28, 2020
dd6f8ea
Merge branch 'cls_metrics_input_formatting' of github.com:tadejsv/pyt…
tadejsv Nov 28, 2020
c28aadf
Extract basic validation into separate function
tadejsv Nov 28, 2020
0cb0eac
Update desciption of parameters in input formatting
tadejsv Nov 29, 2020
28acf4c
Merge branch 'master' into cls_metrics_input_formatting
tchaton Nov 30, 2020
8e7a85a
Apply suggestions from code review
tadejsv Nov 30, 2020
829155e
Check that probabilities in preds sum to 1 (for MC)
tadejsv Nov 30, 2020
768879d
Fix coverage
tadejsv Nov 30, 2020
15ef14d
Merge branch 'master' into cls_metrics_input_formatting
teddykoker Dec 2, 2020
1568970
Merge branch 'master' into cls_metrics_input_formatting
tchaton Dec 3, 2020
a9fa730
Merge with master and resolve conflicts
tadejsv Dec 6, 2020
44ad276
Merge branch 'master' into cls_metrics_input_formatting
Borda Dec 6, 2020
96d40c8
Minor changes
tadejsv Dec 6, 2020
cca430a
Merge branch 'cls_metrics_input_formatting' of github.com:tadejsv/pyt…
tadejsv Dec 6, 2020
f3c47f9
Fix edge case and simplify testing
tadejsv Dec 6, 2020
ecb5472
Merge branch 'master' into cls_metrics_input_formatting
Borda Dec 7, 2020
4a71a56
Merge branch 'master' into cls_metrics_input_formatting
Borda Dec 7, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
229 changes: 145 additions & 84 deletions docs/source/metrics.rst
Original file line number Diff line number Diff line change
Expand Up @@ -196,53 +196,76 @@ Metric API
.. autoclass:: pytorch_lightning.metrics.Metric
:noindex:

*************
Class metrics
*************
***************************
Class vs Functional Metrics
Borda marked this conversation as resolved.
Show resolved Hide resolved
***************************

The functional metrics follow the simple paradigm input in, output out. This means, they don't provide any advanced mechanisms for syncing across DDP nodes or aggregation over batches. They simply compute the metric value based on the given inputs.

Also, the integration within other parts of PyTorch Lightning will never be as tight as with the class-based interface.
If you look for just computing the values, the functional metrics are the way to go. However, if you are looking for the best integration and user experience, please consider also using the class interface.

**********************
Classification Metrics
----------------------
**********************

Accuracy
~~~~~~~~
Input types
-----------

.. autoclass:: pytorch_lightning.metrics.classification.Accuracy
:noindex:
For the purposes of classification metrics, inputs (predictions and targets) are split
into these categories (``N`` stands for the batch size and ``C`` for number of classes):

Precision
~~~~~~~~~
.. csv-table:: \*dtype ``binary`` means integers that are either 0 or 1
:header: "Type", "preds shape", "preds dtype", "target shape", "target dtype"
:widths: 20, 10, 10, 10, 10

.. autoclass:: pytorch_lightning.metrics.classification.Precision
:noindex:
"Binary", "(N,)", "``float``", "(N,)", "``binary``\*"
"Multi-class", "(N,)", "``int``", "(N,)", "``int``"
"Multi-class with probabilities", "(N, C)", "``float``", "(N,)", "``int``"
"Multi-label", "(N, ...)", "``float``", "(N, ...)", "``binary``\*"
"Multi-dimensional multi-class", "(N, ...)", "``int``", "(N, ...)", "``int``"
"Multi-dimensional multi-class with probabilities", "(N, C, ...)", "``float``", "(N, ...)", "``int``"

Recall
~~~~~~
.. note::
All dimensions of size 1 (except ``N``) are "squeezed out" at the beginning, so
that, for example, a tensor of shape ``(N, 1)`` is treated as ``(N, )``.

.. autoclass:: pytorch_lightning.metrics.classification.Recall
:noindex:
When predictions or targets are integers, it is assumed that class labels start at 0, i.e.
the possible class labels are 0, 1, 2, 3, etc. Below are some examples of different input types

FBeta
~~~~~
.. testcode::

.. autoclass:: pytorch_lightning.metrics.classification.FBeta
:noindex:
# Binary inputs
binary_preds = torch.tensor([0.6, 0.1, 0.9])
binary_target = torch.tensor([1, 0, 2])

F1
~~
# Multi-class inputs
mc_preds = torch.tensor([0, 2, 1])
mc_target = torch.tensor([0, 1, 2])

.. autoclass:: pytorch_lightning.metrics.classification.F1
:noindex:
# Multi-class inputs with probabilities
mc_preds_probs = torch.tensor([[0.8, 0.2, 0], [0.1, 0.2, 0.7], [0.3, 0.6, 0.1]])
mc_target_probs = torch.tensor([0, 1, 2])

ConfusionMatrix
~~~~~~~~~~~~~~~
# Multi-label inputs
ml_preds = torch.tensor([[0.2, 0.8, 0.9], [0.5, 0.6, 0.1], [0.3, 0.1, 0.1]])
ml_target = torch.tensor([[0, 1, 1], [1, 0, 0], [0, 0, 0]])

.. autoclass:: pytorch_lightning.metrics.classification.ConfusionMatrix
:noindex:
In some rare cases, you might have inputs which appear to be (multi-dimensional) multi-class
but are actually binary/multi-label. For example, if both predictions and targets are 1d
binary tensors. Or it could be the other way around, you want to treat binary/multi-label
inputs as 2-class (multi-dimensional) multi-class inputs.

PrecisionRecallCurve
~~~~~~~~~~~~~~~~~~~~
For these cases, the metrics where this distinction would make a difference, expose the
``is_multiclass`` argument.

.. autoclass:: pytorch_lightning.metrics.classification.PrecisionRecallCurve
Class Metrics (Classification)
------------------------------

Accuracy
~~~~~~~~

.. autoclass:: pytorch_lightning.metrics.classification.Accuracy
:noindex:

AveragePrecision
Expand All @@ -251,67 +274,51 @@ AveragePrecision
.. autoclass:: pytorch_lightning.metrics.classification.AveragePrecision
:noindex:

ROC
~~~
ConfusionMatrix
~~~~~~~~~~~~~~~

.. autoclass:: pytorch_lightning.metrics.classification.ROC
.. autoclass:: pytorch_lightning.metrics.classification.ConfusionMatrix
:noindex:

Regression Metrics
------------------

MeanSquaredError
~~~~~~~~~~~~~~~~
F1
~~

.. autoclass:: pytorch_lightning.metrics.regression.MeanSquaredError
.. autoclass:: pytorch_lightning.metrics.classification.F1
:noindex:

FBeta
~~~~~

MeanAbsoluteError
~~~~~~~~~~~~~~~~~

.. autoclass:: pytorch_lightning.metrics.regression.MeanAbsoluteError
.. autoclass:: pytorch_lightning.metrics.classification.FBeta
:noindex:

Precision
~~~~~~~~~

MeanSquaredLogError
~~~~~~~~~~~~~~~~~~~

.. autoclass:: pytorch_lightning.metrics.regression.MeanSquaredLogError
.. autoclass:: pytorch_lightning.metrics.classification.Precision
:noindex:

PrecisionRecallCurve
~~~~~~~~~~~~~~~~~~~~

ExplainedVariance
~~~~~~~~~~~~~~~~~

.. autoclass:: pytorch_lightning.metrics.regression.ExplainedVariance
.. autoclass:: pytorch_lightning.metrics.classification.PrecisionRecallCurve
:noindex:

Recall
~~~~~~

PSNR
~~~~

.. autoclass:: pytorch_lightning.metrics.regression.PSNR
.. autoclass:: pytorch_lightning.metrics.classification.Recall
:noindex:

ROC
~~~

SSIM
~~~~

.. autoclass:: pytorch_lightning.metrics.regression.SSIM
.. autoclass:: pytorch_lightning.metrics.classification.ROC
:noindex:

******************
Functional Metrics
******************

The functional metrics follow the simple paradigm input in, output out. This means, they don't provide any advanced mechanisms for syncing across DDP nodes or aggregation over batches. They simply compute the metric value based on the given inputs.

Also the integration within other parts of PyTorch Lightning will never be as tight as with the class-based interface.
If you look for just computing the values, the functional metrics are the way to go. However, if you are looking for the best integration and user experience, please consider also to use the class interface.

Classification
--------------
Functional Metrics (Classification)
-----------------------------------

accuracy [func]
~~~~~~~~~~~~~~~
Expand Down Expand Up @@ -417,6 +424,12 @@ recall [func]
.. autofunction:: pytorch_lightning.metrics.functional.classification.recall
:noindex:

select_topk [func]
~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: pytorch_lightning.metrics.utils.select_topk
:noindex:


stat_scores [func]
~~~~~~~~~~~~~~~~~~
Expand Down Expand Up @@ -445,9 +458,57 @@ to_onehot [func]
.. autofunction:: pytorch_lightning.metrics.utils.to_onehot
:noindex:

******************
Regression Metrics
******************
tadejsv marked this conversation as resolved.
Show resolved Hide resolved

Class Metrics (Regression)
--------------------------

Regression
----------
ExplainedVariance
~~~~~~~~~~~~~~~~~

.. autoclass:: pytorch_lightning.metrics.regression.ExplainedVariance
:noindex:


MeanAbsoluteError
~~~~~~~~~~~~~~~~~

.. autoclass:: pytorch_lightning.metrics.regression.MeanAbsoluteError
:noindex:


MeanSquaredError
~~~~~~~~~~~~~~~~

.. autoclass:: pytorch_lightning.metrics.regression.MeanSquaredError
:noindex:


MeanSquaredLogError
~~~~~~~~~~~~~~~~~~~

.. autoclass:: pytorch_lightning.metrics.regression.MeanSquaredLogError
:noindex:


PSNR
~~~~

.. autoclass:: pytorch_lightning.metrics.regression.PSNR
:noindex:


SSIM
~~~~

.. autoclass:: pytorch_lightning.metrics.regression.SSIM
:noindex:


Functional Metrics (Regression)
-------------------------------

explained_variance [func]
~~~~~~~~~~~~~~~~~~~~~~~~~
Expand All @@ -470,17 +531,17 @@ mean_squared_error [func]
:noindex:


psnr [func]
~~~~~~~~~~~
mean_squared_log_error [func]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: pytorch_lightning.metrics.functional.psnr
.. autofunction:: pytorch_lightning.metrics.functional.mean_squared_log_error
:noindex:


mean_squared_log_error [func]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
psnr [func]
~~~~~~~~~~~

.. autofunction:: pytorch_lightning.metrics.functional.mean_squared_log_error
.. autofunction:: pytorch_lightning.metrics.functional.psnr
:noindex:


Expand All @@ -490,22 +551,22 @@ ssim [func]
.. autofunction:: pytorch_lightning.metrics.functional.ssim
:noindex:


***
NLP
---
***

bleu_score [func]
~~~~~~~~~~~~~~~~~
-----------------

.. autofunction:: pytorch_lightning.metrics.functional.nlp.bleu_score
:noindex:


********
Pairwise
--------
********

embedding_similarity [func]
~~~~~~~~~~~~~~~~~~~~~~~~~~~
---------------------------

.. autofunction:: pytorch_lightning.metrics.functional.self_supervised.embedding_similarity
:noindex:
Loading