Skip to content

Commit

Permalink
don't accept bytes/unicode objects in sequence casters
Browse files Browse the repository at this point in the history
  • Loading branch information
wjakob committed Oct 5, 2023
1 parent 8bd58c7 commit cd2379e
Show file tree
Hide file tree
Showing 7 changed files with 28 additions and 7 deletions.
1 change: 1 addition & 0 deletions include/nanobind/nb_lib.h
Original file line number Diff line number Diff line change
Expand Up @@ -419,6 +419,7 @@ NB_CORE PyObject *module_new_submodule(PyObject *base, const char *name,

// Try to import a reference-counted ndarray object via DLPack
NB_CORE ndarray_handle *ndarray_import(PyObject *o, const ndarray_req *req,
cleanup_list *cleanup,
bool convert) noexcept;

// Describe a local ndarray object using a DLPack capsule
Expand Down
5 changes: 3 additions & 2 deletions include/nanobind/ndarray.h
Original file line number Diff line number Diff line change
Expand Up @@ -540,14 +540,15 @@ template <typename... Args> struct type_caster<ndarray<Args...>> {
concat_maybe(detail::ndarray_arg<Args>::name...) +
const_name("]"));

bool from_python(handle src, uint8_t flags, cleanup_list *) noexcept {
bool from_python(handle src, uint8_t flags,
cleanup_list *cleanup) noexcept {
constexpr size_t size = (0 + ... + detail::ndarray_arg<Args>::size);
size_t shape[size + 1];
detail::ndarray_req req;
req.shape = shape;
(detail::ndarray_arg<Args>::apply(req), ...);
value = ndarray<Args...>(ndarray_import(
src.ptr(), &req, flags & (uint8_t) cast_flags::convert));
src.ptr(), &req, cleanup, flags & (uint8_t) cast_flags::convert));
return value.is_valid();
}

Expand Down
6 changes: 6 additions & 0 deletions src/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -577,6 +577,12 @@ PyObject **seq_get(PyObject *seq, size_t *size_out, PyObject **temp_out) noexcep
goes wrong, it fails gracefully without reporting errors. Other
overloads will then be tried. */

if (PyUnicode_CheckExact(seq) || PyBytes_CheckExact(seq)) {
*size_out = 0;
*temp_out = nullptr;
return nullptr;
}

#if !defined(Py_LIMITED_API) && !defined(PYPY_VERSION)
if (PyTuple_CheckExact(seq)) {
size = (size_t) PyTuple_GET_SIZE(seq);
Expand Down
12 changes: 7 additions & 5 deletions src/nb_ndarray.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -291,9 +291,8 @@ bool ndarray_check(PyObject *o) noexcept {
return result;
}


ndarray_handle *ndarray_import(PyObject *o, const ndarray_req *req,
bool convert) noexcept {
cleanup_list *cleanup, bool convert) noexcept {
object capsule;
bool is_pycapsule = PyCapsule_CheckExact(o);

Expand Down Expand Up @@ -456,10 +455,13 @@ ndarray_handle *ndarray_import(PyObject *o, const ndarray_req *req,
} catch (...) { converted.reset(); }

// Potentially try again recursively
if (!converted.is_valid())
if (!converted.is_valid()) {
return nullptr;
else
return ndarray_import(converted.ptr(), req, false);
} else {
if (cleanup)
cleanup->append(converted.inc_ref().ptr());
return ndarray_import(converted.ptr(), req, nullptr, false);
}
}

if (!pass_dtype || !pass_device || !pass_shape || !pass_order)
Expand Down
1 change: 1 addition & 0 deletions src/nb_type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1123,6 +1123,7 @@ void keep_alive(PyObject *nurse, PyObject *patient) {
} else {
PyObject *callback =
PyCFunction_New(&keep_alive_callback_def, patient);

PyObject *weakref = PyWeakref_NewRef(nurse, callback);
if (!weakref) {
Py_DECREF(callback);
Expand Down
6 changes: 6 additions & 0 deletions tests/test_stl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -439,4 +439,10 @@ NB_MODULE(test_stl_ext, m) {
return x;
});

m.def("vector_str", [](const std::vector<std::string>& x){
return x;
});
m.def("vector_str", [](std::string& x){
return x;
});
}
4 changes: 4 additions & 0 deletions tests/test_stl.py
Original file line number Diff line number Diff line change
Expand Up @@ -820,3 +820,7 @@ def test69_complex_array():
assert t.complex_array_float(np.array([val1_64, -1j, val2_64],dtype=np.complex64)) == [val1_32, (-0-1j), val2_32]
except ImportError:
pass

def test70_vec_char():
assert isinstance(t.vector_str("123"), str)
assert isinstance(t.vector_str(["123", "345"]), list)

0 comments on commit cd2379e

Please sign in to comment.