Skip to content

Commit

Permalink
Added CPython free-threading support and basic CI env
Browse files Browse the repository at this point in the history
+ multithread tests
+ wheels job
  • Loading branch information
vfdev-5 committed Aug 28, 2024
1 parent f053b3c commit c9f6fee
Show file tree
Hide file tree
Showing 12 changed files with 82 additions and 14 deletions.
28 changes: 28 additions & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -96,3 +96,31 @@ jobs:
- name: Run tests
run: |
pytest -n auto
build-free-threading:
# Later we can merge this job with build similarly to
# https://github.com/python-pillow/Pillow/blob/f0d8fd3059bc1b291563d8a0b1f224b6fd7d0b90/.github/workflows/test.yml#L56-L57
name: Python 3.13 with free-threading
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # ratchet:actions/checkout@v4
with:
submodules: true
- name: Set up Python 3.13 with free-threading
# TODO: replace with setup-python when there is support
uses: deadsnakes/action@6c8b9b82fe0b4344f4b98f2775fcc395df45e494 # v3.1.0
with:
python-version: '3.13-dev'
nogil: true
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install setuptools wheel
python -m pip install -U --pre numpy \
-i https://pypi.anaconda.org/scientific-python-nightly-wheels/simple
python -c "import numpy; print(f'{numpy.__version__=}')"
- name: Build ml_dtypes
run: |
python -m pip install .[dev] --no-build-isolation
- name: Run tests
run: |
pytest -n auto
9 changes: 6 additions & 3 deletions .github/workflows/wheels.yml
Original file line number Diff line number Diff line change
Expand Up @@ -40,17 +40,20 @@ jobs:
platforms: all

- name: Install cibuildwheel
run: python -m pip install cibuildwheel==2.15.0
run: python -m pip install cibuildwheel==2.20.0

- name: Build wheels
run: python -m cibuildwheel --output-dir wheelhouse
env:
CIBW_ARCHS_LINUX: auto aarch64
CIBW_ARCHS_MACOS: universal2
CIBW_BUILD: cp39-* cp310-* cp311-* cp312-*
CIBW_SKIP: "*musllinux* *i686* *win32*"
CIBW_BUILD: cp39-* cp310-* cp311-* cp312-* cp313*
CIBW_FREE_THREADED_SUPPORT: True
CIBW_PRERELEASE_PYTHONS: True
CIBW_SKIP: "*musllinux* *i686* *win32* *t-win*"
CIBW_TEST_REQUIRES: absl-py pytest pytest-xdist
CIBW_TEST_COMMAND: pytest -n auto {project}
CIBW_BUILD_VERBOSITY: 1

- uses: actions/upload-artifact@834a144ee995460fba8ed112a2fc961b36a5ec5a # ratchet: actions/upload-artifact@v4
with:
Expand Down
1 change: 0 additions & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ repos:
rev: 23.10.0
hooks:
- id: pyink
language_version: python3.9
args: [
"--line-length=80",
"--preview",
Expand Down
5 changes: 5 additions & 0 deletions ml_dtypes/_src/dtypes.cc
Original file line number Diff line number Diff line change
Expand Up @@ -431,6 +431,11 @@ extern "C" EXPORT_SYMBOL PyObject* PyInit__ml_dtypes_ext() {
reinterpret_cast<PyObject*>(TypeDescriptor<uint4>::type_ptr)) < 0) {
return nullptr;
}

#ifdef Py_GIL_DISABLED
PyUnstable_Module_SetGIL(m.get(), Py_MOD_GIL_NOT_USED);
#endif

return m.release();
}
} // namespace ml_dtypes
Empty file added ml_dtypes/tests/__init__.py
Empty file.
5 changes: 5 additions & 0 deletions ml_dtypes/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
import sys
from pathlib import Path

# Add ml_dtypes/tests folder to discover multi_thread_utils.py module
sys.path.insert(0, str(Path(__file__).parent))
2 changes: 2 additions & 0 deletions ml_dtypes/tests/finfo_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from absl.testing import parameterized
import ml_dtypes
import numpy as np
from multi_thread_utils import multi_threaded

