From 9b22feb92da42b86b19ed0c54cd30d2baa960b4f Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Wed, 21 Aug 2024 09:40:23 +0200 Subject: [PATCH] Added CPython free-threading support and basic CI env + multithread tests + wheels job --- .github/workflows/test.yml | 28 ++++++++++++++++++++++++++ .github/workflows/wheels.yml | 9 ++++++--- ml_dtypes/_src/dtypes.cc | 5 +++++ ml_dtypes/tests/conftest.py | 5 +++++ ml_dtypes/tests/custom_float_test.py | 18 +++++++++++++++++ ml_dtypes/tests/finfo_test.py | 2 ++ ml_dtypes/tests/iinfo_test.py | 2 ++ ml_dtypes/tests/intn_test.py | 3 +++ ml_dtypes/tests/metadata_test.py | 2 ++ ml_dtypes/tests/multi_thread_utils.py | 29 +++++++++++++++++++++++++++ pyproject.toml | 1 + 11 files changed, 101 insertions(+), 3 deletions(-) create mode 100644 ml_dtypes/tests/conftest.py create mode 100644 ml_dtypes/tests/multi_thread_utils.py diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 4711f2dd..67327b1c 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -96,3 +96,31 @@ jobs: - name: Run tests run: | pytest -n auto + build-free-threading: + # Later we can merge this job with build similarly to + # https://github.com/python-pillow/Pillow/blob/f0d8fd3059bc1b291563d8a0b1f224b6fd7d0b90/.github/workflows/test.yml#L56-L57 + name: Python 3.13 with free-threading + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # ratchet:actions/checkout@v4 + with: + submodules: true + - name: Set up Python 3.13 with free-threading + # TODO: replace with setup-python when there is support + uses: deadsnakes/action@6c8b9b82fe0b4344f4b98f2775fcc395df45e494 # v3.1.0 + with: + python-version: '3.13-dev' + nogil: true + - name: Install dependencies + run: | + python -m pip install --upgrade pip + python -m pip install setuptools wheel + python -m pip install -U --pre numpy \ + -i https://pypi.anaconda.org/scientific-python-nightly-wheels/simple + python -c "import numpy; print(f'{numpy.__version__=}')" + - name: Build ml_dtypes + run: | + python -m pip install .[dev] --no-build-isolation + - name: Run tests + run: | + pytest -n auto diff --git a/.github/workflows/wheels.yml b/.github/workflows/wheels.yml index db1859f9..af91f615 100644 --- a/.github/workflows/wheels.yml +++ b/.github/workflows/wheels.yml @@ -40,17 +40,20 @@ jobs: platforms: all - name: Install cibuildwheel - run: python -m pip install cibuildwheel==2.15.0 + run: python -m pip install cibuildwheel==2.20.0 - name: Build wheels run: python -m cibuildwheel --output-dir wheelhouse env: CIBW_ARCHS_LINUX: auto aarch64 CIBW_ARCHS_MACOS: universal2 - CIBW_BUILD: cp39-* cp310-* cp311-* cp312-* - CIBW_SKIP: "*musllinux* *i686* *win32*" + CIBW_BUILD: cp39-* cp310-* cp311-* cp312-* cp313* + CIBW_FREE_THREADED_SUPPORT: True + CIBW_PRERELEASE_PYTHONS: True + CIBW_SKIP: "*musllinux* *i686* *win32* *t-win*" CIBW_TEST_REQUIRES: absl-py pytest pytest-xdist CIBW_TEST_COMMAND: pytest -n auto {project} + CIBW_BUILD_VERBOSITY: 1 - uses: actions/upload-artifact@834a144ee995460fba8ed112a2fc961b36a5ec5a # ratchet: actions/upload-artifact@v4 with: diff --git a/ml_dtypes/_src/dtypes.cc b/ml_dtypes/_src/dtypes.cc index 84914704..90ac4bed 100644 --- a/ml_dtypes/_src/dtypes.cc +++ b/ml_dtypes/_src/dtypes.cc @@ -460,6 +460,11 @@ extern "C" EXPORT_SYMBOL PyObject* PyInit__ml_dtypes_ext() { reinterpret_cast(TypeDescriptor::type_ptr)) < 0) { return nullptr; } + +#ifdef Py_GIL_DISABLED + PyUnstable_Module_SetGIL(m.get(), Py_MOD_GIL_NOT_USED); +#endif + return m.release(); } } // namespace ml_dtypes diff --git a/ml_dtypes/tests/conftest.py b/ml_dtypes/tests/conftest.py new file mode 100644 index 00000000..0249c342 --- /dev/null +++ b/ml_dtypes/tests/conftest.py @@ -0,0 +1,5 @@ +import sys +from pathlib import Path + +# Add ml_dtypes/tests folder to discover multi_thread_utils.py module +sys.path.insert(0, str(Path(__file__).absolute().parent)) diff --git a/ml_dtypes/tests/custom_float_test.py b/ml_dtypes/tests/custom_float_test.py index 54a37d01..b0553119 100644 --- a/ml_dtypes/tests/custom_float_test.py +++ b/ml_dtypes/tests/custom_float_test.py @@ -28,6 +28,7 @@ from absl.testing import parameterized import ml_dtypes import numpy as np +from multi_thread_utils import multi_threaded bfloat16 = ml_dtypes.bfloat16 float8_e3m4 = ml_dtypes.float8_e3m4 @@ -196,6 +197,10 @@ def dtype_has_inf(dtype): # pylint: disable=g-complex-comprehension +@multi_threaded( + num_workers=3, + skip_tests=["testDiv", "testRoundTripNumpyTypes", "testRoundTripToNumpy"], +) @parameterized.named_parameters( ( {"testcase_name": "_" + dtype.__name__, "float_type": dtype} @@ -604,6 +609,19 @@ def testDtypeFromString(self, float_type): # pylint: disable=g-complex-comprehension +@multi_threaded( + num_workers=3, + skip_tests=[ + "testBinaryUfunc", + "testConformNumpyComplex", + "testFloordivCornerCases", + "testDivmodCornerCases", + "testSpacing", + "testUnaryUfunc", + "testCasts", + "testLdexp", + ], +) @parameterized.named_parameters( ( {"testcase_name": "_" + dtype.__name__, "float_type": dtype} diff --git a/ml_dtypes/tests/finfo_test.py b/ml_dtypes/tests/finfo_test.py index ab92ea2f..de6560ab 100644 --- a/ml_dtypes/tests/finfo_test.py +++ b/ml_dtypes/tests/finfo_test.py @@ -16,6 +16,7 @@ from absl.testing import parameterized import ml_dtypes import numpy as np +from multi_thread_utils import multi_threaded ALL_DTYPES = [ ml_dtypes.bfloat16, @@ -41,6 +42,7 @@ } +@multi_threaded(num_workers=3) class FinfoTest(parameterized.TestCase): def assertNanEqual(self, x, y): diff --git a/ml_dtypes/tests/iinfo_test.py b/ml_dtypes/tests/iinfo_test.py index 66d0aa3d..a17b410c 100644 --- a/ml_dtypes/tests/iinfo_test.py +++ b/ml_dtypes/tests/iinfo_test.py @@ -16,8 +16,10 @@ from absl.testing import parameterized import ml_dtypes import numpy as np +from multi_thread_utils import multi_threaded +@multi_threaded(num_workers=3) class IinfoTest(parameterized.TestCase): def testIinfoInt2(self): diff --git a/ml_dtypes/tests/intn_test.py b/ml_dtypes/tests/intn_test.py index 3adda76f..babd60e8 100644 --- a/ml_dtypes/tests/intn_test.py +++ b/ml_dtypes/tests/intn_test.py @@ -24,6 +24,7 @@ from absl.testing import parameterized import ml_dtypes import numpy as np +from multi_thread_utils import multi_threaded int2 = ml_dtypes.int2 int4 = ml_dtypes.int4 @@ -48,6 +49,7 @@ def ignore_warning(**kw): # Tests for the Python scalar type +@multi_threaded(num_workers=3) class ScalarTest(parameterized.TestCase): @parameterized.product(scalar_type=INTN_TYPES) @@ -245,6 +247,7 @@ def testCanCast(self, a, b): # Tests for the Python scalar type +@multi_threaded(num_workers=3, skip_tests=["testBinaryUfuncs"]) class ArrayTest(parameterized.TestCase): @parameterized.product(scalar_type=INTN_TYPES) diff --git a/ml_dtypes/tests/metadata_test.py b/ml_dtypes/tests/metadata_test.py index bb2335b4..81da5367 100644 --- a/ml_dtypes/tests/metadata_test.py +++ b/ml_dtypes/tests/metadata_test.py @@ -16,8 +16,10 @@ from absl.testing import absltest import ml_dtypes +from multi_thread_utils import multi_threaded +@multi_threaded(num_workers=3) class CustomFloatTest(absltest.TestCase): def test_version_matches_package_metadata(self): diff --git a/ml_dtypes/tests/multi_thread_utils.py b/ml_dtypes/tests/multi_thread_utils.py new file mode 100644 index 00000000..8f92fc6e --- /dev/null +++ b/ml_dtypes/tests/multi_thread_utils.py @@ -0,0 +1,29 @@ +from typing import Optional +from concurrent.futures import ThreadPoolExecutor +from functools import wraps + + +def multi_threaded(*, num_workers: int, skip_tests: Optional[list[str]] = None): + def decorator(test_cls): + for name, test_fn in test_cls.__dict__.copy().items(): + if not (name.startswith("test") and callable(test_fn)): + continue + + if skip_tests is not None: + if any(test_name in name for test_name in skip_tests): + continue + + @wraps(test_fn) + def multi_threaded_test_fn(*args, __test_fn__=test_fn, **kwargs): + with ThreadPoolExecutor(max_workers=num_workers) as executor: + futures = [] + for _ in range(num_workers): + futures.append(executor.submit(__test_fn__, *args, **kwargs)) + # We should call future.result() to re-raise an exception if test has failed + list(f.result() for f in futures) + + setattr(test_cls, f"{name}_multi_threaded", multi_threaded_test_fn) + + return test_cls + + return decorator diff --git a/pyproject.toml b/pyproject.toml index 147abd55..661b7b40 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,6 +21,7 @@ dependencies = [ "numpy>=1.21.2; python_version>='3.10'", "numpy>=1.23.3; python_version>='3.11'", "numpy>=1.26.0; python_version>='3.12'", + "numpy>=2.1.0; python_version>='3.13'", ] [project.urls]