diff --git a/deepmd/dpmodel/array_api.py b/deepmd/dpmodel/array_api.py new file mode 100644 index 0000000000..e4af2ad627 --- /dev/null +++ b/deepmd/dpmodel/array_api.py @@ -0,0 +1,29 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Utilities for the array API.""" + + +def support_array_api(version: str) -> callable: + """Mark a function as supporting the specific version of the array API. + + Parameters + ---------- + version : str + The version of the array API + + Returns + ------- + callable + The decorated function + + Examples + -------- + >>> @support_array_api(version="2022.12") + ... def f(x): + ... pass + """ + + def set_version(func: callable) -> callable: + func.array_api_version = version + return func + + return set_version diff --git a/deepmd/dpmodel/utils/env_mat.py b/deepmd/dpmodel/utils/env_mat.py index 94cf3a7c21..41f2591279 100644 --- a/deepmd/dpmodel/utils/env_mat.py +++ b/deepmd/dpmodel/utils/env_mat.py @@ -4,13 +4,18 @@ Union, ) +import array_api_compat import numpy as np from deepmd.dpmodel import ( NativeOP, ) +from deepmd.dpmodel.array_api import ( + support_array_api, +) +@support_array_api(version="2022.12") def compute_smooth_weight( distance: np.ndarray, rmin: float, @@ -19,12 +24,15 @@ def compute_smooth_weight( """Compute smooth weight for descriptor elements.""" if rmin >= rmax: raise ValueError("rmin should be less than rmax.") + xp = array_api_compat.array_namespace(distance) min_mask = distance <= rmin max_mask = distance >= rmax - mid_mask = np.logical_not(np.logical_or(min_mask, max_mask)) + mid_mask = xp.logical_not(xp.logical_or(min_mask, max_mask)) uu = (distance - rmin) / (rmax - rmin) vv = uu * uu * uu * (-6.0 * uu * uu + 15.0 * uu - 10.0) + 1.0 - return vv * mid_mask + min_mask + return vv * xp.astype(mid_mask, distance.dtype) + xp.astype( + min_mask, distance.dtype + ) def _make_env_mat( diff --git a/doc/backend.md b/doc/backend.md index 2f0bc7ed20..e164cd8405 100644 --- a/doc/backend.md +++ b/doc/backend.md @@ -37,6 +37,8 @@ As a reference backend, it is not aimed at the best performance, but only the co The DP backend uses [HDF5](https://docs.h5py.org/) to store model serialization data, which is backend-independent. Only Python inference interface can load this format. +NumPy 1.21 or above is required. + ## Switch the backend ### Training diff --git a/pyproject.toml b/pyproject.toml index d9cbeb44e4..861fea6399 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,6 +48,7 @@ dependencies = [ 'packaging', 'ml_dtypes', 'mendeleev', + 'array-api-compat', ] requires-python = ">=3.8" keywords = ["deepmd"] @@ -79,6 +80,7 @@ test = [ "pytest-sugar", "pytest-split", "dpgui", + 'array-api-strict>=2;python_version>="3.9"', ] docs = [ "sphinx>=3.1.1", diff --git a/source/tests/common/dpmodel/array_api/__init__.py b/source/tests/common/dpmodel/array_api/__init__.py new file mode 100644 index 0000000000..e02301188e --- /dev/null +++ b/source/tests/common/dpmodel/array_api/__init__.py @@ -0,0 +1,2 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Test array API compatibility to be completely sure their usage of the array API is portable.""" diff --git a/source/tests/common/dpmodel/array_api/test_env_mat.py b/source/tests/common/dpmodel/array_api/test_env_mat.py new file mode 100644 index 0000000000..d5bc7b6c18 --- /dev/null +++ b/source/tests/common/dpmodel/array_api/test_env_mat.py @@ -0,0 +1,30 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import sys +import unittest + +if sys.version_info >= (3, 9): + import array_api_strict as xp +else: + raise unittest.SkipTest("array_api_strict doesn't support Python<=3.8") + +from deepmd.dpmodel.utils.env_mat import ( + compute_smooth_weight, +) + +from .utils import ( + ArrayAPITest, +) + + +class TestEnvMat(unittest.TestCase, ArrayAPITest): + def test_compute_smooth_weight(self): + self.set_array_api_version(compute_smooth_weight) + d = xp.arange(10, dtype=xp.float64) + w = compute_smooth_weight( + d, + 4.0, + 6.0, + ) + self.assert_namespace_equal(w, d) + self.assert_device_equal(w, d) + self.assert_dtype_equal(w, d) diff --git a/source/tests/common/dpmodel/array_api/utils.py b/source/tests/common/dpmodel/array_api/utils.py new file mode 100644 index 0000000000..7e422c2ead --- /dev/null +++ b/source/tests/common/dpmodel/array_api/utils.py @@ -0,0 +1,27 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import array_api_compat +from array_api_strict import ( + set_array_api_strict_flags, +) + + +class ArrayAPITest: + """Utils for array API tests.""" + + def set_array_api_version(self, func): + """Set the array API version for a function.""" + set_array_api_strict_flags(api_version=func.array_api_version) + + def assert_namespace_equal(self, a, b): + """Assert two array has the same namespace.""" + self.assertEqual( + array_api_compat.array_namespace(a), array_api_compat.array_namespace(b) + ) + + def assert_dtype_equal(self, a, b): + """Assert two array has the same dtype.""" + self.assertEqual(a.dtype, b.dtype) + + def assert_device_equal(self, a, b): + """Assert two array has the same device.""" + self.assertEqual(array_api_compat.device(a), array_api_compat.device(b))