Skip to content

Commit

Permalink
Merge branch 'smart_holder' into pybind11k_merge_sh
Browse files Browse the repository at this point in the history
  • Loading branch information
Ralf W. Grosse-Kunstleve committed Aug 25, 2024
2 parents 243ae9f + 04d9f84 commit 43bdd56
Show file tree
Hide file tree
Showing 7 changed files with 121 additions and 40 deletions.
50 changes: 40 additions & 10 deletions include/pybind11/cast.h
Original file line number Diff line number Diff line change
Expand Up @@ -998,7 +998,7 @@ struct copyable_holder_caster<
pybind11_fail("Passing `std::shared_ptr<T> *` from Python to C++ is not supported "
"(inherently unsafe).");
}
return std::addressof(shared_ptr_holder);
return std::addressof(shared_ptr_storage);
}

explicit operator std::shared_ptr<type> &() {
Expand All @@ -1011,9 +1011,9 @@ struct copyable_holder_caster<
throw cast_error("Unowned pointer from a void pointer capsule cannot be converted "
"to a std::shared_ptr.");
}
shared_ptr_holder = sh_load_helper.load_as_shared_ptr(value);
shared_ptr_storage = sh_load_helper.load_as_shared_ptr(value);
}
return shared_ptr_holder;
return shared_ptr_storage;
}

static handle
Expand Down Expand Up @@ -1059,7 +1059,7 @@ struct copyable_holder_caster<
}
if (v_h.holder_constructed()) {
value = v_h.value_ptr();
shared_ptr_holder = v_h.template holder<std::shared_ptr<type>>();
shared_ptr_storage = v_h.template holder<std::shared_ptr<type>>();
return;
}
throw cast_error("Unable to cast from non-held to held instance (T& to Holder<T>) "
Expand Down Expand Up @@ -1088,8 +1088,8 @@ struct copyable_holder_caster<
if (typeinfo->holder_enum_v == detail::holder_enum_t::smart_holder) {
sh_load_helper.loaded_v_h = sub_caster.sh_load_helper.loaded_v_h;
} else {
shared_ptr_holder
= std::shared_ptr<type>(sub_caster.shared_ptr_holder, (type *) value);
shared_ptr_storage
= std::shared_ptr<type>(sub_caster.shared_ptr_storage, (type *) value);
}
return true;
}
Expand All @@ -1113,8 +1113,8 @@ struct copyable_holder_caster<
return false;
}

std::shared_ptr<type> shared_ptr_holder;
smart_holder_type_caster_support::load_helper<remove_cv_t<type>> sh_load_helper; // Const2Mutbl
std::shared_ptr<type> shared_ptr_storage;
bool from_direct_conversion = false;
bool from_as_void_ptr_capsule = false;
};
Expand Down Expand Up @@ -1191,7 +1191,7 @@ struct move_only_holder_caster<
policy = return_value_policy::reference_internal;
}
if (policy != return_value_policy::reference_internal) {
throw cast_error("Invalid return_value_policy for unique_ptr&");
throw cast_error("Invalid return_value_policy for const unique_ptr&");
}
return type_caster_base<type>::cast(src.get(), policy, parent);
}
Expand All @@ -1218,8 +1218,14 @@ struct move_only_holder_caster<
+ clean_type_id(typeinfo->cpptype->name()) + ")");
}

template <typename>
using cast_op_type = std::unique_ptr<type, deleter>;
template <typename T_>
using cast_op_type
= conditional_t<std::is_same<typename std::remove_volatile<T_>::type,
const std::unique_ptr<type, deleter> &>::value
|| std::is_same<typename std::remove_volatile<T_>::type,
const std::unique_ptr<const type, deleter> &>::value,
const std::unique_ptr<type, deleter> &,
std::unique_ptr<type, deleter>>;

