From 6ab21dbaacce4a8b3efc5df73350df10297b107c Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Wed, 23 Oct 2024 21:23:17 +0200 Subject: [PATCH] Fix data race all_type_info_populate in free-threading mode Description: - fixed data race all_type_info_populate in free-threading mode - added test For example, we have 2 threads entering `all_type_info`. Both enter `all_type_info_get_cache`` function and there is a first one which inserts a tuple (type, empty_vector) to the map and second is waiting. Inserting thread gets the (iter_to_key, True) and non-inserting thread after waiting gets (iter_to_key, False). Inserting thread than will add a weakref and will then call into `all_type_info_populate`. However, non-inserting thread is not entering `if (ins.second) {` clause and returns `ins.first->second;`` which is just empty_vector. Finally, non-inserting thread is failing the check in `allocate_layout`: ```c++ if (n_types == 0) { pybind11_fail( "instance allocation failed: new instance has no pybind11-registered base types"); } ``` --- include/pybind11/detail/type_caster_base.h | 6 --- include/pybind11/pybind11.h | 15 ++++++-- tests/pybind11_tests.cpp | 5 +++ tests/pybind11_tests.h | 20 ++++++++++ tests/test_class.py | 43 ++++++++++++++++++++++ 5 files changed, 79 insertions(+), 10 deletions(-) diff --git a/include/pybind11/detail/type_caster_base.h b/include/pybind11/detail/type_caster_base.h index e7b94aff2a..e7d4c6cdaf 100644 --- a/include/pybind11/detail/type_caster_base.h +++ b/include/pybind11/detail/type_caster_base.h @@ -117,7 +117,6 @@ PYBIND11_NOINLINE void all_type_info_populate(PyTypeObject *t, std::vector(t->tp_bases)) { check.push_back((PyTypeObject *) parent.ptr()); } - auto const &type_dict = get_internals().registered_types_py; for (size_t i = 0; i < check.size(); i++) { auto *type = check[i]; @@ -177,11 +176,6 @@ PYBIND11_NOINLINE void all_type_info_populate(PyTypeObject *t, std::vector &all_type_info(PyTypeObject *type) { auto ins = all_type_info_get_cache(type); - if (ins.second) { - // New cache entry: populate it - all_type_info_populate(type, ins.first->second); - } - return ins.first->second; } diff --git a/include/pybind11/pybind11.h b/include/pybind11/pybind11.h index 2527d25faf..d54108052d 100644 --- a/include/pybind11/pybind11.h +++ b/include/pybind11/pybind11.h @@ -2326,13 +2326,20 @@ keep_alive_impl(size_t Nurse, size_t Patient, function_call &call, handle ret) { inline std::pair all_type_info_get_cache(PyTypeObject *type) { auto res = with_internals([type](internals &internals) { - return internals - .registered_types_py + auto ins = internals + .registered_types_py #ifdef __cpp_lib_unordered_map_try_emplace - .try_emplace(type); + .try_emplace(type); #else - .emplace(type, std::vector()); + .emplace(type, std::vector()); #endif + if (ins.second) { + // In free-threading this method should be called + // under pymutex lock to avoid other threads + // continue running with empty ins.first->second + all_type_info_populate(type, ins.first->second); + } + return ins; }); if (res.second) { // New cache entry created; set up a weak reference to automatically remove it if the type diff --git a/tests/pybind11_tests.cpp b/tests/pybind11_tests.cpp index 3d2d84e77a..818d53a548 100644 --- a/tests/pybind11_tests.cpp +++ b/tests/pybind11_tests.cpp @@ -128,4 +128,9 @@ PYBIND11_MODULE(pybind11_tests, m, py::mod_gil_not_used()) { for (const auto &initializer : initializers()) { initializer(m); } + + py::class_(m, "TestContext") + .def(py::init<>(&TestContext::createNewContextForInit)) + .def("__enter__", &TestContext::contextEnter) + .def("__exit__", &TestContext::contextExit); } diff --git a/tests/pybind11_tests.h b/tests/pybind11_tests.h index 7be58feb6c..cfcc024fc7 100644 --- a/tests/pybind11_tests.h +++ b/tests/pybind11_tests.h @@ -96,3 +96,23 @@ void ignoreOldStyleInitWarnings(F &&body) { )", py::dict(py::arg("body") = py::cpp_function(body))); } + +class TestContext { +public: + TestContext() = delete; + TestContext(const TestContext &) = delete; + TestContext(TestContext &&) = delete; + static TestContext *createNewContextForInit() { return new TestContext("new-context"); } + + pybind11::object contextEnter() { + py::object contextObj = py::cast(*this); + return contextObj; + } + void contextExit(const pybind11::object & /*excType*/, + const pybind11::object & /*excVal*/, + const pybind11::object & /*excTb*/) {} + +private: + TestContext(std::string context) : context(context) {} + std::string context; +}; \ No newline at end of file diff --git a/tests/test_class.py b/tests/test_class.py index 470d2a3269..5809056f02 100644 --- a/tests/test_class.py +++ b/tests/test_class.py @@ -1,5 +1,6 @@ from __future__ import annotations +import sys from unittest import mock import pytest @@ -501,3 +502,45 @@ def test_pr4220_tripped_over_this(): m.Empty0().get_msg() == "This is really only meant to exercise successful compilation." ) + + +@pytest.mark.skipif(sys.platform.startswith("emscripten"), reason="Requires threads") +def test_all_type_info_multithreaded(): + # Test data race in all_type_info method in free-threading mode. + # For example, we have 2 threads entering `all_type_info`. + # Both enter `all_type_info_get_cache`` function and + # there is a first one which inserts a tuple (type, empty_vector) to the map + # and second is waiting. Inserting thread gets the (iter_to_key, True) and non-inserting thread + # after waiting gets (iter_to_key, False). + # Inserting thread than will add a weakref and will then call into `all_type_info_populate`. + # However, non-inserting thread is not entering `if (ins.second) {` clause and + # returns `ins.first->second;`` which is just empty_vector. + # Finally, non-inserting thread is failing the check in `allocate_layout`: + # if (n_types == 0) { + # pybind11_fail( + # "instance allocation failed: new instance has no pybind11-registered base types"); + # } + import threading + + from pybind11_tests import TestContext + + class Context(TestContext): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + num_runs = 4 + num_threads = 5 + barrier = threading.Barrier(num_threads) + + def func(): + barrier.wait() + with Context(): + pass + + for _ in range(num_runs): + threads = [threading.Thread(target=func) for _ in range(num_threads)] + for thread in threads: + thread.start() + + for thread in threads: + thread.join()