Skip to content

Commit

Permalink
feat: support array API (#3922)
Browse files Browse the repository at this point in the history
Fix #3430.
This PR sets up the basic support for the array API, and make an example
function (`compute_smooth_weight`) to support the array API. I believe
NumPy and JAX have supported it (or through `array-api-compat`), so we
don't need to write things twice for NumPy and JAX (although we can
write them using the ChatGPT, it's still better to maintain only one
thing). There are some challeging to use it in the TorchScript, so I
give it up. Supporting more function can be implemented in the following
PRs.

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

- **New Features**
- Introduced testing for `compute_smooth_weight` function using
`array_api_strict` for enhanced array operations.

- **Chores**
- Updated dependencies to include `'array-api-compat'` and
`'array-api-strict>=2'` for improved compatibility and testing
capabilities.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
  • Loading branch information
njzjz authored Jun 28, 2024
1 parent cf8bd2a commit 56c3e17
Show file tree
Hide file tree
Showing 7 changed files with 102 additions and 2 deletions.
29 changes: 29 additions & 0 deletions deepmd/dpmodel/array_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
"""Utilities for the array API."""


def support_array_api(version: str) -> callable:
"""Mark a function as supporting the specific version of the array API.
Parameters
----------
version : str
The version of the array API
Returns
-------
callable
The decorated function
Examples
--------
>>> @support_array_api(version="2022.12")
... def f(x):
... pass
"""

def set_version(func: callable) -> callable:
func.array_api_version = version
return func

return set_version
12 changes: 10 additions & 2 deletions deepmd/dpmodel/utils/env_mat.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,18 @@
Union,
)

import array_api_compat
import numpy as np

from deepmd.dpmodel import (
NativeOP,
)
from deepmd.dpmodel.array_api import (
support_array_api,
)


@support_array_api(version="2022.12")
def compute_smooth_weight(
distance: np.ndarray,
rmin: float,
Expand All @@ -19,12 +24,15 @@ def compute_smooth_weight(
"""Compute smooth weight for descriptor elements."""
if rmin >= rmax:
raise ValueError("rmin should be less than rmax.")
xp = array_api_compat.array_namespace(distance)
min_mask = distance <= rmin
max_mask = distance >= rmax
mid_mask = np.logical_not(np.logical_or(min_mask, max_mask))
mid_mask = xp.logical_not(xp.logical_or(min_mask, max_mask))
uu = (distance - rmin) / (rmax - rmin)
vv = uu * uu * uu * (-6.0 * uu * uu + 15.0 * uu - 10.0) + 1.0
return vv * mid_mask + min_mask
return vv * xp.astype(mid_mask, distance.dtype) + xp.astype(
min_mask, distance.dtype
)


def _make_env_mat(
Expand Down
2 changes: 2 additions & 0 deletions doc/backend.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ As a reference backend, it is not aimed at the best performance, but only the co
The DP backend uses [HDF5](https://docs.h5py.org/) to store model serialization data, which is backend-independent.
Only Python inference interface can load this format.

NumPy 1.21 or above is required.

## Switch the backend

### Training
Expand Down
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ dependencies = [
'packaging',
'ml_dtypes',
'mendeleev',
'array-api-compat',
]
requires-python = ">=3.8"
keywords = ["deepmd"]
Expand Down Expand Up @@ -79,6 +80,7 @@ test = [
"pytest-sugar",
"pytest-split",
"dpgui",
'array-api-strict>=2;python_version>="3.9"',
]
docs = [
"sphinx>=3.1.1",
Expand Down
2 changes: 2 additions & 0 deletions source/tests/common/dpmodel/array_api/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
"""Test array API compatibility to be completely sure their usage of the array API is portable."""
30 changes: 30 additions & 0 deletions source/tests/common/dpmodel/array_api/test_env_mat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import sys
import unittest

if sys.version_info >= (3, 9):
import array_api_strict as xp
else:
raise unittest.SkipTest("array_api_strict doesn't support Python<=3.8")

from deepmd.dpmodel.utils.env_mat import (
compute_smooth_weight,
)

from .utils import (
ArrayAPITest,
)


class TestEnvMat(unittest.TestCase, ArrayAPITest):
def test_compute_smooth_weight(self):
self.set_array_api_version(compute_smooth_weight)
d = xp.arange(10, dtype=xp.float64)
w = compute_smooth_weight(
d,
4.0,
6.0,
)
self.assert_namespace_equal(w, d)
self.assert_device_equal(w, d)
self.assert_dtype_equal(w, d)
27 changes: 27 additions & 0 deletions source/tests/common/dpmodel/array_api/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import array_api_compat
from array_api_strict import (
set_array_api_strict_flags,
)


class ArrayAPITest:
"""Utils for array API tests."""

def set_array_api_version(self, func):
"""Set the array API version for a function."""
set_array_api_strict_flags(api_version=func.array_api_version)

def assert_namespace_equal(self, a, b):
"""Assert two array has the same namespace."""
self.assertEqual(
array_api_compat.array_namespace(a), array_api_compat.array_namespace(b)
)

def assert_dtype_equal(self, a, b):
"""Assert two array has the same dtype."""
self.assertEqual(a.dtype, b.dtype)

def assert_device_equal(self, a, b):
"""Assert two array has the same device."""
self.assertEqual(array_api_compat.device(a), array_api_compat.device(b))

0 comments on commit 56c3e17

Please sign in to comment.