Skip to content

Commit

Permalink
Relax numpy upper version cap (#1172)
Browse files Browse the repository at this point in the history
* Relax numpy upper version cap

In #1012 we added an upper version cap to numpy to prevent it from
installing numpy 2.0 before we confirmed that rustworkx was compatible
with it. Now that numpy 2.0.0rc1 has been released we're able to confirm
that rustworkx works fine with numpy 2.0. This commit raises the upper
bound on the numpy version to < 3 to enable installing numpy 2.0 with
rustworkx.

* Handle new __array__ API in numpy 2.0

While we didn't have any test coverage for this looking at the numpy 2.0
migration guide one thing we'll have to handle is the new copy kwarg on
array:

https://numpy.org/devdocs/numpy_2_0_migration_guide.html#adapting-to-changes-in-the-copy-keyword

This commit updates the sole use of __array__ we have on custom sequence
return types so that if copy=False is passed in we raise a ValueError.
Additionally, the dtype handling is done directly in the rustworkx code
now to ensure we don't have any issues with numpy 2.0.

* Fix __array__ stubs

* Update src/iterators.rs

* Pin ruff to 0.4.1

---------

Co-authored-by: Ivan Carvalho <8753214+IvanIsCoding@users.noreply.github.com>
  • Loading branch information
mtreinish and IvanIsCoding committed Apr 27, 2024
1 parent cc01ee8 commit a6c9849
Show file tree
Hide file tree
Showing 5 changed files with 35 additions and 9 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ jobs:
- uses: actions/setup-python@v5
with:
python-version: 3.8
- run: pip install -U ruff black~=22.0
- run: pip install -U ruff==0.4.1 black~=22.0
- uses: dtolnay/rust-toolchain@stable
with:
components: rustfmt
Expand Down
2 changes: 1 addition & 1 deletion rustworkx/rustworkx.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -993,7 +993,7 @@ class _RustworkxCustomVecIter(Generic[_T_co], Sequence[_T_co], ABC):
def __len__(self) -> int: ...
def __ne__(self, other: object) -> bool: ...
def __setstate__(self, state: Sequence[_T_co]) -> None: ...
def __array__(self, _dt: np.dtype | None = ...) -> np.ndarray: ...
def __array__(self, dtype: np.dtype | None = ..., copy: bool | None = ...) -> np.ndarray: ...
def __iter__(self) -> Iterator[_T_co]: ...
def __reversed__(self) -> Iterator[_T_co]: ...

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def readme():
PKG_NAME = os.getenv("RUSTWORKX_PKG_NAME", "rustworkx")
PKG_VERSION = "0.15.0"
PKG_PACKAGES = ["rustworkx", "rustworkx.visualization"]
PKG_INSTALL_REQUIRES = ["numpy>=1.16.0,<2"]
PKG_INSTALL_REQUIRES = ["numpy>=1.16.0,<3"]
RUST_EXTENSIONS = [RustExtension("rustworkx.rustworkx", "Cargo.toml",
binding=Binding.PyO3, debug=rustworkx_debug)]
RUST_OPTS ={"bdist_wheel": {"py_limited_api": "cp38"}}
Expand Down
28 changes: 22 additions & 6 deletions src/iterators.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,11 @@ use num_bigint::BigUint;
use rustworkx_core::dictmap::*;

use ndarray::prelude::*;
use numpy::{IntoPyArray, PyArrayDescr};
use pyo3::exceptions::{PyIndexError, PyKeyError, PyNotImplementedError};
use numpy::IntoPyArray;
use pyo3::exceptions::{PyIndexError, PyKeyError, PyNotImplementedError, PyValueError};
use pyo3::gc::PyVisit;
use pyo3::prelude::*;
use pyo3::types::IntoPyDict;
use pyo3::types::PySlice;
use pyo3::PyTraverseError;

Expand Down Expand Up @@ -601,11 +602,26 @@ macro_rules! custom_vec_iter_impl {
fn __array__(
&self,
py: Python,
_dt: Option<&Bound<PyArrayDescr>>,
dtype: Option<PyObject>,
copy: Option<bool>,
) -> PyResult<PyObject> {
// Note: we accept the dtype argument on the signature but
// effictively do nothing with it to let Numpy handle the conversion itself
self.$data.convert_to_pyarray(py)
if copy == Some(false) {
return Err(PyValueError::new_err(
"A copy is needed to return an array from this object.",
));
}
let res = self.$data.convert_to_pyarray(py)?;
Ok(match dtype {
Some(dtype) => {
let numpy_mod = py.import_bound("numpy")?;
let args = (res,);
let kwargs = [("dtype", dtype)].into_py_dict_bound(py);
numpy_mod
.call_method("asarray", args, Some(&kwargs))?
.into()
}
None => res,
})
}

fn __traverse__(&self, vis: PyVisit) -> Result<(), PyTraverseError> {
Expand Down
10 changes: 10 additions & 0 deletions tests/test_custom_return_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,16 @@ def test_numpy_conversion(self):
res = self.dag.node_indexes()
np.testing.assert_array_equal(np.asarray(res, dtype=np.uintp), np.array([0, 1]))

def test_numpy_conversion_copy_false(self):
res = self.dag.node_indices()
with self.assertRaises(ValueError):
res.__array__(copy=False)

def test_numpy_conversion_dtype_complex(self):
res = self.dag.node_indices()
array = res.__array__(dtype=complex)
self.assertEqual(np.dtype(complex), array.dtype)


class TestNodesCountMapping(unittest.TestCase):
def setUp(self):
Expand Down

0 comments on commit a6c9849

Please sign in to comment.