From 822dba28b206e69fc4b4d632d5f37bfa556c44f2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Furkan=20=C3=87elik?= Date: Tue, 16 Apr 2024 13:57:22 +0200 Subject: [PATCH] Support for `torch.float` weighted networks for FID and KID calculations. (#2483) * add support to normalized custom models * - documentation fix - support for float weighted custom networks - support for custom sized input imgs * added dummz feature extractor network to test custom extractor * add dummy feature extractor to tests for testing custom feature extractor * fixed init error * convert int8 tensor imgs to float32 on model side * prehook commit changes * precommit hook changes * fix typing error * fix argument quotation * changelog * Update src/torchmetrics/image/fid.py Co-authored-by: Nicki Skafte Detlefsen * Update src/torchmetrics/image/fid.py Co-authored-by: Nicki Skafte Detlefsen * Update src/torchmetrics/image/fid.py Co-authored-by: Nicki Skafte Detlefsen * Update src/torchmetrics/image/fid.py Co-authored-by: Nicki Skafte Detlefsen * Update src/torchmetrics/image/fid.py Co-authored-by: Nicki Skafte Detlefsen * Update src/torchmetrics/image/fid.py Co-authored-by: Nicki Skafte Detlefsen * try fixing issues in docs --------- Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com> Co-authored-by: Nicki Skafte Detlefsen --- CHANGELOG.md | 3 ++ src/torchmetrics/image/fid.py | 56 ++++++++++++++++++++++++------- src/torchmetrics/image/kid.py | 26 ++++++++++++-- tests/unittests/image/test_fid.py | 13 ++++++- tests/unittests/image/test_kid.py | 13 ++++++- 5 files changed, 93 insertions(+), 18 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 3a9cd6e4be8..704c9618167 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -21,6 +21,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added support for calculating segmentation quality and recognition quality in `PanopticQuality` metric ([#2381](https://github.com/Lightning-AI/torchmetrics/pull/2381)) +- Added support for `torch.float` weighted networks for FID and KID calculations ([#2483](https://github.com/Lightning-AI/torchmetrics/pull/2483)) + + ### Changed - Made `__getattr__` and `__setattr__` of `ClasswiseWrapper` more general ([#2424](https://github.com/Lightning-AI/torchmetrics/pull/2424)) diff --git a/src/torchmetrics/image/fid.py b/src/torchmetrics/image/fid.py index 9e148bc94ad..af4b93d7ad5 100644 --- a/src/torchmetrics/image/fid.py +++ b/src/torchmetrics/image/fid.py @@ -198,6 +198,15 @@ class FrechetInceptionDistance(Metric): flag ``real`` determines if the images should update the statistics of the real distribution or the fake distribution. + Using custom feature extractor is also possible. One can give a torch.nn.Module as `feature` argument. This + custom feature extractor is expected to have output shape of ``(1, num_features)``. This would change the + used feature extractor from default (Inception v3) to the given network. In case network doesn't have + ``num_features`` attribute, a random tensor will be given to the network to infer feature dimensionality. + Size of this tensor can be controlled by ``input_img_size`` argument and type of the tensor can be controlled + with ``normalize`` argument (``True`` uses float32 tensors and ``False`` uses int8 tensors). In this case, update + method expects to have the tensor given to `imgs` argument to be in the correct shape and type that is compatible + to the custom feature extractor. + This metric is known to be unstable in its calculatations, and we recommend for the best results using this metric that you calculate using `torch.float64` (default is `torch.float32`) which can be set using the `.set_dtype` method of the metric. @@ -228,13 +237,20 @@ class FrechetInceptionDistance(Metric): reset_real_features: Whether to also reset the real features. Since in many cases the real dataset does not change, the features can be cached them to avoid recomputing them which is costly. Set this to ``False`` if your dataset does not change. - kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. + normalize: + Argument for controlling the input image dtype normalization: + + - If default feature extractor is used, controls whether input imgs have values in range [0, 1] or not: + + - True: if input imgs have values ranged in [0, 1]. They are cast to int8/byte tensors. + - False: if input imgs have values ranged in [0, 255]. No casting is done. + + - If custom feature extractor module is used, controls type of the input img tensors: - .. note:: - If a custom feature extractor is provided through the `feature` argument it is expected to either have a - attribute called ``num_features`` that indicates the number of features returned by the forward pass or - alternatively we will pass through tensor of shape ``(1, 3, 299, 299)`` and dtype ``torch.uint8``` to the - forward pass and expect a tensor of shape ``(1, num_features)`` as output. + - True: if input imgs are expected to be in the data type of torch.float32. + - False: if input imgs are expected to be in the data type of torch.int8. + input_img_size: tuple of integers. Indicates input img size to the custom feature extractor network if provided. + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. Raises: ValueError: @@ -284,9 +300,16 @@ def __init__( feature: Union[int, Module] = 2048, reset_real_features: bool = True, normalize: bool = False, + input_img_size: Tuple[int, int, int] = (3, 299, 299), **kwargs: Any, ) -> None: super().__init__(**kwargs) + + if not isinstance(normalize, bool): + raise ValueError("Argument `normalize` expected to be a bool") + self.normalize = normalize + self.used_custom_model = False + if isinstance(feature, int): num_features = feature if not _TORCH_FIDELITY_AVAILABLE: @@ -304,10 +327,14 @@ def __init__( elif isinstance(feature, Module): self.inception = feature + self.used_custom_model = True if hasattr(self.inception, "num_features"): num_features = self.inception.num_features else: - dummy_image = torch.randint(0, 255, (1, 3, 299, 299), dtype=torch.uint8) + if self.normalize: + dummy_image = torch.rand(1, *input_img_size, dtype=torch.float32) + else: + dummy_image = torch.randint(0, 255, (1, *input_img_size), dtype=torch.uint8) num_features = self.inception(dummy_image).shape[-1] else: raise TypeError("Got unknown input to argument `feature`") @@ -316,10 +343,6 @@ def __init__( raise ValueError("Argument `reset_real_features` expected to be a bool") self.reset_real_features = reset_real_features - if not isinstance(normalize, bool): - raise ValueError("Argument `normalize` expected to be a bool") - self.normalize = normalize - mx_num_feats = (num_features, num_features) self.add_state("real_features_sum", torch.zeros(num_features).double(), dist_reduce_fx="sum") self.add_state("real_features_cov_sum", torch.zeros(mx_num_feats).double(), dist_reduce_fx="sum") @@ -330,8 +353,15 @@ def __init__( self.add_state("fake_features_num_samples", torch.tensor(0).long(), dist_reduce_fx="sum") def update(self, imgs: Tensor, real: bool) -> None: - """Update the state with extracted features.""" - imgs = (imgs * 255).byte() if self.normalize else imgs + """Update the state with extracted features. + + Args: + imgs: Input img tensors to evaluate. If used custom feature extractor please + make sure dtype and size is correct for the model. + real: Whether given image is real or fake. + + """ + imgs = (imgs * 255).byte() if self.normalize and (not self.used_custom_model) else imgs features = self.inception(imgs) self.orig_dtype = features.dtype features = features.double() diff --git a/src/torchmetrics/image/kid.py b/src/torchmetrics/image/kid.py index 813e875da45..e080116a33f 100644 --- a/src/torchmetrics/image/kid.py +++ b/src/torchmetrics/image/kid.py @@ -91,6 +91,12 @@ class KernelInceptionDistance(Metric): flag ``real`` determines if the images should update the statistics of the real distribution or the fake distribution. + Using custom feature extractor is also possible. One can give a torch.nn.Module as `feature` argument. This + custom feature extractor is expected to have output shape of ``(1, num_features)`` This would change the + used feature extractor from default (Inception v3) to the given network. ``normalize`` argument won't have any + effect and update method expects to have the tensor given to `imgs` argument to be in the correct shape and + type that is compatible to the custom feature extractor. + .. note:: using this metric with the default feature extractor requires that ``torch-fidelity`` is installed. Either install as ``pip install torchmetrics[image]`` or ``pip install torch-fidelity`` @@ -103,7 +109,7 @@ class KernelInceptionDistance(Metric): As output of `forward` and `compute` the metric returns the following output - ``kid_mean`` (:class:`~torch.Tensor`): float scalar tensor with mean value over subsets - - ``kid_std`` (:class:`~torch.Tensor`): float scalar tensor with mean value over subsets + - ``kid_std`` (:class:`~torch.Tensor`): float scalar tensor with standard deviation value over subsets Args: feature: Either an str, integer or ``nn.Module``: @@ -187,6 +193,8 @@ def __init__( UserWarning, ) + self.used_custom_model = False + if isinstance(feature, (str, int)): if not _TORCH_FIDELITY_AVAILABLE: raise ModuleNotFoundError( @@ -202,6 +210,7 @@ def __init__( self.inception: Module = NoTrainInceptionV3(name="inception-v3-compat", features_list=[str(feature)]) elif isinstance(feature, Module): self.inception = feature + self.used_custom_model = True else: raise TypeError("Got unknown input to argument `feature`") @@ -238,8 +247,15 @@ def __init__( self.add_state("fake_features", [], dist_reduce_fx=None) def update(self, imgs: Tensor, real: bool) -> None: - """Update the state with extracted features.""" - imgs = (imgs * 255).byte() if self.normalize else imgs + """Update the state with extracted features. + + Args: + imgs: Input img tensors to evaluate. If used custom feature extractor please + make sure dtype and size is correct for the model. + real: Whether given image is real or fake. + + """ + imgs = (imgs * 255).byte() if self.normalize and (not self.used_custom_model) else imgs features = self.inception(imgs) if real: @@ -252,6 +268,10 @@ def compute(self) -> Tuple[Tensor, Tensor]: Implementation inspired by `Fid Score`_ + Returns: + kid_mean (:class:`~torch.Tensor`): float scalar tensor with mean value over subsets + kid_std (:class:`~torch.Tensor`): float scalar tensor with standard deviation value over subsets + """ real_features = dim_zero_cat(self.real_features) fake_features = dim_zero_cat(self.fake_features) diff --git a/tests/unittests/image/test_fid.py b/tests/unittests/image/test_fid.py index 252f0d0ebba..f6bbf8e20e5 100644 --- a/tests/unittests/image/test_fid.py +++ b/tests/unittests/image/test_fid.py @@ -80,8 +80,19 @@ def test_fid_raises_errors_and_warnings(): _ = FrechetInceptionDistance(feature=[1, 2]) +class _DummyFeatureExtractor(Module): + def __init__(self) -> None: + super().__init__() + self.flatten = torch.nn.Flatten() + self.extractor = torch.nn.Linear(3 * 299 * 299, 64) + + def __call__(self, img) -> torch.Tensor: + img = (img / 125.5).float() # Convert int img input to float as Linear layer expects float inputs + return self.extractor(self.flatten(img)) + + @pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="metric requires torch-fidelity") -@pytest.mark.parametrize("feature", [64, 192, 768, 2048]) +@pytest.mark.parametrize("feature", [64, 192, 768, 2048, _DummyFeatureExtractor()]) def test_fid_same_input(feature): """If real and fake are update on the same data the fid score should be 0.""" metric = FrechetInceptionDistance(feature=feature) diff --git a/tests/unittests/image/test_kid.py b/tests/unittests/image/test_kid.py index a754768003c..eea26bb2c0e 100644 --- a/tests/unittests/image/test_kid.py +++ b/tests/unittests/image/test_kid.py @@ -102,8 +102,19 @@ def test_kid_extra_parameters(): KernelInceptionDistance(coef=-1) +class _DummyFeatureExtractor(Module): + def __init__(self) -> None: + super().__init__() + self.flatten = torch.nn.Flatten() + self.extractor = torch.nn.Linear(3 * 299 * 299, 64) + + def __call__(self, img) -> torch.Tensor: + img = (img / 125.5).float() # Convert int img input to float as Linear layer expects float inputs + return self.extractor(self.flatten(img)) + + @pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="metric requires torch-fidelity") -@pytest.mark.parametrize("feature", [64, 192, 768, 2048]) +@pytest.mark.parametrize("feature", [64, 192, 768, 2048, _DummyFeatureExtractor()]) def test_kid_same_input(feature): """Test that the metric works.""" metric = KernelInceptionDistance(feature=feature, subsets=5, subset_size=2)