Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(jax/array-api): hybrid descriptor #4275

Merged
merged 2 commits into from
Oct 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 8 additions & 5 deletions deepmd/dpmodel/descriptor/hybrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
Union,
)

import array_api_compat
import numpy as np

from deepmd.dpmodel.common import (
Expand Down Expand Up @@ -66,7 +67,7 @@ def __init__(
), f"number of atom types in {ii}th descriptor {self.descrpt_list[0].__class__.__name__} does not match others"
# if hybrid sel is larger than sub sel, the nlist needs to be cut for each type
hybrid_sel = self.get_sel()
self.nlist_cut_idx: list[np.ndarray] = []
nlist_cut_idx: list[np.ndarray] = []
if self.mixed_types() and not all(
descrpt.mixed_types() for descrpt in self.descrpt_list
):
Expand All @@ -92,7 +93,8 @@ def __init__(
cut_idx = np.concatenate(
[range(ss, ee) for ss, ee in zip(start_idx, end_idx)]
)
self.nlist_cut_idx.append(cut_idx)
nlist_cut_idx.append(cut_idx)
self.nlist_cut_idx = nlist_cut_idx

def get_rcut(self) -> float:
"""Returns the cut-off radius."""
Expand Down Expand Up @@ -242,6 +244,7 @@ def call(
sw
The smooth switch function.
"""
xp = array_api_compat.array_namespace(coord_ext, atype_ext, nlist)
out_descriptor = []
out_gr = []
out_g2 = None
Expand All @@ -258,7 +261,7 @@ def call(
for descrpt, nci in zip(self.descrpt_list, self.nlist_cut_idx):
# cut the nlist to the correct length
if self.mixed_types() == descrpt.mixed_types():
nl = nlist[:, :, nci]
nl = xp.take(nlist, nci, axis=2)
else:
# mixed_types is True, but descrpt.mixed_types is False
assert nl_distinguish_types is not None
Expand All @@ -268,8 +271,8 @@ def call(
if gr is not None:
out_gr.append(gr)

out_descriptor = np.concatenate(out_descriptor, axis=-1)
out_gr = np.concatenate(out_gr, axis=-2) if out_gr else None
out_descriptor = xp.concat(out_descriptor, axis=-1)
out_gr = xp.concat(out_gr, axis=-2) if out_gr else None
return out_descriptor, out_gr, out_g2, out_h2, out_sw

@classmethod
Expand Down
4 changes: 4 additions & 0 deletions deepmd/jax/descriptor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
from deepmd.jax.descriptor.dpa1 import (
DescrptDPA1,
)
from deepmd.jax.descriptor.hybrid import (
DescrptHybrid,
)
from deepmd.jax.descriptor.se_e2_a import (
DescrptSeA,
)
Expand All @@ -13,4 +16,5 @@
"DescrptSeA",
"DescrptSeR",
"DescrptDPA1",
"DescrptHybrid",
]
26 changes: 26 additions & 0 deletions deepmd/jax/descriptor/hybrid.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
Any,
)

from deepmd.dpmodel.descriptor.hybrid import DescrptHybrid as DescrptHybridDP
from deepmd.jax.common import (
ArrayAPIVariable,
flax_module,
to_jax_array,
)
from deepmd.jax.descriptor.base_descriptor import (
BaseDescriptor,
)


@BaseDescriptor.register("hybrid")
@flax_module
class DescrptHybrid(DescrptHybridDP):
    def __setattr__(self, name: str, value: Any) -> None:
        """Convert selected attributes before storing them on the instance.

        ``nlist_cut_idx`` is expected to be an iterable of index arrays; each
        one is passed through ``to_jax_array`` and wrapped in an
        ``ArrayAPIVariable``.  ``descrpt_list`` is expected to be an iterable
        of descriptors; each one is serialized and re-created through this
        package's ``BaseDescriptor.deserialize``.  Any other attribute is
        stored unchanged.
        """
        if name == "descrpt_list":
            # Round-trip through serialize/deserialize so every child is
            # rebuilt by the BaseDescriptor imported in this module.
            value = [BaseDescriptor.deserialize(child.serialize()) for child in value]
        elif name == "nlist_cut_idx":
            value = [ArrayAPIVariable(to_jax_array(idx)) for idx in value]

        return super().__setattr__(name, value)
4 changes: 2 additions & 2 deletions doc/model/train-hybrid.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Descriptor `"hybrid"` {{ tensorflow_icon }} {{ pytorch_icon }} {{ dpmodel_icon }}
# Descriptor `"hybrid"` {{ tensorflow_icon }} {{ pytorch_icon }} {{ jax_icon }} {{ dpmodel_icon }}

:::{note}
**Supported backends**: TensorFlow {{ tensorflow_icon }}, PyTorch {{ pytorch_icon }}, DP {{ dpmodel_icon }}
**Supported backends**: TensorFlow {{ tensorflow_icon }}, PyTorch {{ pytorch_icon }}, JAX {{ jax_icon }}, DP {{ dpmodel_icon }}
:::

This descriptor hybridizes multiple descriptors to form a new descriptor. For example, given a list of descriptors denoted by $\mathcal D_1$, $\mathcal D_2$, ..., $\mathcal D_N$, the hybrid descriptor is the concatenation of the list, i.e. $\mathcal D = (\mathcal D_1, \mathcal D_2, \cdots, \mathcal D_N)$.
Expand Down
19 changes: 19 additions & 0 deletions source/tests/array_api_strict/descriptor/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,20 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from .dpa1 import (
DescrptDPA1,
)
from .hybrid import (
DescrptHybrid,
)
from .se_e2_a import (
DescrptSeA,
)
from .se_e2_r import (
DescrptSeR,
)

__all__ = [
"DescrptSeA",
"DescrptSeR",
"DescrptDPA1",
"DescrptHybrid",
]
11 changes: 11 additions & 0 deletions source/tests/array_api_strict/descriptor/base_descriptor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
Any,
)

from deepmd.dpmodel.descriptor.make_base_descriptor import (
make_base_descriptor,
)

# The array API standard defines no standard array type annotation, so Any is used.
BaseDescriptor = make_base_descriptor(Any)
5 changes: 5 additions & 0 deletions source/tests/array_api_strict/descriptor/dpa1.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@
from ..utils.type_embed import (
TypeEmbedNet,
)
from .base_descriptor import (
BaseDescriptor,
)


class GatedAttentionLayer(GatedAttentionLayerDP):
Expand Down Expand Up @@ -72,6 +75,8 @@ def __setattr__(self, name: str, value: Any) -> None:
return super().__setattr__(name, value)


@BaseDescriptor.register("dpa1")
@BaseDescriptor.register("se_atten")
class DescrptDPA1(DescrptDPA1DP):
def __setattr__(self, name: str, value: Any) -> None:
if name == "se_atten":
Expand Down
24 changes: 24 additions & 0 deletions source/tests/array_api_strict/descriptor/hybrid.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
Any,
)

from deepmd.dpmodel.descriptor.hybrid import DescrptHybrid as DescrptHybridDP

from ..common import (
to_array_api_strict_array,
)
from .base_descriptor import (
BaseDescriptor,
)


@BaseDescriptor.register("hybrid")
class DescrptHybrid(DescrptHybridDP):
    def __setattr__(self, name: str, value: Any) -> None:
        """Convert selected attributes before storing them on the instance.

        ``nlist_cut_idx`` is expected to be an iterable of index arrays; each
        one is converted with ``to_array_api_strict_array``.  ``descrpt_list``
        is expected to be an iterable of descriptors; each one is serialized
        and re-created through this package's ``BaseDescriptor.deserialize``.
        Any other attribute is stored unchanged.
        """
        if name == "descrpt_list":
            # Round-trip through serialize/deserialize so every child is
            # rebuilt by the BaseDescriptor imported in this module.
            value = [BaseDescriptor.deserialize(child.serialize()) for child in value]
        elif name == "nlist_cut_idx":
            value = [to_array_api_strict_array(idx) for idx in value]

        return super().__setattr__(name, value)
5 changes: 5 additions & 0 deletions source/tests/array_api_strict/descriptor/se_e2_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,13 @@
from ..utils.network import (
NetworkCollection,
)
from .base_descriptor import (
BaseDescriptor,
)


@BaseDescriptor.register("se_e2_a")
@BaseDescriptor.register("se_a")
class DescrptSeA(DescrptSeADP):
def __setattr__(self, name: str, value: Any) -> None:
if name in {"dstd", "davg"}:
Expand Down
5 changes: 5 additions & 0 deletions source/tests/array_api_strict/descriptor/se_e2_r.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,13 @@
from ..utils.network import (
NetworkCollection,
)
from .base_descriptor import (
BaseDescriptor,
)


@BaseDescriptor.register("se_e2_r")
@BaseDescriptor.register("se_r")
class DescrptSeR(DescrptSeRDP):
def __setattr__(self, name: str, value: Any) -> None:
if name in {"dstd", "davg"}:
Expand Down
35 changes: 35 additions & 0 deletions source/tests/consistent/descriptor/test_hybrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
)

from ..common import (
INSTALLED_ARRAY_API_STRICT,
INSTALLED_JAX,
INSTALLED_PT,
INSTALLED_TF,
CommonTest,
Expand All @@ -28,6 +30,16 @@
from deepmd.tf.descriptor.hybrid import DescrptHybrid as DescrptHybridTF
else:
DescrptHybridTF = None
if INSTALLED_JAX:
from deepmd.jax.descriptor.hybrid import DescrptHybrid as DescrptHybridJAX
else:
DescrptHybridJAX = None
if INSTALLED_ARRAY_API_STRICT:
from ...array_api_strict.descriptor.hybrid import (
DescrptHybrid as DescrptHybridStrict,
)
else:
DescrptHybridStrict = None
from deepmd.utils.argcheck import (
descrpt_hybrid_args,
)
Expand Down Expand Up @@ -68,8 +80,13 @@ def data(self) -> dict:
tf_class = DescrptHybridTF
dp_class = DescrptHybridDP
pt_class = DescrptHybridPT
jax_class = DescrptHybridJAX
array_api_strict_class = DescrptHybridStrict
args = descrpt_hybrid_args()

skip_jax = not INSTALLED_JAX
skip_array_api_strict = not INSTALLED_ARRAY_API_STRICT

def setUp(self):
CommonTest.setUp(self)

Expand Down Expand Up @@ -132,5 +149,23 @@ def eval_pt(self, pt_obj: Any) -> Any:
self.box,
)

def eval_array_api_strict(self, array_api_strict_obj: Any) -> Any:
    """Evaluate ``array_api_strict_obj`` with the inputs stored on this test case.

    Delegates to ``self.eval_array_api_strict_descriptor`` and returns its
    result unchanged.
    """
    # Inputs are read from the test case, in the order the helper expects.
    system = (self.natoms, self.coords, self.atype, self.box)
    return self.eval_array_api_strict_descriptor(array_api_strict_obj, *system)

def eval_jax(self, jax_obj: Any) -> Any:
    """Evaluate ``jax_obj`` with the inputs stored on this test case.

    Delegates to ``self.eval_jax_descriptor`` and returns its result
    unchanged.
    """
    # Inputs are read from the test case, in the order the helper expects.
    system = (self.natoms, self.coords, self.atype, self.box)
    return self.eval_jax_descriptor(jax_obj, *system)

def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]:
    """Return a one-element tuple holding the first item of ``ret``.

    ``backend`` is accepted but not used here.
    """
    leading = ret[0]
    return (leading,)