diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 4711f2dd..67327b1c 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -96,3 +96,31 @@ jobs: - name: Run tests run: | pytest -n auto + build-free-threading: + # Later we can merge this job with build similarly to + # https://github.com/python-pillow/Pillow/blob/f0d8fd3059bc1b291563d8a0b1f224b6fd7d0b90/.github/workflows/test.yml#L56-L57 + name: Python 3.13 with free-threading + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # ratchet:actions/checkout@v4 + with: + submodules: true + - name: Set up Python 3.13 with free-threading + # TODO: replace with setup-python when there is support + uses: deadsnakes/action@6c8b9b82fe0b4344f4b98f2775fcc395df45e494 # v3.1.0 + with: + python-version: '3.13-dev' + nogil: true + - name: Install dependencies + run: | + python -m pip install --upgrade pip + python -m pip install setuptools wheel + python -m pip install -U --pre numpy \ + -i https://pypi.anaconda.org/scientific-python-nightly-wheels/simple + python -c "import numpy; print(f'{numpy.__version__=}')" + - name: Build ml_dtypes + run: | + python -m pip install .[dev] --no-build-isolation + - name: Run tests + run: | + pytest -n auto diff --git a/.github/workflows/wheels.yml b/.github/workflows/wheels.yml index db1859f9..af91f615 100644 --- a/.github/workflows/wheels.yml +++ b/.github/workflows/wheels.yml @@ -40,17 +40,20 @@ jobs: platforms: all - name: Install cibuildwheel - run: python -m pip install cibuildwheel==2.15.0 + run: python -m pip install cibuildwheel==2.20.0 - name: Build wheels run: python -m cibuildwheel --output-dir wheelhouse env: CIBW_ARCHS_LINUX: auto aarch64 CIBW_ARCHS_MACOS: universal2 - CIBW_BUILD: cp39-* cp310-* cp311-* cp312-* - CIBW_SKIP: "*musllinux* *i686* *win32*" + CIBW_BUILD: cp39-* cp310-* cp311-* cp312-* cp313* + CIBW_FREE_THREADED_SUPPORT: True + CIBW_PRERELEASE_PYTHONS: True + CIBW_SKIP: "*musllinux* *i686* *win32* *t-win*" CIBW_TEST_REQUIRES: absl-py pytest pytest-xdist CIBW_TEST_COMMAND: pytest -n auto {project} + CIBW_BUILD_VERBOSITY: 1 - uses: actions/upload-artifact@834a144ee995460fba8ed112a2fc961b36a5ec5a # ratchet: actions/upload-artifact@v4 with: diff --git a/ml_dtypes/_src/dtypes.cc b/ml_dtypes/_src/dtypes.cc index 5075728b..287a60bf 100644 --- a/ml_dtypes/_src/dtypes.cc +++ b/ml_dtypes/_src/dtypes.cc @@ -460,6 +460,11 @@ extern "C" EXPORT_SYMBOL PyObject* PyInit__ml_dtypes_ext() { reinterpret_cast(TypeDescriptor::type_ptr)) < 0) { return nullptr; } + +#ifdef Py_GIL_DISABLED + PyUnstable_Module_SetGIL(m.get(), Py_MOD_GIL_NOT_USED); +#endif + return m.release(); } } // namespace ml_dtypes diff --git a/ml_dtypes/tests/conftest.py b/ml_dtypes/tests/conftest.py new file mode 100644 index 00000000..21196949 --- /dev/null +++ b/ml_dtypes/tests/conftest.py @@ -0,0 +1,21 @@ +# Copyright 2024 The ml_dtypes Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""pytest configuration file.""" + +import pathlib +import sys + +# Add ml_dtypes/tests folder to discover multi_thread_utils.py module +sys.path.insert(0, str(pathlib.Path(__file__).absolute().parent)) diff --git a/ml_dtypes/tests/custom_float_test.py b/ml_dtypes/tests/custom_float_test.py index 54a37d01..00f9b1a0 100644 --- a/ml_dtypes/tests/custom_float_test.py +++ b/ml_dtypes/tests/custom_float_test.py @@ -27,6 +27,7 @@ from absl.testing import absltest from absl.testing import parameterized import ml_dtypes +from multi_thread_utils import multi_threaded import numpy as np bfloat16 = ml_dtypes.bfloat16 @@ -196,6 +197,10 @@ def dtype_has_inf(dtype): # pylint: disable=g-complex-comprehension +@multi_threaded( + num_workers=3, + skip_tests=["testDiv", "testRoundTripNumpyTypes", "testRoundTripToNumpy"], +) @parameterized.named_parameters( ( {"testcase_name": "_" + dtype.__name__, "float_type": dtype} @@ -604,6 +609,19 @@ def testDtypeFromString(self, float_type): # pylint: disable=g-complex-comprehension +@multi_threaded( + num_workers=3, + skip_tests=[ + "testBinaryUfunc", + "testConformNumpyComplex", + "testFloordivCornerCases", + "testDivmodCornerCases", + "testSpacing", + "testUnaryUfunc", + "testCasts", + "testLdexp", + ], +) @parameterized.named_parameters( ( {"testcase_name": "_" + dtype.__name__, "float_type": dtype} diff --git a/ml_dtypes/tests/finfo_test.py b/ml_dtypes/tests/finfo_test.py index ab92ea2f..c7135fc7 100644 --- a/ml_dtypes/tests/finfo_test.py +++ b/ml_dtypes/tests/finfo_test.py @@ -15,6 +15,7 @@ from absl.testing import absltest from absl.testing import parameterized import ml_dtypes +from multi_thread_utils import multi_threaded import numpy as np ALL_DTYPES = [ @@ -41,6 +42,7 @@ } +@multi_threaded(num_workers=3) class FinfoTest(parameterized.TestCase): def assertNanEqual(self, x, y): diff --git a/ml_dtypes/tests/iinfo_test.py b/ml_dtypes/tests/iinfo_test.py index 66d0aa3d..8936c523 100644 --- a/ml_dtypes/tests/iinfo_test.py +++ b/ml_dtypes/tests/iinfo_test.py @@ -15,9 +15,11 @@ from absl.testing import absltest from absl.testing import parameterized import ml_dtypes +from multi_thread_utils import multi_threaded import numpy as np +@multi_threaded(num_workers=3) class IinfoTest(parameterized.TestCase): def testIinfoInt2(self): diff --git a/ml_dtypes/tests/intn_test.py b/ml_dtypes/tests/intn_test.py index 3adda76f..86ab5a81 100644 --- a/ml_dtypes/tests/intn_test.py +++ b/ml_dtypes/tests/intn_test.py @@ -23,6 +23,7 @@ from absl.testing import absltest from absl.testing import parameterized import ml_dtypes +from multi_thread_utils import multi_threaded import numpy as np int2 = ml_dtypes.int2 @@ -48,6 +49,7 @@ def ignore_warning(**kw): # Tests for the Python scalar type +@multi_threaded(num_workers=3) class ScalarTest(parameterized.TestCase): @parameterized.product(scalar_type=INTN_TYPES) @@ -245,6 +247,7 @@ def testCanCast(self, a, b): # Tests for the Python scalar type +@multi_threaded(num_workers=3, skip_tests=["testBinaryUfuncs"]) class ArrayTest(parameterized.TestCase): @parameterized.product(scalar_type=INTN_TYPES) diff --git a/ml_dtypes/tests/metadata_test.py b/ml_dtypes/tests/metadata_test.py index bb2335b4..81da5367 100644 --- a/ml_dtypes/tests/metadata_test.py +++ b/ml_dtypes/tests/metadata_test.py @@ -16,8 +16,10 @@ from absl.testing import absltest import ml_dtypes +from multi_thread_utils import multi_threaded +@multi_threaded(num_workers=3) class CustomFloatTest(absltest.TestCase): def test_version_matches_package_metadata(self): diff --git a/ml_dtypes/tests/multi_thread_utils.py b/ml_dtypes/tests/multi_thread_utils.py new file mode 100644 index 00000000..bd901906 --- /dev/null +++ b/ml_dtypes/tests/multi_thread_utils.py @@ -0,0 +1,50 @@ +# Copyright 2024 The ml_dtypes Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utilities for multi-threaded tests.""" + +import concurrent.futures +import functools +from typing import Optional + + +def multi_threaded(*, num_workers: int, skip_tests: Optional[list[str]] = None): + """Decorator that runs a test in a multi-threaded environment.""" + + def decorator(test_cls): + for name, test_fn in test_cls.__dict__.copy().items(): + if not (name.startswith("test") and callable(test_fn)): + continue + + if skip_tests is not None: + if any(test_name in name for test_name in skip_tests): + continue + + @functools.wraps(test_fn) # pylint: disable=cell-var-from-loop + def multi_threaded_test_fn(*args, __test_fn__=test_fn, **kwargs): + with concurrent.futures.ThreadPoolExecutor( + max_workers=num_workers + ) as executor: + futures = [] + for _ in range(num_workers): + futures.append(executor.submit(__test_fn__, *args, **kwargs)) + # We should call future.result() to re-raise an exception if test has + # failed + list(f.result() for f in futures) + + setattr(test_cls, f"{name}_multi_threaded", multi_threaded_test_fn) + + return test_cls + + return decorator diff --git a/pyproject.toml b/pyproject.toml index 147abd55..661b7b40 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,6 +21,7 @@ dependencies = [ "numpy>=1.21.2; python_version>='3.10'", "numpy>=1.23.3; python_version>='3.11'", "numpy>=1.26.0; python_version>='3.12'", + "numpy>=2.1.0; python_version>='3.13'", ] [project.urls]