From a6c9849e61228be20158dee03fe687456cbb3022 Mon Sep 17 00:00:00 2001 From: Matthew Treinish Date: Fri, 26 Apr 2024 20:42:45 -0400 Subject: [PATCH] Relax numpy upper version cap (#1172) * Relax numpy upper version cap In #1012 we added an upper version cap to numpy to prevent it from installing numpy 2.0 before we confirmed that rustworkx was compatible with it. Now that numpy 2.0.0rc1 has been released we're able to confirm that rustworkx works fine with numpy 2.0. This commit raises the upper bound on the numpy version to < 3 to enable installing numpy 2.0 with rustworkx. * Handle new __array__ API in numpy 2.0 While we didn't have any test coverage for this looking at the numpy 2.0 migration guide one thing we'll have to handle is the new copy kwarg on array: https://numpy.org/devdocs/numpy_2_0_migration_guide.html#adapting-to-changes-in-the-copy-keyword This commit updates the sole use of __array__ we have on custom sequence return types so that if copy=False is passed in we raise a ValueError. Additionally, the dtype handling is done directly in the rustworkx code now to ensure we don't have any issues with numpy 2.0. * Fix __array__ stubs * Update src/iterators.rs * Pin ruff to 0.4.1 --------- Co-authored-by: Ivan Carvalho <8753214+IvanIsCoding@users.noreply.github.com> --- .github/workflows/main.yml | 2 +- rustworkx/rustworkx.pyi | 2 +- setup.py | 2 +- src/iterators.rs | 28 ++++++++++++++++++++++------ tests/test_custom_return_types.py | 10 ++++++++++ 5 files changed, 35 insertions(+), 9 deletions(-) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 97f925fd1..4fa06758d 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -25,7 +25,7 @@ jobs: - uses: actions/setup-python@v5 with: python-version: 3.8 - - run: pip install -U ruff black~=22.0 + - run: pip install -U ruff==0.4.1 black~=22.0 - uses: dtolnay/rust-toolchain@stable with: components: rustfmt diff --git a/rustworkx/rustworkx.pyi b/rustworkx/rustworkx.pyi index 991631460..2edcdde67 100644 --- a/rustworkx/rustworkx.pyi +++ b/rustworkx/rustworkx.pyi @@ -993,7 +993,7 @@ class _RustworkxCustomVecIter(Generic[_T_co], Sequence[_T_co], ABC): def __len__(self) -> int: ... def __ne__(self, other: object) -> bool: ... def __setstate__(self, state: Sequence[_T_co]) -> None: ... - def __array__(self, _dt: np.dtype | None = ...) -> np.ndarray: ... + def __array__(self, dtype: np.dtype | None = ..., copy: bool | None = ...) -> np.ndarray: ... def __iter__(self) -> Iterator[_T_co]: ... def __reversed__(self) -> Iterator[_T_co]: ... diff --git a/setup.py b/setup.py index 82a390d7e..40bf25ca8 100644 --- a/setup.py +++ b/setup.py @@ -28,7 +28,7 @@ def readme(): PKG_NAME = os.getenv("RUSTWORKX_PKG_NAME", "rustworkx") PKG_VERSION = "0.15.0" PKG_PACKAGES = ["rustworkx", "rustworkx.visualization"] -PKG_INSTALL_REQUIRES = ["numpy>=1.16.0,<2"] +PKG_INSTALL_REQUIRES = ["numpy>=1.16.0,<3"] RUST_EXTENSIONS = [RustExtension("rustworkx.rustworkx", "Cargo.toml", binding=Binding.PyO3, debug=rustworkx_debug)] RUST_OPTS ={"bdist_wheel": {"py_limited_api": "cp38"}} diff --git a/src/iterators.rs b/src/iterators.rs index 5ed425342..f766dfc3b 100644 --- a/src/iterators.rs +++ b/src/iterators.rs @@ -46,10 +46,11 @@ use num_bigint::BigUint; use rustworkx_core::dictmap::*; use ndarray::prelude::*; -use numpy::{IntoPyArray, PyArrayDescr}; -use pyo3::exceptions::{PyIndexError, PyKeyError, PyNotImplementedError}; +use numpy::IntoPyArray; +use pyo3::exceptions::{PyIndexError, PyKeyError, PyNotImplementedError, PyValueError}; use pyo3::gc::PyVisit; use pyo3::prelude::*; +use pyo3::types::IntoPyDict; use pyo3::types::PySlice; use pyo3::PyTraverseError; @@ -601,11 +602,26 @@ macro_rules! custom_vec_iter_impl { fn __array__( &self, py: Python, - _dt: Option<&Bound>, + dtype: Option, + copy: Option, ) -> PyResult { - // Note: we accept the dtype argument on the signature but - // effictively do nothing with it to let Numpy handle the conversion itself - self.$data.convert_to_pyarray(py) + if copy == Some(false) { + return Err(PyValueError::new_err( + "A copy is needed to return an array from this object.", + )); + } + let res = self.$data.convert_to_pyarray(py)?; + Ok(match dtype { + Some(dtype) => { + let numpy_mod = py.import_bound("numpy")?; + let args = (res,); + let kwargs = [("dtype", dtype)].into_py_dict_bound(py); + numpy_mod + .call_method("asarray", args, Some(&kwargs))? + .into() + } + None => res, + }) } fn __traverse__(&self, vis: PyVisit) -> Result<(), PyTraverseError> { diff --git a/tests/test_custom_return_types.py b/tests/test_custom_return_types.py index 504a9f734..725cf73ed 100644 --- a/tests/test_custom_return_types.py +++ b/tests/test_custom_return_types.py @@ -198,6 +198,16 @@ def test_numpy_conversion(self): res = self.dag.node_indexes() np.testing.assert_array_equal(np.asarray(res, dtype=np.uintp), np.array([0, 1])) + def test_numpy_conversion_copy_false(self): + res = self.dag.node_indices() + with self.assertRaises(ValueError): + res.__array__(copy=False) + + def test_numpy_conversion_dtype_complex(self): + res = self.dag.node_indices() + array = res.__array__(dtype=complex) + self.assertEqual(np.dtype(complex), array.dtype) + class TestNodesCountMapping(unittest.TestCase): def setUp(self):