From 543276a5601848075a442cad4d7feaeed6aa6457 Mon Sep 17 00:00:00 2001 From: Anyang Peng <137014849+anyangml@users.noreply.github.com> Date: Fri, 23 Feb 2024 17:17:16 +0800 Subject: [PATCH] Feat: add polar consistency test (#3327) This PR is to add cross framework consistency test on PolarFittingNet. Note: `shift_diag` not yet implemented in PT. --------- Signed-off-by: Anyang Peng <137014849+anyangml@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .../dpmodel/fitting/polarizability_fitting.py | 15 +- deepmd/pt/model/task/polarizability.py | 16 +- deepmd/tf/fit/polar.py | 97 ++++++++- source/tests/consistent/fitting/test_polar.py | 192 ++++++++++++++++++ source/tests/tf/test_polar_se_a.py | 4 +- source/tests/tf/test_polar_se_a_tebd.py | 4 +- 6 files changed, 309 insertions(+), 19 deletions(-) create mode 100644 source/tests/consistent/fitting/test_polar.py diff --git a/deepmd/dpmodel/fitting/polarizability_fitting.py b/deepmd/dpmodel/fitting/polarizability_fitting.py index 9811e8e1c8..0b22fa03f8 100644 --- a/deepmd/dpmodel/fitting/polarizability_fitting.py +++ b/deepmd/dpmodel/fitting/polarizability_fitting.py @@ -102,6 +102,8 @@ def __init__( fit_diag: bool = True, scale: Optional[List[float]] = None, shift_diag: bool = True, + # not used + seed: Optional[int] = None, ): # seed, uniform_seed are not included if tot_ener_zero: @@ -119,9 +121,16 @@ def __init__( if self.scale is None: self.scale = [1.0 for _ in range(ntypes)] else: - assert ( - isinstance(self.scale, list) and len(self.scale) == ntypes - ), "Scale should be a list of length ntypes." + if isinstance(self.scale, list): + assert ( + len(self.scale) == ntypes + ), "Scale should be a list of length ntypes." + elif isinstance(self.scale, float): + self.scale = [self.scale for _ in range(ntypes)] + else: + raise ValueError( + "Scale must be a list of float of length ntypes or a float." + ) self.scale = np.array(self.scale, dtype=GLOBAL_NP_FLOAT_PRECISION).reshape( ntypes, 1 ) diff --git a/deepmd/pt/model/task/polarizability.py b/deepmd/pt/model/task/polarizability.py index c240567903..9b2d2635cb 100644 --- a/deepmd/pt/model/task/polarizability.py +++ b/deepmd/pt/model/task/polarizability.py @@ -3,6 +3,7 @@ from typing import ( List, Optional, + Union, ) import torch @@ -85,7 +86,7 @@ def __init__( seed: Optional[int] = None, exclude_types: List[int] = [], fit_diag: bool = True, - scale: Optional[List[float]] = None, + scale: Optional[Union[List[float], float]] = None, shift_diag: bool = True, **kwargs, ): @@ -95,9 +96,16 @@ def __init__( if self.scale is None: self.scale = [1.0 for _ in range(ntypes)] else: - assert ( - isinstance(self.scale, list) and len(self.scale) == ntypes - ), "Scale should be a list of length ntypes." + if isinstance(self.scale, list): + assert ( + len(self.scale) == ntypes + ), "Scale should be a list of length ntypes." + elif isinstance(self.scale, float): + self.scale = [self.scale for _ in range(ntypes)] + else: + raise ValueError( + "Scale must be a list of float of length ntypes or a float." + ) self.scale = torch.tensor( self.scale, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE ).view(ntypes, 1) diff --git a/deepmd/tf/fit/polar.py b/deepmd/tf/fit/polar.py index ae02c02064..2f4dbfc280 100644 --- a/deepmd/tf/fit/polar.py +++ b/deepmd/tf/fit/polar.py @@ -42,8 +42,12 @@ class PolarFittingSeA(Fitting): Parameters ---------- - descrpt : tf.Tensor - The descrptor + ntypes + The ntypes of the descrptor :math:`\mathcal{D}` + dim_descrpt + The dimension of the descrptor :math:`\mathcal{D}` + embedding_width + The rotation matrix dimension of the descrptor :math:`\mathcal{D}` neuron : List[int] Number of neurons in each hidden layer of the fitting net resnet_dt : bool @@ -69,7 +73,9 @@ class PolarFittingSeA(Fitting): def __init__( self, - descrpt: tf.Tensor, + ntypes: int, + dim_descrpt: int, + embedding_width: int, neuron: List[int] = [120, 120, 120], resnet_dt: bool = True, sel_type: Optional[List[int]] = None, @@ -84,8 +90,8 @@ def __init__( **kwargs, ) -> None: """Constructor.""" - self.ntypes = descrpt.get_ntypes() - self.dim_descrpt = descrpt.get_dim_out() + self.ntypes = ntypes + self.dim_descrpt = dim_descrpt self.n_neuron = neuron self.resnet_dt = resnet_dt self.sel_type = sel_type @@ -96,6 +102,7 @@ def __init__( # self.diag_shift = diag_shift self.shift_diag = shift_diag self.scale = scale + self.activation_function_name = activation_function self.fitting_activation_fn = get_activation_func(activation_function) self.fitting_precision = get_precision(precision) if self.sel_type is None: @@ -104,7 +111,19 @@ def __init__( [ii in self.sel_type for ii in range(self.ntypes)], dtype=bool ) if self.scale is None: - self.scale = [1.0 for ii in range(self.ntypes)] + self.scale = np.array([1.0 for ii in range(self.ntypes)]) + else: + if isinstance(self.scale, list): + assert ( + len(self.scale) == ntypes + ), "Scale should be a list of length ntypes." + elif isinstance(self.scale, float): + self.scale = [self.scale for _ in range(ntypes)] + else: + raise ValueError( + "Scale must be a list of float of length ntypes or a float." + ) + self.scale = np.array(self.scale) # if self.diag_shift is None: # self.diag_shift = [0.0 for ii in range(self.ntypes)] if not isinstance(self.sel_type, list): @@ -115,10 +134,7 @@ def __init__( ) # self.ntypes x 1, store the average diagonal value # if type(self.diag_shift) is not list: # self.diag_shift = [self.diag_shift] - if not isinstance(self.scale, list): - self.scale = [self.scale for ii in range(self.ntypes)] - self.scale = np.array(self.scale) - self.dim_rot_mat_1 = descrpt.get_dim_rot_mat_1() + self.dim_rot_mat_1 = embedding_width self.dim_rot_mat = self.dim_rot_mat_1 * 3 self.useBN = False self.fitting_net_variables = None @@ -509,6 +525,67 @@ def get_loss(self, loss: dict, lr) -> Loss: label_name="polarizability", ) + def serialize(self, suffix: str) -> dict: + """Serialize the model. + + Returns + ------- + dict + The serialized data + """ + data = { + "var_name": "polar", + "ntypes": self.ntypes, + "dim_descrpt": self.dim_descrpt, + "embedding_width": self.dim_rot_mat_1, + # very bad design: type embedding is not passed to the class + # TODO: refactor the class + "mixed_types": False, + "dim_out": 3, + "neuron": self.n_neuron, + "resnet_dt": self.resnet_dt, + "activation_function": self.activation_function_name, + "precision": self.fitting_precision.name, + "exclude_types": [], + "fit_diag": self.fit_diag, + "scale": list(self.scale), + "shift_diag": self.shift_diag, + "nets": self.serialize_network( + ntypes=self.ntypes, + # TODO: consider type embeddings + ndim=1, + in_dim=self.dim_descrpt, + out_dim=self.dim_rot_mat_1, + neuron=self.n_neuron, + activation_function=self.activation_function_name, + resnet_dt=self.resnet_dt, + variables=self.fitting_net_variables, + suffix=suffix, + ), + } + return data + + @classmethod + def deserialize(cls, data: dict, suffix: str): + """Deserialize the model. + + Parameters + ---------- + data : dict + The serialized data + + Returns + ------- + Model + The deserialized model + """ + fitting = cls(**data) + fitting.fitting_net_variables = cls.deserialize_network( + data["nets"], + suffix=suffix, + ) + return fitting + class GlobalPolarFittingSeA: r"""Fit the system polarizability with descriptor se_a. diff --git a/source/tests/consistent/fitting/test_polar.py b/source/tests/consistent/fitting/test_polar.py new file mode 100644 index 0000000000..7bc11961eb --- /dev/null +++ b/source/tests/consistent/fitting/test_polar.py @@ -0,0 +1,192 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import unittest +from typing import ( + Any, + Tuple, +) + +import numpy as np + +from deepmd.dpmodel.fitting.polarizability_fitting import PolarFitting as PolarFittingDP +from deepmd.env import ( + GLOBAL_NP_FLOAT_PRECISION, +) + +from ..common import ( + INSTALLED_PT, + INSTALLED_TF, + CommonTest, + parameterized, +) +from .common import ( + DipoleFittingTest, +) + +if INSTALLED_PT: + import torch + + from deepmd.pt.model.task.polarizability import PolarFittingNet as PolarFittingPT + from deepmd.pt.utils.env import DEVICE as PT_DEVICE +else: + PolarFittingPT = object +if INSTALLED_TF: + from deepmd.tf.fit.polar import PolarFittingSeA as PolarFittingTF +else: + PolarFittingTF = object +from deepmd.utils.argcheck import ( + fitting_polar, +) + + +@parameterized( + (True, False), # resnet_dt + ("float64", "float32"), # precision + (True, False), # mixed_types +) +class TestPolar(CommonTest, DipoleFittingTest, unittest.TestCase): + @property + def data(self) -> dict: + ( + resnet_dt, + precision, + mixed_types, + ) = self.param + return { + "neuron": [5, 5, 5], + "resnet_dt": resnet_dt, + "precision": precision, + "seed": 20240217, + } + + @property + def skip_tf(self) -> bool: + ( + resnet_dt, + precision, + mixed_types, + ) = self.param + # TODO: mixed_types + return mixed_types or CommonTest.skip_pt + + @property + def skip_pt(self) -> bool: + ( + resnet_dt, + precision, + mixed_types, + ) = self.param + return CommonTest.skip_pt + + tf_class = PolarFittingTF + dp_class = PolarFittingDP + pt_class = PolarFittingPT + args = fitting_polar() + + def setUp(self): + CommonTest.setUp(self) + + self.ntypes = 2 + self.natoms = np.array([6, 6, 2, 4], dtype=np.int32) + self.inputs = np.ones((1, 6, 20), dtype=GLOBAL_NP_FLOAT_PRECISION) + self.gr = np.ones((1, 6, 30, 3), dtype=GLOBAL_NP_FLOAT_PRECISION) + self.atype = np.array([0, 1, 1, 0, 1, 1], dtype=np.int32) + # inconsistent if not sorted + self.atype.sort() + + @property + def addtional_data(self) -> dict: + ( + resnet_dt, + precision, + mixed_types, + ) = self.param + return { + "ntypes": self.ntypes, + "dim_descrpt": self.inputs.shape[-1], + "mixed_types": mixed_types, + "var_name": "polar", + "embedding_width": 30, + } + + def build_tf(self, obj: Any, suffix: str) -> Tuple[list, dict]: + ( + resnet_dt, + precision, + mixed_types, + ) = self.param + return self.build_tf_fitting( + obj, + self.inputs.ravel(), + self.gr, + self.natoms, + self.atype, + None, + suffix, + ) + + def eval_pt(self, pt_obj: Any) -> Any: + ( + resnet_dt, + precision, + mixed_types, + ) = self.param + return ( + pt_obj( + torch.from_numpy(self.inputs).to(device=PT_DEVICE), + torch.from_numpy(self.atype.reshape(1, -1)).to(device=PT_DEVICE), + torch.from_numpy(self.gr).to(device=PT_DEVICE), + None, + )["polar"] + .detach() + .cpu() + .numpy() + ) + + def eval_dp(self, dp_obj: Any) -> Any: + ( + resnet_dt, + precision, + mixed_types, + ) = self.param + return dp_obj( + self.inputs, + self.atype.reshape(1, -1), + self.gr, + None, + )["polar"] + + def extract_ret(self, ret: Any, backend) -> Tuple[np.ndarray, ...]: + if backend == self.RefBackend.TF: + # shape is not same + ret = ret[0].reshape(-1, self.natoms[0], 1) + return (ret,) + + @property + def rtol(self) -> float: + """Relative tolerance for comparing the return value.""" + ( + resnet_dt, + precision, + mixed_types, + ) = self.param + if precision == "float64": + return 1e-10 + elif precision == "float32": + return 1e-3 + else: + raise ValueError(f"Unknown precision: {precision}") + + @property + def atol(self) -> float: + """Absolute tolerance for comparing the return value.""" + ( + resnet_dt, + precision, + mixed_types, + ) = self.param + if precision == "float64": + return 1e-10 + elif precision == "float32": + return 1e-3 + else: + raise ValueError(f"Unknown precision: {precision}") diff --git a/source/tests/tf/test_polar_se_a.py b/source/tests/tf/test_polar_se_a.py index 04487dec7b..031d9330bc 100644 --- a/source/tests/tf/test_polar_se_a.py +++ b/source/tests/tf/test_polar_se_a.py @@ -55,7 +55,9 @@ def test_model(self): jdata["model"]["descriptor"].pop("type", None) jdata["model"]["fitting_net"].pop("type", None) descrpt = DescrptSeA(**jdata["model"]["descriptor"], uniform_seed=True) - jdata["model"]["fitting_net"]["descrpt"] = descrpt + jdata["model"]["fitting_net"]["ntypes"] = descrpt.get_ntypes() + jdata["model"]["fitting_net"]["dim_descrpt"] = descrpt.get_dim_out() + jdata["model"]["fitting_net"]["embedding_width"] = descrpt.get_dim_rot_mat_1() fitting = PolarFittingSeA(**jdata["model"]["fitting_net"], uniform_seed=True) model = PolarModel(descrpt, fitting) diff --git a/source/tests/tf/test_polar_se_a_tebd.py b/source/tests/tf/test_polar_se_a_tebd.py index 38c3ae20ef..c7aa94d5e8 100644 --- a/source/tests/tf/test_polar_se_a_tebd.py +++ b/source/tests/tf/test_polar_se_a_tebd.py @@ -65,7 +65,9 @@ def test_model(self): jdata["model"]["descriptor"].pop("type", None) jdata["model"]["fitting_net"].pop("type", None) descrpt = DescrptSeA(**jdata["model"]["descriptor"], uniform_seed=True) - jdata["model"]["fitting_net"]["descrpt"] = descrpt + jdata["model"]["fitting_net"]["ntypes"] = descrpt.get_ntypes() + jdata["model"]["fitting_net"]["dim_descrpt"] = descrpt.get_dim_out() + jdata["model"]["fitting_net"]["embedding_width"] = descrpt.get_dim_rot_mat_1() fitting = PolarFittingSeA(**jdata["model"]["fitting_net"], uniform_seed=True) typeebd_param = jdata["model"]["type_embedding"] typeebd = TypeEmbedNet(