Skip to content

Commit

Permalink
Fix data race all_type_info_populate in free-threading mode
Browse files Browse the repository at this point in the history
Description:
- fixed data race all_type_info_populate in free-threading mode
- added test

For example, we have 2 threads entering `all_type_info`.
Both enter `all_type_info_get_cache`` function and
there is a first one which inserts a tuple (type, empty_vector) to the map
and second is waiting. Inserting thread gets the (iter_to_key, True) and non-inserting thread
after waiting gets (iter_to_key, False).
Inserting thread than will add a weakref and will then call into `all_type_info_populate`.
However, non-inserting thread is not entering `if (ins.second) {` clause and
returns `ins.first->second;`` which is just empty_vector.
Finally, non-inserting thread is failing the check in `allocate_layout`:
```c++
if (n_types == 0) {
    pybind11_fail(
        "instance allocation failed: new instance has no pybind11-registered base types");
}
```
  • Loading branch information
vfdev-5 committed Oct 24, 2024
1 parent f7e14e9 commit 6ab21db
Show file tree
Hide file tree
Showing 5 changed files with 79 additions and 10 deletions.
6 changes: 0 additions & 6 deletions include/pybind11/detail/type_caster_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,6 @@ PYBIND11_NOINLINE void all_type_info_populate(PyTypeObject *t, std::vector<type_
for (handle parent : reinterpret_borrow<tuple>(t->tp_bases)) {
check.push_back((PyTypeObject *) parent.ptr());
}

auto const &type_dict = get_internals().registered_types_py;
for (size_t i = 0; i < check.size(); i++) {
auto *type = check[i];
Expand Down Expand Up @@ -177,11 +176,6 @@ PYBIND11_NOINLINE void all_type_info_populate(PyTypeObject *t, std::vector<type_
*/
inline const std::vector<detail::type_info *> &all_type_info(PyTypeObject *type) {
auto ins = all_type_info_get_cache(type);
if (ins.second) {
// New cache entry: populate it
all_type_info_populate(type, ins.first->second);
}

return ins.first->second;
}

Expand Down
15 changes: 11 additions & 4 deletions include/pybind11/pybind11.h
Original file line number Diff line number Diff line change
Expand Up @@ -2326,13 +2326,20 @@ keep_alive_impl(size_t Nurse, size_t Patient, function_call &call, handle ret) {
inline std::pair<decltype(internals::registered_types_py)::iterator, bool>
all_type_info_get_cache(PyTypeObject *type) {
auto res = with_internals([type](internals &internals) {
return internals
.registered_types_py
auto ins = internals
.registered_types_py
#ifdef __cpp_lib_unordered_map_try_emplace
.try_emplace(type);
.try_emplace(type);
#else
.emplace(type, std::vector<detail::type_info *>());
.emplace(type, std::vector<detail::type_info *>());
#endif
if (ins.second) {
// In free-threading this method should be called
// under pymutex lock to avoid other threads
// continue running with empty ins.first->second
all_type_info_populate(type, ins.first->second);
}
return ins;
});
if (res.second) {
// New cache entry created; set up a weak reference to automatically remove it if the type
Expand Down
5 changes: 5 additions & 0 deletions tests/pybind11_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -128,4 +128,9 @@ PYBIND11_MODULE(pybind11_tests, m, py::mod_gil_not_used()) {
for (const auto &initializer : initializers()) {
initializer(m);
}

py::class_<TestContext>(m, "TestContext")
.def(py::init<>(&TestContext::createNewContextForInit))
.def("__enter__", &TestContext::contextEnter)
.def("__exit__", &TestContext::contextExit);
}
20 changes: 20 additions & 0 deletions tests/pybind11_tests.h
Original file line number Diff line number Diff line change
Expand Up @@ -96,3 +96,23 @@ void ignoreOldStyleInitWarnings(F &&body) {
)",
py::dict(py::arg("body") = py::cpp_function(body)));
}

class TestContext {
public:
TestContext() = delete;
TestContext(const TestContext &) = delete;
TestContext(TestContext &&) = delete;
static TestContext *createNewContextForInit() { return new TestContext("new-context"); }

pybind11::object contextEnter() {
py::object contextObj = py::cast(*this);
return contextObj;
}
void contextExit(const pybind11::object & /*excType*/,
const pybind11::object & /*excVal*/,
const pybind11::object & /*excTb*/) {}

private:
TestContext(std::string context) : context(context) {}
std::string context;
};
43 changes: 43 additions & 0 deletions tests/test_class.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import sys
from unittest import mock

import pytest
Expand Down Expand Up @@ -501,3 +502,45 @@ def test_pr4220_tripped_over_this():
m.Empty0().get_msg()
== "This is really only meant to exercise successful compilation."
)


@pytest.mark.skipif(sys.platform.startswith("emscripten"), reason="Requires threads")
def test_all_type_info_multithreaded():
# Test data race in all_type_info method in free-threading mode.
# For example, we have 2 threads entering `all_type_info`.
# Both enter `all_type_info_get_cache`` function and
# there is a first one which inserts a tuple (type, empty_vector) to the map
# and second is waiting. Inserting thread gets the (iter_to_key, True) and non-inserting thread
# after waiting gets (iter_to_key, False).
# Inserting thread than will add a weakref and will then call into `all_type_info_populate`.
# However, non-inserting thread is not entering `if (ins.second) {` clause and
# returns `ins.first->second;`` which is just empty_vector.
# Finally, non-inserting thread is failing the check in `allocate_layout`:
# if (n_types == 0) {
# pybind11_fail(
# "instance allocation failed: new instance has no pybind11-registered base types");
# }
import threading

from pybind11_tests import TestContext

class Context(TestContext):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

num_runs = 4
num_threads = 5
barrier = threading.Barrier(num_threads)

def func():
barrier.wait()
with Context():
pass

for _ in range(num_runs):
threads = [threading.Thread(target=func) for _ in range(num_threads)]
for thread in threads:
thread.start()

for thread in threads:
thread.join()

0 comments on commit 6ab21db

Please sign in to comment.