From d234393a07199a8118bcc554d5bc6d2bc5fcf44c Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Wed, 21 Aug 2024 09:40:23 +0200 Subject: [PATCH] Added CPython free-threading support and basic CI env + multithread tests + wheels job --- .github/workflows/test.yml | 30 ++++++++++++++++++++++ .github/workflows/wheels.yml | 9 ++++--- ml_dtypes/_src/dtypes.cc | 5 ++++ ml_dtypes/tests/finfo_test.py | 5 ++++ ml_dtypes/tests/iinfo_test.py | 5 ++++ ml_dtypes/tests/intn_test.py | 9 +++++++ ml_dtypes/tests/metadata_test.py | 5 ++++ ml_dtypes/tests/multi_thread_test_mixin.py | 26 +++++++++++++++++++ pyproject.toml | 8 +++--- 9 files changed, 96 insertions(+), 6 deletions(-) create mode 100644 ml_dtypes/tests/multi_thread_test_mixin.py diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index e30946eb..6a7fba8d 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -96,3 +96,33 @@ jobs: - name: Run tests run: | pytest -n auto + build-free-threading: + # Later we can merge this job with build similarly to + # https://github.com/python-pillow/Pillow/blob/f0d8fd3059bc1b291563d8a0b1f224b6fd7d0b90/.github/workflows/test.yml#L56-L57 + name: Python 3.13 with free-threading + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + with: + submodules: true + - name: Set up Python 3.13 with free-threading + # TODO: replace with setup-python when there is support + uses: deadsnakes/action@6c8b9b82fe0b4344f4b98f2775fcc395df45e494 # v3.1.0 + with: + python-version: '3.13-dev' + nogil: true + - name: Install dependencies + run: | + python -m pip install --upgrade pip + python -m pip install setuptools wheel + python -m pip install -U --pre numpy \ + -i https://pypi.anaconda.org/scientific-python-nightly-wheels/simple + python -c "import numpy; print(f'{numpy.__version__=}')" + - name: Build ml_dtypes + run: | + python -m pip install .[dev] --no-build-isolation + - name: Run tests + env: + PYTHON_GIL: 0 + run: | + pytest -n auto diff --git a/.github/workflows/wheels.yml b/.github/workflows/wheels.yml index 4d96082c..2db41900 100644 --- a/.github/workflows/wheels.yml +++ b/.github/workflows/wheels.yml @@ -40,17 +40,20 @@ jobs: platforms: all - name: Install cibuildwheel - run: python -m pip install cibuildwheel==2.15.0 + run: python -m pip install cibuildwheel==2.20.0 - name: Build wheels run: python -m cibuildwheel --output-dir wheelhouse env: CIBW_ARCHS_LINUX: auto aarch64 CIBW_ARCHS_MACOS: universal2 - CIBW_BUILD: cp39-* cp310-* cp311-* cp312-* - CIBW_SKIP: "*musllinux* *i686* *win32*" + CIBW_BUILD: cp39-* cp310-* cp311-* cp312-* cp313* + CIBW_FREE_THREADED_SUPPORT: True + CIBW_PRERELEASE_PYTHONS: True + CIBW_SKIP: "*musllinux* *i686* *win32* *t-win*" CIBW_TEST_REQUIRES: absl-py pytest pytest-xdist CIBW_TEST_COMMAND: pytest -n auto {project} + CIBW_BUILD_VERBOSITY: 1 - uses: actions/upload-artifact@v3 with: diff --git a/ml_dtypes/_src/dtypes.cc b/ml_dtypes/_src/dtypes.cc index 1dde49b4..68976e41 100644 --- a/ml_dtypes/_src/dtypes.cc +++ b/ml_dtypes/_src/dtypes.cc @@ -403,6 +403,11 @@ extern "C" EXPORT_SYMBOL PyObject* PyInit__ml_dtypes_ext() { reinterpret_cast(TypeDescriptor::type_ptr)) < 0) { return nullptr; } + +#ifdef Py_GIL_DISABLED + PyUnstable_Module_SetGIL(m.get(), Py_MOD_GIL_NOT_USED); +#endif + return m.release(); } } // namespace ml_dtypes diff --git a/ml_dtypes/tests/finfo_test.py b/ml_dtypes/tests/finfo_test.py index 855c00ba..78903ba6 100644 --- a/ml_dtypes/tests/finfo_test.py +++ b/ml_dtypes/tests/finfo_test.py @@ -16,6 +16,7 @@ from absl.testing import parameterized import ml_dtypes import numpy as np +from multi_thread_test_mixin import MultiThreadTestMixin ALL_DTYPES = [ ml_dtypes.bfloat16, @@ -108,5 +109,9 @@ def assert_zero(val): ) +class FinfoMultiThreadTest(FinfoTest, MultiThreadTestMixin): + pass + + if __name__ == "__main__": absltest.main() diff --git a/ml_dtypes/tests/iinfo_test.py b/ml_dtypes/tests/iinfo_test.py index 66d0aa3d..7160699d 100644 --- a/ml_dtypes/tests/iinfo_test.py +++ b/ml_dtypes/tests/iinfo_test.py @@ -16,6 +16,7 @@ from absl.testing import parameterized import ml_dtypes import numpy as np +from multi_thread_test_mixin import MultiThreadTestMixin class IinfoTest(parameterized.TestCase): @@ -79,5 +80,9 @@ def testIinfoNonInteger(self): ml_dtypes.iinfo(bool) +class IinfoMultiThreadTest(IinfoTest, MultiThreadTestMixin): + pass + + if __name__ == "__main__": absltest.main() diff --git a/ml_dtypes/tests/intn_test.py b/ml_dtypes/tests/intn_test.py index 3adda76f..fdfa521f 100644 --- a/ml_dtypes/tests/intn_test.py +++ b/ml_dtypes/tests/intn_test.py @@ -24,6 +24,7 @@ from absl.testing import parameterized import ml_dtypes import numpy as np +from multi_thread_test_mixin import MultiThreadTestMixin int2 = ml_dtypes.int2 int4 = ml_dtypes.int4 @@ -380,5 +381,13 @@ def testBinaryUfuncs(self, scalar_type, ufunc): ) +class ScalarMultiThreadTest(ScalarTest, MultiThreadTestMixin): + pass + + +class ArrayMultiThreadTest(ArrayTest, MultiThreadTestMixin): + pass + + if __name__ == "__main__": absltest.main() diff --git a/ml_dtypes/tests/metadata_test.py b/ml_dtypes/tests/metadata_test.py index bb2335b4..3380f76e 100644 --- a/ml_dtypes/tests/metadata_test.py +++ b/ml_dtypes/tests/metadata_test.py @@ -16,6 +16,7 @@ from absl.testing import absltest import ml_dtypes +from multi_thread_test_mixin import MultiThreadTestMixin class CustomFloatTest(absltest.TestCase): @@ -31,5 +32,9 @@ def test_version_matches_package_metadata(self): self.assertEqual(metadata_version, package_version) +class CustomFloatMultiThreadTest(CustomFloatTest, MultiThreadTestMixin): + pass + + if __name__ == "__main__": absltest.main() diff --git a/ml_dtypes/tests/multi_thread_test_mixin.py b/ml_dtypes/tests/multi_thread_test_mixin.py new file mode 100644 index 00000000..694f7fd7 --- /dev/null +++ b/ml_dtypes/tests/multi_thread_test_mixin.py @@ -0,0 +1,26 @@ +class MultiThreadTestMixin: + max_workers: int = 2 + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + test_fn_names = [v for v in dir(self) if v.startswith("test")] + + for test_fn_name in test_fn_names: + test_fn = getattr(self, test_fn_name) + + def get_mt_test_fn(test_func): + from concurrent.futures import ThreadPoolExecutor + from functools import wraps + + @wraps(test_func) + def wrapper(*args, **kwargs): + def test_func_noargs(_): + test_func(*args, **kwargs) + + with ThreadPoolExecutor(max_workers=self.max_workers) as executor: + list(executor.map(test_func_noargs, range(self.max_workers))) + + wrapper.__doc__ == test_func.__doc__ + return wrapper + + setattr(self, test_fn.__name__, get_mt_test_fn(test_fn)) diff --git a/pyproject.toml b/pyproject.toml index 84287a70..a4d5653a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,6 +21,7 @@ dependencies = [ "numpy>=1.21.2; python_version>='3.10'", "numpy>=1.23.3; python_version>='3.11'", "numpy>=1.26.0; python_version>='3.12'", + "numpy>=2.1.0; python_version>='3.13'", ] [project.urls] @@ -48,9 +49,10 @@ pyink-use-majority-quotes = true [build-system] requires = [ - # We must build against NumPy 2.0 for the resulting wheels to - # be compatible with both NumPy 1.X and 2.X. - "numpy~=2.0.0", + # We build against the oldest NumPy 2.x release that + # supports each Python version. + "numpy~=2.0.0; python_version<'3.13'", + "numpy~=2.1.0; python_version>='3.13'", "setuptools~=70.1.1", ] build-backend = "setuptools.build_meta"