Skip to content

Commit

Permalink
MAINT: Include numpy._core imports (google#4857)
Browse files Browse the repository at this point in the history
* MAINT: Include numpy._core imports

* style: pre-commit fixes

* Apply review comments

* style: pre-commit fixes

* Add no-inline attribute

* Select submodule name based on numpy version

* style: pre-commit fixes

* Update pre-commit check

* Add error_already_set and simplify if statement

* Update .pre-commit-config.yaml

Co-authored-by: Ralf W. Grosse-Kunstleve <rwgkio@gmail.com>

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Ralf W. Grosse-Kunstleve <rwgkio@gmail.com>
  • Loading branch information
3 people authored Sep 27, 2023
1 parent f468b07 commit 0a756c0
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 7 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ repos:
- id: disallow-caps
name: Disallow improper capitalization
language: pygrep
entry: PyBind|Numpy|Cmake|CCache|PyTest
entry: PyBind|\bNumpy\b|Cmake|CCache|PyTest
exclude: ^\.pre-commit-config.yaml$

# PyLint has native support - not always usable, but works for us
Expand Down
27 changes: 21 additions & 6 deletions include/pybind11/numpy.h
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,20 @@ inline numpy_internals &get_numpy_internals() {
return *ptr;
}

PYBIND11_NOINLINE module_ import_numpy_core_submodule(const char *submodule_name) {
module_ numpy = module_::import("numpy");
str version_string = numpy.attr("__version__");

module_ numpy_lib = module_::import("numpy.lib");
object numpy_version = numpy_lib.attr("NumpyVersion")(version_string);
int major_version = numpy_version.attr("major").cast<int>();

/* `numpy.core` was renamed to `numpy._core` in NumPy 2.0 as it officially
became a private module. */
std::string numpy_core_path = major_version >= 2 ? "numpy._core" : "numpy.core";
return module_::import((numpy_core_path + "." + submodule_name).c_str());
}

template <typename T>
struct same_size {
template <typename U>
Expand Down Expand Up @@ -263,9 +277,13 @@ struct npy_api {
};

static npy_api lookup() {
module_ m = module_::import("numpy.core.multiarray");
module_ m = detail::import_numpy_core_submodule("multiarray");
auto c = m.attr("_ARRAY_API");
void **api_ptr = (void **) PyCapsule_GetPointer(c.ptr(), nullptr);
if (api_ptr == nullptr) {
raise_from(PyExc_SystemError, "FAILURE obtaining numpy _ARRAY_API pointer.");
throw error_already_set();
}
npy_api api;
#define DECL_NPY_API(Func) api.Func##_ = (decltype(api.Func##_)) api_ptr[API_##Func];
DECL_NPY_API(PyArray_GetNDArrayCFeatureVersion);
Expand Down Expand Up @@ -626,11 +644,8 @@ class dtype : public object {

private:
static object _dtype_from_pep3118() {
static PyObject *obj = module_::import("numpy.core._internal")
.attr("_dtype_from_pep3118")
.cast<object>()
.release()
.ptr();
module_ m = detail::import_numpy_core_submodule("_internal");
static PyObject *obj = m.attr("_dtype_from_pep3118").cast<object>().release().ptr();
return reinterpret_borrow<object>(obj);
}

Expand Down

0 comments on commit 0a756c0

Please sign in to comment.