From b5d53f26e90ea42e6c413ae69902647907c7ddbe Mon Sep 17 00:00:00 2001 From: Igor Sugak Date: Fri, 18 Oct 2024 21:01:42 -0700 Subject: [PATCH] replace uses of np.ndarray with npt.NDArray (#681) Summary: Pull Request resolved: https://github.com/pytorch/opacus/pull/681 X-link: https://github.com/pytorch/captum/pull/1389 X-link: https://github.com/pytorch/botorch/pull/2586 X-link: https://github.com/pytorch/audio/pull/3846 This replaces uses of `numpy.ndarray` in type annotations with `numpy.typing.NDArray`. In Numpy-1.24.0+ `numpy.ndarray` is annotated as generic type. Without template parameters it triggers static analysis errors: ```counterexample Generic type `ndarray` expects 2 type parameters. ``` `numpy.typing.NDArray` is an alias that provides default template parameters. Reviewed By: ryanthomasjohnson Differential Revision: D64619891 fbshipit-source-id: dffc096b1ce90d11e73d475f0bbcb8867ed9ef01 --- opacus/accountants/analysis/prv/prvs.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/opacus/accountants/analysis/prv/prvs.py b/opacus/accountants/analysis/prv/prvs.py index b4ffcf76..957ad336 100644 --- a/opacus/accountants/analysis/prv/prvs.py +++ b/opacus/accountants/analysis/prv/prvs.py @@ -16,6 +16,7 @@ from typing import Tuple import numpy as np +import numpy.typing as npt from scipy import integrate from scipy.special import erfc @@ -133,7 +134,7 @@ def mean(self) -> float: @dataclass class DiscretePRV: - pmf: np.ndarray + pmf: npt.NDArray domain: Domain def __len__(self) -> int: