From cad90343c0941c455737458a74a6bc8985ad0a33 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Mon, 23 Sep 2024 20:54:12 -0400 Subject: [PATCH 01/14] feat(jax/array-api): dpa1 Signed-off-by: Jinzhe Zeng --- deepmd/dpmodel/descriptor/dpa1.py | 241 +++++++++++++----- deepmd/dpmodel/utils/env_mat.py | 37 ++- deepmd/dpmodel/utils/exclude_mask.py | 33 ++- deepmd/dpmodel/utils/network.py | 38 +-- deepmd/dpmodel/utils/nlist.py | 114 +++++---- deepmd/dpmodel/utils/region.py | 34 ++- deepmd/dpmodel/utils/type_embed.py | 2 +- deepmd/jax/common.py | 4 +- deepmd/jax/descriptor/__init__.py | 1 + deepmd/jax/descriptor/dpa1.py | 80 ++++++ deepmd/jax/utils/exclude_mask.py | 16 ++ deepmd/jax/utils/network.py | 16 ++ source/tests/array_api_strict/__init__.py | 2 + source/tests/array_api_strict/common.py | 25 ++ .../array_api_strict/descriptor/__init__.py | 1 + .../tests/array_api_strict/descriptor/dpa1.py | 81 ++++++ .../tests/array_api_strict/utils/__init__.py | 1 + .../array_api_strict/utils/exclude_mask.py | 17 ++ .../tests/array_api_strict/utils/network.py | 46 ++++ .../array_api_strict/utils/type_embed.py | 22 ++ .../common/dpmodel/test_descriptor_dpa1.py | 19 ++ source/tests/consistent/common.py | 66 +++++ source/tests/consistent/descriptor/common.py | 63 +++++ .../tests/consistent/descriptor/test_dpa1.py | 96 +++++++ .../tests/consistent/test_type_embedding.py | 15 ++ 25 files changed, 900 insertions(+), 170 deletions(-) create mode 100644 deepmd/jax/descriptor/__init__.py create mode 100644 deepmd/jax/descriptor/dpa1.py create mode 100644 deepmd/jax/utils/exclude_mask.py create mode 100644 source/tests/array_api_strict/__init__.py create mode 100644 source/tests/array_api_strict/common.py create mode 100644 source/tests/array_api_strict/descriptor/__init__.py create mode 100644 source/tests/array_api_strict/descriptor/dpa1.py create mode 100644 source/tests/array_api_strict/utils/__init__.py create mode 100644 source/tests/array_api_strict/utils/exclude_mask.py create mode 100644 source/tests/array_api_strict/utils/network.py create mode 100644 source/tests/array_api_strict/utils/type_embed.py diff --git a/deepmd/dpmodel/descriptor/dpa1.py b/deepmd/dpmodel/descriptor/dpa1.py index 70cb818eef..d16c87d37c 100644 --- a/deepmd/dpmodel/descriptor/dpa1.py +++ b/deepmd/dpmodel/descriptor/dpa1.py @@ -8,6 +8,7 @@ Union, ) +import array_api_compat import numpy as np from deepmd.dpmodel import ( @@ -34,9 +35,6 @@ from deepmd.dpmodel.utils.update_sel import ( UpdateSel, ) -from deepmd.env import ( - GLOBAL_NP_FLOAT_PRECISION, -) from deepmd.utils.data_system import ( DeepmdDataSystem, ) @@ -61,13 +59,16 @@ def np_softmax(x, axis=-1): - x = np.nan_to_num(x) # to avoid value warning - e_x = np.exp(x - np.max(x, axis=axis, keepdims=True)) - return e_x / np.sum(e_x, axis=axis, keepdims=True) + xp = array_api_compat.array_namespace(x) + # x = xp.nan_to_num(x) # to avoid value warning + x = xp.where(xp.isnan(x), xp.zeros_like(x), x) + e_x = xp.exp(x - xp.max(x, axis=axis, keepdims=True)) + return e_x / xp.sum(e_x, axis=axis, keepdims=True) def np_normalize(x, axis=-1): - return x / np.linalg.norm(x, axis=axis, keepdims=True) + xp = array_api_compat.array_namespace(x) + return x / xp.linalg.vector_norm(x, axis=axis, keepdims=True) @BaseDescriptor.register("se_atten") @@ -476,10 +477,14 @@ def call( The smooth switch function. """ del mapping + xp = array_api_compat.array_namespace(coord_ext, atype_ext, nlist) nf, nloc, nnei = nlist.shape - nall = coord_ext.reshape(nf, -1).shape[1] // 3 + nall = xp.reshape(coord_ext, (nf, -1)).shape[1] // 3 # nf x nall x tebd_dim - atype_embd_ext = self.type_embedding.call()[atype_ext] + atype_embd_ext = xp.reshape( + xp.take(self.type_embedding.call(), xp.reshape(atype_ext, [-1]), axis=0), + (nf, nall, self.tebd_dim), + ) # nfnl x tebd_dim atype_embd = atype_embd_ext[:, :nloc, :] grrg, g2, h2, rot_mat, sw = self.se_atten( @@ -491,8 +496,8 @@ def call( ) # nf x nloc x (ng x ng1 + tebd_dim) if self.concat_output_tebd: - grrg = np.concatenate( - [grrg, atype_embd.reshape(nf, nloc, self.tebd_dim)], axis=-1 + grrg = xp.concat( + [grrg, xp.reshape(atype_embd, (nf, nloc, self.tebd_dim))], axis=-1 ) return grrg, rot_mat, None, None, sw @@ -538,8 +543,8 @@ def serialize(self) -> dict: "exclude_types": obj.exclude_types, "env_protection": obj.env_protection, "@variables": { - "davg": obj["davg"], - "dstd": obj["dstd"], + "davg": np.array(obj["davg"]), + "dstd": np.array(obj["dstd"]), }, ## to be updated when the options are supported. "trainable": self.trainable, @@ -685,12 +690,12 @@ def __init__( self.embd_input_dim = 1 + self.tebd_dim_input else: self.embd_input_dim = 1 - self.embeddings = NetworkCollection( + embeddings = NetworkCollection( ndim=0, ntypes=self.ntypes, network_type="embedding_network", ) - self.embeddings[0] = EmbeddingNet( + embeddings[0] = EmbeddingNet( self.embd_input_dim, self.neuron, self.activation_function, @@ -698,13 +703,14 @@ def __init__( self.precision, seed=child_seed(seed, 0), ) + self.embeddings = embeddings if self.tebd_input_mode in ["strip"]: - self.embeddings_strip = NetworkCollection( + embeddings_strip = NetworkCollection( ndim=0, ntypes=self.ntypes, network_type="embedding_network", ) - self.embeddings_strip[0] = EmbeddingNet( + embeddings_strip[0] = EmbeddingNet( self.tebd_dim_input, self.neuron, self.activation_function, @@ -712,6 +718,7 @@ def __init__( self.precision, seed=child_seed(seed, 1), ) + self.embeddings_strip = embeddings_strip else: self.embeddings_strip = None self.dpa1_attention = NeighborGatedAttention( @@ -839,9 +846,10 @@ def cal_g( ss, embedding_idx, ): + xp = array_api_compat.array_namespace(ss) nfnl, nnei = ss.shape[0:2] - shape2 = np.prod(ss.shape[2:]) - ss = ss.reshape(nfnl, nnei, shape2) + shape2 = xp.prod(xp.asarray(ss.shape[2:])) + ss = xp.reshape(ss, (nfnl, nnei, shape2)) # nfnl x nnei x ng gg = self.embeddings[embedding_idx].call(ss) return gg @@ -852,9 +860,10 @@ def cal_g_strip( embedding_idx, ): assert self.embeddings_strip is not None + xp = array_api_compat.array_namespace(ss) nfnl, nnei = ss.shape[0:2] - shape2 = np.prod(ss.shape[2:]) - ss = ss.reshape(nfnl, nnei, shape2) + shape2 = xp.prod(xp.asarray(ss.shape[2:])) + ss = xp.reshape(ss, (nfnl, nnei, shape2)) # nfnl x nnei x ng gg = self.embeddings_strip[embedding_idx].call(ss) return gg @@ -867,6 +876,7 @@ def call( atype_embd_ext: Optional[np.ndarray] = None, mapping: Optional[np.ndarray] = None, ): + xp = array_api_compat.array_namespace(nlist, coord_ext, atype_ext) # nf x nloc x nnei x 4 dmatrix, diff, sw = self.env_mat.call( coord_ext, atype_ext, nlist, self.mean, self.stddev @@ -874,41 +884,50 @@ def call( nf, nloc, nnei, _ = dmatrix.shape exclude_mask = self.emask.build_type_exclude_mask(nlist, atype_ext) # nfnl x nnei - exclude_mask = exclude_mask.reshape(nf * nloc, nnei) + exclude_mask = xp.reshape(exclude_mask, (nf * nloc, nnei)) # nfnl x nnei - nlist = nlist.reshape(nf * nloc, nnei) - nlist = np.where(exclude_mask, nlist, -1) + nlist = xp.reshape(nlist, (nf * nloc, nnei)) + nlist = xp.where(exclude_mask, nlist, xp.full_like(nlist, -1)) # nfnl x nnei x 4 - dmatrix = dmatrix.reshape(nf * nloc, nnei, 4) + dmatrix = xp.reshape(dmatrix, (nf * nloc, nnei, 4)) # nfnl x nnei x 1 - sw = sw.reshape(nf * nloc, nnei, 1) + sw = xp.reshape(sw, (nf * nloc, nnei, 1)) # nfnl x tebd_dim - atype_embd = atype_embd_ext[:, :nloc, :].reshape(nf * nloc, self.tebd_dim) + atype_embd = xp.reshape(atype_embd_ext[:, :nloc, :], (nf * nloc, self.tebd_dim)) # nfnl x nnei x tebd_dim - atype_embd_nnei = np.tile(atype_embd[:, np.newaxis, :], (1, nnei, 1)) + atype_embd_nnei = xp.tile(atype_embd[:, xp.newaxis, :], (1, nnei, 1)) # nfnl x nnei nlist_mask = nlist != -1 # nfnl x nnei x 1 - sw = np.where(nlist_mask[:, :, None], sw, 0.0) - nlist_masked = np.where(nlist_mask, nlist, 0) - index = np.tile(nlist_masked.reshape(nf, -1, 1), (1, 1, self.tebd_dim)) + sw = xp.where(nlist_mask[:, :, None], sw, xp.full_like(sw, 0.0)) + nall = atype_embd_ext.shape[1] + nlist_ = nlist + xp.reshape( + xp.repeat(xp.arange(nf) * nall, nloc * nnei), (nf * nloc, nnei) + ) + nlist_masked = xp.where(nlist_mask, nlist_, xp.full_like(nlist, 0)) + # index = xp.tile(xp.reshape(nlist_masked,(nf, -1, 1)), (1, 1, self.tebd_dim)) # nfnl x nnei x tebd_dim - atype_embd_nlist = np.take_along_axis(atype_embd_ext, index, axis=1).reshape( - nf * nloc, nnei, self.tebd_dim + # atype_embd_nlist = xp.take_along_axis(atype_embd_ext, index, axis=1) + index = xp.reshape(nlist_masked, [-1]) + atype_embd_nlist = xp.take( + xp.reshape(atype_embd_ext, (nf * nall, self.tebd_dim)), index, axis=0 + ) + atype_embd_nlist = xp.reshape( + atype_embd_nlist, (nf * nloc, nnei, self.tebd_dim) ) ng = self.neuron[-1] # nfnl x nnei x 4 - rr = dmatrix.reshape(nf * nloc, nnei, 4) - rr = rr * exclude_mask[:, :, None] + rr = xp.reshape(dmatrix, (nf * nloc, nnei, 4)) + rr = rr * xp.astype(exclude_mask[:, :, None], rr.dtype) # nfnl x nnei x 1 ss = rr[..., 0:1] if self.tebd_input_mode in ["concat"]: if not self.type_one_side: # nfnl x nnei x (1 + 2 * tebd_dim) - ss = np.concatenate([ss, atype_embd_nlist, atype_embd_nnei], axis=-1) + ss = xp.concat([ss, atype_embd_nlist, atype_embd_nnei], axis=-1) else: # nfnl x nnei x (1 + tebd_dim) - ss = np.concatenate([ss, atype_embd_nlist], axis=-1) + ss = xp.concat([ss, atype_embd_nlist], axis=-1) # calculate gg # nfnl x nnei x ng gg = self.cal_g(ss, 0) @@ -918,42 +937,47 @@ def call( assert self.embeddings_strip is not None if not self.type_one_side: # nfnl x nnei x (tebd_dim * 2) - tt = np.concatenate([atype_embd_nlist, atype_embd_nnei], axis=-1) + tt = xp.concat([atype_embd_nlist, atype_embd_nnei], axis=-1) else: # nfnl x nnei x tebd_dim tt = atype_embd_nlist # nfnl x nnei x ng gg_t = self.cal_g_strip(tt, 0) if self.smooth: - gg_t = gg_t * sw.reshape(-1, self.nnei, 1) + gg_t = gg_t * xp.reshape(sw, (-1, self.nnei, 1)) # nfnl x nnei x ng gg = gg_s * gg_t + gg_s else: raise NotImplementedError - input_r = rr.reshape(-1, nnei, 4)[:, :, 1:4] / np.maximum( - np.linalg.norm(rr.reshape(-1, nnei, 4)[:, :, 1:4], axis=-1, keepdims=True), - 1e-12, + normed = xp.linalg.vector_norm( + xp.reshape(rr, (-1, nnei, 4))[:, :, 1:4], axis=-1, keepdims=True + ) + input_r = xp.reshape(rr, (-1, nnei, 4))[:, :, 1:4] / xp.maximum( + normed, + xp.full_like(normed, 1e-12), ) gg = self.dpa1_attention( gg, nlist_mask, input_r=input_r, sw=sw ) # shape is [nframes*nloc, self.neei, out_size] # nfnl x ng x 4 - gr = np.einsum("lni,lnj->lij", gg, rr) + # gr = xp.einsum("lni,lnj->lij", gg, rr) + gr = xp.sum(gg[:, :, :, None] * rr[:, :, None, :], axis=1) gr /= self.nnei gr1 = gr[:, : self.axis_neuron, :] # nfnl x ng x ng1 - grrg = np.einsum("lid,ljd->lij", gr, gr1) + # grrg = xp.einsum("lid,ljd->lij", gr, gr1) + grrg = xp.sum(gr[:, :, None, :] * gr1[:, None, :, :], axis=3) # nf x nloc x (ng x ng1) - grrg = grrg.reshape(nf, nloc, ng * self.axis_neuron).astype( - GLOBAL_NP_FLOAT_PRECISION + grrg = xp.astype( + xp.reshape(grrg, (nf, nloc, ng * self.axis_neuron)), coord_ext.dtype ) return ( - grrg.reshape(nf, nloc, self.filter_neuron[-1] * self.axis_neuron), - gg.reshape(nf, nloc, self.nnei, self.filter_neuron[-1]), - dmatrix.reshape(nf, nloc, self.nnei, 4)[..., 1:], - gr[..., 1:].reshape(nf, nloc, self.filter_neuron[-1], 3), - sw, + xp.reshape(grrg, (nf, nloc, self.filter_neuron[-1] * self.axis_neuron)), + xp.reshape(gg, (nf, nloc, self.nnei, self.filter_neuron[-1])), + xp.reshape(dmatrix, (nf, nloc, self.nnei, 4))[..., 1:], + xp.reshape(gr[..., 1:], (nf, nloc, self.filter_neuron[-1], 3)), + xp.reshape(sw, (nf, nloc, nnei, 1)), ) def has_message_passing(self) -> bool: @@ -964,6 +988,77 @@ def need_sorted_nlist_for_lower(self) -> bool: """Returns whether the descriptor block needs sorted nlist when using `forward_lower`.""" return False + def serialize(self) -> dict: + """Serialize the descriptor to dict.""" + obj = self + data = { + "@class": "DescriptorBlock", + "type": "dpa1", + "@version": 1, + "rcut": obj.rcut, + "rcut_smth": obj.rcut_smth, + "sel": obj.sel, + "ntypes": obj.ntypes, + "neuron": obj.neuron, + "axis_neuron": obj.axis_neuron, + "tebd_dim": obj.tebd_dim, + "tebd_input_mode": obj.tebd_input_mode, + "set_davg_zero": obj.set_davg_zero, + "attn": obj.attn, + "attn_layer": obj.attn_layer, + "attn_dotr": obj.attn_dotr, + "attn_mask": obj.attn_mask, + "activation_function": obj.activation_function, + "resnet_dt": obj.resnet_dt, + "scaling_factor": obj.scaling_factor, + "normalize": obj.normalize, + "temperature": obj.temperature, + "trainable_ln": obj.trainable_ln, + "ln_eps": obj.ln_eps, + "smooth": obj.smooth, + "type_one_side": obj.type_one_side, + # make deterministic + "precision": np.dtype(PRECISION_DICT[obj.precision]).name, + "embeddings": obj.embeddings.serialize(), + "attention_layers": obj.dpa1_attention.serialize(), + "env_mat": obj.env_mat.serialize(), + "exclude_types": obj.exclude_types, + "env_protection": obj.env_protection, + "@variables": { + "davg": np.array(obj["davg"]), + "dstd": np.array(obj["dstd"]), + }, + } + if obj.tebd_input_mode in ["strip"]: + data.update({"embeddings_strip": obj.embeddings_strip.serialize()}) + return data + + @classmethod + def deserialize(cls, data: dict) -> "DescrptDPA1": + """Deserialize from dict.""" + data = data.copy() + check_version_compatibility(data.pop("@version"), 1, 1) + data.pop("@class") + data.pop("type") + variables = data.pop("@variables") + embeddings = data.pop("embeddings") + attention_layers = data.pop("attention_layers") + env_mat = data.pop("env_mat") + tebd_input_mode = data["tebd_input_mode"] + if tebd_input_mode in ["strip"]: + embeddings_strip = data.pop("embeddings_strip") + else: + embeddings_strip = None + obj = cls(**data) + + obj["davg"] = variables["davg"] + obj["dstd"] = variables["dstd"] + obj.embeddings = NetworkCollection.deserialize(embeddings) + if tebd_input_mode in ["strip"]: + obj.embeddings_strip = NetworkCollection.deserialize(embeddings_strip) + obj.dpa1_attention = NeighborGatedAttention.deserialize(attention_layers) + return obj + class NeighborGatedAttention(NativeOP): def __init__( @@ -1256,18 +1351,23 @@ def __init__( ) def call(self, query, nei_mask, input_r=None, sw=None, attnw_shift=20.0): + xp = array_api_compat.array_namespace(query, nei_mask) # Linear projection - q, k, v = np.split(self.in_proj(query), 3, axis=-1) + # q, k, v = xp.split(self.in_proj(query), 3, axis=-1) + _query = self.in_proj(query) + q = _query[..., 0 : self.head_dim] + k = _query[..., self.head_dim : self.head_dim * 2] + v = _query[..., self.head_dim * 2 : self.head_dim * 3] # Reshape and normalize # (nf x nloc) x num_heads x nnei x head_dim - q = q.reshape(-1, self.nnei, self.num_heads, self.head_dim).transpose( - 0, 2, 1, 3 + q = xp.permute_dims( + xp.reshape(q, (-1, self.nnei, self.num_heads, self.head_dim)), (0, 2, 1, 3) ) - k = k.reshape(-1, self.nnei, self.num_heads, self.head_dim).transpose( - 0, 2, 1, 3 + k = xp.permute_dims( + xp.reshape(k, (-1, self.nnei, self.num_heads, self.head_dim)), (0, 2, 1, 3) ) - v = v.reshape(-1, self.nnei, self.num_heads, self.head_dim).transpose( - 0, 2, 1, 3 + v = xp.permute_dims( + xp.reshape(v, (-1, self.nnei, self.num_heads, self.head_dim)), (0, 2, 1, 3) ) if self.normalize: q = np_normalize(q, axis=-1) @@ -1276,29 +1376,38 @@ def call(self, query, nei_mask, input_r=None, sw=None, attnw_shift=20.0): q = q * self.scaling # Attention weights # (nf x nloc) x num_heads x nnei x nnei - attn_weights = q @ k.transpose(0, 1, 3, 2) - nei_mask = nei_mask.reshape(-1, self.nnei) + attn_weights = q @ xp.permute_dims(k, (0, 1, 3, 2)) + nei_mask = xp.reshape(nei_mask, (-1, self.nnei)) if self.smooth: - sw = sw.reshape(-1, 1, self.nnei) + sw = xp.reshape(sw, (-1, 1, self.nnei)) attn_weights = (attn_weights + attnw_shift) * sw[:, :, :, None] * sw[ :, :, None, : ] - attnw_shift else: - attn_weights = np.where(nei_mask[:, None, None, :], attn_weights, -np.inf) + attn_weights = xp.where( + nei_mask[:, None, None, :], + attn_weights, + xp.full_like(attn_weights, -xp.inf), + ) attn_weights = np_softmax(attn_weights, axis=-1) - attn_weights = np.where(nei_mask[:, None, :, None], attn_weights, 0.0) + attn_weights = xp.where( + nei_mask[:, None, :, None], attn_weights, xp.zeros_like(attn_weights) + ) if self.smooth: attn_weights = attn_weights * sw[:, :, :, None] * sw[:, :, None, :] if self.dotr: - angular_weight = (input_r @ input_r.transpose(0, 2, 1)).reshape( - -1, 1, self.nnei, self.nnei + angular_weight = xp.reshape( + input_r @ xp.permute_dims(input_r, (0, 2, 1)), + (-1, 1, self.nnei, self.nnei), ) attn_weights = attn_weights * angular_weight # Output projection # (nf x nloc) x num_heads x nnei x head_dim o = attn_weights @ v # (nf x nloc) x nnei x (num_heads x head_dim) - o = o.transpose(0, 2, 1, 3).reshape(-1, self.nnei, self.hidden_dim) + o = xp.reshape( + xp.permute_dims(o, (0, 2, 1, 3)), (-1, self.nnei, self.hidden_dim) + ) output = self.out_proj(o) return output, attn_weights diff --git a/deepmd/dpmodel/utils/env_mat.py b/deepmd/dpmodel/utils/env_mat.py index 41f2591279..247e80e926 100644 --- a/deepmd/dpmodel/utils/env_mat.py +++ b/deepmd/dpmodel/utils/env_mat.py @@ -44,33 +44,43 @@ def _make_env_mat( protection: float = 0.0, ): """Make smooth environment matrix.""" + xp = array_api_compat.array_namespace(nlist) nf, nloc, nnei = nlist.shape # nf x nall x 3 - coord = coord.reshape(nf, -1, 3) + coord = xp.reshape(coord, (nf, -1, 3)) mask = nlist >= 0 - nlist = nlist * mask + nlist = nlist * xp.astype(mask, nlist.dtype) # nf x (nloc x nnei) x 3 - index = np.tile(nlist.reshape(nf, -1, 1), (1, 1, 3)) - coord_r = np.take_along_axis(coord, index, 1) + # index = xp.reshape(nlist, (nf, -1, 1)) + # index = xp.tile(xp.reshape(nlist, (nf, -1, 1)), (1, 1, 3)) + # coord_r = xp.take_along_axis(coord, xp.tile(index, (1, 1, 3)), 1) + # note: array api doesn't contain take_along_axis until the next version + # reimplement + nall = coord.shape[1] + index = xp.reshape(nlist, (nf * nloc * nnei,)) + xp.repeat( + (xp.arange(nf) * nall), nloc * nnei + ) + coord_ = xp.reshape(coord, (-1, 3)) + coord_r = xp.take(coord_, index, axis=0) # nf x nloc x nnei x 3 - coord_r = coord_r.reshape(nf, nloc, nnei, 3) + coord_r = xp.reshape(coord_r, (nf, nloc, nnei, 3)) # nf x nloc x 1 x 3 - coord_l = coord[:, :nloc].reshape(nf, -1, 1, 3) + coord_l = xp.reshape(coord[:, :nloc, ...], (nf, -1, 1, 3)) # nf x nloc x nnei x 3 diff = coord_r - coord_l # nf x nloc x nnei - length = np.linalg.norm(diff, axis=-1, keepdims=True) + length = xp.linalg.vector_norm(diff, axis=-1, keepdims=True) # for index 0 nloc atom - length = length + ~np.expand_dims(mask, -1) + length = length + xp.astype(~xp.expand_dims(mask, axis=-1), length.dtype) t0 = 1 / (length + protection) t1 = diff / (length + protection) ** 2 weight = compute_smooth_weight(length, ruct_smth, rcut) - weight = weight * np.expand_dims(mask, -1) + weight = weight * xp.astype(xp.expand_dims(mask, axis=-1), weight.dtype) if radial_only: env_mat = t0 * weight else: - env_mat = np.concatenate([t0, t1], axis=-1) * weight - return env_mat, diff * np.expand_dims(mask, -1), weight + env_mat = xp.concat([t0, t1], axis=-1) * weight + return env_mat, diff * xp.astype(xp.expand_dims(mask, axis=-1), diff.dtype), weight class EnvMat(NativeOP): @@ -122,13 +132,14 @@ def call( switch The value of switch function. shape: nf x nloc x nnei """ + xp = array_api_compat.array_namespace(coord_ext, atype_ext, nlist) em, diff, sw = self._call(nlist, coord_ext, radial_only) nf, nloc, nnei = nlist.shape atype = atype_ext[:, :nloc] if davg is not None: - em -= davg[atype] + em -= xp.reshape(xp.take(davg, xp.reshape(atype, (-1,)), axis=0), em.shape) if dstd is not None: - em /= dstd[atype] + em /= xp.reshape(xp.take(dstd, xp.reshape(atype, (-1,)), axis=0), em.shape) return em, diff, sw def _call(self, nlist, coord_ext, radial_only): diff --git a/deepmd/dpmodel/utils/exclude_mask.py b/deepmd/dpmodel/utils/exclude_mask.py index ff668b8153..426ee0b99a 100644 --- a/deepmd/dpmodel/utils/exclude_mask.py +++ b/deepmd/dpmodel/utils/exclude_mask.py @@ -4,6 +4,7 @@ Tuple, ) +import array_api_compat import numpy as np @@ -49,8 +50,9 @@ def build_type_exclude_mask( otherwise being 1. """ + xp = array_api_compat.array_namespace(atype) nf, natom = atype.shape - return self.type_mask[atype].reshape(nf, natom) + return xp.reshape(self.type_mask[atype], (nf, natom)) class PairExcludeMask: @@ -68,7 +70,7 @@ def __init__( self.exclude_types.add((tt[0], tt[1])) self.exclude_types.add((tt[1], tt[0])) # ntypes + 1 for nlist masks - self.type_mask = np.array( + type_mask = np.array( [ [ 1 if (tt_i, tt_j) not in self.exclude_types else 0 @@ -79,7 +81,7 @@ def __init__( dtype=np.int32, ) # (ntypes+1 x ntypes+1) - self.type_mask = self.type_mask.reshape([-1]) + self.type_mask = type_mask.reshape([-1]) def get_exclude_types(self): return self.exclude_types @@ -106,23 +108,32 @@ def build_type_exclude_mask( otherwise being 1. """ + xp = array_api_compat.array_namespace(nlist, atype_ext) if len(self.exclude_types) == 0: # safely return 1 if nothing is excluded. - return np.ones_like(nlist, dtype=np.int32) + return xp.ones_like(nlist, dtype=xp.int32) nf, nloc, nnei = nlist.shape nall = atype_ext.shape[1] # add virtual atom of type ntypes. nf x nall+1 - ae = np.concatenate( - [atype_ext, self.ntypes * np.ones([nf, 1], dtype=atype_ext.dtype)], axis=-1 + ae = xp.concat( + [atype_ext, self.ntypes * xp.ones([nf, 1], dtype=atype_ext.dtype)], axis=-1 ) - type_i = atype_ext[:, :nloc].reshape(nf, nloc) * (self.ntypes + 1) + type_i = xp.reshape(atype_ext[:, :nloc], (nf, nloc)) * (self.ntypes + 1) # nf x nloc x nnei - index = np.where(nlist == -1, nall, nlist).reshape(nf, nloc * nnei) - type_j = np.take_along_axis(ae, index, axis=1).reshape(nf, nloc, nnei) + index = xp.reshape( + xp.where(nlist == -1, xp.full_like(nlist, nall), nlist), (nf, nloc * nnei) + ) + # type_j = xp.take_along_axis(ae, index, axis=1).reshape(nf, nloc, nnei) + index = xp.reshape(index, [-1]) + index += xp.repeat(xp.arange(nf) * (nall + 1), nloc * nnei) + type_j = xp.take(xp.reshape(ae, [-1]), index, axis=0) + type_j = xp.reshape(type_j, (nf, nloc, nnei)) type_ij = type_i[:, :, None] + type_j # nf x (nloc x nnei) - type_ij = type_ij.reshape(nf, nloc * nnei) - mask = self.type_mask[type_ij].reshape(nf, nloc, nnei) + type_ij = xp.reshape(type_ij, (nf, nloc * nnei)) + mask = xp.reshape( + xp.take(self.type_mask, xp.reshape(type_ij, (-1,))), (nf, nloc, nnei) + ) return mask def __contains__(self, item): diff --git a/deepmd/dpmodel/utils/network.py b/deepmd/dpmodel/utils/network.py index 22e85c9890..66104cb01c 100644 --- a/deepmd/dpmodel/utils/network.py +++ b/deepmd/dpmodel/utils/network.py @@ -148,15 +148,18 @@ def deserialize(cls, data: dict) -> "NativeLayer": num_out, **data, ) - obj.w, obj.b, obj.idt = ( + w, b, idt = ( variables["w"], variables.get("b", None), variables.get("idt", None), ) - if obj.b is not None: - obj.b = obj.b.ravel() - if obj.idt is not None: - obj.idt = obj.idt.ravel() + if b is not None: + b = b.ravel() + if idt is not None: + idt = idt.ravel() + obj.w = w + obj.b = b + obj.idt = idt obj.check_shape_consistency() return obj @@ -177,8 +180,11 @@ def check_type_consistency(self): def check_var(var): if var is not None: + # array api standard doesn't provide a API to get the dtype name + # this is really hacked + dtype_name = str(var.dtype).split(".")[-1] # assertion "float64" == "double" would fail - assert PRECISION_DICT[var.dtype.name] is PRECISION_DICT[precision] + assert PRECISION_DICT[dtype_name] is PRECISION_DICT[precision] check_var(self.w) check_var(self.b) @@ -251,7 +257,7 @@ def call(self, x: np.ndarray) -> np.ndarray: if self.resnet and self.w.shape[1] == self.w.shape[0]: y += x elif self.resnet and self.w.shape[1] == 2 * self.w.shape[0]: - y += xp.concatenate([x, x], axis=-1) + y += xp.concat([x, x], axis=-1) return y @@ -362,10 +368,11 @@ def __init__( precision=precision, seed=seed, ) - self.w = self.w.squeeze(0) # keep the weight shape to be [num_in] + xp = array_api_compat.array_namespace(self.w, self.b) + self.w = xp.squeeze(self.w, 0) # keep the weight shape to be [num_in] if self.uni_init: - self.w = np.ones_like(self.w) - self.b = np.zeros_like(self.b) + self.w = xp.ones_like(self.w) + self.b = xp.zeros_like(self.b) # only to keep consistent with other backends self.trainable = trainable @@ -378,8 +385,8 @@ def serialize(self) -> dict: The serialized layer. """ data = { - "w": self.w, - "b": self.b, + "w": np.array(self.w), + "b": np.array(self.b), } return { "@class": "LayerNorm", @@ -473,11 +480,12 @@ def call(self, x: np.ndarray) -> np.ndarray: @staticmethod def layer_norm_numpy(x, shape, weight=None, bias=None, eps=1e-5): + xp = array_api_compat.array_namespace(x) # mean and variance - mean = np.mean(x, axis=tuple(range(-len(shape), 0)), keepdims=True) - var = np.var(x, axis=tuple(range(-len(shape), 0)), keepdims=True) + mean = xp.mean(x, axis=tuple(range(-len(shape), 0)), keepdims=True) + var = xp.var(x, axis=tuple(range(-len(shape), 0)), keepdims=True) # normalize - x_normalized = (x - mean) / np.sqrt(var + eps) + x_normalized = (x - mean) / xp.sqrt(var + eps) # shift and scale if weight is not None and bias is not None: x_normalized = x_normalized * weight + bias diff --git a/deepmd/dpmodel/utils/nlist.py b/deepmd/dpmodel/utils/nlist.py index c935377e6a..43a5c7e4f5 100644 --- a/deepmd/dpmodel/utils/nlist.py +++ b/deepmd/dpmodel/utils/nlist.py @@ -6,6 +6,7 @@ Union, ) +import array_api_compat import numpy as np from .region import ( @@ -90,34 +91,36 @@ def build_neighbor_list( For virtual atoms all neighboring positions are filled with -1. """ + xp = array_api_compat.array_namespace(coord, atype) batch_size = coord.shape[0] - coord = coord.reshape(batch_size, -1) + coord = xp.reshape(coord, (batch_size, -1)) nall = coord.shape[1] // 3 # fill virtual atoms with large coords so they are not neighbors of any # real atom. if coord.size > 0: - xmax = np.max(coord) + 2.0 * rcut + xmax = xp.max(coord) + 2.0 * rcut else: xmax = 2.0 * rcut # nf x nall is_vir = atype < 0 - coord1 = np.where( - is_vir[:, :, None], xmax, coord.reshape(batch_size, nall, 3) - ).reshape(batch_size, nall * 3) + coord1 = xp.where( + is_vir[:, :, None], xmax, xp.reshape(coord, (batch_size, nall, 3)) + ) + coord1 = xp.reshape(coord1, (batch_size, nall * 3)) if isinstance(sel, int): sel = [sel] nsel = sum(sel) coord0 = coord1[:, : nloc * 3] diff = ( - coord1.reshape([batch_size, -1, 3])[:, None, :, :] - - coord0.reshape([batch_size, -1, 3])[:, :, None, :] + xp.reshape(coord1, [batch_size, -1, 3])[:, None, :, :] + - xp.reshape(coord0, [batch_size, -1, 3])[:, :, None, :] ) assert list(diff.shape) == [batch_size, nloc, nall, 3] - rr = np.linalg.norm(diff, axis=-1) + rr = xp.linalg.vector_norm(diff, axis=-1) # if central atom has two zero distances, sorting sometimes can not exclude itself - rr -= np.eye(nloc, nall, dtype=diff.dtype)[np.newaxis, :, :] - nlist = np.argsort(rr, axis=-1) - rr = np.sort(rr, axis=-1) + rr -= xp.eye(nloc, nall, dtype=diff.dtype)[xp.newaxis, :, :] + nlist = xp.argsort(rr, axis=-1) + rr = xp.sort(rr, axis=-1) rr = rr[:, :, 1:] nlist = nlist[:, :, 1:] nnei = rr.shape[2] @@ -125,16 +128,20 @@ def build_neighbor_list( rr = rr[:, :, :nsel] nlist = nlist[:, :, :nsel] else: - rr = np.concatenate( - [rr, np.ones([batch_size, nloc, nsel - nnei]) + rcut], # pylint: disable=no-explicit-dtype + rr = xp.concatenate( + [rr, xp.ones([batch_size, nloc, nsel - nnei]) + rcut], # pylint: disable=no-explicit-dtype axis=-1, ) - nlist = np.concatenate( - [nlist, np.ones([batch_size, nloc, nsel - nnei], dtype=nlist.dtype)], + nlist = xp.concatenate( + [nlist, xp.ones([batch_size, nloc, nsel - nnei], dtype=nlist.dtype)], axis=-1, ) assert list(nlist.shape) == [batch_size, nloc, nsel] - nlist = np.where(np.logical_or((rr > rcut), is_vir[:, :nloc, None]), -1, nlist) + nlist = xp.where( + xp.logical_or((rr > rcut), is_vir[:, :nloc, None]), + xp.full_like(nlist, -1), + nlist, + ) if distinguish_types: return nlist_distinguish_types(nlist, atype, sel) @@ -151,23 +158,24 @@ def nlist_distinguish_types( distinguish atom types. """ + xp = array_api_compat.array_namespace(nlist, atype) nf, nloc, _ = nlist.shape ret_nlist = [] - tmp_atype = np.tile(atype[:, None], [1, nloc, 1]) + tmp_atype = xp.tile(atype[:, None], [1, nloc, 1]) mask = nlist == -1 tnlist_0 = nlist.copy() tnlist_0[mask] = 0 - tnlist = np.take_along_axis(tmp_atype, tnlist_0, axis=2).squeeze() - tnlist = np.where(mask, -1, tnlist) + tnlist = xp.take_along_axis(tmp_atype, tnlist_0, axis=2).squeeze() + tnlist = xp.where(mask, -1, tnlist) snsel = tnlist.shape[2] for ii, ss in enumerate(sel): - pick_mask = (tnlist == ii).astype(np.int32) - sorted_indices = np.argsort(-pick_mask, kind="stable", axis=-1) - pick_mask_sorted = -np.sort(-pick_mask, axis=-1) - inlist = np.take_along_axis(nlist, sorted_indices, axis=2) - inlist = np.where(~pick_mask_sorted.astype(bool), -1, inlist) - ret_nlist.append(np.split(inlist, [ss, snsel - ss], axis=-1)[0]) - ret = np.concatenate(ret_nlist, axis=-1) + pick_mask = (tnlist == ii).astype(xp.int32) + sorted_indices = xp.argsort(-pick_mask, kind="stable", axis=-1) + pick_mask_sorted = -xp.sort(-pick_mask, axis=-1) + inlist = xp.take_along_axis(nlist, sorted_indices, axis=2) + inlist = xp.where(~pick_mask_sorted.astype(bool), -1, inlist) + ret_nlist.append(xp.split(inlist, [ss, snsel - ss], axis=-1)[0]) + ret = xp.concatenate(ret_nlist, axis=-1) return ret @@ -265,36 +273,46 @@ def extend_coord_with_ghosts( maping extended index to the local index """ + xp = array_api_compat.array_namespace(coord, atype, cell) nf, nloc = atype.shape - aidx = np.tile(np.arange(nloc)[np.newaxis, :], (nf, 1)) # pylint: disable=no-explicit-dtype + aidx = xp.tile(xp.arange(nloc)[xp.newaxis, :], (nf, 1)) # pylint: disable=no-explicit-dtype if cell is None: nall = nloc - extend_coord = coord.copy() - extend_atype = atype.copy() - extend_aidx = aidx.copy() + extend_coord = coord + extend_atype = atype + extend_aidx = aidx else: - coord = coord.reshape((nf, nloc, 3)) - cell = cell.reshape((nf, 3, 3)) + coord = xp.reshape(coord, (nf, nloc, 3)) + cell = xp.reshape(cell, (nf, 3, 3)) to_face = to_face_distance(cell) - nbuff = np.ceil(rcut / to_face).astype(int) - nbuff = np.max(nbuff, axis=0) - xi = np.arange(-nbuff[0], nbuff[0] + 1, 1) # pylint: disable=no-explicit-dtype - yi = np.arange(-nbuff[1], nbuff[1] + 1, 1) # pylint: disable=no-explicit-dtype - zi = np.arange(-nbuff[2], nbuff[2] + 1, 1) # pylint: disable=no-explicit-dtype - xyz = np.outer(xi, np.array([1, 0, 0]))[:, np.newaxis, np.newaxis, :] - xyz = xyz + np.outer(yi, np.array([0, 1, 0]))[np.newaxis, :, np.newaxis, :] - xyz = xyz + np.outer(zi, np.array([0, 0, 1]))[np.newaxis, np.newaxis, :, :] - xyz = xyz.reshape(-1, 3) - shift_idx = xyz[np.argsort(np.linalg.norm(xyz, axis=1))] + nbuff = xp.astype(xp.ceil(rcut / to_face), xp.int64) + nbuff = xp.max(nbuff, axis=0) + xi = xp.arange(-int(nbuff[0]), int(nbuff[0]) + 1, 1) # pylint: disable=no-explicit-dtype + yi = xp.arange(-int(nbuff[1]), int(nbuff[1]) + 1, 1) # pylint: disable=no-explicit-dtype + zi = xp.arange(-int(nbuff[2]), int(nbuff[2]) + 1, 1) # pylint: disable=no-explicit-dtype + xyz = xp.linalg.outer(xi, xp.asarray([1, 0, 0]))[:, xp.newaxis, xp.newaxis, :] + xyz = ( + xyz + + xp.linalg.outer(yi, xp.asarray([0, 1, 0]))[xp.newaxis, :, xp.newaxis, :] + ) + xyz = ( + xyz + + xp.linalg.outer(zi, xp.asarray([0, 0, 1]))[xp.newaxis, xp.newaxis, :, :] + ) + xyz = xp.reshape(xyz, (-1, 3)) + xyz = xp.astype(xyz, coord.dtype) + shift_idx = xp.take(xyz, xp.argsort(xp.linalg.vector_norm(xyz, axis=1)), axis=0) ns, _ = shift_idx.shape nall = ns * nloc - shift_vec = np.einsum("sd,fdk->fsk", shift_idx, cell) + # shift_vec = xp.einsum("sd,fdk->fsk", shift_idx, cell) + shift_vec = xp.tensordot(shift_idx, cell, axes=([1], [1])) + shift_vec = xp.permute_dims(shift_vec, (1, 0, 2)) extend_coord = coord[:, None, :, :] + shift_vec[:, :, None, :] - extend_atype = np.tile(atype[:, :, np.newaxis], (1, ns, 1)) - extend_aidx = np.tile(aidx[:, :, np.newaxis], (1, ns, 1)) + extend_atype = xp.tile(atype[:, :, xp.newaxis], (1, ns, 1)) + extend_aidx = xp.tile(aidx[:, :, xp.newaxis], (1, ns, 1)) return ( - extend_coord.reshape((nf, nall * 3)), - extend_atype.reshape((nf, nall)), - extend_aidx.reshape((nf, nall)), + xp.reshape(extend_coord, (nf, nall * 3)), + xp.reshape(extend_atype, (nf, nall)), + xp.reshape(extend_aidx, (nf, nall)), ) diff --git a/deepmd/dpmodel/utils/region.py b/deepmd/dpmodel/utils/region.py index ddbc4b29b8..8102020827 100644 --- a/deepmd/dpmodel/utils/region.py +++ b/deepmd/dpmodel/utils/region.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +import array_api_compat import numpy as np @@ -21,8 +22,9 @@ def phys2inter( the internal coordinates """ - rec_cell = np.linalg.inv(cell) - return np.matmul(coord, rec_cell) + xp = array_api_compat.array_namespace(coord, cell) + rec_cell = xp.linalg.inv(cell) + return xp.matmul(coord, rec_cell) def inter2phys( @@ -44,7 +46,8 @@ def inter2phys( the physical coordinates """ - return np.matmul(coord, cell) + xp = array_api_compat.array_namespace(coord, cell) + return xp.matmul(coord, cell) def normalize_coord( @@ -66,8 +69,9 @@ def normalize_coord( wrapped coordinates of shape [*, na, 3]. """ + xp = array_api_compat.array_namespace(coord, cell) icoord = phys2inter(coord, cell) - icoord = np.remainder(icoord, 1.0) + icoord = xp.remainder(icoord, 1.0) return inter2phys(icoord, cell) @@ -87,17 +91,19 @@ def to_face_distance( the to face distances of shape [*, 3] """ + xp = array_api_compat.array_namespace(cell) cshape = cell.shape - dist = b_to_face_distance(cell.reshape([-1, 3, 3])) - return dist.reshape(list(cshape[:-2]) + [3]) # noqa:RUF005 + dist = b_to_face_distance(xp.reshape(cell, [-1, 3, 3])) + return xp.reshape(dist, list(cshape[:-2]) + [3]) # noqa:RUF005 def b_to_face_distance(cell): - volume = np.linalg.det(cell) - c_yz = np.cross(cell[:, 1], cell[:, 2], axis=-1) - _h2yz = volume / np.linalg.norm(c_yz, axis=-1) - c_zx = np.cross(cell[:, 2], cell[:, 0], axis=-1) - _h2zx = volume / np.linalg.norm(c_zx, axis=-1) - c_xy = np.cross(cell[:, 0], cell[:, 1], axis=-1) - _h2xy = volume / np.linalg.norm(c_xy, axis=-1) - return np.stack([_h2yz, _h2zx, _h2xy], axis=1) + xp = array_api_compat.array_namespace(cell) + volume = xp.linalg.det(cell) + c_yz = xp.linalg.cross(cell[:, 1, ...], cell[:, 2, ...], axis=-1) + _h2yz = volume / xp.linalg.vector_norm(c_yz, axis=-1) + c_zx = xp.linalg.cross(cell[:, 2, ...], cell[:, 0, ...], axis=-1) + _h2zx = volume / xp.linalg.vector_norm(c_zx, axis=-1) + c_xy = xp.linalg.cross(cell[:, 0, ...], cell[:, 1, ...], axis=-1) + _h2xy = volume / xp.linalg.vector_norm(c_xy, axis=-1) + return xp.stack([_h2yz, _h2zx, _h2xy], axis=1) diff --git a/deepmd/dpmodel/utils/type_embed.py b/deepmd/dpmodel/utils/type_embed.py index e11c415cfd..dd7d612e00 100644 --- a/deepmd/dpmodel/utils/type_embed.py +++ b/deepmd/dpmodel/utils/type_embed.py @@ -107,7 +107,7 @@ def call(self) -> np.ndarray: embed = self.embedding_net(self.econf_tebd) if self.padding: embed_pad = xp.zeros((1, embed.shape[-1]), dtype=embed.dtype) - embed = xp.concatenate([embed, embed_pad], axis=0) + embed = xp.concat([embed, embed_pad], axis=0) return embed @classmethod diff --git a/deepmd/jax/common.py b/deepmd/jax/common.py index 550b168b29..8c3860cf39 100644 --- a/deepmd/jax/common.py +++ b/deepmd/jax/common.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: LGPL-3.0-or-later from typing import ( - Union, + Optional, overload, ) @@ -19,7 +19,7 @@ def to_jax_array(array: np.ndarray) -> jnp.ndarray: ... def to_jax_array(array: None) -> None: ... -def to_jax_array(array: Union[np.ndarray]) -> Union[jnp.ndarray]: +def to_jax_array(array: Optional[np.ndarray]) -> Optional[jnp.ndarray]: """Convert a numpy array to a JAX array. Parameters diff --git a/deepmd/jax/descriptor/__init__.py b/deepmd/jax/descriptor/__init__.py new file mode 100644 index 0000000000..6ceb116d85 --- /dev/null +++ b/deepmd/jax/descriptor/__init__.py @@ -0,0 +1 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later diff --git a/deepmd/jax/descriptor/dpa1.py b/deepmd/jax/descriptor/dpa1.py new file mode 100644 index 0000000000..73ef6055e5 --- /dev/null +++ b/deepmd/jax/descriptor/dpa1.py @@ -0,0 +1,80 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Any, +) + +from deepmd.dpmodel.descriptor.dpa1 import DescrptBlockSeAtten as DescrptBlockSeAttenDP +from deepmd.dpmodel.descriptor.dpa1 import DescrptDPA1 as DescrptDPA1DP +from deepmd.dpmodel.descriptor.dpa1 import GatedAttentionLayer as GatedAttentionLayerDP +from deepmd.dpmodel.descriptor.dpa1 import ( + NeighborGatedAttention as NeighborGatedAttentionDP, +) +from deepmd.dpmodel.descriptor.dpa1 import ( + NeighborGatedAttentionLayer as NeighborGatedAttentionLayerDP, +) +from deepmd.jax.common import ( + to_jax_array, +) +from deepmd.jax.utils.exclude_mask import ( + PairExcludeMask, +) +from deepmd.jax.utils.network import ( + LayerNorm, + NativeLayer, + NetworkCollection, +) +from deepmd.jax.utils.type_embed import ( + TypeEmbedNet, +) + + +class GatedAttentionLayer(GatedAttentionLayerDP): + def __setattr__(self, name: str, value: Any) -> None: + if name in {"in_proj", "out_proj"}: + value = NativeLayer.deserialize(value.serialize()) + return super().__setattr__(name, value) + + +class NeighborGatedAttentionLayer(NeighborGatedAttentionLayerDP): + def __setattr__(self, name: str, value: Any) -> None: + if name == "attention_layer": + value = GatedAttentionLayer.deserialize(value.serialize()) + elif name == "attn_layer_norm": + value = LayerNorm.deserialize(value.serialize()) + return super().__setattr__(name, value) + + +class NeighborGatedAttention(NeighborGatedAttentionDP): + def __setattr__(self, name: str, value: Any) -> None: + if name == "attention_layers": + value = [ + NeighborGatedAttentionLayer.deserialize(ii.serialize()) for ii in value + ] + return super().__setattr__(name, value) + + +class DescrptBlockSeAtten(DescrptBlockSeAttenDP): + def __setattr__(self, name: str, value: Any) -> None: + if name in {"mean", "stddev"}: + value = to_jax_array(value) + elif name in {"embeddings", "embeddings_strip"}: + if value is not None: + value = NetworkCollection.deserialize(value.serialize()) + elif name == "dpa1_attention": + value = NeighborGatedAttention.deserialize(value.serialize()) + elif name == "env_mat": + # env_mat doesn't store any value + pass + elif name == "emask": + value = PairExcludeMask(value.ntypes, value.exclude_types) + + return super().__setattr__(name, value) + + +class DescrptDPA1(DescrptDPA1DP): + def __setattr__(self, name: str, value: Any) -> None: + if name == "se_atten": + value = DescrptBlockSeAtten.deserialize(value.serialize()) + elif name == "type_embedding": + value = TypeEmbedNet.deserialize(value.serialize()) + return super().__setattr__(name, value) diff --git a/deepmd/jax/utils/exclude_mask.py b/deepmd/jax/utils/exclude_mask.py new file mode 100644 index 0000000000..6519648514 --- /dev/null +++ b/deepmd/jax/utils/exclude_mask.py @@ -0,0 +1,16 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Any, +) + +from deepmd.dpmodel.utils.exclude_mask import PairExcludeMask as PairExcludeMaskDP +from deepmd.jax.common import ( + to_jax_array, +) + + +class PairExcludeMask(PairExcludeMaskDP): + def __setattr__(self, name: str, value: Any) -> None: + if name in {"type_mask"}: + value = to_jax_array(value) + return super().__setattr__(name, value) diff --git a/deepmd/jax/utils/network.py b/deepmd/jax/utils/network.py index 629b51b8cd..6517573b38 100644 --- a/deepmd/jax/utils/network.py +++ b/deepmd/jax/utils/network.py @@ -1,12 +1,16 @@ # SPDX-License-Identifier: LGPL-3.0-or-later from typing import ( Any, + ClassVar, + Dict, ) from deepmd.dpmodel.common import ( NativeOP, ) +from deepmd.dpmodel.utils.network import LayerNorm as LayerNormDP from deepmd.dpmodel.utils.network import NativeLayer as NativeLayerDP +from deepmd.dpmodel.utils.network import NetworkCollection as NetworkCollectionDP from deepmd.dpmodel.utils.network import ( make_embedding_network, make_fitting_network, @@ -27,3 +31,15 @@ def __setattr__(self, name: str, value: Any) -> None: NativeNet = make_multilayer_network(NativeLayer, NativeOP) EmbeddingNet = make_embedding_network(NativeNet, NativeLayer) FittingNet = make_fitting_network(EmbeddingNet, NativeNet, NativeLayer) + + +class NetworkCollection(NetworkCollectionDP): + NETWORK_TYPE_MAP: ClassVar[Dict[str, type]] = { + "network": NativeNet, + "embedding_network": EmbeddingNet, + "fitting_network": FittingNet, + } + + +class LayerNorm(LayerNormDP, NativeLayer): + pass diff --git a/source/tests/array_api_strict/__init__.py b/source/tests/array_api_strict/__init__.py new file mode 100644 index 0000000000..27785c2fd5 --- /dev/null +++ b/source/tests/array_api_strict/__init__.py @@ -0,0 +1,2 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Synchronize with deepmd.jax for test purpose only.""" diff --git a/source/tests/array_api_strict/common.py b/source/tests/array_api_strict/common.py new file mode 100644 index 0000000000..28f67a97f6 --- /dev/null +++ b/source/tests/array_api_strict/common.py @@ -0,0 +1,25 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Optional, +) + +import array_api_strict +import numpy as np + + +def to_array_api_strict_array(array: Optional[np.ndarray]): + """Convert a numpy array to a JAX array. + + Parameters + ---------- + array : np.ndarray + The numpy array to convert. + + Returns + ------- + jnp.ndarray + The JAX tensor. + """ + if array is None: + return None + return array_api_strict.asarray(array) diff --git a/source/tests/array_api_strict/descriptor/__init__.py b/source/tests/array_api_strict/descriptor/__init__.py new file mode 100644 index 0000000000..6ceb116d85 --- /dev/null +++ b/source/tests/array_api_strict/descriptor/__init__.py @@ -0,0 +1 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later diff --git a/source/tests/array_api_strict/descriptor/dpa1.py b/source/tests/array_api_strict/descriptor/dpa1.py new file mode 100644 index 0000000000..ebd688e303 --- /dev/null +++ b/source/tests/array_api_strict/descriptor/dpa1.py @@ -0,0 +1,81 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Any, +) + +from deepmd.dpmodel.descriptor.dpa1 import DescrptBlockSeAtten as DescrptBlockSeAttenDP +from deepmd.dpmodel.descriptor.dpa1 import DescrptDPA1 as DescrptDPA1DP +from deepmd.dpmodel.descriptor.dpa1 import GatedAttentionLayer as GatedAttentionLayerDP +from deepmd.dpmodel.descriptor.dpa1 import ( + NeighborGatedAttention as NeighborGatedAttentionDP, +) +from deepmd.dpmodel.descriptor.dpa1 import ( + NeighborGatedAttentionLayer as NeighborGatedAttentionLayerDP, +) + +from ..common import ( + to_array_api_strict_array, +) +from ..utils.exclude_mask import ( + PairExcludeMask, +) +from ..utils.network import ( + LayerNorm, + NativeLayer, + NetworkCollection, +) +from ..utils.type_embed import ( + TypeEmbedNet, +) + + +class GatedAttentionLayer(GatedAttentionLayerDP): + def __setattr__(self, name: str, value: Any) -> None: + if name in {"in_proj", "out_proj"}: + value = NativeLayer.deserialize(value.serialize()) + return super().__setattr__(name, value) + + +class NeighborGatedAttentionLayer(NeighborGatedAttentionLayerDP): + def __setattr__(self, name: str, value: Any) -> None: + if name == "attention_layer": + value = GatedAttentionLayer.deserialize(value.serialize()) + elif name == "attn_layer_norm": + value = LayerNorm.deserialize(value.serialize()) + return super().__setattr__(name, value) + + +class NeighborGatedAttention(NeighborGatedAttentionDP): + def __setattr__(self, name: str, value: Any) -> None: + if name == "attention_layers": + value = [ + NeighborGatedAttentionLayer.deserialize(ii.serialize()) for ii in value + ] + return super().__setattr__(name, value) + + +class DescrptBlockSeAtten(DescrptBlockSeAttenDP): + def __setattr__(self, name: str, value: Any) -> None: + if name in {"mean", "stddev"}: + value = to_array_api_strict_array(value) + elif name in {"embeddings", "embeddings_strip"}: + if value is not None: + value = NetworkCollection.deserialize(value.serialize()) + elif name == "dpa1_attention": + value = NeighborGatedAttention.deserialize(value.serialize()) + elif name == "env_mat": + # env_mat doesn't store any value + pass + elif name == "emask": + value = PairExcludeMask(value.ntypes, value.exclude_types) + + return super().__setattr__(name, value) + + +class DescrptDPA1(DescrptDPA1DP): + def __setattr__(self, name: str, value: Any) -> None: + if name == "se_atten": + value = DescrptBlockSeAtten.deserialize(value.serialize()) + elif name == "type_embedding": + value = TypeEmbedNet.deserialize(value.serialize()) + return super().__setattr__(name, value) diff --git a/source/tests/array_api_strict/utils/__init__.py b/source/tests/array_api_strict/utils/__init__.py new file mode 100644 index 0000000000..6ceb116d85 --- /dev/null +++ b/source/tests/array_api_strict/utils/__init__.py @@ -0,0 +1 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later diff --git a/source/tests/array_api_strict/utils/exclude_mask.py b/source/tests/array_api_strict/utils/exclude_mask.py new file mode 100644 index 0000000000..06f2e94b52 --- /dev/null +++ b/source/tests/array_api_strict/utils/exclude_mask.py @@ -0,0 +1,17 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Any, +) + +from deepmd.dpmodel.utils.exclude_mask import PairExcludeMask as PairExcludeMaskDP + +from ..common import ( + to_array_api_strict_array, +) + + +class PairExcludeMask(PairExcludeMaskDP): + def __setattr__(self, name: str, value: Any) -> None: + if name in {"type_mask"}: + value = to_array_api_strict_array(value) + return super().__setattr__(name, value) diff --git a/source/tests/array_api_strict/utils/network.py b/source/tests/array_api_strict/utils/network.py new file mode 100644 index 0000000000..8a1324a8da --- /dev/null +++ b/source/tests/array_api_strict/utils/network.py @@ -0,0 +1,46 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Any, + ClassVar, + Dict, +) + +from deepmd.dpmodel.common import ( + NativeOP, +) +from deepmd.dpmodel.utils.network import LayerNorm as LayerNormDP +from deepmd.dpmodel.utils.network import NativeLayer as NativeLayerDP +from deepmd.dpmodel.utils.network import NetworkCollection as NetworkCollectionDP +from deepmd.dpmodel.utils.network import ( + make_embedding_network, + make_fitting_network, + make_multilayer_network, +) + +from ..common import ( + to_array_api_strict_array, +) + + +class NativeLayer(NativeLayerDP): + def __setattr__(self, name: str, value: Any) -> None: + if name in {"w", "b", "idt"}: + value = to_array_api_strict_array(value) + return super().__setattr__(name, value) + + +NativeNet = make_multilayer_network(NativeLayer, NativeOP) +EmbeddingNet = make_embedding_network(NativeNet, NativeLayer) +FittingNet = make_fitting_network(EmbeddingNet, NativeNet, NativeLayer) + + +class NetworkCollection(NetworkCollectionDP): + NETWORK_TYPE_MAP: ClassVar[Dict[str, type]] = { + "network": NativeNet, + "embedding_network": EmbeddingNet, + "fitting_network": FittingNet, + } + + +class LayerNorm(LayerNormDP, NativeLayer): + pass diff --git a/source/tests/array_api_strict/utils/type_embed.py b/source/tests/array_api_strict/utils/type_embed.py new file mode 100644 index 0000000000..7551279002 --- /dev/null +++ b/source/tests/array_api_strict/utils/type_embed.py @@ -0,0 +1,22 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Any, +) + +from deepmd.dpmodel.utils.type_embed import TypeEmbedNet as TypeEmbedNetDP + +from ..common import ( + to_array_api_strict_array, +) +from ..utils.network import ( + EmbeddingNet, +) + + +class TypeEmbedNet(TypeEmbedNetDP): + def __setattr__(self, name: str, value: Any) -> None: + if name in {"econf_tebd"}: + value = to_array_api_strict_array(value) + if name in {"embedding_net"}: + value = EmbeddingNet.deserialize(value.serialize()) + return super().__setattr__(name, value) diff --git a/source/tests/common/dpmodel/test_descriptor_dpa1.py b/source/tests/common/dpmodel/test_descriptor_dpa1.py index 317f4c3d3d..f441895f15 100644 --- a/source/tests/common/dpmodel/test_descriptor_dpa1.py +++ b/source/tests/common/dpmodel/test_descriptor_dpa1.py @@ -36,3 +36,22 @@ def test_self_consistency( mm1 = em1.call(self.coord_ext, self.atype_ext, self.nlist) for ii in [0, 1, 4]: np.testing.assert_allclose(mm0[ii], mm1[ii]) + + def test_multiple_frames(self): + rng = np.random.default_rng(GLOBAL_SEED) + nf, nloc, nnei = self.nlist.shape + davg = rng.normal(size=(self.nt, nnei, 4)) + dstd = rng.normal(size=(self.nt, nnei, 4)) + dstd = 0.1 + np.abs(dstd) + + em0 = DescrptDPA1(self.rcut, self.rcut_smth, self.sel, ntypes=2) + em0.davg = davg + em0.dstd = dstd + two_coord_ext = np.concatenate([self.coord_ext, self.coord_ext], axis=0) + two_atype_ext = np.concatenate([self.atype_ext, self.atype_ext], axis=0) + two_nlist = np.concatenate([self.nlist, self.nlist], axis=0) + + mm0 = em0.call(two_coord_ext, two_atype_ext, two_nlist) + for ii in [0, 1, 4]: + np.testing.assert_allclose(mm0[ii][0], mm0[ii][2], err_msg=f"{ii} 0~2") + np.testing.assert_allclose(mm0[ii][1], mm0[ii][3], err_msg=f"{ii} 1~3") diff --git a/source/tests/consistent/common.py b/source/tests/consistent/common.py index e8873e528a..d7ece13806 100644 --- a/source/tests/consistent/common.py +++ b/source/tests/consistent/common.py @@ -10,6 +10,9 @@ from enum import ( Enum, ) +from importlib.util import ( + find_spec, +) from typing import ( Any, Callable, @@ -36,6 +39,7 @@ INSTALLED_TF = Backend.get_backend("tensorflow")().is_available() INSTALLED_PT = Backend.get_backend("pytorch")().is_available() INSTALLED_JAX = Backend.get_backend("jax")().is_available() +INSTALLED_ARRAY_API_STRICT = find_spec("array_api_strict") is not None if os.environ.get("CI") and not (INSTALLED_TF and INSTALLED_PT): raise ImportError("TensorFlow or PyTorch should be tested in the CI") @@ -59,6 +63,7 @@ "INSTALLED_TF", "INSTALLED_PT", "INSTALLED_JAX", + "INSTALLED_ARRAY_API_STRICT", ] @@ -75,6 +80,7 @@ class CommonTest(ABC): """PyTorch model class.""" jax_class: ClassVar[Optional[type]] """JAX model class.""" + array_api_strict_class: ClassVar[Optional[type]] args: ClassVar[Optional[Union[Argument, List[Argument]]]] """Arguments that maps to the `data`.""" skip_dp: ClassVar[bool] = False @@ -86,6 +92,8 @@ class CommonTest(ABC): # we may usually skip jax before jax is fully supported skip_jax: ClassVar[bool] = True """Whether to skip the JAX model.""" + skip_array_api_strict: ClassVar[bool] = True + """Whether to skip the array_api_strict model.""" rtol = 1e-10 """Relative tolerance for comparing the return value. Override for float32.""" atol = 1e-10 @@ -166,6 +174,16 @@ def eval_jax(self, jax_obj: Any) -> Any: """ raise NotImplementedError("Not implemented") + def eval_array_api_strict(self, array_api_strict_obj: Any) -> Any: + """Evaluate the return value of array_api_strict. + + Parameters + ---------- + array_api_strict_obj : Any + The object of array_api_strict + """ + raise NotImplementedError("Not implemented") + class RefBackend(Enum): """Reference backend.""" @@ -173,6 +191,7 @@ class RefBackend(Enum): DP = 2 PT = 3 JAX = 5 + ARRAY_API_STRICT = 6 @abstractmethod def extract_ret(self, ret: Any, backend: RefBackend) -> Tuple[np.ndarray, ...]: @@ -238,6 +257,11 @@ def get_jax_ret_serialization_from_cls(self, obj): data = obj.serialize() return ret, data + def get_array_api_strict_ret_serialization_from_cls(self, obj): + ret = self.eval_array_api_strict(obj) + data = obj.serialize() + return ret, data + def get_reference_backend(self): """Get the reference backend. @@ -251,6 +275,8 @@ def get_reference_backend(self): return self.RefBackend.PT if not self.skip_jax: return self.RefBackend.JAX + if not self.skip_array_api_strict: + return self.RefBackend.ARRAY_API_STRICT raise ValueError("No available reference") def get_reference_ret_serialization(self, ref: RefBackend): @@ -264,6 +290,12 @@ def get_reference_ret_serialization(self, ref: RefBackend): if ref == self.RefBackend.PT: obj = self.init_backend_cls(self.pt_class) return self.get_pt_ret_serialization_from_cls(obj) + if ref == self.RefBackend.JAX: + obj = self.init_backend_cls(self.jax_class) + return self.get_jax_ret_serialization_from_cls(obj) + if ref == self.RefBackend.ARRAY_API_STRICT: + obj = self.init_backend_cls(self.array_api_strict_class) + return self.get_array_api_ret_serialization_from_cls(obj) raise ValueError("No available reference") def test_tf_consistent_with_ref(self): @@ -418,6 +450,40 @@ def test_jax_self_consistent(self): else: self.assertEqual(rr1, rr2) + def test_array_api_strict_consistent_with_ref(self): + """Test whether array_api_strict and reference are consistent.""" + if self.skip_array_api_strict: + self.skipTest("Unsupported backend") + ref_backend = self.get_reference_backend() + if ref_backend == self.RefBackend.ARRAY_API_STRICT: + self.skipTest("Reference is self") + ret1, data1 = self.get_reference_ret_serialization(ref_backend) + ret1 = self.extract_ret(ret1, ref_backend) + array_api_strict_obj = self.array_api_strict_class.deserialize(data1) + ret2 = self.eval_array_api_strict(array_api_strict_obj) + ret2 = self.extract_ret(ret2, self.RefBackend.ARRAY_API_STRICT) + data2 = array_api_strict_obj.serialize() + np.testing.assert_equal(data1, data2) + for rr1, rr2 in zip(ret1, ret2): + np.testing.assert_allclose(rr1, rr2, rtol=self.rtol, atol=self.atol) + assert rr1.dtype == rr2.dtype, f"{rr1.dtype} != {rr2.dtype}" + + def test_array_api_strict_self_consistent(self): + """Test whether array_api_strict is self consistent.""" + if self.skip_array_api_strict: + self.skipTest("Unsupported backend") + obj1 = self.init_backend_cls(self.array_api_strict_class) + ret1, data1 = self.get_array_api_strict_ret_serialization_from_cls(obj1) + obj1 = self.array_api_strict_class.deserialize(data1) + ret2, data2 = self.get_array_api_strict_ret_serialization_from_cls(obj1) + np.testing.assert_equal(data1, data2) + for rr1, rr2 in zip(ret1, ret2): + if isinstance(rr1, np.ndarray) and isinstance(rr2, np.ndarray): + np.testing.assert_allclose(rr1, rr2, rtol=self.rtol, atol=self.atol) + assert rr1.dtype == rr2.dtype, f"{rr1.dtype} != {rr2.dtype}" + else: + self.assertEqual(rr1, rr2) + def tearDown(self) -> None: """Clear the TF session.""" if not self.skip_tf: diff --git a/source/tests/consistent/descriptor/common.py b/source/tests/consistent/descriptor/common.py index 74fc3d9b07..e0ca30c799 100644 --- a/source/tests/consistent/descriptor/common.py +++ b/source/tests/consistent/descriptor/common.py @@ -3,6 +3,8 @@ Any, ) +import numpy as np + from deepmd.common import ( make_default_mesh, ) @@ -12,6 +14,8 @@ ) from ..common import ( + INSTALLED_ARRAY_API_STRICT, + INSTALLED_JAX, INSTALLED_PT, INSTALLED_TF, ) @@ -29,6 +33,12 @@ GLOBAL_TF_FLOAT_PRECISION, tf, ) +if INSTALLED_JAX: + from deepmd.jax.env import ( + jnp, + ) +if INSTALLED_ARRAY_API_STRICT: + import array_api_strict class DescriptorTest: @@ -99,3 +109,56 @@ def eval_pt_descriptor( x.detach().cpu().numpy() if torch.is_tensor(x) else x for x in pt_obj(ext_coords, ext_atype, nlist=nlist, mapping=mapping) ] + + def eval_jax_descriptor( + self, jax_obj: Any, natoms, coords, atype, box, mixed_types: bool = False + ) -> Any: + ext_coords, ext_atype, mapping = extend_coord_with_ghosts( + jnp.array(coords).reshape(1, -1, 3), + jnp.array(atype).reshape(1, -1), + jnp.array(box).reshape(1, 3, 3), + jax_obj.get_rcut(), + ) + nlist = build_neighbor_list( + ext_coords, + ext_atype, + natoms[0], + jax_obj.get_rcut(), + jax_obj.get_sel(), + distinguish_types=(not mixed_types), + ) + return [ + np.asarray(x) if isinstance(x, jnp.ndarray) else x + for x in jax_obj(ext_coords, ext_atype, nlist=nlist, mapping=mapping) + ] + + def eval_array_api_strict_descriptor( + self, + array_api_strict_obj: Any, + natoms, + coords, + atype, + box, + mixed_types: bool = False, + ) -> Any: + array_api_strict.set_array_api_strict_flags(api_version="2023.12") + ext_coords, ext_atype, mapping = extend_coord_with_ghosts( + array_api_strict.asarray(coords.reshape(1, -1, 3)), + array_api_strict.asarray(atype.reshape(1, -1)), + array_api_strict.asarray(box.reshape(1, 3, 3)), + array_api_strict_obj.get_rcut(), + ) + nlist = build_neighbor_list( + ext_coords, + ext_atype, + natoms[0], + array_api_strict_obj.get_rcut(), + array_api_strict_obj.get_sel(), + distinguish_types=(not mixed_types), + ) + return [ + np.asarray(x) if hasattr(x, "__array_namespace__") else x + for x in array_api_strict_obj( + ext_coords, ext_atype, nlist=nlist, mapping=mapping + ) + ] diff --git a/source/tests/consistent/descriptor/test_dpa1.py b/source/tests/consistent/descriptor/test_dpa1.py index 0f44ecaae1..e01960cd21 100644 --- a/source/tests/consistent/descriptor/test_dpa1.py +++ b/source/tests/consistent/descriptor/test_dpa1.py @@ -17,6 +17,8 @@ ) from ..common import ( + INSTALLED_ARRAY_API_STRICT, + INSTALLED_JAX, INSTALLED_PT, INSTALLED_TF, CommonTest, @@ -34,6 +36,14 @@ from deepmd.tf.descriptor.se_atten import DescrptDPA1Compat as DescrptDPA1TF else: DescrptDPA1TF = None +if INSTALLED_JAX: + from deepmd.jax.descriptor.dpa1 import DescrptDPA1 as DescriptorDPA1JAX +else: + DescrptDPA1JAX = None +if INSTALLED_ARRAY_API_STRICT: + from ...array_api_strict.descriptor.dpa1 import DescrptDPA1 as DescriptorDPA1Strict +else: + DescriptorDPA1Strict = None from deepmd.utils.argcheck import ( descrpt_se_atten_args, ) @@ -184,6 +194,69 @@ def skip_dp(self) -> bool: temperature, ) + @property + def skip_jax(self) -> bool: + ( + tebd_dim, + tebd_input_mode, + resnet_dt, + type_one_side, + attn, + attn_layer, + attn_dotr, + excluded_types, + env_protection, + set_davg_zero, + scaling_factor, + normalize, + temperature, + ln_eps, + smooth_type_embedding, + concat_output_tebd, + precision, + use_econf_tebd, + use_tebd_bias, + ) = self.param + return not INSTALLED_JAX or self.is_meaningless_zero_attention_layer_tests( + attn_layer, + attn_dotr, + normalize, + temperature, + ) + + @property + def skip_array_api_strict(self) -> bool: + ( + tebd_dim, + tebd_input_mode, + resnet_dt, + type_one_side, + attn, + attn_layer, + attn_dotr, + excluded_types, + env_protection, + set_davg_zero, + scaling_factor, + normalize, + temperature, + ln_eps, + smooth_type_embedding, + concat_output_tebd, + precision, + use_econf_tebd, + use_tebd_bias, + ) = self.param + return ( + not INSTALLED_ARRAY_API_STRICT + or self.is_meaningless_zero_attention_layer_tests( + attn_layer, + attn_dotr, + normalize, + temperature, + ) + ) + @property def skip_tf(self) -> bool: ( @@ -227,6 +300,9 @@ def skip_tf(self) -> bool: tf_class = DescrptDPA1TF dp_class = DescrptDPA1DP pt_class = DescrptDPA1PT + jax_class = DescriptorDPA1JAX + array_api_strict_class = DescriptorDPA1Strict + args = descrpt_se_atten_args().append(Argument("ntypes", int, optional=False)) def setUp(self): @@ -314,6 +390,26 @@ def eval_pt(self, pt_obj: Any) -> Any: mixed_types=True, ) + def eval_jax(self, jax_obj: Any) -> Any: + return self.eval_jax_descriptor( + jax_obj, + self.natoms, + self.coords, + self.atype, + self.box, + mixed_types=True, + ) + + def eval_array_api_strict(self, array_api_strict_obj: Any) -> Any: + return self.eval_array_api_strict_descriptor( + array_api_strict_obj, + self.natoms, + self.coords, + self.atype, + self.box, + mixed_types=True, + ) + def extract_ret(self, ret: Any, backend) -> Tuple[np.ndarray, ...]: return (ret[0],) diff --git a/source/tests/consistent/test_type_embedding.py b/source/tests/consistent/test_type_embedding.py index c66ef0fbaa..696730d44d 100644 --- a/source/tests/consistent/test_type_embedding.py +++ b/source/tests/consistent/test_type_embedding.py @@ -13,6 +13,7 @@ ) from .common import ( + INSTALLED_ARRAY_API_STRICT, INSTALLED_JAX, INSTALLED_PT, INSTALLED_TF, @@ -38,6 +39,12 @@ from deepmd.jax.utils.type_embed import TypeEmbedNet as TypeEmbedNetJAX else: TypeEmbedNetJAX = object +if INSTALLED_ARRAY_API_STRICT: + import array_api_strict + + from ..array_api_strict.utils.type_embed import TypeEmbedNet as TypeEmbedNetStrict +else: + TypeEmbedNetStrict = None @parameterized( @@ -72,6 +79,7 @@ def data(self) -> dict: dp_class = TypeEmbedNetDP pt_class = TypeEmbedNetPT jax_class = TypeEmbedNetJAX + array_api_strict_class = TypeEmbedNetStrict args = type_embedding_args() skip_jax = not INSTALLED_JAX @@ -121,6 +129,13 @@ def eval_jax(self, jax_obj: Any) -> Any: raise ValueError("Output is numpy array") return [np.array(x) if isinstance(x, jnp.ndarray) else x for x in (out,)] + def eval_array_api_strict(self, array_api_strict_obj: Any) -> Any: + out = array_api_strict_obj() + return [ + np.array(x) if isinstance(x, array_api_strict.ndarray) else x + for x in (out,) + ] + def extract_ret(self, ret: Any, backend) -> Tuple[np.ndarray, ...]: return (ret[0],) From 801f36e0d1fcb551aa1c2ee188af9758fac06e87 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Mon, 23 Sep 2024 21:58:52 -0400 Subject: [PATCH 02/14] Apply suggestions from code review Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> Signed-off-by: Jinzhe Zeng --- source/tests/consistent/common.py | 2 +- source/tests/consistent/descriptor/test_dpa1.py | 2 +- source/tests/consistent/test_type_embedding.py | 1 + 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/source/tests/consistent/common.py b/source/tests/consistent/common.py index d7ece13806..a19528fbd7 100644 --- a/source/tests/consistent/common.py +++ b/source/tests/consistent/common.py @@ -295,7 +295,7 @@ def get_reference_ret_serialization(self, ref: RefBackend): return self.get_jax_ret_serialization_from_cls(obj) if ref == self.RefBackend.ARRAY_API_STRICT: obj = self.init_backend_cls(self.array_api_strict_class) - return self.get_array_api_ret_serialization_from_cls(obj) + return self.get_array_api_strict_ret_serialization_from_cls(obj) raise ValueError("No available reference") def test_tf_consistent_with_ref(self): diff --git a/source/tests/consistent/descriptor/test_dpa1.py b/source/tests/consistent/descriptor/test_dpa1.py index e01960cd21..3f103c8983 100644 --- a/source/tests/consistent/descriptor/test_dpa1.py +++ b/source/tests/consistent/descriptor/test_dpa1.py @@ -39,7 +39,7 @@ if INSTALLED_JAX: from deepmd.jax.descriptor.dpa1 import DescrptDPA1 as DescriptorDPA1JAX else: - DescrptDPA1JAX = None + DescriptorDPA1JAX = None if INSTALLED_ARRAY_API_STRICT: from ...array_api_strict.descriptor.dpa1 import DescrptDPA1 as DescriptorDPA1Strict else: diff --git a/source/tests/consistent/test_type_embedding.py b/source/tests/consistent/test_type_embedding.py index 696730d44d..cc2f3b96f3 100644 --- a/source/tests/consistent/test_type_embedding.py +++ b/source/tests/consistent/test_type_embedding.py @@ -82,6 +82,7 @@ def data(self) -> dict: array_api_strict_class = TypeEmbedNetStrict args = type_embedding_args() skip_jax = not INSTALLED_JAX + skip_array_api_strict = not INSTALLED_ARRAY_API_STRICT @property def addtional_data(self) -> dict: From 298ed95ac0947d678c50ae823d8b5acbf5df29b7 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Mon, 23 Sep 2024 22:49:17 -0400 Subject: [PATCH 03/14] Fix test_type_embedding.py Signed-off-by: Jinzhe Zeng --- source/tests/consistent/test_type_embedding.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/source/tests/consistent/test_type_embedding.py b/source/tests/consistent/test_type_embedding.py index cc2f3b96f3..41f63773fb 100644 --- a/source/tests/consistent/test_type_embedding.py +++ b/source/tests/consistent/test_type_embedding.py @@ -133,7 +133,7 @@ def eval_jax(self, jax_obj: Any) -> Any: def eval_array_api_strict(self, array_api_strict_obj: Any) -> Any: out = array_api_strict_obj() return [ - np.array(x) if isinstance(x, array_api_strict.ndarray) else x + np.asarray(x) if hasattr(x, "__array_namespace__") else x for x in (out,) ] From b3f0dd78defebfca4dfc23fc6431f478bbb7c198 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 24 Sep 2024 02:50:03 +0000 Subject: [PATCH 04/14] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- source/tests/consistent/test_type_embedding.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/source/tests/consistent/test_type_embedding.py b/source/tests/consistent/test_type_embedding.py index 41f63773fb..278dc53efe 100644 --- a/source/tests/consistent/test_type_embedding.py +++ b/source/tests/consistent/test_type_embedding.py @@ -40,8 +40,6 @@ else: TypeEmbedNetJAX = object if INSTALLED_ARRAY_API_STRICT: - import array_api_strict - from ..array_api_strict.utils.type_embed import TypeEmbedNet as TypeEmbedNetStrict else: TypeEmbedNetStrict = None @@ -133,8 +131,7 @@ def eval_jax(self, jax_obj: Any) -> Any: def eval_array_api_strict(self, array_api_strict_obj: Any) -> Any: out = array_api_strict_obj() return [ - np.asarray(x) if hasattr(x, "__array_namespace__") else x - for x in (out,) + np.asarray(x) if hasattr(x, "__array_namespace__") else x for x in (out,) ] def extract_ret(self, ret: Any, backend) -> Tuple[np.ndarray, ...]: From be6a8dca0825f394710ef21089c5efeacd0392cf Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Tue, 24 Sep 2024 01:05:41 -0400 Subject: [PATCH 05/14] cell may be none --- deepmd/dpmodel/utils/nlist.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deepmd/dpmodel/utils/nlist.py b/deepmd/dpmodel/utils/nlist.py index 43a5c7e4f5..03b4651553 100644 --- a/deepmd/dpmodel/utils/nlist.py +++ b/deepmd/dpmodel/utils/nlist.py @@ -273,7 +273,7 @@ def extend_coord_with_ghosts( maping extended index to the local index """ - xp = array_api_compat.array_namespace(coord, atype, cell) + xp = array_api_compat.array_namespace(coord, atype) nf, nloc = atype.shape aidx = xp.tile(xp.arange(nloc)[xp.newaxis, :], (nf, 1)) # pylint: disable=no-explicit-dtype if cell is None: From 7de9ee328c29acf68ea8300cb8c6b9f3f6726943 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Tue, 24 Sep 2024 02:28:14 -0400 Subject: [PATCH 06/14] fix nlist_masked for multiple frames Signed-off-by: Jinzhe Zeng --- deepmd/dpmodel/descriptor/dpa1.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/deepmd/dpmodel/descriptor/dpa1.py b/deepmd/dpmodel/descriptor/dpa1.py index d16c87d37c..5cf73e79d9 100644 --- a/deepmd/dpmodel/descriptor/dpa1.py +++ b/deepmd/dpmodel/descriptor/dpa1.py @@ -901,10 +901,11 @@ def call( # nfnl x nnei x 1 sw = xp.where(nlist_mask[:, :, None], sw, xp.full_like(sw, 0.0)) nall = atype_embd_ext.shape[1] - nlist_ = nlist + xp.reshape( + nfidx = xp.reshape( xp.repeat(xp.arange(nf) * nall, nloc * nnei), (nf * nloc, nnei) ) - nlist_masked = xp.where(nlist_mask, nlist_, xp.full_like(nlist, 0)) + nlist_ = nlist + nfidx + nlist_masked = xp.where(nlist_mask, nlist_, nfidx) # index = xp.tile(xp.reshape(nlist_masked,(nf, -1, 1)), (1, 1, self.tebd_dim)) # nfnl x nnei x tebd_dim # atype_embd_nlist = xp.take_along_axis(atype_embd_ext, index, axis=1) From d65206f8e6bae8ea7a2b9f6f6f55a991ac4d7674 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Wed, 25 Sep 2024 02:28:05 -0400 Subject: [PATCH 07/14] use a Python implementation of take_along_axis Signed-off-by: Jinzhe Zeng --- deepmd/dpmodel/array_api.py | 37 ++++++++++++++++++++++++++++ deepmd/dpmodel/descriptor/dpa1.py | 18 +++++--------- deepmd/dpmodel/utils/env_mat.py | 14 +++-------- deepmd/dpmodel/utils/exclude_mask.py | 10 ++++---- deepmd/dpmodel/utils/nlist.py | 10 +++++--- 5 files changed, 58 insertions(+), 31 deletions(-) diff --git a/deepmd/dpmodel/array_api.py b/deepmd/dpmodel/array_api.py index e4af2ad627..9d966b798c 100644 --- a/deepmd/dpmodel/array_api.py +++ b/deepmd/dpmodel/array_api.py @@ -1,6 +1,8 @@ # SPDX-License-Identifier: LGPL-3.0-or-later """Utilities for the array API.""" +import array_api_compat + def support_array_api(version: str) -> callable: """Mark a function as supporting the specific version of the array API. @@ -27,3 +29,38 @@ def set_version(func: callable) -> callable: return func return set_version + + +# array api adds take_along_axis in https://github.com/data-apis/array-api/pull/816 +# but it hasn't been released yet +# below is a pure Python implementation of take_along_axis +# https://github.com/data-apis/array-api/issues/177#issuecomment-2093630595 +def xp_swapaxes(a, axis1, axis2): + xp = array_api_compat.array_namespace(a) + axes = list(range(a.ndim)) + axes[axis1], axes[axis2] = axes[axis2], axes[axis1] + a = xp.permute_dims(a, axes) + return a + + +def xp_take_along_axis(arr, indices, axis): + xp = array_api_compat.array_namespace(arr) + arr = xp_swapaxes(arr, axis, -1) + indices = xp_swapaxes(indices, axis, -1) + + m = arr.shape[-1] + n = indices.shape[-1] + + shape = list(arr.shape) + shape.pop(-1) + shape = [*shape, n] + + arr = xp.reshape(arr, (-1,)) + indices = xp.reshape(indices, (-1, n)) + + offset = (xp.arange(indices.shape[0]) * m)[:, xp.newaxis] + indices = xp.reshape(offset + indices, (-1,)) + + out = xp.take(arr, indices) + out = xp.reshape(out, shape) + return xp_swapaxes(out, axis, -1) diff --git a/deepmd/dpmodel/descriptor/dpa1.py b/deepmd/dpmodel/descriptor/dpa1.py index 5cf73e79d9..d015fd349d 100644 --- a/deepmd/dpmodel/descriptor/dpa1.py +++ b/deepmd/dpmodel/descriptor/dpa1.py @@ -16,6 +16,9 @@ PRECISION_DICT, NativeOP, ) +from deepmd.dpmodel.array_api import ( + xp_take_along_axis, +) from deepmd.dpmodel.utils import ( EmbeddingNet, EnvMat, @@ -900,19 +903,10 @@ def call( nlist_mask = nlist != -1 # nfnl x nnei x 1 sw = xp.where(nlist_mask[:, :, None], sw, xp.full_like(sw, 0.0)) - nall = atype_embd_ext.shape[1] - nfidx = xp.reshape( - xp.repeat(xp.arange(nf) * nall, nloc * nnei), (nf * nloc, nnei) - ) - nlist_ = nlist + nfidx - nlist_masked = xp.where(nlist_mask, nlist_, nfidx) - # index = xp.tile(xp.reshape(nlist_masked,(nf, -1, 1)), (1, 1, self.tebd_dim)) + nlist_masked = xp.where(nlist_mask, nlist, xp.zeros_like(nlist)) + index = xp.tile(xp.reshape(nlist_masked, (nf, -1, 1)), (1, 1, self.tebd_dim)) # nfnl x nnei x tebd_dim - # atype_embd_nlist = xp.take_along_axis(atype_embd_ext, index, axis=1) - index = xp.reshape(nlist_masked, [-1]) - atype_embd_nlist = xp.take( - xp.reshape(atype_embd_ext, (nf * nall, self.tebd_dim)), index, axis=0 - ) + atype_embd_nlist = xp_take_along_axis(atype_embd_ext, index, axis=1) atype_embd_nlist = xp.reshape( atype_embd_nlist, (nf * nloc, nnei, self.tebd_dim) ) diff --git a/deepmd/dpmodel/utils/env_mat.py b/deepmd/dpmodel/utils/env_mat.py index 247e80e926..f4bc333a03 100644 --- a/deepmd/dpmodel/utils/env_mat.py +++ b/deepmd/dpmodel/utils/env_mat.py @@ -12,6 +12,7 @@ ) from deepmd.dpmodel.array_api import ( support_array_api, + xp_take_along_axis, ) @@ -51,17 +52,8 @@ def _make_env_mat( mask = nlist >= 0 nlist = nlist * xp.astype(mask, nlist.dtype) # nf x (nloc x nnei) x 3 - # index = xp.reshape(nlist, (nf, -1, 1)) - # index = xp.tile(xp.reshape(nlist, (nf, -1, 1)), (1, 1, 3)) - # coord_r = xp.take_along_axis(coord, xp.tile(index, (1, 1, 3)), 1) - # note: array api doesn't contain take_along_axis until the next version - # reimplement - nall = coord.shape[1] - index = xp.reshape(nlist, (nf * nloc * nnei,)) + xp.repeat( - (xp.arange(nf) * nall), nloc * nnei - ) - coord_ = xp.reshape(coord, (-1, 3)) - coord_r = xp.take(coord_, index, axis=0) + index = xp.tile(xp.reshape(nlist, (nf, -1, 1)), (1, 1, 3)) + coord_r = xp_take_along_axis(coord, index, 1) # nf x nloc x nnei x 3 coord_r = xp.reshape(coord_r, (nf, nloc, nnei, 3)) # nf x nloc x 1 x 3 diff --git a/deepmd/dpmodel/utils/exclude_mask.py b/deepmd/dpmodel/utils/exclude_mask.py index 426ee0b99a..a8e8dc7ef3 100644 --- a/deepmd/dpmodel/utils/exclude_mask.py +++ b/deepmd/dpmodel/utils/exclude_mask.py @@ -7,6 +7,10 @@ import array_api_compat import numpy as np +from deepmd.dpmodel.array_api import ( + xp_take_along_axis, +) + class AtomExcludeMask: """Computes the type exclusion mask for atoms.""" @@ -123,11 +127,7 @@ def build_type_exclude_mask( index = xp.reshape( xp.where(nlist == -1, xp.full_like(nlist, nall), nlist), (nf, nloc * nnei) ) - # type_j = xp.take_along_axis(ae, index, axis=1).reshape(nf, nloc, nnei) - index = xp.reshape(index, [-1]) - index += xp.repeat(xp.arange(nf) * (nall + 1), nloc * nnei) - type_j = xp.take(xp.reshape(ae, [-1]), index, axis=0) - type_j = xp.reshape(type_j, (nf, nloc, nnei)) + type_j = xp_take_along_axis(ae, index, axis=1).reshape(nf, nloc, nnei) type_ij = type_i[:, :, None] + type_j # nf x (nloc x nnei) type_ij = xp.reshape(type_ij, (nf, nloc * nnei)) diff --git a/deepmd/dpmodel/utils/nlist.py b/deepmd/dpmodel/utils/nlist.py index 03b4651553..7aac947712 100644 --- a/deepmd/dpmodel/utils/nlist.py +++ b/deepmd/dpmodel/utils/nlist.py @@ -9,6 +9,10 @@ import array_api_compat import numpy as np +from deepmd.dpmodel.array_api import ( + xp_take_along_axis, +) + from .region import ( normalize_coord, to_face_distance, @@ -165,17 +169,17 @@ def nlist_distinguish_types( mask = nlist == -1 tnlist_0 = nlist.copy() tnlist_0[mask] = 0 - tnlist = xp.take_along_axis(tmp_atype, tnlist_0, axis=2).squeeze() + tnlist = xp_take_along_axis(tmp_atype, tnlist_0, axis=2).squeeze() tnlist = xp.where(mask, -1, tnlist) snsel = tnlist.shape[2] for ii, ss in enumerate(sel): pick_mask = (tnlist == ii).astype(xp.int32) sorted_indices = xp.argsort(-pick_mask, kind="stable", axis=-1) pick_mask_sorted = -xp.sort(-pick_mask, axis=-1) - inlist = xp.take_along_axis(nlist, sorted_indices, axis=2) + inlist = xp_take_along_axis(nlist, sorted_indices, axis=2) inlist = xp.where(~pick_mask_sorted.astype(bool), -1, inlist) ret_nlist.append(xp.split(inlist, [ss, snsel - ss], axis=-1)[0]) - ret = xp.concatenate(ret_nlist, axis=-1) + ret = xp.concat(ret_nlist, axis=-1) return ret From aff0b42123fab7caf0d89eaa7da48f856a869661 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Wed, 25 Sep 2024 03:27:41 -0400 Subject: [PATCH 08/14] fix zero shape Signed-off-by: Jinzhe Zeng --- deepmd/dpmodel/array_api.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/deepmd/dpmodel/array_api.py b/deepmd/dpmodel/array_api.py index 9d966b798c..360df78a7b 100644 --- a/deepmd/dpmodel/array_api.py +++ b/deepmd/dpmodel/array_api.py @@ -56,7 +56,10 @@ def xp_take_along_axis(arr, indices, axis): shape = [*shape, n] arr = xp.reshape(arr, (-1,)) - indices = xp.reshape(indices, (-1, n)) + if n != 0: + indices = xp.reshape(indices, (-1, n)) + else: + indices = xp.reshape(indices, (0, 0)) offset = (xp.arange(indices.shape[0]) * m)[:, xp.newaxis] indices = xp.reshape(offset + indices, (-1,)) From e7aeca024da0ca4dad30f4bcf428e730d49faf64 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Wed, 25 Sep 2024 03:44:39 -0400 Subject: [PATCH 09/14] fix reshape Signed-off-by: Jinzhe Zeng --- deepmd/dpmodel/utils/exclude_mask.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/deepmd/dpmodel/utils/exclude_mask.py b/deepmd/dpmodel/utils/exclude_mask.py index a8e8dc7ef3..e744a726f6 100644 --- a/deepmd/dpmodel/utils/exclude_mask.py +++ b/deepmd/dpmodel/utils/exclude_mask.py @@ -127,7 +127,8 @@ def build_type_exclude_mask( index = xp.reshape( xp.where(nlist == -1, xp.full_like(nlist, nall), nlist), (nf, nloc * nnei) ) - type_j = xp_take_along_axis(ae, index, axis=1).reshape(nf, nloc, nnei) + type_j = xp_take_along_axis(ae, index, axis=1) + type_j = xp.reshape(type_j, (nf, nloc, nnei)) type_ij = type_i[:, :, None] + type_j # nf x (nloc x nnei) type_ij = xp.reshape(type_ij, (nf, nloc * nnei)) From bac980e6c4f94fb0b13bcf0eb0b88d00dbfcf20a Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Sun, 29 Sep 2024 03:32:55 -0400 Subject: [PATCH 10/14] apply flax.nnx.Module Signed-off-by: Jinzhe Zeng --- deepmd/jax/common.py | 40 ++++++++++++++++++++++++++++++++ deepmd/jax/descriptor/dpa1.py | 6 +++++ deepmd/jax/env.py | 4 ++++ deepmd/jax/utils/exclude_mask.py | 2 ++ deepmd/jax/utils/network.py | 20 +++++++++++++--- deepmd/jax/utils/type_embed.py | 2 ++ pyproject.toml | 1 + 7 files changed, 72 insertions(+), 3 deletions(-) diff --git a/deepmd/jax/common.py b/deepmd/jax/common.py index 8c3860cf39..0af7921c6d 100644 --- a/deepmd/jax/common.py +++ b/deepmd/jax/common.py @@ -6,8 +6,12 @@ import numpy as np +from deepmd.dpmodel.common import ( + NativeOP, +) from deepmd.jax.env import ( jnp, + nnx, ) @@ -35,3 +39,39 @@ def to_jax_array(array: Optional[np.ndarray]) -> Optional[jnp.ndarray]: if array is None: return None return jnp.array(array) + + +def flax_module( + module: NativeOP, +) -> nnx.Module: + """Convert a NativeOP to a Flax module. + + Parameters + ---------- + module : NativeOP + The NativeOP to convert. + + Returns + ------- + flax.nnx.Module + The Flax module. + + Examples + -------- + >>> @flax_module + ... class MyModule(NativeOP): + ... pass + """ + metas = set() + if not issubclass(type(nnx.Module), type(module)): + metas.add(type(module)) + if not issubclass(type(module), type(nnx.Module)): + metas.add(type(nnx.Module)) + + class MixedMetaClass(*metas): + pass + + class FlaxModule(module, nnx.Module, metaclass=MixedMetaClass): + pass + + return FlaxModule diff --git a/deepmd/jax/descriptor/dpa1.py b/deepmd/jax/descriptor/dpa1.py index 73ef6055e5..a9b0404970 100644 --- a/deepmd/jax/descriptor/dpa1.py +++ b/deepmd/jax/descriptor/dpa1.py @@ -13,6 +13,7 @@ NeighborGatedAttentionLayer as NeighborGatedAttentionLayerDP, ) from deepmd.jax.common import ( + flax_module, to_jax_array, ) from deepmd.jax.utils.exclude_mask import ( @@ -28,6 +29,7 @@ ) +@flax_module class GatedAttentionLayer(GatedAttentionLayerDP): def __setattr__(self, name: str, value: Any) -> None: if name in {"in_proj", "out_proj"}: @@ -35,6 +37,7 @@ def __setattr__(self, name: str, value: Any) -> None: return super().__setattr__(name, value) +@flax_module class NeighborGatedAttentionLayer(NeighborGatedAttentionLayerDP): def __setattr__(self, name: str, value: Any) -> None: if name == "attention_layer": @@ -44,6 +47,7 @@ def __setattr__(self, name: str, value: Any) -> None: return super().__setattr__(name, value) +@flax_module class NeighborGatedAttention(NeighborGatedAttentionDP): def __setattr__(self, name: str, value: Any) -> None: if name == "attention_layers": @@ -53,6 +57,7 @@ def __setattr__(self, name: str, value: Any) -> None: return super().__setattr__(name, value) +@flax_module class DescrptBlockSeAtten(DescrptBlockSeAttenDP): def __setattr__(self, name: str, value: Any) -> None: if name in {"mean", "stddev"}: @@ -71,6 +76,7 @@ def __setattr__(self, name: str, value: Any) -> None: return super().__setattr__(name, value) +@flax_module class DescrptDPA1(DescrptDPA1DP): def __setattr__(self, name: str, value: Any) -> None: if name == "se_atten": diff --git a/deepmd/jax/env.py b/deepmd/jax/env.py index 34e4aa6240..5a5a7f6bf0 100644 --- a/deepmd/jax/env.py +++ b/deepmd/jax/env.py @@ -5,10 +5,14 @@ import jax import jax.numpy as jnp +from flax import ( + nnx, +) jax.config.update("jax_enable_x64", True) __all__ = [ "jax", "jnp", + "nnx", ] diff --git a/deepmd/jax/utils/exclude_mask.py b/deepmd/jax/utils/exclude_mask.py index 6519648514..cac4cee092 100644 --- a/deepmd/jax/utils/exclude_mask.py +++ b/deepmd/jax/utils/exclude_mask.py @@ -5,10 +5,12 @@ from deepmd.dpmodel.utils.exclude_mask import PairExcludeMask as PairExcludeMaskDP from deepmd.jax.common import ( + flax_module, to_jax_array, ) +@flax_module class PairExcludeMask(PairExcludeMaskDP): def __setattr__(self, name: str, value: Any) -> None: if name in {"type_mask"}: diff --git a/deepmd/jax/utils/network.py b/deepmd/jax/utils/network.py index 6517573b38..fc6e168c7b 100644 --- a/deepmd/jax/utils/network.py +++ b/deepmd/jax/utils/network.py @@ -17,10 +17,12 @@ make_multilayer_network, ) from deepmd.jax.common import ( + flax_module, to_jax_array, ) +@flax_module class NativeLayer(NativeLayerDP): def __setattr__(self, name: str, value: Any) -> None: if name in {"w", "b", "idt"}: @@ -28,11 +30,22 @@ def __setattr__(self, name: str, value: Any) -> None: return super().__setattr__(name, value) -NativeNet = make_multilayer_network(NativeLayer, NativeOP) -EmbeddingNet = make_embedding_network(NativeNet, NativeLayer) -FittingNet = make_fitting_network(EmbeddingNet, NativeNet, NativeLayer) +@flax_module +class NativeNet(make_multilayer_network(NativeLayer, NativeOP)): + pass + + +@flax_module +class EmbeddingNet(make_embedding_network(NativeNet, NativeLayer)): + pass + + +@flax_module +class FittingNet(make_fitting_network(EmbeddingNet, NativeNet, NativeLayer)): + pass +@flax_module class NetworkCollection(NetworkCollectionDP): NETWORK_TYPE_MAP: ClassVar[Dict[str, type]] = { "network": NativeNet, @@ -41,5 +54,6 @@ class NetworkCollection(NetworkCollectionDP): } +@flax_module class LayerNorm(LayerNormDP, NativeLayer): pass diff --git a/deepmd/jax/utils/type_embed.py b/deepmd/jax/utils/type_embed.py index bc7c469524..3143460244 100644 --- a/deepmd/jax/utils/type_embed.py +++ b/deepmd/jax/utils/type_embed.py @@ -5,6 +5,7 @@ from deepmd.dpmodel.utils.type_embed import TypeEmbedNet as TypeEmbedNetDP from deepmd.jax.common import ( + flax_module, to_jax_array, ) from deepmd.jax.utils.network import ( @@ -12,6 +13,7 @@ ) +@flax_module class TypeEmbedNet(TypeEmbedNetDP): def __setattr__(self, name: str, value: Any) -> None: if name in {"econf_tebd"}: diff --git a/pyproject.toml b/pyproject.toml index 28fe114e01..9fa1425c2b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -134,6 +134,7 @@ cu12 = [ ] jax = [ 'jax>=0.4.33;python_version>="3.10"', + 'flax>=0.8.0;python_version>="3.10"', ] [tool.deepmd_build_backend.scripts] From 723d6ed9888abb885ac5df49b997a54c6fb00c3e Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Sun, 29 Sep 2024 06:36:09 -0400 Subject: [PATCH 11/14] fix metaclass Signed-off-by: Jinzhe Zeng --- deepmd/jax/common.py | 10 ++++++++-- deepmd/jax/utils/network.py | 3 --- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/deepmd/jax/common.py b/deepmd/jax/common.py index 0af7921c6d..9c144a41d1 100644 --- a/deepmd/jax/common.py +++ b/deepmd/jax/common.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: LGPL-3.0-or-later from typing import ( + Any, Optional, overload, ) @@ -69,9 +70,14 @@ def flax_module( metas.add(type(nnx.Module)) class MixedMetaClass(*metas): - pass + def __call__(self, *args, **kwargs): + return type(nnx.Module).__call__(self, *args, **kwargs) class FlaxModule(module, nnx.Module, metaclass=MixedMetaClass): - pass + def __init_subclass__(cls, **kwargs) -> None: + return super().__init_subclass__(**kwargs) + + def __setattr__(self, name: str, value: Any) -> None: + return super().__setattr__(name, value) return FlaxModule diff --git a/deepmd/jax/utils/network.py b/deepmd/jax/utils/network.py index fc6e168c7b..bbd6419663 100644 --- a/deepmd/jax/utils/network.py +++ b/deepmd/jax/utils/network.py @@ -35,12 +35,10 @@ class NativeNet(make_multilayer_network(NativeLayer, NativeOP)): pass -@flax_module class EmbeddingNet(make_embedding_network(NativeNet, NativeLayer)): pass -@flax_module class FittingNet(make_fitting_network(EmbeddingNet, NativeNet, NativeLayer)): pass @@ -54,6 +52,5 @@ class NetworkCollection(NetworkCollectionDP): } -@flax_module class LayerNorm(LayerNormDP, NativeLayer): pass From d159ffd90236ff9f8515cc36c4be7d5860ee0e4b Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Sun, 29 Sep 2024 17:34:42 -0400 Subject: [PATCH 12/14] set w, b, idt to Param Signed-off-by: Jinzhe Zeng --- deepmd/dpmodel/common.py | 2 +- deepmd/dpmodel/utils/network.py | 4 ++-- deepmd/jax/utils/network.py | 16 ++++++++++++++++ 3 files changed, 19 insertions(+), 3 deletions(-) diff --git a/deepmd/dpmodel/common.py b/deepmd/dpmodel/common.py index d9d57d2d6c..8391c56443 100644 --- a/deepmd/dpmodel/common.py +++ b/deepmd/dpmodel/common.py @@ -78,7 +78,7 @@ def to_numpy_array(x: Any) -> Optional[np.ndarray]: """ if x is None: return None - return np.asarray(x) + return np.from_dlpack(x) __all__ = [ diff --git a/deepmd/dpmodel/utils/network.py b/deepmd/dpmodel/utils/network.py index 66104cb01c..ee7abd5f95 100644 --- a/deepmd/dpmodel/utils/network.py +++ b/deepmd/dpmodel/utils/network.py @@ -385,8 +385,8 @@ def serialize(self) -> dict: The serialized layer. """ data = { - "w": np.array(self.w), - "b": np.array(self.b), + "w": to_numpy_array(self.w), + "b": to_numpy_array(self.b), } return { "@class": "LayerNorm", diff --git a/deepmd/jax/utils/network.py b/deepmd/jax/utils/network.py index bbd6419663..2fce6831fe 100644 --- a/deepmd/jax/utils/network.py +++ b/deepmd/jax/utils/network.py @@ -20,6 +20,20 @@ flax_module, to_jax_array, ) +from deepmd.jax.env import ( + nnx, +) + + +class ArrayAPIParam(nnx.Param): + def __array_namespace__(self, *args, **kwargs): + return self.value.__array_namespace__(*args, **kwargs) + + def __dlpack__(self, *args, **kwargs): + return self.value.__dlpack__(*args, **kwargs) + + def __dlpack_device__(self, *args, **kwargs): + return self.value.__dlpack_device__(*args, **kwargs) @flax_module @@ -27,6 +41,8 @@ class NativeLayer(NativeLayerDP): def __setattr__(self, name: str, value: Any) -> None: if name in {"w", "b", "idt"}: value = to_jax_array(value) + if value is not None: + value = ArrayAPIParam(value) return super().__setattr__(name, value) From 13e85f70d66656cdfbbd38a619a5ce4316bb5fcb Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Sun, 29 Sep 2024 17:44:08 -0400 Subject: [PATCH 13/14] __array__ Signed-off-by: Jinzhe Zeng --- deepmd/jax/utils/network.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/deepmd/jax/utils/network.py b/deepmd/jax/utils/network.py index 2fce6831fe..887ad7147e 100644 --- a/deepmd/jax/utils/network.py +++ b/deepmd/jax/utils/network.py @@ -26,6 +26,9 @@ class ArrayAPIParam(nnx.Param): + def __array__(self, *args, **kwargs): + return self.value.__array__(*args, **kwargs) + def __array_namespace__(self, *args, **kwargs): return self.value.__array_namespace__(*args, **kwargs) From d007f81c39afb7b47e4efd69d95280c76dfbd6ca Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Sun, 29 Sep 2024 19:21:35 -0400 Subject: [PATCH 14/14] revert --- deepmd/dpmodel/common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deepmd/dpmodel/common.py b/deepmd/dpmodel/common.py index 8391c56443..d9d57d2d6c 100644 --- a/deepmd/dpmodel/common.py +++ b/deepmd/dpmodel/common.py @@ -78,7 +78,7 @@ def to_numpy_array(x: Any) -> Optional[np.ndarray]: """ if x is None: return None - return np.from_dlpack(x) + return np.asarray(x) __all__ = [