ALL_DTYPES = [
ml_dtypes.bfloat16,
Expand All @@ -40,6 +41,7 @@
}


@multi_threaded(num_workers=3)
class FinfoTest(parameterized.TestCase):

def assertNanEqual(self, x, y):
Expand Down
2 changes: 2 additions & 0 deletions ml_dtypes/tests/iinfo_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,10 @@
from absl.testing import parameterized
import ml_dtypes
import numpy as np
from multi_thread_utils import multi_threaded


@multi_threaded(num_workers=3)
class IinfoTest(parameterized.TestCase):

def testIinfoInt2(self):
Expand Down
17 changes: 7 additions & 10 deletions ml_dtypes/tests/intn_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,15 @@

"""Test cases for int4 types."""

import contextlib
import copy
import operator
import pickle
import warnings

from absl.testing import absltest
from absl.testing import parameterized
import ml_dtypes
import numpy as np
from multi_thread_utils import multi_threaded

int2 = ml_dtypes.int2
int4 = ml_dtypes.int4
Expand All @@ -40,14 +39,8 @@
}


@contextlib.contextmanager
def ignore_warning(**kw):
with warnings.catch_warnings():
warnings.filterwarnings("ignore", **kw)
yield


# Tests for the Python scalar type
@multi_threaded(num_workers=3)
class ScalarTest(parameterized.TestCase):

@parameterized.product(scalar_type=INTN_TYPES)
Expand Down Expand Up @@ -245,6 +238,7 @@ def testCanCast(self, a, b):


# Tests for the Python scalar type
@multi_threaded(num_workers=3)
class ArrayTest(parameterized.TestCase):

@parameterized.product(scalar_type=INTN_TYPES)
Expand Down Expand Up @@ -370,10 +364,13 @@ def testCastBetweenCustomTypes(self, types):
np.remainder,
],
)
@ignore_warning(category=RuntimeWarning, message="divide by zero encountered")
def testBinaryUfuncs(self, scalar_type, ufunc):
x = np.array(VALUES[scalar_type])
y = np.array(VALUES[scalar_type], dtype=scalar_type)
# Avoid the following "RuntimeWarning: divide by zero encountered"
if ufunc in (np.floor_divide, np.remainder):
x[x == 0] = 1
y[y == 0] = 1
np.testing.assert_array_equal(
ufunc(x[:, None], x[None, :]).astype(scalar_type),
ufunc(y[:, None], y[None, :]),
Expand Down
2 changes: 2 additions & 0 deletions ml_dtypes/tests/metadata_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,10 @@

from absl.testing import absltest
import ml_dtypes
from multi_thread_utils import multi_threaded


@multi_threaded(num_workers=3)
class CustomFloatTest(absltest.TestCase):

def test_version_matches_package_metadata(self):
Expand Down
24 changes: 24 additions & 0 deletions ml_dtypes/tests/multi_thread_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from concurrent.futures import ThreadPoolExecutor
from functools import wraps


def multi_threaded(*, num_workers: int):
def decorator(test_cls):
for name, test_fn in test_cls.__dict__.copy().items():
if not (name.startswith("test") and callable(test_fn)):
continue

@wraps(test_fn)
def multi_threaded_test_fn(*args, __test_fn__=test_fn, **kwargs):
with ThreadPoolExecutor(max_workers=num_workers) as executor:
futures = []
for _ in range(num_workers):
futures.append(executor.submit(__test_fn__, *args, **kwargs))
# We should call future.result() to re-raise an exception if test has failed
list(f.result() for f in futures)

setattr(test_cls, f"{name}_multi_threaded", multi_threaded_test_fn)

return test_cls

return decorator
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ dependencies = [
"numpy>=1.21.2; python_version>='3.10'",
"numpy>=1.23.3; python_version>='3.11'",
"numpy>=1.26.0; python_version>='3.12'",
"numpy>=2.1.0; python_version>='3.13'",
]

[project.urls]
Expand Down

0 comments on commit c9f6fee

Please sign in to comment.