explicit operator std::unique_ptr<type, deleter>() {
if (typeinfo->holder_enum_v == detail::holder_enum_t::smart_holder) {
Expand All @@ -1236,6 +1242,28 @@ struct move_only_holder_caster<
pybind11_fail("Expected to be UNREACHABLE: " __FILE__ ":" PYBIND11_TOSTRING(__LINE__));
}

explicit operator const std::unique_ptr<type, deleter> &() {
if (typeinfo->holder_enum_v == detail::holder_enum_t::smart_holder) {
// Get shared_ptr to ensure that the Python object is not disowned elsewhere.
shared_ptr_storage = sh_load_helper.load_as_shared_ptr(value);
// Build a temporary unique_ptr that is meant to never expire.
unique_ptr_storage = std::shared_ptr<std::unique_ptr<type, deleter>>(
new std::unique_ptr<type, deleter>{
sh_load_helper.template load_as_const_unique_ptr<deleter>(
shared_ptr_storage.get())},
[](std::unique_ptr<type, deleter> *ptr) {
if (!ptr) {
pybind11_fail("FATAL: `const std::unique_ptr<T, D> &` was disowned "
"(EXPECT UNDEFINED BEHAVIOR).");
}
(void) ptr->release();
delete ptr;
});
return *unique_ptr_storage;
}
pybind11_fail("Expected to be UNREACHABLE: " __FILE__ ":" PYBIND11_TOSTRING(__LINE__));
}

bool try_implicit_casts(handle src, bool convert) {
for (auto &cast : typeinfo->implicit_casts) {
move_only_holder_caster sub_caster(*cast.first);
Expand Down Expand Up @@ -1270,6 +1298,8 @@ struct move_only_holder_caster<
}

smart_holder_type_caster_support::load_helper<remove_cv_t<type>> sh_load_helper; // Const2Mutbl
std::shared_ptr<type> shared_ptr_storage; // Serves as a pseudo lock.
std::shared_ptr<std::unique_ptr<type, deleter>> unique_ptr_storage;
bool from_direct_conversion = false;
bool from_as_void_ptr_capsule = false;
};
Expand Down
18 changes: 18 additions & 0 deletions include/pybind11/detail/struct_smart_holder.h
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,24 @@ struct smart_holder {
vptr_del_ptr->armed_flag = armed_flag;
}

// Caller is responsible for precondition: ensure_compatible_rtti_uqp_del<T, D>() must succeed.
template <typename T, typename D>
std::unique_ptr<D> extract_deleter(const char *context) const {
const auto *gd = std::get_deleter<guarded_delete>(vptr);
if (gd && gd->use_del_fun) {
const auto &custom_deleter_ptr = gd->del_fun.template target<custom_deleter<T, D>>();
if (custom_deleter_ptr == nullptr) {
throw std::runtime_error(
std::string("smart_holder::extract_deleter() precondition failure (") + context
+ ").");
}
static_assert(std::is_copy_constructible<D>::value,
"Required for compatibility with smart_holder functionality.");
return std::unique_ptr<D>(new D(custom_deleter_ptr->deleter));
}
return nullptr;
}

static smart_holder from_raw_ptr_unowned(void *raw_ptr) {
smart_holder hld;
hld.vptr.reset(raw_ptr, [](void *) {});
Expand Down
30 changes: 14 additions & 16 deletions include/pybind11/detail/type_caster_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -797,22 +797,7 @@ struct load_helper : value_and_holder_helper {
"instance cannot safely be transferred to C++.");
}

// Temporary variable to store the extracted deleter in.
std::unique_ptr<D> extracted_deleter;

auto *gd = std::get_deleter<pybindit::memory::guarded_delete>(holder().vptr);
if (gd && gd->use_del_fun) { // Note the ensure_compatible_rtti_uqp_del<T, D>() call above.
// In struct_smart_holder, a custom deleter is always stored in a guarded delete.
// The guarded delete's std::function<void(void*)> actually points at the
// custom_deleter type, so we can verify it is of the custom deleter type and
// finally extract its deleter.
using custom_deleter_D = pybindit::memory::custom_deleter<T, D>;
const auto &custom_deleter_ptr = gd->del_fun.template target<custom_deleter_D>();
assert(custom_deleter_ptr != nullptr);
// Now that we have confirmed the type of the deleter matches the desired return
// value we can extract the function.
extracted_deleter = std::unique_ptr<D>(new D(std::move(custom_deleter_ptr->deleter)));
}
std::unique_ptr<D> extracted_deleter = holder().template extract_deleter<T, D>(context);

// Critical transfer-of-ownership section. This must stay together.
if (self_life_support != nullptr) {
Expand All @@ -832,6 +817,19 @@ struct load_helper : value_and_holder_helper {

return result;
}

// This assumes load_as_shared_ptr succeeded(), and the returned shared_ptr is still alive.
// The returned unique_ptr is meant to never expire (the behavior is undefined otherwise).
template <typename D>
std::unique_ptr<T, D>
load_as_const_unique_ptr(T *raw_type_ptr, const char *context = "load_as_const_unique_ptr") {
if (!have_holder()) {
return unique_with_deleter<T, D>(nullptr, std::unique_ptr<D>());
}
holder().template ensure_compatible_rtti_uqp_del<T, D>(context);
return unique_with_deleter<T, D>(
raw_type_ptr, std::move(holder().template extract_deleter<T, D>(context)));
}
};

PYBIND11_NAMESPACE_END(smart_holder_type_caster_support)
Expand Down
14 changes: 14 additions & 0 deletions tests/test_class_sh_basic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,17 @@ std::string get_mtxt(atyp const &obj) { return obj.mtxt; }
std::ptrdiff_t get_ptr(atyp const &obj) { return reinterpret_cast<std::ptrdiff_t>(&obj); }

std::unique_ptr<atyp> unique_ptr_roundtrip(std::unique_ptr<atyp> obj) { return obj; }

std::string pass_unique_ptr_cref(const std::unique_ptr<atyp> &obj) { return obj->mtxt; }

const std::unique_ptr<atyp> &rtrn_unique_ptr_cref(const std::string &mtxt) {
static std::unique_ptr<atyp> obj{new atyp{"static_ctor_arg"}};
if (!mtxt.empty()) {
obj->mtxt = mtxt;
}
return obj;
}

const std::unique_ptr<atyp> &unique_ptr_cref_roundtrip(const std::unique_ptr<atyp> &obj) {
return obj;
}
Expand Down Expand Up @@ -217,6 +228,9 @@ TEST_SUBMODULE(class_sh_basic, m) {
m.def("get_ptr", get_ptr); // pass_cref

m.def("unique_ptr_roundtrip", unique_ptr_roundtrip); // pass_uqmp, rtrn_uqmp

m.def("pass_unique_ptr_cref", pass_unique_ptr_cref);
m.def("rtrn_unique_ptr_cref", rtrn_unique_ptr_cref);
m.def("unique_ptr_cref_roundtrip", unique_ptr_cref_roundtrip);

py::classh<SharedPtrStash>(m, "SharedPtrStash")
Expand Down
34 changes: 23 additions & 11 deletions tests/test_class_sh_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,19 +151,31 @@ def test_unique_ptr_roundtrip(num_round_trips=1000):
id_orig = id_rtrn


# This currently fails, because a unique_ptr is always loaded by value
# due to pybind11/detail/smart_holder_type_casters.h:689
# I think, we need to provide more cast operators.
@pytest.mark.skip()
def test_unique_ptr_cref_roundtrip():
def test_pass_unique_ptr_cref():
obj = m.atyp("ctor_arg")
assert re.match("ctor_arg(_MvCtor)*_MvCtor", m.get_mtxt(obj))
assert re.match("ctor_arg(_MvCtor)*_MvCtor", m.pass_unique_ptr_cref(obj))
assert re.match("ctor_arg(_MvCtor)*_MvCtor", m.get_mtxt(obj))


def test_rtrn_unique_ptr_cref():
obj0 = m.rtrn_unique_ptr_cref("")
assert m.get_mtxt(obj0) == "static_ctor_arg"
obj1 = m.rtrn_unique_ptr_cref("passed_mtxt_1")
assert m.get_mtxt(obj1) == "passed_mtxt_1"
assert m.get_mtxt(obj0) == "passed_mtxt_1"
assert obj0 is obj1


def test_unique_ptr_cref_roundtrip(num_round_trips=1000):
# Multiple roundtrips to stress-test implementation.
orig = m.atyp("passenger")
id_orig = id(orig)
mtxt_orig = m.get_mtxt(orig)

recycled = m.unique_ptr_cref_roundtrip(orig)
assert m.get_mtxt(orig) == mtxt_orig
assert m.get_mtxt(recycled) == mtxt_orig
assert id(recycled) == id_orig
recycled = orig
for _ in range(num_round_trips):
recycled = m.unique_ptr_cref_roundtrip(recycled)
assert recycled is orig
assert m.get_mtxt(recycled) == mtxt_orig


@pytest.mark.parametrize(
Expand Down
10 changes: 8 additions & 2 deletions tests/test_class_sh_trampoline_shared_from_this.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,12 @@ long pass_shared_ptr(const std::shared_ptr<Sft> &obj) {
return sft.use_count();
}

void pass_unique_ptr(const std::unique_ptr<Sft> &) {}
std::string pass_unique_ptr_cref(const std::unique_ptr<Sft> &obj) {
return obj ? obj->history : "<NULLPTR>";
}
void pass_unique_ptr_rref(std::unique_ptr<Sft> &&) {
throw std::runtime_error("Expected to not be reached.");
}

Sft *make_pure_cpp_sft_raw_ptr(const std::string &history_seed) { return new Sft{history_seed}; }

Expand Down Expand Up @@ -135,7 +140,8 @@ TEST_SUBMODULE(class_sh_trampoline_shared_from_this, m) {

m.def("use_count", use_count);
m.def("pass_shared_ptr", pass_shared_ptr);
m.def("pass_unique_ptr", pass_unique_ptr);
m.def("pass_unique_ptr_cref", pass_unique_ptr_cref);
m.def("pass_unique_ptr_rref", pass_unique_ptr_rref);
m.def("make_pure_cpp_sft_raw_ptr", make_pure_cpp_sft_raw_ptr);
m.def("make_pure_cpp_sft_unq_ptr", make_pure_cpp_sft_unq_ptr);
m.def("make_pure_cpp_sft_shd_ptr", make_pure_cpp_sft_shd_ptr);
Expand Down
5 changes: 4 additions & 1 deletion tests/test_class_sh_trampoline_shared_from_this.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,11 +137,14 @@ def test_pass_released_shared_ptr_as_unique_ptr():
obj = PySft("PySft")
stash1 = m.SftSharedPtrStash(1)
stash1.Add(obj) # Releases shared_ptr to C++.
assert m.pass_unique_ptr_cref(obj) == "PySft_Stash1Add"
assert obj.history == "PySft_Stash1Add"
with pytest.raises(ValueError) as exc_info:
m.pass_unique_ptr(obj)
m.pass_unique_ptr_rref(obj)
assert str(exc_info.value) == (
"Python instance is currently owned by a std::shared_ptr."
)
assert obj.history == "PySft_Stash1Add"


@pytest.mark.parametrize(
Expand Down

0 comments on commit 43bdd56

Please sign in to comment.