From 55b72394269f3d96bc671851bb406fe912c42d86 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Detlefsen Date: Thu, 10 Oct 2024 17:28:11 +0200 Subject: [PATCH] Smaller fixes to docs + integration tests (#2775) * smaller fixes * fix plot fairness --- docs/source/pages/implement.rst | 22 +++++++++---------- docs/source/pages/lightning.rst | 19 ++++++++-------- .../classification/group_fairness.py | 4 ++-- src/torchmetrics/clustering/dunn_index.py | 2 +- tests/integrations/test_lightning.py | 6 ++--- 5 files changed, 26 insertions(+), 27 deletions(-) diff --git a/docs/source/pages/implement.rst b/docs/source/pages/implement.rst index f7da4aa8ba7..1620ce29cd9 100644 --- a/docs/source/pages/implement.rst +++ b/docs/source/pages/implement.rst @@ -9,8 +9,8 @@ Implementing a Metric ##################### -While we strive to include as many metrics as possible in ``torchmetrics``, we cannot include them all. Therefore, we -have made it easy to implement your own metric and possible contribute it to ``torchmetrics``. This page will guide +While we strive to include as many metrics as possible in ``torchmetrics``, we cannot include them all. We have made it +easy to implement your own metric, and you can contribute it to ``torchmetrics`` if you wish. This page will guide you through the process. If you afterwards are interested in contributing your metric to ``torchmetrics``, please read the `contribution guidelines `_ and see this :ref:`section `. @@ -63,7 +63,7 @@ A few important things to note: * The ``dist_reduce_fx`` argument to ``add_state`` is used to specify how the metric states should be reduced between batches in distributed settings. In this case we use ``"sum"`` to sum the metric states across batches. A couple of - build in options are available: ``"sum"``, ``"mean"``, ``"cat"``, ``"min"`` or ``"max"``, but a custom reduction is + built-in options are available: ``"sum"``, ``"mean"``, ``"cat"``, ``"min"`` or ``"max"``, but a custom reduction is also supported. * In ``update`` we do not return anything but instead update the metric states in-place. @@ -101,7 +101,7 @@ because we need to calculate the rank of the predictions and targets. preds = dim_zero_cat(self.preds) target = dim_zero_cat(self.target) # some intermediate computation... - r_preds, r_target = _rank_data(preds), _rank_dat(target) + r_preds, r_target = _rank_data(preds), _rank_data(target) preds_diff = r_preds - r_preds.mean(0) target_diff = r_target - r_target.mean(0) cov = (preds_diff * target_diff).mean(0) @@ -118,10 +118,10 @@ A few important things to note for this example: * When working with list states, The ``update(...)`` method should append the batch states to the list. -* In the the ``compute`` method the list states behave a bit differently dependeding on weather you are running in +* In the the ``compute`` method the list states behave a bit differently dependeding on whether you are running in distributed mode or not. In non-distributed mode the list states will be a list of tensors, while in distributed mode the list have already been concatenated into a single tensor. For this reason, we recommend always using the - ``dim_zero_cat`` helper function which will standardize the list states to be a single concatenate tensor regardless + ``dim_zero_cat`` helper function which will standardize the list states to be a single concatenated tensor regardless of the mode. * Calling the ``reset`` method will clear the list state, deleting any values inserted into it. For this reason, care @@ -179,7 +179,7 @@ used, that provides the common plotting functionality for most metrics in torchm return self._plot(val, ax) If the metric returns a more complex output, a custom implementation of the `plot` method is required. For more details -on the plotting API, see the this :ref:`page ` . In addti +on the plotting API, see the this :ref:`page ` . ******************************* Internal implementation details @@ -247,7 +247,7 @@ as long as they serve a general purpose. However, to keep all our metrics consis and tests gets formatted in the following way: 1. Start by reading our `contribution guidelines `_. -2. First implement the functional backend. This takes cares of all the logic that goes into the metric. The code should +2. First implement the functional backend. This takes care of all the logic that goes into the metric. The code should be put into a single file placed under ``src/torchmetrics/functional/"domain"/"new_metric".py`` where ``domain`` is the type of metric (classification, regression, text etc.) and ``new_metric`` is the name of the metric. In this file, there should be the following three functions: @@ -259,7 +259,7 @@ and tests gets formatted in the following way: .. note:: The `functional mean squared error `_ - metric is a great example of this division of logic. + metric is a is a great example of how to divide the logic. 3. In a corresponding file placed in ``src/torchmetrics/"domain"/"new_metric".py`` create the module interface: @@ -283,12 +283,12 @@ and tests gets formatted in the following way: both the functional and module interface. 2. In that file, start by defining a number of test inputs that your metric should be evaluated on. 3. Create a testclass ``class NewMetric(MetricTester)`` that inherits from ``tests.helpers.testers.MetricTester``. - This testclass should essentially implement the ``test_"new_metric"_class`` and ``test_"new_metric"_fn`` methods that + This test class should essentially implement the ``test_"new_metric"_class`` and ``test_"new_metric"_fn`` methods that respectively tests the module interface and the functional interface. 4. The testclass should be parameterized (using ``@pytest.mark.parametrize``) by the different test inputs defined initially. Additionally, the ``test_"new_metric"_class`` method should also be parameterized with an ``ddp`` parameter such that it gets tested in a distributed setting. If your metric has additional parameters, then make sure to also parameterize these - such that different combinations of inputs and parameters gets tested. + so that different combinations of inputs and parameters get tested. 5. (optional) If your metric raises any exception, please add tests that showcase this. .. note:: diff --git a/docs/source/pages/lightning.rst b/docs/source/pages/lightning.rst index a7678f3e2a5..a96196396b8 100644 --- a/docs/source/pages/lightning.rst +++ b/docs/source/pages/lightning.rst @@ -15,8 +15,8 @@ framework designed for scaling models without boilerplate. .. note:: - TorchMetrics always offers compatibility with the last 2 major PyTorch Lightning versions, but we recommend to always keep both frameworks - up-to-date for the best experience. + TorchMetrics always offers compatibility with the last 2 major PyTorch Lightning versions, but we recommend always + keeping both frameworks up-to-date for the best experience. While TorchMetrics was built to be used with native PyTorch, using TorchMetrics with Lightning offers additional benefits: @@ -74,7 +74,7 @@ method, Lightning will log the metric based on ``on_step`` and ``on_epoch`` flag ``sync_dist``, ``sync_dist_group`` and ``reduce_fx`` flags from ``self.log(...)`` don't affect the metric logging in any manner. The metric class contains its own distributed synchronization logic. - This however is only true for metrics that inherit the base class ``Metric``, + This, however is only true for metrics that inherit the base class ``Metric``, and thus the functional metric API provides no support for in-built distributed synchronization or reduction functions. @@ -202,7 +202,7 @@ Common Pitfalls The following contains a list of pitfalls to be aware of: * Logging a `MetricCollection` object directly using ``self.log_dict`` is only supported if all metrics in the - collection returns a scalar tensor. If any of the metrics in the collection returns a non-scalar tensor, + collection return a scalar tensor. If any of the metrics in the collection return a non-scalar tensor, the logging will fail. This can especially happen when either nesting multiple ``MetricCollection`` objects or when using wrapper metrics such as :class:`~torchmetrics.wrappers.ClasswiseWrapper`, :class:`~torchmetrics.wrappers.MinMaxMetric` etc. inside a ``MetricCollection`` since all these wrappers return @@ -290,7 +290,7 @@ The following contains a list of pitfalls to be aware of: self.log('val_fid', val) * Calling ``self.log("val", self.metric(preds, target))`` with the intention of logging the metric object. Because - ``self.metric(preds, target)`` corresponds to calling the forward method, this will return a tensor and not the + ``self.metric(preds, target)`` corresponds to calling the ``forward`` method, this will return a tensor and not the metric object. Such logging will be wrong in this case. Instead, it is essential to separate into several lines: .. testcode:: python @@ -303,7 +303,8 @@ The following contains a list of pitfalls to be aware of: self.accuracy(preds, y) # compute metrics self.log('train_acc_step', self.accuracy) # log metric object -* Using :class:`~torchmetrics.wrappers.MetricTracker` wrapper with Lightning is a special case, because the wrapper in itself is not a metric - i.e. it does not inherit from the base :class:`~torchmetrics.Metric` class but instead from :class:`~torch.nn.ModuleList`. Thus, - to log the output of this metric one needs to manually log the returned values (not the object) using ``self.log`` - and for epoch level logging this should be done in the appropriate ``on_{train|validation|test}_epoch_end`` method. +* Using :class:`~torchmetrics.wrappers.MetricTracker` wrapper with Lightning is a special case, because the wrapper in + itself is not a metric i.e. it does not inherit from the base :class:`~torchmetrics.Metric` class but instead from + :class:`~torch.nn.ModuleList`. Thus, to log the output of this metric one needs to manually log the returned values + (not the object) using ``self.log`` and for epoch level logging this should be done in the appropriate + ``on_{train|validation|test}_epoch_end`` method. diff --git a/src/torchmetrics/classification/group_fairness.py b/src/torchmetrics/classification/group_fairness.py index 8575d4d464b..8e38b24faeb 100644 --- a/src/torchmetrics/classification/group_fairness.py +++ b/src/torchmetrics/classification/group_fairness.py @@ -302,11 +302,11 @@ def plot( .. plot:: :scale: 75 - >>> from torch import rand, randint + >>> from torch import ones, rand, randint >>> # Example plotting a single value >>> from torchmetrics.classification import BinaryFairness >>> metric = BinaryFairness(2) - >>> metric.update(rand(20), randint(2, (20,)), randint(2, (20,))) + >>> metric.update(rand(20), randint(2, (20,)), ones(20).long()) >>> fig_, ax_ = metric.plot() .. plot:: diff --git a/src/torchmetrics/clustering/dunn_index.py b/src/torchmetrics/clustering/dunn_index.py index 89565261f3e..9373db1045e 100644 --- a/src/torchmetrics/clustering/dunn_index.py +++ b/src/torchmetrics/clustering/dunn_index.py @@ -121,7 +121,7 @@ def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_ >>> metric = DunnIndex(p=2) >>> values = [ ] >>> for _ in range(10): - ... values.append(metric(torch.randn(10, 3), torch.randint(0, 2, (10,)))) + ... values.append(metric(torch.randn(50, 3), torch.randint(0, 2, (50,)))) >>> fig_, ax_ = metric.plot(values) """ diff --git a/tests/integrations/test_lightning.py b/tests/integrations/test_lightning.py index 302a4353f76..05799b2711d 100644 --- a/tests/integrations/test_lightning.py +++ b/tests/integrations/test_lightning.py @@ -530,8 +530,7 @@ def training_step(self, batch, batch_idx): preds = torch.randint(0, 5, (100,), device=batch.device) target = torch.randint(0, 5, (100,), device=batch.device) - self.train_metrics.update(preds, target) - batch_values = self.train_metrics.compute() + batch_values = self.train_metrics(preds, target) self.log_dict(batch_values, on_step=True, on_epoch=False) return {"loss": loss} @@ -589,8 +588,7 @@ def training_step(self, batch, batch_idx): preds = torch.randint(0, 5, (100,), device=batch.device) target = torch.randint(0, 5, (100,), device=batch.device) - self.train_metrics.update(preds, target) - batch_values = self.train_metrics.compute() + batch_values = self.train_metrics(preds, target) self.log_dict(batch_values, on_step=True, on_epoch=False) return {"loss": loss}