From c9f6feeab155668fee362681e71d1c75b66fe8c8 Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Wed, 21 Aug 2024 09:40:23 +0200 Subject: [PATCH] Added CPython free-threading support and basic CI env + multithread tests + wheels job --- .github/workflows/test.yml | 28 +++++++++++++++++++++++++++ .github/workflows/wheels.yml | 9 ++++++--- .pre-commit-config.yaml | 1 - ml_dtypes/_src/dtypes.cc | 5 +++++ ml_dtypes/tests/__init__.py | 0 ml_dtypes/tests/conftest.py | 5 +++++ ml_dtypes/tests/finfo_test.py | 2 ++ ml_dtypes/tests/iinfo_test.py | 2 ++ ml_dtypes/tests/intn_test.py | 17 +++++++--------- ml_dtypes/tests/metadata_test.py | 2 ++ ml_dtypes/tests/multi_thread_utils.py | 24 +++++++++++++++++++++++ pyproject.toml | 1 + 12 files changed, 82 insertions(+), 14 deletions(-) create mode 100644 ml_dtypes/tests/__init__.py create mode 100644 ml_dtypes/tests/conftest.py create mode 100644 ml_dtypes/tests/multi_thread_utils.py diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 4711f2dd..67327b1c 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -96,3 +96,31 @@ jobs: - name: Run tests run: | pytest -n auto + build-free-threading: + # Later we can merge this job with build similarly to + # https://github.com/python-pillow/Pillow/blob/f0d8fd3059bc1b291563d8a0b1f224b6fd7d0b90/.github/workflows/test.yml#L56-L57 + name: Python 3.13 with free-threading + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # ratchet:actions/checkout@v4 + with: + submodules: true + - name: Set up Python 3.13 with free-threading + # TODO: replace with setup-python when there is support + uses: deadsnakes/action@6c8b9b82fe0b4344f4b98f2775fcc395df45e494 # v3.1.0 + with: + python-version: '3.13-dev' + nogil: true + - name: Install dependencies + run: | + python -m pip install --upgrade pip + python -m pip install setuptools wheel + python -m pip install -U --pre numpy \ + -i https://pypi.anaconda.org/scientific-python-nightly-wheels/simple + python -c "import numpy; print(f'{numpy.__version__=}')" + - name: Build ml_dtypes + run: | + python -m pip install .[dev] --no-build-isolation + - name: Run tests + run: | + pytest -n auto diff --git a/.github/workflows/wheels.yml b/.github/workflows/wheels.yml index db1859f9..af91f615 100644 --- a/.github/workflows/wheels.yml +++ b/.github/workflows/wheels.yml @@ -40,17 +40,20 @@ jobs: platforms: all - name: Install cibuildwheel - run: python -m pip install cibuildwheel==2.15.0 + run: python -m pip install cibuildwheel==2.20.0 - name: Build wheels run: python -m cibuildwheel --output-dir wheelhouse env: CIBW_ARCHS_LINUX: auto aarch64 CIBW_ARCHS_MACOS: universal2 - CIBW_BUILD: cp39-* cp310-* cp311-* cp312-* - CIBW_SKIP: "*musllinux* *i686* *win32*" + CIBW_BUILD: cp39-* cp310-* cp311-* cp312-* cp313* + CIBW_FREE_THREADED_SUPPORT: True + CIBW_PRERELEASE_PYTHONS: True + CIBW_SKIP: "*musllinux* *i686* *win32* *t-win*" CIBW_TEST_REQUIRES: absl-py pytest pytest-xdist CIBW_TEST_COMMAND: pytest -n auto {project} + CIBW_BUILD_VERBOSITY: 1 - uses: actions/upload-artifact@834a144ee995460fba8ed112a2fc961b36a5ec5a # ratchet: actions/upload-artifact@v4 with: diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 5638b893..99e638b5 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -13,7 +13,6 @@ repos: rev: 23.10.0 hooks: - id: pyink - language_version: python3.9 args: [ "--line-length=80", "--preview", diff --git a/ml_dtypes/_src/dtypes.cc b/ml_dtypes/_src/dtypes.cc index 87f7578f..25a9bf6c 100644 --- a/ml_dtypes/_src/dtypes.cc +++ b/ml_dtypes/_src/dtypes.cc @@ -431,6 +431,11 @@ extern "C" EXPORT_SYMBOL PyObject* PyInit__ml_dtypes_ext() { reinterpret_cast(TypeDescriptor::type_ptr)) < 0) { return nullptr; } + +#ifdef Py_GIL_DISABLED + PyUnstable_Module_SetGIL(m.get(), Py_MOD_GIL_NOT_USED); +#endif + return m.release(); } } // namespace ml_dtypes diff --git a/ml_dtypes/tests/__init__.py b/ml_dtypes/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/ml_dtypes/tests/conftest.py b/ml_dtypes/tests/conftest.py new file mode 100644 index 00000000..01aef641 --- /dev/null +++ b/ml_dtypes/tests/conftest.py @@ -0,0 +1,5 @@ +import sys +from pathlib import Path + +# Add ml_dtypes/tests folder to discover multi_thread_utils.py module +sys.path.insert(0, str(Path(__file__).parent)) diff --git a/ml_dtypes/tests/finfo_test.py b/ml_dtypes/tests/finfo_test.py index 3999476b..b3d94eca 100644 --- a/ml_dtypes/tests/finfo_test.py +++ b/ml_dtypes/tests/finfo_test.py @@ -16,6 +16,7 @@ from absl.testing import parameterized import ml_dtypes import numpy as np +from multi_thread_utils import multi_threaded ALL_DTYPES = [ ml_dtypes.bfloat16, @@ -40,6 +41,7 @@ } +@multi_threaded(num_workers=3) class FinfoTest(parameterized.TestCase): def assertNanEqual(self, x, y): diff --git a/ml_dtypes/tests/iinfo_test.py b/ml_dtypes/tests/iinfo_test.py index 66d0aa3d..a17b410c 100644 --- a/ml_dtypes/tests/iinfo_test.py +++ b/ml_dtypes/tests/iinfo_test.py @@ -16,8 +16,10 @@ from absl.testing import parameterized import ml_dtypes import numpy as np +from multi_thread_utils import multi_threaded +@multi_threaded(num_workers=3) class IinfoTest(parameterized.TestCase): def testIinfoInt2(self): diff --git a/ml_dtypes/tests/intn_test.py b/ml_dtypes/tests/intn_test.py index 3adda76f..14185b13 100644 --- a/ml_dtypes/tests/intn_test.py +++ b/ml_dtypes/tests/intn_test.py @@ -14,16 +14,15 @@ """Test cases for int4 types.""" -import contextlib import copy import operator import pickle -import warnings from absl.testing import absltest from absl.testing import parameterized import ml_dtypes import numpy as np +from multi_thread_utils import multi_threaded int2 = ml_dtypes.int2 int4 = ml_dtypes.int4 @@ -40,14 +39,8 @@ } -@contextlib.contextmanager -def ignore_warning(**kw): - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", **kw) - yield - - # Tests for the Python scalar type +@multi_threaded(num_workers=3) class ScalarTest(parameterized.TestCase): @parameterized.product(scalar_type=INTN_TYPES) @@ -245,6 +238,7 @@ def testCanCast(self, a, b): # Tests for the Python scalar type +@multi_threaded(num_workers=3) class ArrayTest(parameterized.TestCase): @parameterized.product(scalar_type=INTN_TYPES) @@ -370,10 +364,13 @@ def testCastBetweenCustomTypes(self, types): np.remainder, ], ) - @ignore_warning(category=RuntimeWarning, message="divide by zero encountered") def testBinaryUfuncs(self, scalar_type, ufunc): x = np.array(VALUES[scalar_type]) y = np.array(VALUES[scalar_type], dtype=scalar_type) + # Avoid the following "RuntimeWarning: divide by zero encountered" + if ufunc in (np.floor_divide, np.remainder): + x[x == 0] = 1 + y[y == 0] = 1 np.testing.assert_array_equal( ufunc(x[:, None], x[None, :]).astype(scalar_type), ufunc(y[:, None], y[None, :]), diff --git a/ml_dtypes/tests/metadata_test.py b/ml_dtypes/tests/metadata_test.py index bb2335b4..81da5367 100644 --- a/ml_dtypes/tests/metadata_test.py +++ b/ml_dtypes/tests/metadata_test.py @@ -16,8 +16,10 @@ from absl.testing import absltest import ml_dtypes +from multi_thread_utils import multi_threaded +@multi_threaded(num_workers=3) class CustomFloatTest(absltest.TestCase): def test_version_matches_package_metadata(self): diff --git a/ml_dtypes/tests/multi_thread_utils.py b/ml_dtypes/tests/multi_thread_utils.py new file mode 100644 index 00000000..b394eb84 --- /dev/null +++ b/ml_dtypes/tests/multi_thread_utils.py @@ -0,0 +1,24 @@ +from concurrent.futures import ThreadPoolExecutor +from functools import wraps + + +def multi_threaded(*, num_workers: int): + def decorator(test_cls): + for name, test_fn in test_cls.__dict__.copy().items(): + if not (name.startswith("test") and callable(test_fn)): + continue + + @wraps(test_fn) + def multi_threaded_test_fn(*args, __test_fn__=test_fn, **kwargs): + with ThreadPoolExecutor(max_workers=num_workers) as executor: + futures = [] + for _ in range(num_workers): + futures.append(executor.submit(__test_fn__, *args, **kwargs)) + # We should call future.result() to re-raise an exception if test has failed + list(f.result() for f in futures) + + setattr(test_cls, f"{name}_multi_threaded", multi_threaded_test_fn) + + return test_cls + + return decorator diff --git a/pyproject.toml b/pyproject.toml index 147abd55..661b7b40 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,6 +21,7 @@ dependencies = [ "numpy>=1.21.2; python_version>='3.10'", "numpy>=1.23.3; python_version>='3.11'", "numpy>=1.26.0; python_version>='3.12'", + "numpy>=2.1.0; python_version>='3.13'", ] [project.urls]