Added free-threading CPython mode support in MLIR Python bindings #107103

Draft: wants to merge 2 commits into base: main

vfdev-5

@vfdev-5 vfdev-5 commented Sep 3, 2024

Related to #105522

Description:

  • Added free-threading CPython mode support in MLIR Python bindings via pybind11
  • Updated python requirements

Context:


github-actions bot commented Sep 3, 2024

Thank you for submitting a Pull Request (PR) to the LLVM Project!

This PR will be automatically labeled and the relevant teams will be notified.

If you wish to, you can add reviewers by using the "Reviewers" section on this page.

If this is not working for you, it is probably because you do not have write permissions for the repository. In which case you can instead tag reviewers by name in a comment by using @ followed by their GitHub username.

If you have received no comments on your PR for a week, you can request a review by "ping"ing the PR by adding a comment “Ping”. The common courtesy "ping" rate is once a week. Please remember that you are asking for valuable time from other developers.

If you have further questions, they may be answered by the LLVM GitHub User Guide.

You can also ask questions in a comment on this PR, on the LLVM Discord or on the forums.

@vfdev-5 vfdev-5 changed the title [WIP] Added free-threading CPython mode support in Python bindings [WIP] Added free-threading CPython mode support in MLIR Python bindings Sep 3, 2024
@vfdev-5 vfdev-5 marked this pull request as ready for review September 18, 2024 13:01
@llvmbot llvmbot added mlir:python MLIR Python bindings mlir labels Sep 18, 2024
@llvmbot
Collaborator

llvmbot commented Sep 18, 2024

@llvm/pr-subscribers-mlir

Author: vfdev (vfdev-5)

Changes


Full diff: https://github.com/llvm/llvm-project/pull/107103.diff

17 Files Affected:

  • (modified) mlir/lib/Bindings/Python/AsyncPasses.cpp (+3-1)
  • (modified) mlir/lib/Bindings/Python/DialectGPU.cpp (+1-1)
  • (modified) mlir/lib/Bindings/Python/DialectLLVM.cpp (+1-1)
  • (modified) mlir/lib/Bindings/Python/DialectLinalg.cpp (+1-1)
  • (modified) mlir/lib/Bindings/Python/DialectNVGPU.cpp (+1-1)
  • (modified) mlir/lib/Bindings/Python/DialectPDL.cpp (+1-1)
  • (modified) mlir/lib/Bindings/Python/DialectQuant.cpp (+1-1)
  • (modified) mlir/lib/Bindings/Python/DialectSparseTensor.cpp (+1-1)
  • (modified) mlir/lib/Bindings/Python/DialectTransform.cpp (+1-1)
  • (modified) mlir/lib/Bindings/Python/ExecutionEngineModule.cpp (+1-1)
  • (modified) mlir/lib/Bindings/Python/GPUPasses.cpp (+3-1)
  • (modified) mlir/lib/Bindings/Python/LinalgPasses.cpp (+3-1)
  • (modified) mlir/lib/Bindings/Python/MainModule.cpp (+1-1)
  • (modified) mlir/lib/Bindings/Python/RegisterEverything.cpp (+1-1)
  • (modified) mlir/lib/Bindings/Python/SparseTensorPasses.cpp (+3-1)
  • (modified) mlir/lib/Bindings/Python/TransformInterpreter.cpp (+1-1)
  • (modified) mlir/python/requirements.txt (+5-3)
diff --git a/mlir/lib/Bindings/Python/AsyncPasses.cpp b/mlir/lib/Bindings/Python/AsyncPasses.cpp
index b611a758dbbb37..d34a164b6e30c2 100644
--- a/mlir/lib/Bindings/Python/AsyncPasses.cpp
+++ b/mlir/lib/Bindings/Python/AsyncPasses.cpp
@@ -11,11 +11,13 @@
 #include <pybind11/detail/common.h>
 #include <pybind11/pybind11.h>
 
+namespace py = pybind11;
+
 // -----------------------------------------------------------------------------
 // Module initialization.
 // -----------------------------------------------------------------------------
 
-PYBIND11_MODULE(_mlirAsyncPasses, m) {
+PYBIND11_MODULE(_mlirAsyncPasses, m, py::mod_gil_not_used()) {
   m.doc() = "MLIR Async Dialect Passes";
 
   // Register all Async passes on load.
diff --git a/mlir/lib/Bindings/Python/DialectGPU.cpp b/mlir/lib/Bindings/Python/DialectGPU.cpp
index 560a54bcd15919..5acfad007c3055 100644
--- a/mlir/lib/Bindings/Python/DialectGPU.cpp
+++ b/mlir/lib/Bindings/Python/DialectGPU.cpp
@@ -23,7 +23,7 @@ using namespace mlir::python::adaptors;
 // Module initialization.
 // -----------------------------------------------------------------------------
 
-PYBIND11_MODULE(_mlirDialectsGPU, m) {
+PYBIND11_MODULE(_mlirDialectsGPU, m, py::mod_gil_not_used()) {
   m.doc() = "MLIR GPU Dialect";
   //===-------------------------------------------------------------------===//
   // AsyncTokenType
diff --git a/mlir/lib/Bindings/Python/DialectLLVM.cpp b/mlir/lib/Bindings/Python/DialectLLVM.cpp
index 42a4c8c0793ba8..2af133a061eb4b 100644
--- a/mlir/lib/Bindings/Python/DialectLLVM.cpp
+++ b/mlir/lib/Bindings/Python/DialectLLVM.cpp
@@ -134,7 +134,7 @@ void populateDialectLLVMSubmodule(const pybind11::module &m) {
       });
 }
 
-PYBIND11_MODULE(_mlirDialectsLLVM, m) {
+PYBIND11_MODULE(_mlirDialectsLLVM, m, py::mod_gil_not_used()) {
   m.doc() = "MLIR LLVM Dialect";
 
   populateDialectLLVMSubmodule(m);
diff --git a/mlir/lib/Bindings/Python/DialectLinalg.cpp b/mlir/lib/Bindings/Python/DialectLinalg.cpp
index 2e54ebeb61fb10..118c4ab3f2f573 100644
--- a/mlir/lib/Bindings/Python/DialectLinalg.cpp
+++ b/mlir/lib/Bindings/Python/DialectLinalg.cpp
@@ -21,7 +21,7 @@ static void populateDialectLinalgSubmodule(py::module m) {
       "op.");
 }
 
-PYBIND11_MODULE(_mlirDialectsLinalg, m) {
+PYBIND11_MODULE(_mlirDialectsLinalg, m, py::mod_gil_not_used()) {
   m.doc() = "MLIR Linalg dialect.";
 
   populateDialectLinalgSubmodule(m);
diff --git a/mlir/lib/Bindings/Python/DialectNVGPU.cpp b/mlir/lib/Bindings/Python/DialectNVGPU.cpp
index 754e0a75b0abc7..4c962c403082cb 100644
--- a/mlir/lib/Bindings/Python/DialectNVGPU.cpp
+++ b/mlir/lib/Bindings/Python/DialectNVGPU.cpp
@@ -34,7 +34,7 @@ static void populateDialectNVGPUSubmodule(const pybind11::module &m) {
       py::arg("ctx") = py::none());
 }
 
-PYBIND11_MODULE(_mlirDialectsNVGPU, m) {
+PYBIND11_MODULE(_mlirDialectsNVGPU, m, py::mod_gil_not_used()) {
   m.doc() = "MLIR NVGPU dialect.";
 
   populateDialectNVGPUSubmodule(m);
diff --git a/mlir/lib/Bindings/Python/DialectPDL.cpp b/mlir/lib/Bindings/Python/DialectPDL.cpp
index 8d3f9a7ab1d6ac..e8542d5e777a65 100644
--- a/mlir/lib/Bindings/Python/DialectPDL.cpp
+++ b/mlir/lib/Bindings/Python/DialectPDL.cpp
@@ -100,7 +100,7 @@ void populateDialectPDLSubmodule(const pybind11::module &m) {
       py::arg("context") = py::none());
 }
 
-PYBIND11_MODULE(_mlirDialectsPDL, m) {
+PYBIND11_MODULE(_mlirDialectsPDL, m, py::mod_gil_not_used()) {
   m.doc() = "MLIR PDL dialect.";
   populateDialectPDLSubmodule(m);
 }
diff --git a/mlir/lib/Bindings/Python/DialectQuant.cpp b/mlir/lib/Bindings/Python/DialectQuant.cpp
index af9cdc7bdd2d89..fc6ef9f46ce8e5 100644
--- a/mlir/lib/Bindings/Python/DialectQuant.cpp
+++ b/mlir/lib/Bindings/Python/DialectQuant.cpp
@@ -307,7 +307,7 @@ static void populateDialectQuantSubmodule(const py::module &m) {
   });
 }
 
-PYBIND11_MODULE(_mlirDialectsQuant, m) {
+PYBIND11_MODULE(_mlirDialectsQuant, m, py::mod_gil_not_used()) {
   m.doc() = "MLIR Quantization dialect";
 
   populateDialectQuantSubmodule(m);
diff --git a/mlir/lib/Bindings/Python/DialectSparseTensor.cpp b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp
index 584981cfe99bf1..d4f35859fdcf1a 100644
--- a/mlir/lib/Bindings/Python/DialectSparseTensor.cpp
+++ b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp
@@ -142,7 +142,7 @@ static void populateDialectSparseTensorSubmodule(const py::module &m) {
       });
 }
 
-PYBIND11_MODULE(_mlirDialectsSparseTensor, m) {
+PYBIND11_MODULE(_mlirDialectsSparseTensor, m, py::mod_gil_not_used()) {
   m.doc() = "MLIR SparseTensor dialect.";
   populateDialectSparseTensorSubmodule(m);
 }
diff --git a/mlir/lib/Bindings/Python/DialectTransform.cpp b/mlir/lib/Bindings/Python/DialectTransform.cpp
index 6b57e652aa9d8b..df665dd66bdc28 100644
--- a/mlir/lib/Bindings/Python/DialectTransform.cpp
+++ b/mlir/lib/Bindings/Python/DialectTransform.cpp
@@ -117,7 +117,7 @@ void populateDialectTransformSubmodule(const pybind11::module &m) {
       "Get the type this ParamType is associated with.");
 }
 
-PYBIND11_MODULE(_mlirDialectsTransform, m) {
+PYBIND11_MODULE(_mlirDialectsTransform, m, py::mod_gil_not_used()) {
   m.doc() = "MLIR Transform dialect.";
   populateDialectTransformSubmodule(m);
 }
diff --git a/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp b/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp
index b3df30583fc963..ddd81d1e7d592e 100644
--- a/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp
+++ b/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp
@@ -64,7 +64,7 @@ class PyExecutionEngine {
 } // namespace
 
 /// Create the `mlir.execution_engine` module here.
-PYBIND11_MODULE(_mlirExecutionEngine, m) {
+PYBIND11_MODULE(_mlirExecutionEngine, m, py::mod_gil_not_used()) {
   m.doc() = "MLIR Execution Engine";
 
   //----------------------------------------------------------------------------
diff --git a/mlir/lib/Bindings/Python/GPUPasses.cpp b/mlir/lib/Bindings/Python/GPUPasses.cpp
index e276a3ce3a56a0..bfc43e99946bb9 100644
--- a/mlir/lib/Bindings/Python/GPUPasses.cpp
+++ b/mlir/lib/Bindings/Python/GPUPasses.cpp
@@ -11,11 +11,13 @@
 #include <pybind11/detail/common.h>
 #include <pybind11/pybind11.h>
 
+namespace py = pybind11;
+
 // -----------------------------------------------------------------------------
 // Module initialization.
 // -----------------------------------------------------------------------------
 
-PYBIND11_MODULE(_mlirGPUPasses, m) {
+PYBIND11_MODULE(_mlirGPUPasses, m, py::mod_gil_not_used()) {
   m.doc() = "MLIR GPU Dialect Passes";
 
   // Register all GPU passes on load.
diff --git a/mlir/lib/Bindings/Python/LinalgPasses.cpp b/mlir/lib/Bindings/Python/LinalgPasses.cpp
index 3f230207a42114..e3d8f237e2bab3 100644
--- a/mlir/lib/Bindings/Python/LinalgPasses.cpp
+++ b/mlir/lib/Bindings/Python/LinalgPasses.cpp
@@ -10,11 +10,13 @@
 
 #include <pybind11/pybind11.h>
 
+namespace py = pybind11;
+
 // -----------------------------------------------------------------------------
 // Module initialization.
 // -----------------------------------------------------------------------------
 
-PYBIND11_MODULE(_mlirLinalgPasses, m) {
+PYBIND11_MODULE(_mlirLinalgPasses, m, py::mod_gil_not_used()) {
   m.doc() = "MLIR Linalg Dialect Passes";
 
   // Register all Linalg passes on load.
diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp
index 8da1ab16a4514b..de713e7031a01e 100644
--- a/mlir/lib/Bindings/Python/MainModule.cpp
+++ b/mlir/lib/Bindings/Python/MainModule.cpp
@@ -22,7 +22,7 @@ using namespace mlir::python;
 // Module initialization.
 // -----------------------------------------------------------------------------
 
-PYBIND11_MODULE(_mlir, m) {
+PYBIND11_MODULE(_mlir, m, py::mod_gil_not_used()) {
   m.doc() = "MLIR Python Native Extension";
 
   py::class_<PyGlobals>(m, "_Globals", py::module_local())
diff --git a/mlir/lib/Bindings/Python/RegisterEverything.cpp b/mlir/lib/Bindings/Python/RegisterEverything.cpp
index 6b2f6b0a6a3b86..5c5c6e32102712 100644
--- a/mlir/lib/Bindings/Python/RegisterEverything.cpp
+++ b/mlir/lib/Bindings/Python/RegisterEverything.cpp
@@ -9,7 +9,7 @@
 #include "mlir-c/RegisterEverything.h"
 #include "mlir/Bindings/Python/PybindAdaptors.h"
 
-PYBIND11_MODULE(_mlirRegisterEverything, m) {
+PYBIND11_MODULE(_mlirRegisterEverything, m, py::mod_gil_not_used()) {
   m.doc() = "MLIR All Upstream Dialects, Translations and Passes Registration";
 
   m.def("register_dialects", [](MlirDialectRegistry registry) {
diff --git a/mlir/lib/Bindings/Python/SparseTensorPasses.cpp b/mlir/lib/Bindings/Python/SparseTensorPasses.cpp
index 2a8e2b802df9c4..1bbdf2f3ccf4e5 100644
--- a/mlir/lib/Bindings/Python/SparseTensorPasses.cpp
+++ b/mlir/lib/Bindings/Python/SparseTensorPasses.cpp
@@ -10,11 +10,13 @@
 
 #include <pybind11/pybind11.h>
 
+namespace py = pybind11;
+
 // -----------------------------------------------------------------------------
 // Module initialization.
 // -----------------------------------------------------------------------------
 
-PYBIND11_MODULE(_mlirSparseTensorPasses, m) {
+PYBIND11_MODULE(_mlirSparseTensorPasses, m, py::mod_gil_not_used()) {
   m.doc() = "MLIR SparseTensor Dialect Passes";
 
   // Register all SparseTensor passes on load.
diff --git a/mlir/lib/Bindings/Python/TransformInterpreter.cpp b/mlir/lib/Bindings/Python/TransformInterpreter.cpp
index f6b4532b1b6be4..93ab447d52bec1 100644
--- a/mlir/lib/Bindings/Python/TransformInterpreter.cpp
+++ b/mlir/lib/Bindings/Python/TransformInterpreter.cpp
@@ -99,7 +99,7 @@ static void populateTransformInterpreterSubmodule(py::module &m) {
       py::arg("target"), py::arg("other"));
 }
 
-PYBIND11_MODULE(_mlirTransformInterpreter, m) {
+PYBIND11_MODULE(_mlirTransformInterpreter, m, py::mod_gil_not_used()) {
   m.doc() = "MLIR Transform dialect interpreter functionality.";
   populateTransformInterpreterSubmodule(m);
 }
diff --git a/mlir/python/requirements.txt b/mlir/python/requirements.txt
index d1b5418cca5b23..49b8471c6b771c 100644
--- a/mlir/python/requirements.txt
+++ b/mlir/python/requirements.txt
@@ -1,4 +1,6 @@
-numpy>=1.19.5, <=1.26
-pybind11>=2.9.0, <=2.10.3
+numpy>=1.19.5, <3.0
+# pybind11>=2.14.0, <2.15.0
+# Temporarily set pybind11 version to master waiting the next release to 2.13.6
+pybind11 @ git+https://github.com/pybind/pybind11@master
 PyYAML>=5.4.0, <=6.0.1
-ml_dtypes>=0.1.0, <=0.4.0   # provides several NumPy dtype extensions, including the bf16
+ml_dtypes>=0.5.0, <=0.6.0   # provides several NumPy dtype extensions, including the bf16

@vfdev-5 vfdev-5 changed the title [WIP] Added free-threading CPython mode support in MLIR Python bindings Added free-threading CPython mode support in MLIR Python bindings Sep 18, 2024
@makslevental
Contributor

I hit "run tests" but I don't think this is going to work - there are places where we do assume GIL. I don't remember exactly (I believe it's in some of the custom stuff in pybind_adaptors). I can take a closer look soon.

@vfdev-5
Author

vfdev-5 commented Sep 19, 2024

@makslevental thanks for the feedback. At the moment, there is no CI job using Python 3.13.0rc1 with free-threading enabled; we may want to add one...

I ran the existing Python tests on Python 3.13 with free-threading, and all of them pass when using pybind11 from source (otherwise there is a deadlock: pybind/pybind11#5346).
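
When reproducing such a run, it helps to confirm that the interpreter really is a free-threaded build and that the GIL is still disabled after the extension modules are imported (CPython re-enables the GIL when it loads an extension that does not declare free-threading support). A minimal sketch, assuming Python 3.13 or newer for `sys._is_gil_enabled()`; the helper name `free_threading_status` is made up for illustration:

```python
import sys
import sysconfig


def free_threading_status():
    """Return (built_free_threaded, gil_enabled_now) for this interpreter."""
    # Py_GIL_DISABLED is 1 on free-threaded builds and 0 or None otherwise.
    built_free_threaded = bool(sysconfig.get_config_var("Py_GIL_DISABLED"))
    # sys._is_gil_enabled() exists from Python 3.13 on; older interpreters
    # always run with the GIL, so fall back to True there.
    is_gil_enabled = getattr(sys, "_is_gil_enabled", lambda: True)
    return built_free_threaded, bool(is_gil_enabled())


if __name__ == "__main__":
    built, gil_on = free_threading_status()
    print(f"free-threaded build: {built}, GIL currently enabled: {gil_on}")
```

Calling the helper once before and once after importing the bindings shows whether an import switched the GIL back on.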

I don't remember exactly (I believe it's in some of the custom stuff in pybind_adaptors). I can take a closer look soon.

I'll check from my side as well.

Here are a few docs on free-threading support, in case they help:

@vfdev-5 vfdev-5 marked this pull request as draft September 22, 2024 19:50
@vfdev-5
Author

vfdev-5 commented Sep 22, 2024

Set this PR as a draft: I found C++ code that is not protected for multi-threading (it uses liveContexts in PyMlirContext) and fails in my local multi-threaded tests.

@stellaraccident
Contributor

Thanks for this. I've been burning through free threading deps in other libraries but hadn't come to this one yet.

Given that this is all very bleeding edge, it may be a little bit before we have proper CI for it upstream, but I'm happy to accept patches that improve free threaded compatibility in the meantime, especially if you have a test rig.

@stellaraccident
Contributor

stellaraccident commented Sep 22, 2024

Maks is right: there are a couple of places where we use mutable global state and we should come up with a convention to protect those with a global mutex. I know there is an idiom for this in CPython itself, but is there a common thing done for pybind/extensions yet? (Basically something that is a no-op in GIL builds but uses a scoped mutex in free-threaded builds.)

Let's verify that before marking the modules free threaded compatible.

@vfdev-5
Author

vfdev-5 commented Sep 22, 2024

@stellaraccident thanks, I have a few questions, though, about the current Python bindings tests using lit.

My local tests are rather simple and I would like to improve that:

import typer

import threading
import concurrent.futures

import mlir.dialects.arith as arith
from mlir.ir import Context, Location, Module, IntegerType, F64Type, InsertionPoint


def mt_run(fn, num_threads, args=(), kwargs=None):
    # Run `fn` concurrently in `num_threads` threads and collect the results.
    # Avoid a mutable default argument for kwargs.
    kwargs = kwargs or {}
    # The barrier releases all workers at once to maximize contention.
    barrier = threading.Barrier(num_threads)

    def closure():
        barrier.wait()
        return fn(*args, **kwargs)

    with concurrent.futures.ThreadPoolExecutor(
        max_workers=num_threads
    ) as executor:
        futures = [executor.submit(closure) for _ in range(num_threads)]
        # Call future.result() so that an exception raised in a worker is
        # re-raised here and fails the test.
        return [f.result() for f in futures]


def func():
    py_values = [123, 234, 345]
    with Context() as ctx:
        module = Module.create(loc=Location.file("foo.txt", 0, 0))

        dtype = IntegerType.get_signless(64)
        with InsertionPoint(module.body), Location.name("a"):
        # with Location.name("a"):
            arith.constant(dtype, py_values[0])

        with InsertionPoint(module.body), Location.name("b"):
        # with Location.name("b"):
            arith.constant(dtype, py_values[1])

        with InsertionPoint(module.body), Location.name("c"):
        # with Location.name("c"):
            arith.constant(dtype, py_values[2])

    return str(module)


def func2():
    py_values = [123, 234, 345]
    with Context() as ctx, Location.file("foo.txt", 0, 0):
        module = Module.create()
        with InsertionPoint(module.body):
            dtype = IntegerType.get_signless(64)
            arith.constant(dtype, py_values[0])

    return str(module)

def test(func, num_threads=10, expected_first = False):

    if expected_first:
        expected_mlir = func()
        print("Expected MLIR:", expected_mlir)

    output_mlir_list = mt_run(func, num_threads=num_threads)

    if not expected_first:
        expected_mlir = func()
        print("Expected MLIR:", expected_mlir)

    for i, output_mlir in enumerate(output_mlir_list):
        assert output_mlir == expected_mlir, (i, output_mlir, expected_mlir)


def main(
    n: int = 100,
    name: str = "test",
    nt: int = 10,
    ef: bool = False,
):
    test_fn = {
        "test": func,
        "test2": func2,
    }[name]
    for i in range(n):
        print("- Count: ", i)
        test(test_fn, num_threads=nt, expected_first=ef)


if __name__ == "__main__":
    typer.run(main)

Ideally, the existing tests would run in multi-threaded execution (either through a manual implementation or with tools like https://github.com/Quansight-Labs/pytest-run-parallel).
It seems that lit runs the tests and checks their stdout output, which may not always work correctly with multi-threaded execution... Do you think we could somehow reuse the existing lit tests in multi-threaded execution?

we should come up with a convention to protect those with a global mutex. I know there is an idiom for this in CPython itself, but is there a common thing done for pybind/extensions yet?

Yes, there is an example in pybind11 for that:
https://github.com/pybind/pybind11/blob/1f8b4a7f1a1c5cc9bd6e0d63fe15540e6c458b24/include/pybind11/detail/internals.h#L645-L649

I applied a similar thing for getLiveContexts (locally):

#ifdef Py_GIL_DISABLED
  static PyMutex &getLock() {
    static PyMutex lock;
    return lock;
  }
#endif

  template<typename F>
  static inline auto withLiveContexts(const F& cb) -> decltype(cb(getLiveContexts())) {
    auto &liveContexts = getLiveContexts();
#ifdef Py_GIL_DISABLED
    // RAII guard: the mutex is released even if cb throws.
    struct LockGuard {
      PyMutex &lock;
      explicit LockGuard(PyMutex &lock) : lock(lock) { PyMutex_Lock(&lock); }
      ~LockGuard() { PyMutex_Unlock(&lock); }
    } guard(getLock());
#endif
    // Returning the call directly also supports callbacks that return void.
    return cb(liveContexts);
  }

Usage:

PyMlirContext::PyMlirContext(MlirContext context) : context(context) {
  py::gil_scoped_acquire acquire;

  withLiveContexts([&](LiveContextMap& liveContexts) {
    liveContexts[context.ptr] = this;
    return this;
  });

}

@stellaraccident
Contributor

stellaraccident commented Sep 22, 2024

That locking idiom looks good to me. Thanks.

For the tests, I won't claim to love the current testing style, but as with most such things, it is costly to change. If this weren't a part of llvm, I never would have written the tests this way, but I digress...

I suppose we should discuss what level of threading compatibility we are targeting (and likely document that). It seems where this patch is going is ensuring that the library itself does not provide any inherent locking hazards if independent contexts are used from multiple threads. Beyond that, the thread safety is approximately what it is for the underlying mlir libraries, which is to say that it is unsafe to perform mutation on any IR structure unless you are sure that it is being exclusively accessed by a single thread. Most non framework code will just assume thread per context as the safe way to use the API.
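
The thread-per-context convention can be sketched without the bindings at all. In the sketch below, `FakeContext` is a hypothetical stand-in for an IR context (it is not the `mlir.ir.Context` class), and `build_module` and `build_many` are illustrative names; the point is only that each worker creates, mutates and discards its own context and hands back plain data.

```python
import concurrent.futures
import threading


class FakeContext:
    """Hypothetical stand-in for an IR context; it is NOT thread-safe."""

    def __init__(self):
        self.owner = threading.get_ident()
        self.ops = []

    def add_op(self, name):
        # Fail loudly if a context leaks into another thread.
        if threading.get_ident() != self.owner:
            raise RuntimeError("context used outside its owning thread")
        self.ops.append(name)


def build_module(values):
    # Everything created here stays inside the calling thread.
    ctx = FakeContext()
    for v in values:
        ctx.add_op(f"constant {v}")
    # Return immutable data instead of the context itself.
    return tuple(ctx.ops)


def build_many(jobs, num_threads=4):
    """Build one module per job, each in its own context and worker call."""
    with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as pool:
        return list(pool.map(build_module, jobs))
```

Under this convention only process-wide state (such as a live-context registry) needs locking inside the library.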

That is quite a bit less safe than CPython itself, which guards mutable data structures (list, shared dict, etc) internally: logical consistency problems are still possible if they are used across threads in an unsafe way, but the data structures themselves are guarded against corruption.

I'm not aware of a consensus that has emerged on the level of safety that a library such as this should provide. Fine grained locking at the binding layer seems like it would be expensive and error prone, especially given the more relaxed stance of the underlying library.

Thoughts?

(The reason this applies to testing is that if we are doing coarse locking of globals only, then we can likely target some specific tests that verify that. But if doing fine grained locking of IR structures, we probably need to rewrite the test suite into something that can be better verified with parallel testing)

@vfdev-5
Author

vfdev-5 commented Sep 23, 2024

Thanks for the reply @stellaraccident !

Yes, we should discuss what level of threading compatibility we are targeting.

It seems where this patch is going is ensuring that the library itself does not provide any inherent locking hazards if independent contexts are used from multiple threads.

IMO the idea behind enabling CPython free-threading mode in python extensions (like MLIR python bindings) is to make them thread-safe as they are with GIL-enabled CPython.

I'm not aware of a consensus that has emerged on the level of safety that a library such as this should provide. Fine grained locking at the binding layer seems like it would be expensive and error prone, especially given the more relaxed stance of the underlying library.

Can you please elaborate on this statement? I was thinking that we wanted to make the Python bindings thread-safe for free-threading CPython...

A few links on the thread-safety discussion around free-threading:

@stellaraccident
Contributor

Can you please detail this statement? I was thinking that we wanted to make python bindings thread-safe for free-threading cpython...

If you want to make them thread-safe, then we will need to add much more locking, likely choosing to use context-scoped locks. The issue is that the underlying MLIR library is thread-compatible but not thread-safe. The GIL is currently making usage from Python more thread safe than MLIR actually is. Concretely, think of two threads attempting to modify the operations in the same function as an example of what would happen without more binding-level locking.
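
As a rough illustration of what a context-scoped lock means, the sketch below guards every mutation of one shared container with a lock owned by that container. `LockedContext` and `run_demo` are hypothetical names, and the "operation id" bookkeeping is invented to make the read-modify-write step visible; nothing here reflects how the bindings are implemented.

```python
import threading


class LockedContext:
    """Hypothetical wrapper: one lock per context guards all mutation."""

    def __init__(self):
        self._lock = threading.RLock()
        self._ops = []

    def append_op(self, name):
        with self._lock:
            # Two dependent steps (read the length, then append) that must
            # not interleave with another thread's append.
            op_id = len(self._ops)
            self._ops.append((op_id, name))
            return op_id

    def snapshot(self):
        with self._lock:
            return list(self._ops)


def run_demo(num_threads=4, per_thread=500):
    """Let several threads append to the SAME context and return its ops."""
    ctx = LockedContext()

    def worker(tid):
        for i in range(per_thread):
            ctx.append_op(f"t{tid}.op{i}")

    threads = [threading.Thread(target=worker, args=(t,)) for t in range(num_threads)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return ctx.snapshot()
```

Every call pays for the lock even when only one thread uses the context, which is the overhead concern with fine-grained locking at the binding layer.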

Given the bleeding edge nature of free-threaded Python, I don't think we need to get to thread-safe in one step, but we should document and decide on where we are trying to take things. Then we could document a big warning saying that the library is not presently fully thread safe in [list ways].

I'm supportive of making incremental progress in the right direction, so long as we document where we are trying to get to and are clear with folks what the current safety level is.

@vfdev-5
Author

vfdev-5 commented Sep 24, 2024

@stellaraccident thanks for the discussion, I agree with your suggestion about documenting the scope of what free-threading would enable.

Concerning this part:

If you want to make them thread-safe, then we will need to add much more locking, likely choosing to use context-scoped locks. The issue is that the underlying MLIR library is thread-compatible but not thread-safe. The GIL is currently making usage from Python more thread safe than MLIR actually is. Concretely, think of two threads attempting to modify the operations in the same function as an example of what would happen without more binding-level locking.

I would still like to highlight the point that free-threading aims:

to retains the thread safety guarantees that were in place before the GIL was removed, and it doesn’t either add nor remove any. So that, for instance, my_int += 1 is unsafe both in the free-threaded and in the default build.
The fact that there is no GIL doesn’t mean that everything must now be thread safe.

(source: https://discuss.python.org/t/free-threading-trove-classifier/62406/24)
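
To make the `my_int += 1` example concrete: the increment is a read-modify-write, so concurrent increments can be lost in either build unless the caller adds a lock. A minimal sketch with hypothetical names (`Counter`, `count_with_lock`); only the locked variant has a guaranteed result.

```python
import threading


class Counter:
    def __init__(self):
        self.value = 0
        self._lock = threading.Lock()

    def increment_unsafe(self):
        # Read-modify-write: another thread may run between the read and
        # the write, so updates can be lost with or without the GIL.
        self.value += 1

    def increment_safe(self):
        # The lock makes the read-modify-write atomic with respect to
        # other callers of increment_safe.
        with self._lock:
            self.value += 1


def count_with_lock(num_threads=8, per_thread=10000):
    """Increment one shared counter from several threads under a lock."""
    counter = Counter()

    def worker():
        for _ in range(per_thread):
            counter.increment_safe()

    threads = [threading.Thread(target=worker) for _ in range(num_threads)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return counter.value
```

Swapping `increment_safe` for `increment_unsafe` in the worker may produce a smaller total, but it does not crash the interpreter.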

The situation is similar with NumPy, which is thread-unsafe in both the GIL-enabled and GIL-disabled builds: users need to be careful not to mutate shared state (cc @ngoldbaum).

I'm supportive of making incremental progress in the right direction, so long as we document where we are trying to get to and are clear with folks what the current safety level is.

Sounds good, so let's make the MLIR Python bindings thread-safe in free-threading mode and document it.

@stellaraccident
Contributor

Yes, but there is an important distinction: my_int += 1 will not crash in either build. Mutating shared IR state from multiple threads with the gil would produce inconsistent/logical errors but would not crash. Without the gil and without some form of extra locking, it will crash as the underlying IR structures are thread compatible but not thread safe.

If the cost of fine grain locking is excessive, many projects (mine included) would probably opt to build the extension without locking, since they always use it in thread per context anyway and the IR building is not a user exposed thing.

@hawkinsp
Contributor

I'll also +1 that adding fine-grain locking may be problematic. The MLIR Python builders are already slow, and adding lots of fine grained locking, even uncontended locking, will make that worse. But in JAX, as in Stella's case, we never access the same IR objects from multiple threads, and if we did need to do that we could add client-side locking. I think that's by far the most important mode to support.

@stellaraccident
Contributor

I'll also +1 that adding fine-grain locking may be problematic. The MLIR Python builders are already slow, and adding lots of fine grained locking, even uncontended locking, will make that worse. But in JAX, as in Stella's case, we never access the same IR objects from multiple threads, and if we did need to do that we could add client-side locking. I think that's by far the most important mode to support.

Off topic, but one of these days, I'm going to have to create a stripped down, build only (probably nanobind based) fast version of the API. Not at the top of the list right now, but yeah... slow. On the torch side, we only use the non-ODS API so that we have more flexibility here in the future.

@hawkinsp
Contributor

Off topic, but one of these days, I'm going to have to create a stripped down, build only (probably nanobind based) fast version of the API. Not at the top of the list right now, but yeah... slow. On the torch side, we only use the non-ODS API so that we have more flexibility here in the future.

Yeah, at some point soon I'm either going to end up trying to send PRs to speed up the existing bindings (switch to nanobind?) or I'll migrate off them and do my own thing.

@stellaraccident
Contributor

Off topic, but one of these days, I'm going to have to create a stripped down, build only (probably nanobind based) fast version of the API. Not at the top of the list right now, but yeah... slow. On the torch side, we only use the non-ODS API so that we have more flexibility here in the future.

Yeah, at some point soon I'm either going to end up trying to send PRs to speed up the existing bindings (switch to nanobind?) or I'll migrate off them and do my own thing.

I haven't benchmarked it in detail, but there are quite a few "productivity features" in there. I don't think for fastest-path builder use, the general IR manipulation APIs are going to be super easy to upgrade -- and they do get used. If the whole thing was ported to nanobind, I'd feel better about contributing a build-only fast-path subset library that could live alongside. In the current state, that would go out of tree because if going to the trouble to write something fast, I'm not going to base it on pybind.

Labels
mlir:python MLIR Python bindings mlir

5 participants