diff --git a/deepmd/dpmodel/common.py b/deepmd/dpmodel/common.py index 982a4eb834..761db2f6aa 100644 --- a/deepmd/dpmodel/common.py +++ b/deepmd/dpmodel/common.py @@ -6,7 +6,8 @@ import numpy as np -from deepmd.common import ( +from deepmd.env import ( + GLOBAL_ENER_FLOAT_PRECISION, GLOBAL_NP_FLOAT_PRECISION, ) @@ -21,6 +22,13 @@ "int64": np.int64, "default": GLOBAL_NP_FLOAT_PRECISION, } +RESERVED_PRECISON_DICT = { + np.float16: "float16", + np.float32: "float32", + np.float64: "float64", + np.int32: "int32", + np.int64: "int64", +} DEFAULT_PRECISION = "float64" @@ -35,3 +43,13 @@ def call(self, *args, **kwargs): def __call__(self, *args, **kwargs): """Forward pass in NumPy implementation.""" return self.call(*args, **kwargs) + + +__all__ = [ + "GLOBAL_NP_FLOAT_PRECISION", + "GLOBAL_ENER_FLOAT_PRECISION", + "PRECISION_DICT", + "RESERVED_PRECISON_DICT", + "DEFAULT_PRECISION", + "NativeOP", +] diff --git a/deepmd/dpmodel/model/make_model.py b/deepmd/dpmodel/model/make_model.py index 7928644061..1261906148 100644 --- a/deepmd/dpmodel/model/make_model.py +++ b/deepmd/dpmodel/model/make_model.py @@ -1,17 +1,25 @@ # SPDX-License-Identifier: LGPL-3.0-or-later from typing import ( Dict, + List, Optional, + Tuple, ) import numpy as np from deepmd.dpmodel.common import ( + GLOBAL_ENER_FLOAT_PRECISION, + GLOBAL_NP_FLOAT_PRECISION, + PRECISION_DICT, + RESERVED_PRECISON_DICT, NativeOP, ) from deepmd.dpmodel.output_def import ( ModelOutputDef, OutputVariableCategory, + OutputVariableOperation, + check_operation_applied, ) from deepmd.dpmodel.utils import ( build_neighbor_list, @@ -59,6 +67,10 @@ def __init__( *args, **kwargs, ) + self.precision_dict = PRECISION_DICT + self.reverse_precision_dict = RESERVED_PRECISON_DICT + self.global_np_float_precision = GLOBAL_NP_FLOAT_PRECISION + self.global_ener_float_precision = GLOBAL_ENER_FLOAT_PRECISION def model_output_def(self): """Get the output def for the model.""" @@ -115,15 +127,19 @@ def call( """ nframes, nloc = atype.shape[:2] - if box is not None: + cc, bb, fp, ap, input_prec = self.input_type_cast( + coord, box=box, fparam=fparam, aparam=aparam + ) + del coord, box, fparam, aparam + if bb is not None: coord_normalized = normalize_coord( - coord.reshape(nframes, nloc, 3), - box.reshape(nframes, 3, 3), + cc.reshape(nframes, nloc, 3), + bb.reshape(nframes, 3, 3), ) else: - coord_normalized = coord.copy() + coord_normalized = cc.copy() extended_coord, extended_atype, mapping = extend_coord_with_ghosts( - coord_normalized, atype, box, self.get_rcut() + coord_normalized, atype, bb, self.get_rcut() ) nlist = build_neighbor_list( extended_coord, @@ -139,8 +155,8 @@ def call( extended_atype, nlist, mapping, - fparam=fparam, - aparam=aparam, + fparam=fp, + aparam=ap, do_atomic_virial=do_atomic_virial, ) model_predict = communicate_extended_output( @@ -149,6 +165,7 @@ def call( mapping, do_atomic_virial=do_atomic_virial, ) + model_predict = self.output_type_cast(model_predict, input_prec) return model_predict def call_lower( @@ -192,22 +209,95 @@ def call_lower( nframes, nall = extended_atype.shape[:2] extended_coord = extended_coord.reshape(nframes, -1, 3) nlist = self.format_nlist(extended_coord, extended_atype, nlist) + cc_ext, _, fp, ap, input_prec = self.input_type_cast( + extended_coord, fparam=fparam, aparam=aparam + ) + del extended_coord, fparam, aparam atomic_ret = self.forward_atomic( - extended_coord, + cc_ext, extended_atype, nlist, mapping=mapping, - fparam=fparam, - aparam=aparam, + fparam=fp, + aparam=ap, ) model_predict = 
fit_output_to_model_output(
             atomic_ret,
             self.fitting_output_def(),
-            extended_coord,
+            cc_ext,
             do_atomic_virial=do_atomic_virial,
         )
+        model_predict = self.output_type_cast(model_predict, input_prec)
         return model_predict
 
+    def input_type_cast(
+        self,
+        coord: np.ndarray,
+        box: Optional[np.ndarray] = None,
+        fparam: Optional[np.ndarray] = None,
+        aparam: Optional[np.ndarray] = None,
+    ) -> Tuple[
+        np.ndarray,
+        Optional[np.ndarray],
+        Optional[np.ndarray],
+        Optional[np.ndarray],
+        str,
+    ]:
+        """Cast the input data to global float type."""
+        input_prec = self.reverse_precision_dict[
+            self.precision_dict[coord.dtype.name]
+        ]
+        ###
+        ### type checking would not pass jit, convert to coord prec anyway
+        ###
+        _lst: List[Optional[np.ndarray]] = [
+            vv.astype(coord.dtype) if vv is not None else None
+            for vv in [box, fparam, aparam]
+        ]
+        box, fparam, aparam = _lst
+        if (
+            input_prec
+            == self.reverse_precision_dict[self.global_np_float_precision]
+        ):
+            return coord, box, fparam, aparam, input_prec
+        else:
+            pp = self.global_np_float_precision
+            return (
+                coord.astype(pp),
+                box.astype(pp) if box is not None else None,
+                fparam.astype(pp) if fparam is not None else None,
+                aparam.astype(pp) if aparam is not None else None,
+                input_prec,
+            )
+
+    def output_type_cast(
+        self,
+        model_ret: Dict[str, np.ndarray],
+        input_prec: str,
+    ) -> Dict[str, np.ndarray]:
+        """Convert the model output to the input prec."""
+        do_cast = (
+            input_prec
+            != self.reverse_precision_dict[self.global_np_float_precision]
+        )
+        pp = self.precision_dict[input_prec]
+        odef = self.model_output_def()
+        for kk in odef.keys():
+            if kk not in model_ret.keys():
+                # do not return energy_derv_c if not do_atomic_virial
+                continue
+            if check_operation_applied(odef[kk], OutputVariableOperation.REDU):
+                model_ret[kk] = (
+                    model_ret[kk].astype(self.global_ener_float_precision)
+                    if model_ret[kk] is not None
+                    else None
+                )
+            elif do_cast:
+                model_ret[kk] = (
+                    model_ret[kk].astype(pp) if model_ret[kk] is not None else None
+                )
+        return model_ret
+
     def format_nlist(
         self,
         extended_coord: np.ndarray,
diff --git a/deepmd/dpmodel/model/transform_output.py b/deepmd/dpmodel/model/transform_output.py
index 49368849ca..c87c79f7d4 100644
--- a/deepmd/dpmodel/model/transform_output.py
+++ b/deepmd/dpmodel/model/transform_output.py
@@ -5,6 +5,9 @@
 
 import numpy as np
 
+from deepmd.dpmodel.common import (
+    GLOBAL_ENER_FLOAT_PRECISION,
+)
 from deepmd.dpmodel.output_def import (
     FittingOutputDef,
     ModelOutputDef,
@@ -30,7 +33,10 @@ def fit_output_to_model_output(
         atom_axis = -(len(shap) + 1)
         if vdef.reduciable:
             kk_redu = get_reduce_name(kk)
-            model_ret[kk_redu] = np.sum(vv, axis=atom_axis)
+            # cast to energy prec before reduction
+            model_ret[kk_redu] = np.sum(
+                vv.astype(GLOBAL_ENER_FLOAT_PRECISION), axis=atom_axis
+            )
         if vdef.r_differentiable:
             kk_derv_r, kk_derv_c = get_deriv_name(kk)
             # name-holders
diff --git a/deepmd/pt/model/model/make_model.py b/deepmd/pt/model/model/make_model.py
index b6478d297f..3efd3fb046 100644
--- a/deepmd/pt/model/model/make_model.py
+++ b/deepmd/pt/model/model/make_model.py
@@ -3,6 +3,7 @@
     Dict,
     List,
     Optional,
+    Tuple,
 )
 
 import torch
@@ -12,13 +13,18 @@
 )
 from deepmd.dpmodel.output_def import (
     OutputVariableCategory,
+    OutputVariableOperation,
+    check_operation_applied,
 )
 from deepmd.pt.model.model.transform_output import (
     communicate_extended_output,
     fit_output_to_model_output,
 )
-from deepmd.pt.utils import (
-    env,
+from deepmd.pt.utils.env import (
+    GLOBAL_PT_ENER_FLOAT_PRECISION,
+    
GLOBAL_PT_FLOAT_PRECISION, + PRECISION_DICT, + RESERVED_PRECISON_DICT, ) from deepmd.pt.utils.nlist import ( extend_input_and_build_neighbor_list, @@ -59,6 +65,10 @@ def __init__( *args, **kwargs, ) + self.precision_dict = PRECISION_DICT + self.reverse_precision_dict = RESERVED_PRECISON_DICT + self.global_pt_float_precision = GLOBAL_PT_FLOAT_PRECISION + self.global_pt_ener_float_precision = GLOBAL_PT_ENER_FLOAT_PRECISION def model_output_def(self): """Get the output def for the model.""" @@ -118,21 +128,22 @@ def forward_common( The keys are defined by the `ModelOutputDef`. """ - coord = coord.to(env.GLOBAL_PT_FLOAT_PRECISION) - if box is not None: - box = box.to(env.GLOBAL_PT_FLOAT_PRECISION) + cc, bb, fp, ap, input_prec = self.input_type_cast( + coord, box=box, fparam=fparam, aparam=aparam + ) + del coord, box, fparam, aparam ( extended_coord, extended_atype, mapping, nlist, ) = extend_input_and_build_neighbor_list( - coord, + cc, atype, self.get_rcut(), self.get_sel(), mixed_types=self.mixed_types(), - box=box, + box=bb, ) model_predict_lower = self.forward_common_lower( extended_coord, @@ -140,8 +151,8 @@ def forward_common( nlist, mapping, do_atomic_virial=do_atomic_virial, - fparam=fparam, - aparam=aparam, + fparam=fp, + aparam=ap, ) model_predict = communicate_extended_output( model_predict_lower, @@ -149,6 +160,7 @@ def forward_common( mapping, do_atomic_virial=do_atomic_virial, ) + model_predict = self.output_type_cast(model_predict, input_prec) return model_predict def forward_common_lower( @@ -189,26 +201,103 @@ def forward_common_lower( the result dict, defined by the `FittingOutputDef`. """ - extended_coord = extended_coord.to(env.GLOBAL_PT_FLOAT_PRECISION) nframes, nall = extended_atype.shape[:2] extended_coord = extended_coord.view(nframes, -1, 3) nlist = self.format_nlist(extended_coord, extended_atype, nlist) + cc_ext, _, fp, ap, input_prec = self.input_type_cast( + extended_coord, fparam=fparam, aparam=aparam + ) + del extended_coord, fparam, aparam atomic_ret = self.forward_atomic( - extended_coord, + cc_ext, extended_atype, nlist, mapping=mapping, - fparam=fparam, - aparam=aparam, + fparam=fp, + aparam=ap, ) model_predict = fit_output_to_model_output( atomic_ret, self.fitting_output_def(), - extended_coord, + cc_ext, do_atomic_virial=do_atomic_virial, ) + model_predict = self.output_type_cast(model_predict, input_prec) return model_predict + def input_type_cast( + self, + coord: torch.Tensor, + box: Optional[torch.Tensor] = None, + fparam: Optional[torch.Tensor] = None, + aparam: Optional[torch.Tensor] = None, + ) -> Tuple[ + torch.Tensor, + Optional[torch.Tensor], + Optional[torch.Tensor], + Optional[torch.Tensor], + str, + ]: + """Cast the input data to global float type.""" + input_prec = self.reverse_precision_dict[coord.dtype] + ### + ### type checking would not pass jit, convert to coord prec anyway + ### + # for vv, kk in zip([fparam, aparam], ["frame", "atomic"]): + # if vv is not None and self.reverse_precision_dict[vv.dtype] != input_prec: + # log.warning( + # f"type of {kk} parameter {self.reverse_precision_dict[vv.dtype]}" + # " does not match" + # f" that of the coordinate {input_prec}" + # ) + _lst: List[Optional[torch.Tensor]] = [ + vv.to(coord.dtype) if vv is not None else None + for vv in [box, fparam, aparam] + ] + box, fparam, aparam = _lst + if ( + input_prec + == self.reverse_precision_dict[self.global_pt_float_precision] + ): + return coord, box, fparam, aparam, input_prec + else: + pp = self.global_pt_float_precision + return ( + coord.to(pp), + 
box.to(pp) if box is not None else None, + fparam.to(pp) if fparam is not None else None, + aparam.to(pp) if aparam is not None else None, + input_prec, + ) + + def output_type_cast( + self, + model_ret: Dict[str, torch.Tensor], + input_prec: str, + ) -> Dict[str, torch.Tensor]: + """Convert the model output to the input prec.""" + do_cast = ( + input_prec + != self.reverse_precision_dict[self.global_pt_float_precision] + ) + pp = self.precision_dict[input_prec] + odef = self.model_output_def() + for kk in odef.keys(): + if kk not in model_ret.keys(): + # do not return energy_derv_c if not do_atomic_virial + continue + if check_operation_applied(odef[kk], OutputVariableOperation.REDU): + model_ret[kk] = ( + model_ret[kk].to(self.global_pt_ener_float_precision) + if model_ret[kk] is not None + else None + ) + elif do_cast: + model_ret[kk] = ( + model_ret[kk].to(pp) if model_ret[kk] is not None else None + ) + return model_ret + @torch.jit.export def format_nlist( self, diff --git a/deepmd/pt/model/model/transform_output.py b/deepmd/pt/model/model/transform_output.py index 312bb952b5..730e6b29d0 100644 --- a/deepmd/pt/model/model/transform_output.py +++ b/deepmd/pt/model/model/transform_output.py @@ -14,6 +14,9 @@ get_deriv_name, get_reduce_name, ) +from deepmd.pt.utils import ( + env, +) def atomic_virial_corr( @@ -148,6 +151,7 @@ def fit_output_to_model_output( the model output. """ + redu_prec = env.GLOBAL_PT_ENER_FLOAT_PRECISION model_ret = dict(fit_ret.items()) for kk, vv in fit_ret.items(): vdef = fit_output_def[kk] @@ -155,7 +159,7 @@ def fit_output_to_model_output( atom_axis = -(len(shap) + 1) if vdef.reduciable: kk_redu = get_reduce_name(kk) - model_ret[kk_redu] = torch.sum(vv, dim=atom_axis) + model_ret[kk_redu] = torch.sum(vv.to(redu_prec), dim=atom_axis) if vdef.r_differentiable: kk_derv_r, kk_derv_c = get_deriv_name(kk) dr, dc = take_deriv( @@ -171,7 +175,7 @@ def fit_output_to_model_output( assert dc is not None model_ret[kk_derv_c] = dc model_ret[kk_derv_c + "_redu"] = torch.sum( - model_ret[kk_derv_c], dim=1 + model_ret[kk_derv_c].to(redu_prec), dim=1 ) return model_ret @@ -186,6 +190,7 @@ def communicate_extended_output( local and ghost (extended) atoms to local atoms. """ + redu_prec = env.GLOBAL_PT_ENER_FLOAT_PRECISION new_ret = {} for kk in model_output_def.keys_outp(): vv = model_ret[kk] @@ -235,7 +240,9 @@ def communicate_extended_output( src=model_ret[kk_derv_c], reduce="sum", ) - new_ret[kk_derv_c + "_redu"] = torch.sum(new_ret[kk_derv_c], dim=1) + new_ret[kk_derv_c + "_redu"] = torch.sum( + new_ret[kk_derv_c].to(redu_prec), dim=1 + ) if not do_atomic_virial: # pop atomic virial, because it is not correctly calculated. 
new_ret.pop(kk_derv_c)
diff --git a/deepmd/pt/utils/env.py b/deepmd/pt/utils/env.py
index 7383cf5c49..0b92953255 100644
--- a/deepmd/pt/utils/env.py
+++ b/deepmd/pt/utils/env.py
@@ -42,6 +42,9 @@
     "int64": torch.int64,
 }
 GLOBAL_PT_FLOAT_PRECISION = PRECISION_DICT[np.dtype(GLOBAL_NP_FLOAT_PRECISION).name]
+GLOBAL_PT_ENER_FLOAT_PRECISION = PRECISION_DICT[
+    np.dtype(GLOBAL_ENER_FLOAT_PRECISION).name
+]
 PRECISION_DICT["default"] = GLOBAL_PT_FLOAT_PRECISION
 # cannot automatically generated
 RESERVED_PRECISON_DICT = {
@@ -65,6 +68,7 @@
     "GLOBAL_ENER_FLOAT_PRECISION",
     "GLOBAL_NP_FLOAT_PRECISION",
     "GLOBAL_PT_FLOAT_PRECISION",
+    "GLOBAL_PT_ENER_FLOAT_PRECISION",
     "DEFAULT_PRECISION",
     "PRECISION_DICT",
     "RESERVED_PRECISON_DICT",
diff --git a/deepmd/pt/utils/exclude_mask.py b/deepmd/pt/utils/exclude_mask.py
index 74b1d8dc41..6df8df8dd0 100644
--- a/deepmd/pt/utils/exclude_mask.py
+++ b/deepmd/pt/utils/exclude_mask.py
@@ -22,6 +22,13 @@ def __init__(
         exclude_types: List[int] = [],
     ):
         super().__init__()
+        self.reinit(ntypes, exclude_types)
+
+    def reinit(
+        self,
+        ntypes: int,
+        exclude_types: List[int] = [],
+    ):
         self.ntypes = ntypes
         self.exclude_types = exclude_types
         self.type_mask = np.array(
@@ -62,6 +69,13 @@ def __init__(
         exclude_types: List[Tuple[int, int]] = [],
     ):
         super().__init__()
+        self.reinit(ntypes, exclude_types)
+
+    def reinit(
+        self,
+        ntypes: int,
+        exclude_types: List[Tuple[int, int]] = [],
+    ):
         self.ntypes = ntypes
         self._exclude_types: Set[Tuple[int, int]] = set()
         for tt in exclude_types:
diff --git a/source/tests/common/dpmodel/case_single_frame_with_nlist.py b/source/tests/common/dpmodel/case_single_frame_with_nlist.py
index df4f73efbd..ecdf3590a8 100644
--- a/source/tests/common/dpmodel/case_single_frame_with_nlist.py
+++ b/source/tests/common/dpmodel/case_single_frame_with_nlist.py
@@ -2,6 +2,27 @@
 import numpy as np
 
 
+class TestCaseSingleFrameWithoutNlist:
+    def setUp(self):
+        # nloc == 3, nall == 4
+        self.nloc = 3
+        self.nf, self.nt = 1, 2
+        self.coord = np.array(
+            [
+                [0, 0, 0],
+                [0, 1, 0],
+                [0, 0, 1],
+            ],
+            dtype=np.float64,
+        ).reshape([1, self.nloc * 3])
+        self.atype = np.array([0, 0, 1], dtype=int).reshape([1, self.nloc])
+        self.cell = 2.0 * np.eye(3).reshape([1, 9])
+        # sel = [5, 2]
+        self.sel = [5, 2]
+        self.rcut = 2.2
+        self.rcut_smth = 0.4
+
+
 class TestCaseSingleFrameWithNlist:
     def setUp(self):
         # nloc == 3, nall == 4
@@ -17,7 +38,9 @@ def setUp(self):
             ],
             dtype=np.float64,
         ).reshape([1, self.nall, 3])
+        self.coord = self.coord_ext[:, : self.nloc, :]
         self.atype_ext = np.array([0, 0, 1, 0], dtype=int).reshape([1, self.nall])
+        self.atype = self.atype_ext[:, : self.nloc]
         # sel = [5, 2]
         self.sel = [5, 2]
         self.nlist = np.array(
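Aside for review: the tests added below pin down the casting contract introduced in make_model.py above. Inputs are promoted to the global working precision on entry; on exit, reduced outputs stay in the energy precision while everything else is cast back to the caller's precision. Here is a minimal standalone sketch of that contract in plain NumPy; the constants and the suffix-based "_redu" check are simplified stand-ins for the real PRECISION_DICT / OutputVariableOperation.REDU machinery, not the deepmd API itself:

    import numpy as np

    # Local stand-ins, assuming a float64 build (see deepmd/dpmodel/common.py).
    GLOBAL_NP_FLOAT_PRECISION = np.float64
    GLOBAL_ENER_FLOAT_PRECISION = np.float64

    def input_type_cast(coord, box=None):
        # Remember the caller's precision, then promote to the global one.
        input_prec = coord.dtype
        box = box.astype(input_prec) if box is not None else None  # follow coord
        if input_prec == GLOBAL_NP_FLOAT_PRECISION:
            return coord, box, input_prec
        pp = GLOBAL_NP_FLOAT_PRECISION
        return coord.astype(pp), box.astype(pp) if box is not None else None, input_prec

    def output_type_cast(model_ret, input_prec):
        # Reduced outputs keep the energy precision; the rest go back to the
        # caller's precision. The real code checks OutputVariableOperation.REDU
        # on the output def rather than the key suffix used here.
        for kk, vv in model_ret.items():
            if kk.endswith("_redu"):
                model_ret[kk] = vv.astype(GLOBAL_ENER_FLOAT_PRECISION)
            else:
                model_ret[kk] = vv.astype(input_prec)
        return model_ret

    coord = np.zeros([1, 3, 3], dtype=np.float32)
    cc, _, input_prec = input_type_cast(coord)
    assert cc.dtype == np.float64  # the model itself evaluates in float64
    ret = output_type_cast(
        {"energy": cc.sum(axis=-1), "energy_redu": cc.sum(axis=(-2, -1))},
        input_prec,
    )
    assert ret["energy"].dtype == np.float32       # follows the input precision
    assert ret["energy_redu"].dtype == np.float64  # reduced output stays wide

The dtype assertions in the tests that follow verify exactly this split between reduced and non-reduced outputs.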
diff --git a/source/tests/common/dpmodel/test_dp_model.py b/source/tests/common/dpmodel/test_dp_model.py
index b982c9c2b5..c3de1f4cdf 100644
--- a/source/tests/common/dpmodel/test_dp_model.py
+++ b/source/tests/common/dpmodel/test_dp_model.py
@@ -15,10 +15,11 @@
 
 from .case_single_frame_with_nlist import (
     TestCaseSingleFrameWithNlist,
+    TestCaseSingleFrameWithoutNlist,
 )
 
 
-class TestDPModel(unittest.TestCase, TestCaseSingleFrameWithNlist):
+class TestDPModelLower(unittest.TestCase, TestCaseSingleFrameWithNlist):
     def setUp(self):
         TestCaseSingleFrameWithNlist.setUp(self)
@@ -47,3 +48,99 @@
 
         np.testing.assert_allclose(ret0["energy"], ret1["energy"])
         np.testing.assert_allclose(ret0["energy_redu"], ret1["energy_redu"])
+
+    def test_prec_consistency(self):
+        rng = np.random.default_rng()
+        nf, nloc, nnei = self.nlist.shape
+        ds = DescrptSeA(
+            self.rcut,
+            self.rcut_smth,
+            self.sel,
+        )
+        ft = InvarFitting(
+            "energy",
+            self.nt,
+            ds.get_dim_out(),
+            1,
+            mixed_types=ds.mixed_types(),
+        )
+        nfp, nap = 2, 3
+        type_map = ["foo", "bar"]
+        # fparam, aparam are converted to coordinate precision by model
+        fparam = rng.normal(size=[self.nf, nfp])
+        aparam = rng.normal(size=[self.nf, nloc, nap])
+
+        md1 = DPModel(ds, ft, type_map=type_map)
+
+        args64 = [self.coord_ext, self.atype_ext, self.nlist]
+        args64[0] = args64[0].astype(np.float64)
+        args32 = [self.coord_ext, self.atype_ext, self.nlist]
+        args32[0] = args32[0].astype(np.float32)
+
+        model_l_ret_64 = md1.call_lower(*args64, fparam=fparam, aparam=aparam)
+        model_l_ret_32 = md1.call_lower(*args32, fparam=fparam, aparam=aparam)
+
+        for ii in model_l_ret_32.keys():
+            if model_l_ret_32[ii] is None:
+                continue
+            if ii[-4:] == "redu":
+                self.assertEqual(model_l_ret_32[ii].dtype, np.float64)
+            else:
+                self.assertEqual(model_l_ret_32[ii].dtype, np.float32)
+            self.assertEqual(model_l_ret_64[ii].dtype, np.float64)
+            np.testing.assert_allclose(
+                model_l_ret_32[ii],
+                model_l_ret_64[ii],
+            )
+
+
+class TestDPModel(unittest.TestCase, TestCaseSingleFrameWithoutNlist):
+    def setUp(self):
+        TestCaseSingleFrameWithoutNlist.setUp(self)
+
+    def test_prec_consistency(self):
+        rng = np.random.default_rng()
+        nf, nloc = self.atype.shape
+        ds = DescrptSeA(
+            self.rcut,
+            self.rcut_smth,
+            self.sel,
+        )
+        ft = InvarFitting(
+            "energy",
+            self.nt,
+            ds.get_dim_out(),
+            1,
+            mixed_types=ds.mixed_types(),
+        )
+        nfp, nap = 2, 3
+        type_map = ["foo", "bar"]
+        # fparam, aparam are converted to coordinate precision by model
+        fparam = rng.normal(size=[self.nf, nfp])
+        aparam = rng.normal(size=[self.nf, nloc, nap])
+
+        md1 = DPModel(ds, ft, type_map=type_map)
+
+        args64 = [self.coord, self.atype, self.cell]
+        args64[0] = args64[0].astype(np.float64)
+        args64[2] = args64[2].astype(np.float64)
+        args32 = [self.coord, self.atype, self.cell]
+        args32[0] = args32[0].astype(np.float32)
+        args32[2] = args32[2].astype(np.float32)
+
+        model_l_ret_64 = md1.call(*args64, fparam=fparam, aparam=aparam)
+        model_l_ret_32 = md1.call(*args32, fparam=fparam, aparam=aparam)
+
+        for ii in model_l_ret_32.keys():
+            if model_l_ret_32[ii] is None:
+                continue
+            if ii[-4:] == "redu":
+                self.assertEqual(model_l_ret_32[ii].dtype, np.float64)
+            else:
+                self.assertEqual(model_l_ret_32[ii].dtype, np.float32)
+            self.assertEqual(model_l_ret_64[ii].dtype, np.float64)
+            np.testing.assert_allclose(
+                model_l_ret_32[ii],
+                model_l_ret_64[ii],
+            )
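The PyTorch-side tests below assert the same contract. The other half of the story is the change in transform_output.py above: per-atom tensors are cast to GLOBAL_PT_ENER_FLOAT_PRECISION before torch.sum, so the accumulation itself runs in the wider type and reduced outputs come back as float64 even from a float32 model. A minimal sketch of just that step (the constant is redefined locally for self-containment; in the real code it comes from deepmd.pt.utils.env):

    import torch

    GLOBAL_PT_ENER_FLOAT_PRECISION = torch.float64  # local stand-in, matches env.py

    # (nframes, nloc) per-atom energies in the model's working precision
    atom_energy = torch.randn(1, 4, dtype=torch.float32)
    # cast before the reduction so the sum accumulates in float64
    energy_redu = torch.sum(atom_energy.to(GLOBAL_PT_ENER_FLOAT_PRECISION), dim=1)
    assert energy_redu.dtype == torch.float64

Casting before the reduction rather than after means the rounding error of summing many per-atom contributions is that of float64, not of the working precision.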
diff --git a/source/tests/pt/model/test_dp_model.py b/source/tests/pt/model/test_dp_model.py
index 0a16d4672c..840ba284e2 100644
--- a/source/tests/pt/model/test_dp_model.py
+++ b/source/tests/pt/model/test_dp_model.py
@@ -186,6 +186,53 @@ def test_dp_consistency_nopbc(self):
             to_numpy_array(ret1["energy_redu"]),
         )
 
+    def test_prec_consistency(self):
+        rng = np.random.default_rng()
+        nf, nloc = self.atype.shape
+        ds = DPDescrptSeA(
+            self.rcut,
+            self.rcut_smth,
+            self.sel,
+        )
+        ft = DPInvarFitting(
+            "energy",
+            self.nt,
+            ds.get_dim_out(),
+            1,
+            mixed_types=ds.mixed_types(),
+        )
+        nfp, nap = 2, 3
+        type_map = ["foo", "bar"]
+        fparam = rng.normal(size=[self.nf, nfp])
+        aparam = rng.normal(size=[self.nf, nloc, nap])
+
+        md0 = DPDPModel(ds, ft, type_map=type_map)
+        md1 = DPModel.deserialize(md0.serialize()).to(env.DEVICE)
+
+        args64 = [to_torch_tensor(ii) for ii in [self.coord, self.atype, self.cell]]
+        args64[0] = args64[0].to(torch.float64)
+        args64[2] = args64[2].to(torch.float64)
+        args32 = [to_torch_tensor(ii) for ii in [self.coord, self.atype, self.cell]]
+        args32[0] = args32[0].to(torch.float32)
+        args32[2] = args32[2].to(torch.float32)
+        # fparam, aparam are converted to coordinate precision by model
+        fparam = to_torch_tensor(fparam)
+        aparam = to_torch_tensor(aparam)
+
+        model_l_ret_64 = md1.forward_common(*args64, fparam=fparam, aparam=aparam)
+        model_l_ret_32 = md1.forward_common(*args32, fparam=fparam, aparam=aparam)
+
+        for ii in model_l_ret_32.keys():
+            if ii[-4:] == "redu":
+                self.assertEqual(model_l_ret_32[ii].dtype, torch.float64)
+            else:
+                self.assertEqual(model_l_ret_32[ii].dtype, torch.float32)
+            self.assertEqual(model_l_ret_64[ii].dtype, torch.float64)
+            np.testing.assert_allclose(
+                to_numpy_array(model_l_ret_32[ii]),
+                to_numpy_array(model_l_ret_64[ii]),
+            )
+
 
 class TestDPModelLower(unittest.TestCase, TestCaseSingleFrameWithNlist):
     def setUp(self):
@@ -269,6 +316,55 @@ def test_dp_consistency(self):
             to_numpy_array(ret1["energy_redu"]),
         )
 
+    def test_prec_consistency(self):
+        rng = np.random.default_rng()
+        nf, nloc, nnei = self.nlist.shape
+        ds = DPDescrptSeA(
+            self.rcut,
+            self.rcut_smth,
+            self.sel,
+        )
+        ft = DPInvarFitting(
+            "energy",
+            self.nt,
+            ds.get_dim_out(),
+            1,
+            mixed_types=ds.mixed_types(),
+        )
+        nfp, nap = 2, 3
+        type_map = ["foo", "bar"]
+        fparam = rng.normal(size=[self.nf, nfp])
+        aparam = rng.normal(size=[self.nf, nloc, nap])
+
+        md0 = DPDPModel(ds, ft, type_map=type_map)
+        md1 = DPModel.deserialize(md0.serialize()).to(env.DEVICE)
+
+        args64 = [
+            to_torch_tensor(ii) for ii in [self.coord_ext, self.atype_ext, self.nlist]
+        ]
+        args64[0] = args64[0].to(torch.float64)
+        args32 = [
+            to_torch_tensor(ii) for ii in [self.coord_ext, self.atype_ext, self.nlist]
+        ]
+        args32[0] = args32[0].to(torch.float32)
+        # fparam, aparam are converted to coordinate precision by model
+        fparam = to_torch_tensor(fparam)
+        aparam = to_torch_tensor(aparam)
+
+        model_l_ret_64 = md1.forward_common_lower(*args64, fparam=fparam, aparam=aparam)
+        model_l_ret_32 = md1.forward_common_lower(*args32, fparam=fparam, aparam=aparam)
+
+        for ii in model_l_ret_32.keys():
+            if ii[-4:] == "redu":
+                self.assertEqual(model_l_ret_32[ii].dtype, torch.float64)
+            else:
+                self.assertEqual(model_l_ret_32[ii].dtype, torch.float32)
+            self.assertEqual(model_l_ret_64[ii].dtype, torch.float64)
+            np.testing.assert_allclose(
+                to_numpy_array(model_l_ret_32[ii]),
+                to_numpy_array(model_l_ret_64[ii]),
+            )
+
     def test_jit(self):
         nf, nloc, nnei = self.nlist.shape
         ds = DescrptSeA(