diff --git a/.github/workflows/test_cuda.yml b/.github/workflows/test_cuda.yml
index 2883f01b5a..d60a9c909a 100644
--- a/.github/workflows/test_cuda.yml
+++ b/.github/workflows/test_cuda.yml
@@ -51,7 +51,7 @@ jobs:
       - run: |
           export PYTORCH_ROOT=$(python -c 'import torch;print(torch.__path__[0])')
           export TENSORFLOW_ROOT=$(python -c 'import importlib,pathlib;print(pathlib.Path(importlib.util.find_spec("tensorflow").origin).parent)')
-          source/install/uv_with_retry.sh pip install --system -v -e .[gpu,test,lmp,cu12,torch] mpi4py
+          source/install/uv_with_retry.sh pip install --system -v -e .[gpu,test,lmp,cu12,torch,jax] mpi4py
         env:
           DP_VARIANT: cuda
           DP_ENABLE_NATIVE_OPTIMIZATION: 1
diff --git a/.github/workflows/test_python.yml b/.github/workflows/test_python.yml
index 36f9bd78b8..8274921909 100644
--- a/.github/workflows/test_python.yml
+++ b/.github/workflows/test_python.yml
@@ -28,7 +28,7 @@ jobs:
           source/install/uv_with_retry.sh pip install --system mpich
           source/install/uv_with_retry.sh pip install --system "torch==2.3.0+cpu.cxx11.abi" -i https://download.pytorch.org/whl/
           export PYTORCH_ROOT=$(python -c 'import torch;print(torch.__path__[0])')
-          source/install/uv_with_retry.sh pip install --system --only-binary=horovod -e .[cpu,test] horovod[tensorflow-cpu] mpi4py
+          source/install/uv_with_retry.sh pip install --system --only-binary=horovod -e .[cpu,test,jax] horovod[tensorflow-cpu] mpi4py
         env:
           # Please note that uv has some issues with finding
           # existing TensorFlow package. Currently, it uses
diff --git a/deepmd/backend/jax.py b/deepmd/backend/jax.py
new file mode 100644
index 0000000000..ece0761772
--- /dev/null
+++ b/deepmd/backend/jax.py
@@ -0,0 +1,110 @@
+# SPDX-License-Identifier: LGPL-3.0-or-later
+from importlib.util import (
+    find_spec,
+)
+from typing import (
+    TYPE_CHECKING,
+    Callable,
+    ClassVar,
+    List,
+    Type,
+)
+
+from deepmd.backend.backend import (
+    Backend,
+)
+
+if TYPE_CHECKING:
+    from argparse import (
+        Namespace,
+    )
+
+    from deepmd.infer.deep_eval import (
+        DeepEvalBackend,
+    )
+    from deepmd.utils.neighbor_stat import (
+        NeighborStat,
+    )
+
+
+@Backend.register("jax")
+class JAXBackend(Backend):
+    """JAX backend."""
+
+    name = "JAX"
+    """The formal name of the backend."""
+    features: ClassVar[Backend.Feature] = (
+        Backend.Feature(0)
+        # Backend.Feature.ENTRY_POINT
+        # | Backend.Feature.DEEP_EVAL
+        # | Backend.Feature.NEIGHBOR_STAT
+        # | Backend.Feature.IO
+    )
+    """The features of the backend."""
+    suffixes: ClassVar[List[str]] = []
+    """The suffixes of the backend."""
+
+    def is_available(self) -> bool:
+        """Check if the backend is available.
+
+        Returns
+        -------
+        bool
+            Whether the backend is available.
+        """
+        return find_spec("jax") is not None
+
+    @property
+    def entry_point_hook(self) -> Callable[["Namespace"], None]:
+        """The entry point hook of the backend.
+
+        Returns
+        -------
+        Callable[[Namespace], None]
+            The entry point hook of the backend.
+        """
+        raise NotImplementedError
+
+    @property
+    def deep_eval(self) -> Type["DeepEvalBackend"]:
+        """The Deep Eval backend of the backend.
+
+        Returns
+        -------
+        type[DeepEvalBackend]
+            The Deep Eval backend of the backend.
+        """
+        raise NotImplementedError
+
+    @property
+    def neighbor_stat(self) -> Type["NeighborStat"]:
+        """The neighbor statistics of the backend.
+
+        Returns
+        -------
+        type[NeighborStat]
+            The neighbor statistics of the backend.
+ """ + raise NotImplementedError + + @property + def serialize_hook(self) -> Callable[[str], dict]: + """The serialize hook to convert the model file to a dictionary. + + Returns + ------- + Callable[[str], dict] + The serialize hook of the backend. + """ + raise NotImplementedError + + @property + def deserialize_hook(self) -> Callable[[str, dict], None]: + """The deserialize hook to convert the dictionary to a model file. + + Returns + ------- + Callable[[str, dict], None] + The deserialize hook of the backend. + """ + raise NotImplementedError diff --git a/deepmd/dpmodel/common.py b/deepmd/dpmodel/common.py index 56cb8ec1e9..d9d57d2d6c 100644 --- a/deepmd/dpmodel/common.py +++ b/deepmd/dpmodel/common.py @@ -3,6 +3,10 @@ ABC, abstractmethod, ) +from typing import ( + Any, + Optional, +) import ml_dtypes import numpy as np @@ -59,6 +63,24 @@ def __call__(self, *args, **kwargs): return self.call(*args, **kwargs) +def to_numpy_array(x: Any) -> Optional[np.ndarray]: + """Convert an array to a NumPy array. + + Parameters + ---------- + x : Any + The array to be converted. + + Returns + ------- + Optional[np.ndarray] + The NumPy array. + """ + if x is None: + return None + return np.asarray(x) + + __all__ = [ "GLOBAL_NP_FLOAT_PRECISION", "GLOBAL_ENER_FLOAT_PRECISION", diff --git a/deepmd/dpmodel/utils/network.py b/deepmd/dpmodel/utils/network.py index 6f0269971e..07421e0b13 100644 --- a/deepmd/dpmodel/utils/network.py +++ b/deepmd/dpmodel/utils/network.py @@ -15,6 +15,7 @@ Union, ) +import array_api_compat import numpy as np from deepmd.dpmodel import ( @@ -22,6 +23,12 @@ PRECISION_DICT, NativeOP, ) +from deepmd.dpmodel.array_api import ( + support_array_api, +) +from deepmd.dpmodel.common import ( + to_numpy_array, +) from deepmd.dpmodel.utils.seed import ( child_seed, ) @@ -105,9 +112,9 @@ def serialize(self) -> dict: The serialized layer. """ data = { - "w": self.w, - "b": self.b, - "idt": self.idt, + "w": to_numpy_array(self.w), + "b": to_numpy_array(self.b), + "idt": to_numpy_array(self.idt), } return { "@class": "Layer", @@ -215,6 +222,7 @@ def dim_in(self) -> int: def dim_out(self) -> int: return self.w.shape[1] + @support_array_api(version="2022.12") def call(self, x: np.ndarray) -> np.ndarray: """Forward pass. 
@@ -230,11 +238,12 @@ def call(self, x: np.ndarray) -> np.ndarray:
         """
         if self.w is None or self.activation_function is None:
             raise ValueError("w, b, and activation_function must be set")
+        xp = array_api_compat.array_namespace(x)
         fn = get_activation_fn(self.activation_function)
         y = (
-            np.matmul(x, self.w) + self.b
+            xp.matmul(x, self.w) + self.b
             if self.b is not None
-            else np.matmul(x, self.w)
+            else xp.matmul(x, self.w)
         )
         y = fn(y)
         if self.idt is not None:
@@ -242,47 +251,64 @@ def call(self, x: np.ndarray) -> np.ndarray:
         if self.resnet and self.w.shape[1] == self.w.shape[0]:
             y += x
         elif self.resnet and self.w.shape[1] == 2 * self.w.shape[0]:
-            y += np.concatenate([x, x], axis=-1)
+            y += xp.concatenate([x, x], axis=-1)
         return y
 
 
+@support_array_api(version="2022.12")
 def get_activation_fn(activation_function: str) -> Callable[[np.ndarray], np.ndarray]:
     activation_function = activation_function.lower()
     if activation_function == "tanh":
-        return np.tanh
+
+        def fn(x):
+            xp = array_api_compat.array_namespace(x)
+            return xp.tanh(x)
+
+        return fn
     elif activation_function == "relu":
 
         def fn(x):
+            xp = array_api_compat.array_namespace(x)
             # https://stackoverflow.com/a/47936476/9567349
-            return x * (x > 0)
+            return x * xp.astype(x > 0, x.dtype)
 
         return fn
     elif activation_function in ("gelu", "gelu_tf"):
 
         def fn(x):
+            xp = array_api_compat.array_namespace(x)
             # generated by GitHub Copilot
-            return 0.5 * x * (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * x**3)))
+            return (
+                0.5
+                * x
+                * (1 + xp.tanh(xp.sqrt(xp.asarray(2 / xp.pi)) * (x + 0.044715 * x**3)))
+            )
 
         return fn
     elif activation_function == "relu6":
 
         def fn(x):
+            xp = array_api_compat.array_namespace(x)
             # generated by GitHub Copilot
-            return np.minimum(np.maximum(x, 0), 6)
+            return xp.where(
+                x < 0, xp.full_like(x, 0), xp.where(x > 6, xp.full_like(x, 6), x)
+            )
 
         return fn
     elif activation_function == "softplus":
 
         def fn(x):
+            xp = array_api_compat.array_namespace(x)
             # generated by GitHub Copilot
-            return np.log(1 + np.exp(x))
+            return xp.log(1 + xp.exp(x))
 
         return fn
     elif activation_function == "sigmoid":
 
         def fn(x):
+            xp = array_api_compat.array_namespace(x)
             # generated by GitHub Copilot
-            return 1 / (1 + np.exp(-x))
+            return 1 / (1 + xp.exp(-x))
 
         return fn
     elif activation_function.lower() in ("none", "linear"):
diff --git a/deepmd/dpmodel/utils/type_embed.py b/deepmd/dpmodel/utils/type_embed.py
index 2e695171d6..e11c415cfd 100644
--- a/deepmd/dpmodel/utils/type_embed.py
+++ b/deepmd/dpmodel/utils/type_embed.py
@@ -5,8 +5,12 @@
     Union,
 )
 
+import array_api_compat
 import numpy as np
 
+from deepmd.dpmodel.array_api import (
+    support_array_api,
+)
 from deepmd.dpmodel.common import (
     PRECISION_DICT,
     NativeOP,
@@ -92,16 +96,18 @@ def __init__(
             bias=self.use_tebd_bias,
         )
 
+    @support_array_api(version="2022.12")
     def call(self) -> np.ndarray:
         """Compute the type embedding network."""
+        sample_array = self.embedding_net[0]["w"]
+        xp = array_api_compat.array_namespace(sample_array)
         if not self.use_econf_tebd:
-            embed = self.embedding_net(
-                np.eye(self.ntypes, dtype=PRECISION_DICT[self.precision])
-            )
+            embed = self.embedding_net(xp.eye(self.ntypes, dtype=sample_array.dtype))
         else:
             embed = self.embedding_net(self.econf_tebd)
         if self.padding:
-            embed = np.pad(embed, ((0, 1), (0, 0)), mode="constant")
+            embed_pad = xp.zeros((1, embed.shape[-1]), dtype=embed.dtype)
+            embed = xp.concatenate([embed, embed_pad], axis=0)
         return embed
 
     @classmethod
diff --git a/deepmd/jax/__init__.py b/deepmd/jax/__init__.py
new file mode 100644
index 0000000000..2ff078e797
--- /dev/null
+++ b/deepmd/jax/__init__.py
@@ -0,0 +1,2 @@
+# SPDX-License-Identifier: LGPL-3.0-or-later
+"""JAX backend."""
diff --git a/deepmd/jax/common.py b/deepmd/jax/common.py
new file mode 100644
index 0000000000..550b168b29
--- /dev/null
+++ b/deepmd/jax/common.py
@@ -0,0 +1,37 @@
+# SPDX-License-Identifier: LGPL-3.0-or-later
+from typing import (
+    Optional,
+    overload,
+)
+
+import numpy as np
+
+from deepmd.jax.env import (
+    jnp,
+)
+
+
+@overload
+def to_jax_array(array: np.ndarray) -> jnp.ndarray: ...
+
+
+@overload
+def to_jax_array(array: None) -> None: ...
+
+
+def to_jax_array(array: Optional[np.ndarray]) -> Optional[jnp.ndarray]:
+    """Convert a NumPy array to a JAX array.
+
+    Parameters
+    ----------
+    array : np.ndarray
+        The NumPy array to convert.
+
+    Returns
+    -------
+    jnp.ndarray
+        The JAX array.
+    """
+    if array is None:
+        return None
+    return jnp.array(array)
diff --git a/deepmd/jax/env.py b/deepmd/jax/env.py
new file mode 100644
index 0000000000..34e4aa6240
--- /dev/null
+++ b/deepmd/jax/env.py
@@ -0,0 +1,14 @@
+# SPDX-License-Identifier: LGPL-3.0-or-later
+import os
+
+os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
+
+import jax
+import jax.numpy as jnp
+
+jax.config.update("jax_enable_x64", True)
+
+__all__ = [
+    "jax",
+    "jnp",
+]
diff --git a/deepmd/jax/utils/__init__.py b/deepmd/jax/utils/__init__.py
new file mode 100644
index 0000000000..6ceb116d85
--- /dev/null
+++ b/deepmd/jax/utils/__init__.py
@@ -0,0 +1 @@
+# SPDX-License-Identifier: LGPL-3.0-or-later
diff --git a/deepmd/jax/utils/network.py b/deepmd/jax/utils/network.py
new file mode 100644
index 0000000000..629b51b8cd
--- /dev/null
+++ b/deepmd/jax/utils/network.py
@@ -0,0 +1,29 @@
+# SPDX-License-Identifier: LGPL-3.0-or-later
+from typing import (
+    Any,
+)
+
+from deepmd.dpmodel.common import (
+    NativeOP,
+)
+from deepmd.dpmodel.utils.network import NativeLayer as NativeLayerDP
+from deepmd.dpmodel.utils.network import (
+    make_embedding_network,
+    make_fitting_network,
+    make_multilayer_network,
+)
+from deepmd.jax.common import (
+    to_jax_array,
+)
+
+
+class NativeLayer(NativeLayerDP):
+    def __setattr__(self, name: str, value: Any) -> None:
+        if name in {"w", "b", "idt"}:
+            value = to_jax_array(value)
+        return super().__setattr__(name, value)
+
+
+NativeNet = make_multilayer_network(NativeLayer, NativeOP)
+EmbeddingNet = make_embedding_network(NativeNet, NativeLayer)
+FittingNet = make_fitting_network(EmbeddingNet, NativeNet, NativeLayer)
diff --git a/deepmd/jax/utils/type_embed.py b/deepmd/jax/utils/type_embed.py
new file mode 100644
index 0000000000..bc7c469524
--- /dev/null
+++ b/deepmd/jax/utils/type_embed.py
@@ -0,0 +1,21 @@
+# SPDX-License-Identifier: LGPL-3.0-or-later
+from typing import (
+    Any,
+)
+
+from deepmd.dpmodel.utils.type_embed import TypeEmbedNet as TypeEmbedNetDP
+from deepmd.jax.common import (
+    to_jax_array,
+)
+from deepmd.jax.utils.network import (
+    EmbeddingNet,
+)
+
+
+class TypeEmbedNet(TypeEmbedNetDP):
+    def __setattr__(self, name: str, value: Any) -> None:
+        if name in {"econf_tebd"}:
+            value = to_jax_array(value)
+        if name in {"embedding_net"}:
+            value = EmbeddingNet.deserialize(value.serialize())
+        return super().__setattr__(name, value)
diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py
index c3d603dadd..9bdc80195f 100644
--- a/deepmd/pt/train/training.py
+++ b/deepmd/pt/train/training.py
@@ -484,7 +484,7 @@ def collect_single_finetune_params(
                 if i != "_extra_state" and f".{_model_key}." in i
             ]
             for item_key in target_keys:
-                if _new_fitting and ".fitting_net." in item_key:
+                if _new_fitting and (".descriptor." not in item_key):
                     # print(f'Keep {item_key} in old model!')
                     _new_state_dict[item_key] = (
                         _random_state_dict[item_key].clone().detach()
diff --git a/pyproject.toml b/pyproject.toml
index f181b616a3..28fe114e01 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -132,6 +132,9 @@ cu12 = [
     "nvidia-cudnn-cu12<9",
     "nvidia-cuda-nvcc-cu12",
 ]
+jax = [
+    'jax>=0.4.33;python_version>="3.10"',
+]
 
 [tool.deepmd_build_backend.scripts]
 dp = "deepmd.main:main"
diff --git a/source/tests/common/dpmodel/array_api/test_activation_functions.py b/source/tests/common/dpmodel/array_api/test_activation_functions.py
new file mode 100644
index 0000000000..6ceb116d85
--- /dev/null
+++ b/source/tests/common/dpmodel/array_api/test_activation_functions.py
@@ -0,0 +1 @@
+# SPDX-License-Identifier: LGPL-3.0-or-later
diff --git a/source/tests/consistent/common.py b/source/tests/consistent/common.py
index edafc7c02e..e8873e528a 100644
--- a/source/tests/consistent/common.py
+++ b/source/tests/consistent/common.py
@@ -35,6 +35,7 @@
 
 INSTALLED_TF = Backend.get_backend("tensorflow")().is_available()
 INSTALLED_PT = Backend.get_backend("pytorch")().is_available()
+INSTALLED_JAX = Backend.get_backend("jax")().is_available()
 
 if os.environ.get("CI") and not (INSTALLED_TF and INSTALLED_PT):
     raise ImportError("TensorFlow or PyTorch should be tested in the CI")
@@ -57,6 +58,7 @@
     "CommonTest",
     "INSTALLED_TF",
     "INSTALLED_PT",
+    "INSTALLED_JAX",
 ]
 
 
@@ -71,6 +73,8 @@ class CommonTest(ABC):
     """Native DP model class."""
     pt_class: ClassVar[Optional[type]]
     """PyTorch model class."""
+    jax_class: ClassVar[Optional[type]]
+    """JAX model class."""
     args: ClassVar[Optional[Union[Argument, List[Argument]]]]
     """Arguments that maps to the `data`."""
     skip_dp: ClassVar[bool] = False
@@ -79,6 +83,9 @@ class CommonTest(ABC):
     """Whether to skip the TensorFlow model."""
    skip_pt: ClassVar[bool] = not INSTALLED_PT
     """Whether to skip the PyTorch model."""
+    # JAX tests are skipped by default until the backend is fully supported
+    skip_jax: ClassVar[bool] = True
+    """Whether to skip the JAX model."""
     rtol = 1e-10
     """Relative tolerance for comparing the return value. Override for float32."""
     atol = 1e-10
@@ -149,12 +156,23 @@ def eval_pt(self, pt_obj: Any) -> Any:
             The object of PT
         """
 
+    def eval_jax(self, jax_obj: Any) -> Any:
+        """Evaluate the return value of JAX.
+
+        Parameters
+        ----------
+        jax_obj : Any
+            The object of JAX
+        """
+        raise NotImplementedError("Not implemented")
+
     class RefBackend(Enum):
         """Reference backend."""
 
         TF = 1
         DP = 2
         PT = 3
+        JAX = 5
 
     @abstractmethod
     def extract_ret(self, ret: Any, backend: RefBackend) -> Tuple[np.ndarray, ...]:
@@ -215,6 +233,11 @@ def get_dp_ret_serialization_from_cls(self, obj):
         data = obj.serialize()
         return ret, data
 
+    def get_jax_ret_serialization_from_cls(self, obj):
+        ret = self.eval_jax(obj)
+        data = obj.serialize()
+        return ret, data
+
     def get_reference_backend(self):
         """Get the reference backend.
@@ -226,6 +249,8 @@ def get_reference_backend(self):
             return self.RefBackend.TF
         if not self.skip_pt:
             return self.RefBackend.PT
+        if not self.skip_jax:
+            return self.RefBackend.JAX
         raise ValueError("No available reference")
 
     def get_reference_ret_serialization(self, ref: RefBackend):
@@ -359,6 +384,40 @@ def test_pt_self_consistent(self):
             else:
                 self.assertEqual(rr1, rr2)
 
+    def test_jax_consistent_with_ref(self):
+        """Test whether JAX and reference are consistent."""
+        if self.skip_jax:
+            self.skipTest("Unsupported backend")
+        ref_backend = self.get_reference_backend()
+        if ref_backend == self.RefBackend.JAX:
+            self.skipTest("Reference is self")
+        ret1, data1 = self.get_reference_ret_serialization(ref_backend)
+        ret1 = self.extract_ret(ret1, ref_backend)
+        jax_obj = self.jax_class.deserialize(data1)
+        ret2 = self.eval_jax(jax_obj)
+        ret2 = self.extract_ret(ret2, self.RefBackend.JAX)
+        data2 = jax_obj.serialize()
+        np.testing.assert_equal(data1, data2)
+        for rr1, rr2 in zip(ret1, ret2):
+            np.testing.assert_allclose(rr1, rr2, rtol=self.rtol, atol=self.atol)
+            assert rr1.dtype == rr2.dtype, f"{rr1.dtype} != {rr2.dtype}"
+
+    def test_jax_self_consistent(self):
+        """Test whether JAX is self consistent."""
+        if self.skip_jax:
+            self.skipTest("Unsupported backend")
+        obj1 = self.init_backend_cls(self.jax_class)
+        ret1, data1 = self.get_jax_ret_serialization_from_cls(obj1)
+        obj1 = self.jax_class.deserialize(data1)
+        ret2, data2 = self.get_jax_ret_serialization_from_cls(obj1)
+        np.testing.assert_equal(data1, data2)
+        for rr1, rr2 in zip(ret1, ret2):
+            if isinstance(rr1, np.ndarray) and isinstance(rr2, np.ndarray):
+                np.testing.assert_allclose(rr1, rr2, rtol=self.rtol, atol=self.atol)
+                assert rr1.dtype == rr2.dtype, f"{rr1.dtype} != {rr2.dtype}"
+            else:
+                self.assertEqual(rr1, rr2)
+
     def tearDown(self) -> None:
         """Clear the TF session."""
         if not self.skip_tf:
diff --git a/source/tests/consistent/test_activation.py b/source/tests/consistent/test_activation.py
index 3fcb9b2fa5..5630e913a8 100644
--- a/source/tests/consistent/test_activation.py
+++ b/source/tests/consistent/test_activation.py
@@ -1,4 +1,5 @@
 # SPDX-License-Identifier: LGPL-3.0-or-later
+import sys
 import unittest
 
 import numpy as np
@@ -12,6 +13,7 @@
     GLOBAL_SEED,
 )
 from .common import (
+    INSTALLED_JAX,
     INSTALLED_PT,
     INSTALLED_TF,
     parameterized,
@@ -28,6 +30,10 @@
     from deepmd.tf.env import (
         tf,
     )
+if INSTALLED_JAX:
+    from deepmd.jax.env import (
+        jnp,
+    )
 
 
 @parameterized(
@@ -57,3 +63,23 @@ def test_pt_consistent_with_ref(self):
                 ActivationFn_pt(self.activation)(to_torch_tensor(self.random_input))
             )
             np.testing.assert_allclose(self.ref, test, atol=1e-10)
+
+    @unittest.skipUnless(
+        sys.version_info >= (3, 9), "array_api_strict doesn't support Python<=3.8"
+    )
+    def test_array_api_strict(self):
+        import array_api_strict as xp
+
+        xp.set_array_api_strict_flags(
+            api_version=get_activation_fn_dp.array_api_version
+        )
+        input = xp.asarray(self.random_input)
+        test = get_activation_fn_dp(self.activation)(input)
+        np.testing.assert_allclose(self.ref, np.array(test), atol=1e-10)
+
+    @unittest.skipUnless(INSTALLED_JAX, "JAX is not installed")
+    def test_jax_consistent_with_ref(self):
+        input = jnp.from_dlpack(self.random_input)
+        test = get_activation_fn_dp(self.activation)(input)
+        self.assertTrue(isinstance(test, jnp.ndarray))
+        np.testing.assert_allclose(self.ref, np.from_dlpack(test), atol=1e-10)
diff --git a/source/tests/consistent/test_type_embedding.py b/source/tests/consistent/test_type_embedding.py
index 6583dddb5f..c66ef0fbaa 100644
--- a/source/tests/consistent/test_type_embedding.py
+++ b/source/tests/consistent/test_type_embedding.py
@@ -13,6 +13,7 @@
 )
 
 from .common import (
+    INSTALLED_JAX,
     INSTALLED_PT,
     INSTALLED_TF,
     CommonTest,
@@ -30,6 +31,13 @@
     from deepmd.tf.utils.type_embed import TypeEmbedNet as TypeEmbedNetTF
 else:
     TypeEmbedNetTF = object
+if INSTALLED_JAX:
+    from deepmd.jax.env import (
+        jnp,
+    )
+    from deepmd.jax.utils.type_embed import TypeEmbedNet as TypeEmbedNetJAX
+else:
+    TypeEmbedNetJAX = object
 
 
 @parameterized(
@@ -63,7 +71,9 @@ def data(self) -> dict:
     tf_class = TypeEmbedNetTF
     dp_class = TypeEmbedNetDP
     pt_class = TypeEmbedNetPT
+    jax_class = TypeEmbedNetJAX
     args = type_embedding_args()
+    skip_jax = not INSTALLED_JAX
 
     @property
     def addtional_data(self) -> dict:
@@ -103,6 +113,14 @@ def eval_pt(self, pt_obj: Any) -> Any:
             for x in (pt_obj(device=PT_DEVICE),)
         ]
 
+    def eval_jax(self, jax_obj: Any) -> Any:
+        out = jax_obj()
+        # ensure output is not numpy array
+        for x in (out,):
+            if isinstance(x, np.ndarray):
+                raise ValueError("Output is numpy array")
+        return [np.array(x) if isinstance(x, jnp.ndarray) else x for x in (out,)]
+
     def extract_ret(self, ret: Any, backend) -> Tuple[np.ndarray, ...]:
         return (ret[0],)
 
diff --git a/source/tests/pt/model/models/dpa1.json b/source/tests/pt/model/models/dpa1.json
index 1321acbd53..a969c290ae 100644
--- a/source/tests/pt/model/models/dpa1.json
+++ b/source/tests/pt/model/models/dpa1.json
@@ -21,7 +21,8 @@
       "activation_function": "tanh",
       "scaling_factor": 1.0,
       "normalize": true,
-      "temperature": 1.0
+      "temperature": 1.0,
+      "seed": 1
     },
     "fitting_net": {
       "neuron": [
diff --git a/source/tests/pt/model/models/dpa2.json b/source/tests/pt/model/models/dpa2.json
index 7495f5d78a..f83e319de3 100644
--- a/source/tests/pt/model/models/dpa2.json
+++ b/source/tests/pt/model/models/dpa2.json
@@ -42,6 +42,7 @@
       "g1_out_conv": false,
       "g1_out_mlp": false
     },
+    "seed": 1,
     "add_tebd_to_repinit_out": false
   },
   "fitting_net": {
diff --git a/source/tests/pt/model/test_descriptor_se_r.py b/source/tests/pt/model/test_descriptor_se_r.py
index a2b9754714..f3692101c5 100644
--- a/source/tests/pt/model/test_descriptor_se_r.py
+++ b/source/tests/pt/model/test_descriptor_se_r.py
@@ -63,6 +63,7 @@ def test_consistency(
             resnet_dt=idt,
             old_impl=False,
             exclude_mask=em,
+            seed=GLOBAL_SEED,
         ).to(env.DEVICE)
         dd0.mean = torch.tensor(davg, dtype=dtype, device=env.DEVICE)
         dd0.dstd = torch.tensor(dstd, dtype=dtype, device=env.DEVICE)
@@ -130,6 +131,7 @@ def test_load_stat(self):
             precision=prec,
             resnet_dt=idt,
             old_impl=False,
+            seed=GLOBAL_SEED,
         )
         dd0.mean = torch.tensor(davg, dtype=dtype, device=env.DEVICE)
         dd0.dstd = torch.tensor(dstd, dtype=dtype, device=env.DEVICE)
@@ -180,6 +182,7 @@ def test_jit(
             precision=prec,
             resnet_dt=idt,
             old_impl=False,
+            seed=GLOBAL_SEED,
         )
         dd0.mean = torch.tensor(davg, dtype=dtype, device=env.DEVICE)
         dd0.dstd = torch.tensor(dstd, dtype=dtype, device=env.DEVICE)
diff --git a/source/tests/pt/model/test_dipole_fitting.py b/source/tests/pt/model/test_dipole_fitting.py
index cd3a032ecc..71da2781ac 100644
--- a/source/tests/pt/model/test_dipole_fitting.py
+++ b/source/tests/pt/model/test_dipole_fitting.py
@@ -87,6 +87,7 @@ def test_consistency(
             numb_fparam=nfp,
             numb_aparam=nap,
             mixed_types=self.dd0.mixed_types(),
+            seed=GLOBAL_SEED,
         ).to(env.DEVICE)
         ft1 = DPDipoleFitting.deserialize(ft0.serialize())
         ft2 = DipoleFittingNet.deserialize(ft1.serialize())
@@ -139,6 +140,7 @@ def test_jit(
             numb_fparam=nfp,
             numb_aparam=nap,
             mixed_types=mixed_types,
+            seed=GLOBAL_SEED,
         ).to(env.DEVICE)
         torch.jit.script(ft0)
@@ -180,6 +182,7 @@ def test_rot(self):
             numb_fparam=nfp,
             numb_aparam=nap,
             mixed_types=self.dd0.mixed_types(),
+            seed=GLOBAL_SEED,
         ).to(env.DEVICE)
         if nfp > 0:
             ifp = torch.tensor(
@@ -234,6 +237,7 @@ def test_permu(self):
             numb_fparam=0,
             numb_aparam=0,
             mixed_types=self.dd0.mixed_types(),
+            seed=GLOBAL_SEED,
         ).to(env.DEVICE)
         res = []
         for idx_perm in [[0, 1, 2, 3, 4], [1, 0, 4, 3, 2]]:
@@ -280,6 +284,7 @@ def test_trans(self):
             numb_fparam=0,
             numb_aparam=0,
             mixed_types=self.dd0.mixed_types(),
+            seed=GLOBAL_SEED,
         ).to(env.DEVICE)
         res = []
         for xyz in [self.coord, coord_s]:
@@ -327,6 +332,7 @@ def setUp(self):
             numb_fparam=0,
             numb_aparam=0,
             mixed_types=self.dd0.mixed_types(),
+            seed=GLOBAL_SEED,
         ).to(env.DEVICE)
         self.type_mapping = ["O", "H", "B"]
         self.model = DipoleModel(self.dd0, self.ft0, self.type_mapping)
diff --git a/source/tests/pt/model/test_dpa1.py b/source/tests/pt/model/test_dpa1.py
index f1994504fc..b825885311 100644
--- a/source/tests/pt/model/test_dpa1.py
+++ b/source/tests/pt/model/test_dpa1.py
@@ -71,6 +71,7 @@ def test_consistency(
             use_econf_tebd=ect,
             type_map=["O", "H"] if ect else None,
             old_impl=False,
+            seed=GLOBAL_SEED,
         ).to(env.DEVICE)
         dd0.se_atten.mean = torch.tensor(davg, dtype=dtype, device=env.DEVICE)
         dd0.se_atten.stddev = torch.tensor(dstd, dtype=dtype, device=env.DEVICE)
@@ -125,6 +126,7 @@ def test_consistency(
             resnet_dt=idt,
             smooth_type_embedding=sm,
             old_impl=True,
+            seed=GLOBAL_SEED,
         ).to(env.DEVICE)
         dd0_state_dict = dd0.se_atten.state_dict()
         dd3_state_dict = dd3.se_atten.state_dict()
@@ -210,6 +212,7 @@ def test_jit(
             use_econf_tebd=ect,
             type_map=["O", "H"] if ect else None,
             old_impl=False,
+            seed=GLOBAL_SEED,
         )
         dd0.se_atten.mean = torch.tensor(davg, dtype=dtype, device=env.DEVICE)
         dd0.se_atten.dstd = torch.tensor(dstd, dtype=dtype, device=env.DEVICE)
diff --git a/source/tests/pt/model/test_dpa2.py b/source/tests/pt/model/test_dpa2.py
index f11be532cb..0beb34c031 100644
--- a/source/tests/pt/model/test_dpa2.py
+++ b/source/tests/pt/model/test_dpa2.py
@@ -20,6 +20,9 @@
     PRECISION_DICT,
 )
 
+from ...seed import (
+    GLOBAL_SEED,
+)
 from .test_env_mat import (
     TestCaseSingleFrameWithNlist,
 )
@@ -152,6 +155,7 @@ def test_consistency(
             use_econf_tebd=ect,
             type_map=["O", "H"] if ect else None,
             old_impl=False,
+            seed=GLOBAL_SEED,
         ).to(env.DEVICE)
 
         dd0.repinit.mean = torch.tensor(davg, dtype=dtype, device=env.DEVICE)
@@ -201,6 +205,7 @@ def test_consistency(
             add_tebd_to_repinit_out=False,
             precision=prec,
             old_impl=True,
+            seed=GLOBAL_SEED,
         ).to(env.DEVICE)
         dd0_state_dict = dd0.state_dict()
         dd3_state_dict = dd3.state_dict()
@@ -346,6 +351,7 @@ def test_jit(
             use_econf_tebd=ect,
             type_map=["O", "H"] if ect else None,
             old_impl=False,
+            seed=GLOBAL_SEED,
         ).to(env.DEVICE)
 
         dd0.repinit.mean = torch.tensor(davg, dtype=dtype, device=env.DEVICE)
diff --git a/source/tests/pt/model/test_embedding_net.py b/source/tests/pt/model/test_embedding_net.py
index 77d14db2a4..3605316437 100644
--- a/source/tests/pt/model/test_embedding_net.py
+++ b/source/tests/pt/model/test_embedding_net.py
@@ -39,6 +39,9 @@
 )
 from deepmd.tf.descriptor import DescrptSeA as DescrptSeA_tf
 
+from ...seed import (
+    GLOBAL_SEED,
+)
 from ..test_finetune import (
     energy_data_requirement,
 )
@@ -153,7 +156,7 @@ def test_consistency(self):
             sel=self.sel,
             neuron=self.filter_neuron,
             axis_neuron=self.axis_neuron,
-            seed=1,
+            seed=GLOBAL_SEED,
         )
         dp_embedding, dp_force, dp_vars = base_se_a(
             descriptor=dp_d,
diff --git a/source/tests/pt/model/test_ener_fitting.py b/source/tests/pt/model/test_ener_fitting.py
index 07c0d19935..3255db2784 100644
--- a/source/tests/pt/model/test_ener_fitting.py
+++ b/source/tests/pt/model/test_ener_fitting.py
@@ -65,6 +65,7 @@ def test_consistency(
             mixed_types=mixed_types,
             exclude_types=et,
             neuron=nn,
+            seed=GLOBAL_SEED,
         ).to(env.DEVICE)
         ft1 = DPInvarFitting.deserialize(ft0.serialize())
         ft2 = InvarFitting.deserialize(ft0.serialize())
@@ -168,6 +169,7 @@ def test_jit(
             numb_aparam=nap,
             mixed_types=mixed_types,
             exclude_types=et,
+            seed=GLOBAL_SEED,
         ).to(env.DEVICE)
         torch.jit.script(ft0)
 
@@ -177,6 +179,7 @@ def test_get_set(self):
             self.nt,
             3,
             1,
+            seed=GLOBAL_SEED,
         )
         rng = np.random.default_rng(GLOBAL_SEED)
         foo = rng.normal([3, 4])
diff --git a/source/tests/pt/model/test_permutation.py b/source/tests/pt/model/test_permutation.py
index 2fbc5fde3c..6aec895041 100644
--- a/source/tests/pt/model/test_permutation.py
+++ b/source/tests/pt/model/test_permutation.py
@@ -88,6 +88,7 @@
         "temperature": 1.0,
         "set_davg_zero": True,
         "type_one_side": True,
+        "seed": 1,
     },
     "fitting_net": {
         "neuron": [24, 24, 24],
@@ -155,6 +156,7 @@
             "update_g2_has_attn": True,
             "attn2_has_gate": True,
         },
+        "seed": 1,
         "add_tebd_to_repinit_out": False,
     },
     "fitting_net": {
@@ -207,6 +209,7 @@
             "g1_out_conv": True,
             "g1_out_mlp": True,
         },
+        "seed": 1,
         "add_tebd_to_repinit_out": False,
     },
     "fitting_net": {
@@ -235,6 +238,7 @@
         "temperature": 1.0,
         "set_davg_zero": True,
         "type_one_side": True,
+        "seed": 1,
     },
     "fitting_net": {
         "neuron": [24, 24, 24],
@@ -264,6 +268,7 @@
             "scaling_factor": 1.0,
             "normalize": True,
             "temperature": 1.0,
+            "seed": 1,
         },
         {
             "type": "dpa2",
@@ -296,6 +301,7 @@
                 "update_g2_has_attn": True,
                 "attn2_has_gate": True,
             },
+            "seed": 1,
             "add_tebd_to_repinit_out": False,
         },
     ],
diff --git a/source/tests/pt/model/test_polarizability_fitting.py b/source/tests/pt/model/test_polarizability_fitting.py
index ba1bf2ea29..1ca563a8c2 100644
--- a/source/tests/pt/model/test_polarizability_fitting.py
+++ b/source/tests/pt/model/test_polarizability_fitting.py
@@ -77,6 +77,7 @@ def test_consistency(
             mixed_types=self.dd0.mixed_types(),
             fit_diag=fit_diag,
             scale=scale,
+            seed=GLOBAL_SEED,
         ).to(env.DEVICE)
         ft1 = DPPolarFitting.deserialize(ft0.serialize())
         ft2 = PolarFittingNet.deserialize(ft0.serialize())
@@ -143,6 +144,7 @@ def test_jit(
             numb_aparam=nap,
             mixed_types=mixed_types,
             fit_diag=fit_diag,
+            seed=GLOBAL_SEED,
         ).to(env.DEVICE)
         torch.jit.script(ft0)
 
@@ -186,6 +188,7 @@ def test_rot(self):
             mixed_types=self.dd0.mixed_types(),
             fit_diag=fit_diag,
             scale=scale,
+            seed=GLOBAL_SEED,
         ).to(env.DEVICE)
         if nfp > 0:
             ifp = torch.tensor(
@@ -248,6 +251,7 @@ def test_permu(self):
             mixed_types=self.dd0.mixed_types(),
             fit_diag=fit_diag,
             scale=scale,
+            seed=GLOBAL_SEED,
         ).to(env.DEVICE)
         res = []
         for idx_perm in [[0, 1, 2, 3, 4], [1, 0, 4, 3, 2]]:
@@ -298,6 +302,7 @@ def test_trans(self):
             mixed_types=self.dd0.mixed_types(),
             fit_diag=fit_diag,
             scale=scale,
+            seed=GLOBAL_SEED,
         ).to(env.DEVICE)
         res = []
         for xyz in [self.coord, coord_s]:
@@ -347,6 +352,7 @@ def setUp(self):
             numb_fparam=0,
             numb_aparam=0,
             mixed_types=self.dd0.mixed_types(),
+            seed=GLOBAL_SEED,
         ).to(env.DEVICE)
         self.type_mapping = ["O", "H", "B"]
         self.model = PolarModel(self.dd0, self.ft0, self.type_mapping)
diff --git a/source/tests/pt/model/test_property_fitting.py b/source/tests/pt/model/test_property_fitting.py
index 59a5b1b172..dfe2725f3b 100644
--- a/source/tests/pt/model/test_property_fitting.py
+++ b/source/tests/pt/model/test_property_fitting.py
@@ -32,6 +32,9 @@
     to_numpy_array,
 )
 
+from ...seed import (
+    GLOBAL_SEED,
+)
 from .test_env_mat import (
     TestCaseSingleFrameWithNlist,
 )
@@ -78,6 +81,7 @@ def test_consistency(
             bias_atom_p=bias_atom_p,
             intensive=intensive,
             bias_method=bias_method,
+            seed=GLOBAL_SEED,
         ).to(env.DEVICE)
 
         ft1 = DPProperFittingNet.deserialize(ft0.serialize())
@@ -146,6 +150,7 @@ def test_jit(
             mixed_types=self.dd0.mixed_types(),
             intensive=intensive,
             bias_method=bias_method,
+            seed=GLOBAL_SEED,
         ).to(env.DEVICE)
         torch.jit.script(ft0)
 
@@ -199,6 +204,7 @@ def test_trans(self):
             numb_fparam=0,
             numb_aparam=0,
             mixed_types=self.dd0.mixed_types(),
+            seed=GLOBAL_SEED,
         ).to(env.DEVICE)
         res = []
         for xyz in [self.coord, coord_s]:
@@ -266,6 +272,7 @@ def test_rot(self):
             mixed_types=self.dd0.mixed_types(),
             intensive=intensive,
             bias_method=bias_method,
+            seed=GLOBAL_SEED,
         ).to(env.DEVICE)
         if nfp > 0:
             ifp = torch.tensor(
@@ -320,6 +327,7 @@ def test_permu(self):
             numb_fparam=0,
             numb_aparam=0,
             mixed_types=self.dd0.mixed_types(),
+            seed=GLOBAL_SEED,
         ).to(env.DEVICE)
         res = []
         for idx_perm in [[0, 1, 2, 3, 4], [1, 0, 4, 3, 2]]:
@@ -367,6 +375,7 @@ def test_trans(self):
             numb_fparam=0,
             numb_aparam=0,
             mixed_types=self.dd0.mixed_types(),
+            seed=GLOBAL_SEED,
         ).to(env.DEVICE)
         res = []
         for xyz in [self.coord, coord_s]:
@@ -417,6 +426,7 @@ def setUp(self):
             numb_aparam=0,
             mixed_types=self.dd0.mixed_types(),
             intensive=True,
+            seed=GLOBAL_SEED,
         ).to(env.DEVICE)
         self.type_mapping = ["O", "H", "B"]
         self.model = PropertyModel(self.dd0, self.ft0, self.type_mapping)
diff --git a/source/tests/pt/model/test_se_atten_v2.py b/source/tests/pt/model/test_se_atten_v2.py
index caecd0a118..f9857fc728 100644
--- a/source/tests/pt/model/test_se_atten_v2.py
+++ b/source/tests/pt/model/test_se_atten_v2.py
@@ -16,6 +16,9 @@
     PRECISION_DICT,
 )
 
+from ...seed import (
+    GLOBAL_SEED,
+)
 from .test_env_mat import (
     TestCaseSingleFrameWithNlist,
 )
@@ -64,6 +67,7 @@ def test_consistency(
             use_econf_tebd=ect,
             type_map=["O", "H"] if ect else None,
             old_impl=False,
+            seed=GLOBAL_SEED,
         ).to(env.DEVICE)
         dd0.se_atten.mean = torch.tensor(davg, dtype=dtype, device=env.DEVICE)
         dd0.se_atten.stddev = torch.tensor(dstd, dtype=dtype, device=env.DEVICE)
@@ -135,6 +139,7 @@ def test_jit(
             use_econf_tebd=ect,
             type_map=["O", "H"] if ect else None,
             old_impl=False,
+            seed=GLOBAL_SEED,
         )
         dd0.se_atten.mean = torch.tensor(davg, dtype=dtype, device=env.DEVICE)
         dd0.se_atten.dstd = torch.tensor(dstd, dtype=dtype, device=env.DEVICE)
diff --git a/source/tests/pt/model/test_se_e2_a.py b/source/tests/pt/model/test_se_e2_a.py
index 75d47c9054..abe13ce86e 100644
--- a/source/tests/pt/model/test_se_e2_a.py
+++ b/source/tests/pt/model/test_se_e2_a.py
@@ -60,6 +60,7 @@ def test_consistency(
             resnet_dt=idt,
             old_impl=False,
             exclude_types=em,
+            seed=GLOBAL_SEED,
         ).to(env.DEVICE)
         dd0.sea.mean = torch.tensor(davg, dtype=dtype, device=env.DEVICE)
         dd0.sea.dstd = torch.tensor(dstd, dtype=dtype, device=env.DEVICE)
@@ -113,6 +114,7 @@ def test_consistency(
             precision=prec,
             resnet_dt=idt,
             old_impl=True,
+            seed=GLOBAL_SEED,
         ).to(env.DEVICE)
         dd0_state_dict = dd0.sea.state_dict()
         dd3_state_dict = dd3.sea.state_dict()
@@ -168,6 +170,7 @@ def test_jit(
             precision=prec,
             resnet_dt=idt,
             old_impl=False,
+            seed=GLOBAL_SEED,
         )
         dd0.sea.mean = torch.tensor(davg, dtype=dtype, device=env.DEVICE)
         dd0.sea.dstd = torch.tensor(dstd, dtype=dtype, device=env.DEVICE)
diff --git a/source/tests/pt/model/test_se_t.py b/source/tests/pt/model/test_se_t.py
index 0d6c87ba8d..d3968d7f03 100644
--- a/source/tests/pt/model/test_se_t.py
+++ b/source/tests/pt/model/test_se_t.py
@@ -63,6 +63,7 @@ def test_consistency(
             precision=prec,
             resnet_dt=idt,
             exclude_types=em,
+            seed=GLOBAL_SEED,
         ).to(env.DEVICE)
         dd0.seat.mean = torch.tensor(davg, dtype=dtype, device=env.DEVICE)
         dd0.seat.dstd = torch.tensor(dstd, dtype=dtype, device=env.DEVICE)
@@ -131,6 +132,7 @@ def test_jit(
             self.sel,
             precision=prec,
             resnet_dt=idt,
+            seed=GLOBAL_SEED,
         )
         dd0.seat.mean = torch.tensor(davg, dtype=dtype, device=env.DEVICE)
         dd0.seat.dstd = torch.tensor(dstd, dtype=dtype, device=env.DEVICE)
diff --git a/source/tests/pt/model/water/se_atten.json b/source/tests/pt/model/water/se_atten.json
index 71cee94d8b..4b4c54e0d2 100644
--- a/source/tests/pt/model/water/se_atten.json
+++ b/source/tests/pt/model/water/se_atten.json
@@ -24,7 +24,8 @@
       "activation_function": "tanh",
       "scaling_factor": 1.0,
       "normalize": false,
-      "temperature": 1.0
+      "temperature": 1.0,
+      "seed": 1
     },
     "fitting_net": {
       "neuron": [
diff --git a/source/tests/pt/test_training.py b/source/tests/pt/test_training.py
index 0833200d47..fa9e5c138a 100644
--- a/source/tests/pt/test_training.py
+++ b/source/tests/pt/test_training.py
@@ -448,5 +448,73 @@ def tearDown(self) -> None:
         DPTrainTest.tearDown(self)
 
 
+class TestPropFinetuneFromEnerModel(unittest.TestCase):
+    def setUp(self):
+        input_json = str(Path(__file__).parent / "water/se_atten.json")
+        with open(input_json) as f:
+            self.config = json.load(f)
+        data_file = [str(Path(__file__).parent / "water/data/data_0")]
+        self.config["training"]["training_data"]["systems"] = data_file
+        self.config["training"]["validation_data"]["systems"] = data_file
+        self.config["model"] = deepcopy(model_dpa1)
+        self.config["model"]["type_map"] = ["H", "C", "N", "O"]
+        self.config["training"]["numb_steps"] = 1
+        self.config["training"]["save_freq"] = 1
+
+        property_input = str(Path(__file__).parent / "property/input.json")
+        with open(property_input) as f:
+            self.config_property = json.load(f)
+        prop_data_file = [str(Path(__file__).parent / "property/single")]
+        self.config_property["training"]["training_data"]["systems"] = prop_data_file
+        self.config_property["training"]["validation_data"]["systems"] = prop_data_file
+        self.config_property["model"]["descriptor"] = deepcopy(model_dpa1["descriptor"])
+        self.config_property["training"]["numb_steps"] = 1
+        self.config_property["training"]["save_freq"] = 1
+
+    def test_dp_train(self):
+        # test training from scratch
+        trainer = get_trainer(deepcopy(self.config))
+        trainer.run()
+        state_dict_trained = trainer.wrapper.model.state_dict()
+
+        # test fine-tuning using a different fitting_net, here using property fitting
+        finetune_model = self.config["training"].get("save_ckpt", "model.ckpt") + ".pt"
+        self.config_property["model"], finetune_links = get_finetune_rules(
+            finetune_model,
+            self.config_property["model"],
+            model_branch="RANDOM",
+        )
+        trainer_finetune = get_trainer(
+            deepcopy(self.config_property),
+            finetune_model=finetune_model,
+            finetune_links=finetune_links,
+        )
+
+        # check parameters
+        state_dict_finetuned = trainer_finetune.wrapper.model.state_dict()
+        for state_key in state_dict_finetuned:
+            if (
+                "out_bias" not in state_key
+                and "out_std" not in state_key
+                and "fitting" not in state_key
+            ):
+                torch.testing.assert_close(
+                    state_dict_trained[state_key],
+                    state_dict_finetuned[state_key],
+                )
+
+        # check running
+        trainer_finetune.run()
+
+    def tearDown(self):
+        for f in os.listdir("."):
+            if f.startswith("model") and f.endswith(".pt"):
+                os.remove(f)
+            if f in ["lcurve.out"]:
+                os.remove(f)
+            if f in ["stat_files"]:
+                shutil.rmtree(f)
+
+
 if __name__ == "__main__":
     unittest.main()
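
Below the patch, a minimal standalone sketch of the pattern these changes build on: `array_api_compat.array_namespace` resolves the array library from the input at call time, and a `__setattr__` override converts stored weights to `jax.numpy` arrays on assignment, as `NativeLayer` in `deepmd/jax/utils/network.py` does for `w`, `b`, and `idt`. `TinyLayer` and `JAXTinyLayer` are hypothetical names used for illustration only, not part of this patch.

```python
# Illustrative sketch, assuming array-api-compat and jax are installed;
# TinyLayer/JAXTinyLayer are hypothetical names, not part of this patch.
import array_api_compat
import numpy as np


class TinyLayer:
    """A bare-bones linear layer in the style of NativeLayer."""

    def __init__(self, w, b):
        self.w = w
        self.b = b

    def call(self, x):
        # Resolve the namespace (numpy, jax.numpy, ...) from the input, so
        # the same code runs under any array-API-compatible backend.
        xp = array_api_compat.array_namespace(x)
        return xp.matmul(x, self.w) + self.b


import jax

jax.config.update("jax_enable_x64", True)  # mirrors deepmd/jax/env.py
import jax.numpy as jnp


class JAXTinyLayer(TinyLayer):
    # Convert weights to JAX arrays on assignment, as NativeLayer.__setattr__
    # in deepmd/jax/utils/network.py does.
    def __setattr__(self, name, value):
        if name in {"w", "b"} and value is not None:
            value = jnp.asarray(value)
        super().__setattr__(name, value)


rng = np.random.default_rng(1)
w, b, x = rng.normal(size=(4, 3)), rng.normal(size=3), rng.normal(size=(2, 4))
y_np = TinyLayer(w, b).call(x)                   # NumPy path
y_jax = JAXTinyLayer(w, b).call(jnp.asarray(x))  # JAX path
np.testing.assert_allclose(y_np, np.asarray(y_jax))
```

The same `call` body serves both backends because only the namespace, never the array type, is hard-coded; that is why the dpmodel activation functions above fetch `xp` from their input instead of closing over `np`.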