Skip to content

Commit

Permalink
Smaller fixes to docs + integration tests (#2775)
Browse files Browse the repository at this point in the history
* smaller fixes
* fix plot fairness
  • Loading branch information
SkafteNicki authored Oct 10, 2024
1 parent 0fe772d commit 55b7239
Show file tree
Hide file tree
Showing 5 changed files with 26 additions and 27 deletions.
22 changes: 11 additions & 11 deletions docs/source/pages/implement.rst
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
Implementing a Metric
#####################

While we strive to include as many metrics as possible in ``torchmetrics``, we cannot include them all. Therefore, we
have made it easy to implement your own metric and possible contribute it to ``torchmetrics``. This page will guide
While we strive to include as many metrics as possible in ``torchmetrics``, we cannot include them all. We have made it
easy to implement your own metric, and you can contribute it to ``torchmetrics`` if you wish. This page will guide
you through the process. If you afterwards are interested in contributing your metric to ``torchmetrics``, please
read the `contribution guidelines <https://lightning.ai/docs/torchmetrics/latest/generated/CONTRIBUTING.html>`_ and
see this :ref:`section <contributing metric>`.
Expand Down Expand Up @@ -63,7 +63,7 @@ A few important things to note:

* The ``dist_reduce_fx`` argument to ``add_state`` is used to specify how the metric states should be reduced between
batches in distributed settings. In this case we use ``"sum"`` to sum the metric states across batches. A couple of
build in options are available: ``"sum"``, ``"mean"``, ``"cat"``, ``"min"`` or ``"max"``, but a custom reduction is
built-in options are available: ``"sum"``, ``"mean"``, ``"cat"``, ``"min"`` or ``"max"``, but a custom reduction is
also supported.

* In ``update`` we do not return anything but instead update the metric states in-place.
Expand Down Expand Up @@ -101,7 +101,7 @@ because we need to calculate the rank of the predictions and targets.
preds = dim_zero_cat(self.preds)
target = dim_zero_cat(self.target)
# some intermediate computation...
r_preds, r_target = _rank_data(preds), _rank_dat(target)
r_preds, r_target = _rank_data(preds), _rank_data(target)
preds_diff = r_preds - r_preds.mean(0)
target_diff = r_target - r_target.mean(0)
cov = (preds_diff * target_diff).mean(0)
Expand All @@ -118,10 +118,10 @@ A few important things to note for this example:

* When working with list states, The ``update(...)`` method should append the batch states to the list.

* In the the ``compute`` method the list states behave a bit differently dependeding on weather you are running in
* In the the ``compute`` method the list states behave a bit differently dependeding on whether you are running in
distributed mode or not. In non-distributed mode the list states will be a list of tensors, while in distributed mode
the list have already been concatenated into a single tensor. For this reason, we recommend always using the
``dim_zero_cat`` helper function which will standardize the list states to be a single concatenate tensor regardless
``dim_zero_cat`` helper function which will standardize the list states to be a single concatenated tensor regardless
of the mode.

* Calling the ``reset`` method will clear the list state, deleting any values inserted into it. For this reason, care
Expand Down Expand Up @@ -179,7 +179,7 @@ used, that provides the common plotting functionality for most metrics in torchm
return self._plot(val, ax)

If the metric returns a more complex output, a custom implementation of the `plot` method is required. For more details
on the plotting API, see the this :ref:`page <plotting>` . In addti
on the plotting API, see the this :ref:`page <plotting>` .

*******************************
Internal implementation details
Expand Down Expand Up @@ -247,7 +247,7 @@ as long as they serve a general purpose. However, to keep all our metrics consis
and tests gets formatted in the following way:

1. Start by reading our `contribution guidelines <https://lightning.ai/docs/torchmetrics/latest/generated/CONTRIBUTING.html>`_.
2. First implement the functional backend. This takes cares of all the logic that goes into the metric. The code should
2. First implement the functional backend. This takes care of all the logic that goes into the metric. The code should
be put into a single file placed under ``src/torchmetrics/functional/"domain"/"new_metric".py`` where ``domain`` is the type of
metric (classification, regression, text etc.) and ``new_metric`` is the name of the metric. In this file, there should be the
following three functions:
Expand All @@ -259,7 +259,7 @@ and tests gets formatted in the following way:

.. note::
The `functional mean squared error <https://github.com/Lightning-AI/torchmetrics/blob/master/src/torchmetrics/functional/regression/mse.py>`_
metric is a great example of this division of logic.
metric is a is a great example of how to divide the logic.

3. In a corresponding file placed in ``src/torchmetrics/"domain"/"new_metric".py`` create the module interface:

Expand All @@ -283,12 +283,12 @@ and tests gets formatted in the following way:
both the functional and module interface.
2. In that file, start by defining a number of test inputs that your metric should be evaluated on.
3. Create a testclass ``class NewMetric(MetricTester)`` that inherits from ``tests.helpers.testers.MetricTester``.
This testclass should essentially implement the ``test_"new_metric"_class`` and ``test_"new_metric"_fn`` methods that
This test class should essentially implement the ``test_"new_metric"_class`` and ``test_"new_metric"_fn`` methods that
respectively tests the module interface and the functional interface.
4. The testclass should be parameterized (using ``@pytest.mark.parametrize``) by the different test inputs defined initially.
Additionally, the ``test_"new_metric"_class`` method should also be parameterized with an ``ddp`` parameter such that it gets
tested in a distributed setting. If your metric has additional parameters, then make sure to also parameterize these
such that different combinations of inputs and parameters gets tested.
so that different combinations of inputs and parameters get tested.
5. (optional) If your metric raises any exception, please add tests that showcase this.

.. note::
Expand Down
19 changes: 10 additions & 9 deletions docs/source/pages/lightning.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ framework designed for scaling models without boilerplate.

.. note::

TorchMetrics always offers compatibility with the last 2 major PyTorch Lightning versions, but we recommend to always keep both frameworks
up-to-date for the best experience.
TorchMetrics always offers compatibility with the last 2 major PyTorch Lightning versions, but we recommend always
keeping both frameworks up-to-date for the best experience.

While TorchMetrics was built to be used with native PyTorch, using TorchMetrics with Lightning offers additional benefits:

Expand Down Expand Up @@ -74,7 +74,7 @@ method, Lightning will log the metric based on ``on_step`` and ``on_epoch`` flag
``sync_dist``, ``sync_dist_group`` and ``reduce_fx`` flags from ``self.log(...)`` don't affect the metric logging
in any manner. The metric class contains its own distributed synchronization logic.

This however is only true for metrics that inherit the base class ``Metric``,
This, however is only true for metrics that inherit the base class ``Metric``,
and thus the functional metric API provides no support for in-built distributed synchronization
or reduction functions.

Expand Down Expand Up @@ -202,7 +202,7 @@ Common Pitfalls
The following contains a list of pitfalls to be aware of:

* Logging a `MetricCollection` object directly using ``self.log_dict`` is only supported if all metrics in the
collection returns a scalar tensor. If any of the metrics in the collection returns a non-scalar tensor,
collection return a scalar tensor. If any of the metrics in the collection return a non-scalar tensor,
the logging will fail. This can especially happen when either nesting multiple ``MetricCollection`` objects or when
using wrapper metrics such as :class:`~torchmetrics.wrappers.ClasswiseWrapper`,
:class:`~torchmetrics.wrappers.MinMaxMetric` etc. inside a ``MetricCollection`` since all these wrappers return
Expand Down Expand Up @@ -290,7 +290,7 @@ The following contains a list of pitfalls to be aware of:
self.log('val_fid', val)

* Calling ``self.log("val", self.metric(preds, target))`` with the intention of logging the metric object. Because
``self.metric(preds, target)`` corresponds to calling the forward method, this will return a tensor and not the
``self.metric(preds, target)`` corresponds to calling the ``forward`` method, this will return a tensor and not the
metric object. Such logging will be wrong in this case. Instead, it is essential to separate into several lines:

.. testcode:: python
Expand All @@ -303,7 +303,8 @@ The following contains a list of pitfalls to be aware of:
self.accuracy(preds, y) # compute metrics
self.log('train_acc_step', self.accuracy) # log metric object

* Using :class:`~torchmetrics.wrappers.MetricTracker` wrapper with Lightning is a special case, because the wrapper in itself is not a metric
i.e. it does not inherit from the base :class:`~torchmetrics.Metric` class but instead from :class:`~torch.nn.ModuleList`. Thus,
to log the output of this metric one needs to manually log the returned values (not the object) using ``self.log``
and for epoch level logging this should be done in the appropriate ``on_{train|validation|test}_epoch_end`` method.
* Using :class:`~torchmetrics.wrappers.MetricTracker` wrapper with Lightning is a special case, because the wrapper in
itself is not a metric i.e. it does not inherit from the base :class:`~torchmetrics.Metric` class but instead from
:class:`~torch.nn.ModuleList`. Thus, to log the output of this metric one needs to manually log the returned values
(not the object) using ``self.log`` and for epoch level logging this should be done in the appropriate
``on_{train|validation|test}_epoch_end`` method.
4 changes: 2 additions & 2 deletions src/torchmetrics/classification/group_fairness.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,11 +302,11 @@ def plot(
.. plot::
:scale: 75
>>> from torch import rand, randint
>>> from torch import ones, rand, randint
>>> # Example plotting a single value
>>> from torchmetrics.classification import BinaryFairness
>>> metric = BinaryFairness(2)
>>> metric.update(rand(20), randint(2, (20,)), randint(2, (20,)))
>>> metric.update(rand(20), randint(2, (20,)), ones(20).long())
>>> fig_, ax_ = metric.plot()
.. plot::
Expand Down
2 changes: 1 addition & 1 deletion src/torchmetrics/clustering/dunn_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_
>>> metric = DunnIndex(p=2)
>>> values = [ ]
>>> for _ in range(10):
... values.append(metric(torch.randn(10, 3), torch.randint(0, 2, (10,))))
... values.append(metric(torch.randn(50, 3), torch.randint(0, 2, (50,))))
>>> fig_, ax_ = metric.plot(values)
"""
Expand Down
6 changes: 2 additions & 4 deletions tests/integrations/test_lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -530,8 +530,7 @@ def training_step(self, batch, batch_idx):
preds = torch.randint(0, 5, (100,), device=batch.device)
target = torch.randint(0, 5, (100,), device=batch.device)

self.train_metrics.update(preds, target)
batch_values = self.train_metrics.compute()
batch_values = self.train_metrics(preds, target)
self.log_dict(batch_values, on_step=True, on_epoch=False)
return {"loss": loss}

Expand Down Expand Up @@ -589,8 +588,7 @@ def training_step(self, batch, batch_idx):
preds = torch.randint(0, 5, (100,), device=batch.device)
target = torch.randint(0, 5, (100,), device=batch.device)

self.train_metrics.update(preds, target)
batch_values = self.train_metrics.compute()
batch_values = self.train_metrics(preds, target)
self.log_dict(batch_values, on_step=True, on_epoch=False)
return {"loss": loss}

Expand Down

0 comments on commit 55b7239

Please sign in to comment.