Skip to content

Commit

Permalink
Attach python lifetime to shared_ptr passed to C++
Browse files Browse the repository at this point in the history
- Reference cycles are possible as a result, but shared_ptr is already susceptible to this in C++
  • Loading branch information
virtuald committed Feb 1, 2021
1 parent 721834b commit 0612e51
Show file tree
Hide file tree
Showing 3 changed files with 110 additions and 1 deletion.
33 changes: 32 additions & 1 deletion include/pybind11/cast.h
Original file line number Diff line number Diff line change
Expand Up @@ -1524,6 +1524,37 @@ struct holder_helper {
static auto get(const T &p) -> decltype(p.get()) { return p.get(); }
};

/// Another helper class for holders that helps construct derivative holders from
/// the original holder
template <typename T>
struct holder_retriever {
static auto get_derivative_holder(const value_and_holder &v_h) -> decltype(v_h.template holder<T>()) {
return v_h.template holder<T>();
}
};

template <typename T>
struct holder_retriever<std::shared_ptr<T>> {
struct shared_ptr_deleter {
object ref;
void operator()(T *) {}
};

static auto get_derivative_holder(const value_and_holder &v_h) -> std::shared_ptr<T> {
// The shared_ptr is always given to C++ code, so construct a new shared_ptr
// that is given a custom deleter. The custom deleter increments the python
// reference count to bind the python instance lifetime with the lifetime
// of the shared_ptr.
//
// This enables things like passing the last python reference of a subclass to a
// C++ function without the python reference dying.
//
// Reference cycles will cause a leak, but this is a limitation of shared_ptr
return std::shared_ptr<T>((T*)v_h.value_ptr(),
shared_ptr_deleter{reinterpret_borrow<object>((PyObject*)v_h.inst)});
}
};

/// Type caster for holder types like std::shared_ptr, etc.
/// The SFINAE hook is provided to help work around the current lack of support
/// for smart-pointer interoperability. Please consider it an implementation
Expand Down Expand Up @@ -1566,7 +1597,7 @@ struct copyable_holder_caster : public type_caster_base<type> {
bool load_value(value_and_holder &&v_h) {
if (v_h.holder_constructed()) {
value = v_h.value_ptr();
holder = v_h.template holder<holder_type>();
holder = holder_retriever<holder_type>::get_derivative_holder(v_h);
return true;
} else {
throw cast_error("Unable to cast from non-held to held instance (T& to Holder<T>) "
Expand Down
32 changes: 32 additions & 0 deletions tests/test_smart_ptr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -397,4 +397,36 @@ TEST_SUBMODULE(smart_ptr, m) {
list.append(py::cast(e));
return list;
});

// For testing whether a python subclass of a C++ object dies when the
// last python reference is lost
struct SpBase {
// returns true if the base virtual function is called
virtual bool is_base_used() { return true; }

SpBase() = default;
SpBase(const SpBase&) = delete;
virtual ~SpBase() = default;
};

struct PySpBase : SpBase {
bool is_base_used() override { PYBIND11_OVERRIDE(bool, SpBase, is_base_used); }
};

struct SpBaseTester {
std::shared_ptr<SpBase> get_object() { return m_obj; }
void set_object(std::shared_ptr<SpBase> obj) { m_obj = obj; }
bool is_base_used() { return m_obj->is_base_used(); }
std::shared_ptr<SpBase> m_obj;
};

py::class_<SpBase, std::shared_ptr<SpBase>, PySpBase>(m, "SpBase")
.def(py::init<>())
.def("is_base_used", &SpBase::is_base_used);

py::class_<SpBaseTester>(m, "SpBaseTester")
.def(py::init<>())
.def("get_object", &SpBaseTester::get_object)
.def("set_object", &SpBaseTester::set_object)
.def("is_base_used", &SpBaseTester::is_base_used);
}
46 changes: 46 additions & 0 deletions tests/test_smart_ptr.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ def test_smart_ptr(capture):
m.print_myobject3_4(o)
assert capture == "MyObject3[{i}]\n".format(i=i) * 4

pytest.gc_collect()
cstats = ConstructorStats.get(m.MyObject3)
assert cstats.alive() == 1
o = None
Expand Down Expand Up @@ -194,6 +195,7 @@ def test_shared_ptr_and_references():
assert s.set_holder(holder_copy)

del ref, copy, holder_ref, holder_copy, s
pytest.gc_collect()
assert stats.alive() == 0


Expand Down Expand Up @@ -316,3 +318,47 @@ def test_shared_ptr_gc():
pytest.gc_collect()
for i, v in enumerate(el.get()):
assert i == v.value()


def test_shared_ptr_cpp_arg():
import weakref

class PyChild(m.SpBase):
def is_base_used(self):
return False

tester = m.SpBaseTester()

obj = PyChild()
objref = weakref.ref(obj)

tester.set_object(obj)
del obj
pytest.gc_collect()

# python reference is still around since C++ has it now
assert objref() is not None
assert tester.is_base_used() is False
assert tester.get_object() is objref()


def test_shared_ptr_arg_identity():
import weakref

tester = m.SpBaseTester()

obj = m.SpBase()
objref = weakref.ref(obj)

tester.set_object(obj)
del obj
pytest.gc_collect()

# python reference is still around since C++ has it
assert objref() is not None
assert tester.get_object() is objref()

# python reference disappears once the C++ object releases it
tester.set_object(None)
pytest.gc_collect()
assert objref() is None

0 comments on commit 0612e51

Please sign in to comment.