From 677d936d8cc79341c7679e31bf8891ecb52e7cb8 Mon Sep 17 00:00:00 2001 From: Han Wang <92130845+wanghan-iapcm@users.noreply.github.com> Date: Sat, 3 Feb 2024 05:05:08 +0800 Subject: [PATCH] fix bug of output def: the reduced virial is not defined. (#3219) Co-authored-by: Han Wang --- deepmd/dpmodel/output_def.py | 21 +++++++++++++++------ source/tests/common/test_output_def.py | 7 +++++++ 2 files changed, 22 insertions(+), 6 deletions(-) diff --git a/deepmd/dpmodel/output_def.py b/deepmd/dpmodel/output_def.py index 583f88491e..6cd83fcf28 100644 --- a/deepmd/dpmodel/output_def.py +++ b/deepmd/dpmodel/output_def.py @@ -147,6 +147,8 @@ def __init__( self.differentiable = differentiable if not self.reduciable and self.differentiable: raise ValueError("only reduciable variable are differentiable") + if self.reduciable and not self.atomic: + raise ValueError("only reduciable variable should be atomic") class FittingOutputDef: @@ -201,14 +203,16 @@ def __init__( fit_defs: FittingOutputDef, ): self.def_outp = fit_defs - self.def_redu = do_reduce(self.def_outp) - self.def_derv_r, self.def_derv_c = do_derivative(self.def_outp) + self.def_redu = do_reduce(self.def_outp.get_data()) + self.def_derv_r, self.def_derv_c = do_derivative(self.def_outp.get_data()) + self.def_derv_c_redu = do_reduce(self.def_derv_c) self.var_defs: Dict[str, OutputVariableDef] = {} for ii in [ self.def_outp.get_data(), self.def_redu, self.def_derv_c, self.def_derv_r, + self.def_derv_c_redu, ]: self.var_defs.update(ii) @@ -239,6 +243,9 @@ def keys_derv_r(self): def keys_derv_c(self): return self.def_derv_c.keys() + def keys_derv_c_redu(self): + return self.def_derv_c_redu.keys() + def get_reduce_name(name: str) -> str: return name + "_redu" @@ -249,10 +256,10 @@ def get_deriv_name(name: str) -> Tuple[str, str]: def do_reduce( - def_outp: FittingOutputDef, + def_outp_data: Dict[str, OutputVariableDef], ) -> Dict[str, OutputVariableDef]: def_redu: Dict[str, OutputVariableDef] = {} - for kk, vv in def_outp.get_data().items(): + for kk, vv in def_outp_data.items(): if vv.reduciable: rk = get_reduce_name(kk) def_redu[rk] = OutputVariableDef( @@ -262,11 +269,11 @@ def do_reduce( def do_derivative( - def_outp: FittingOutputDef, + def_outp_data: Dict[str, OutputVariableDef], ) -> Tuple[Dict[str, OutputVariableDef], Dict[str, OutputVariableDef]]: def_derv_r: Dict[str, OutputVariableDef] = {} def_derv_c: Dict[str, OutputVariableDef] = {} - for kk, vv in def_outp.get_data().items(): + for kk, vv in def_outp_data.items(): if vv.differentiable: rkr, rkc = get_deriv_name(kk) def_derv_r[rkr] = OutputVariableDef( @@ -274,11 +281,13 @@ def do_derivative( vv.shape + [3], # noqa: RUF005 reduciable=False, differentiable=False, + atomic=True, ) def_derv_c[rkc] = OutputVariableDef( rkc, vv.shape + [3, 3], # noqa: RUF005 reduciable=True, differentiable=False, + atomic=True, ) return def_derv_r, def_derv_c diff --git a/source/tests/common/test_output_def.py b/source/tests/common/test_output_def.py index d0cf822247..aaabdc0ba6 100644 --- a/source/tests/common/test_output_def.py +++ b/source/tests/common/test_output_def.py @@ -70,6 +70,7 @@ def test_model_output_def(self): "energy_redu", "energy_derv_r", "energy_derv_c", + "energy_derv_c_redu", "dos_redu", ] self.assertEqual( @@ -93,6 +94,7 @@ def test_model_output_def(self): self.assertEqual(md["energy_redu"].shape, [1]) self.assertEqual(md["energy_derv_r"].shape, [1, 3]) self.assertEqual(md["energy_derv_c"].shape, [1, 3, 3]) + self.assertEqual(md["energy_derv_c_redu"].shape, [1, 3, 3]) # atomic self.assertEqual(md["energy"].atomic, True) self.assertEqual(md["dos"].atomic, True) @@ -100,11 +102,16 @@ def test_model_output_def(self): self.assertEqual(md["energy_redu"].atomic, False) self.assertEqual(md["energy_derv_r"].atomic, True) self.assertEqual(md["energy_derv_c"].atomic, True) + self.assertEqual(md["energy_derv_c_redu"].atomic, False) def test_raise_no_redu_deriv(self): with self.assertRaises(ValueError) as context: (OutputVariableDef("energy", [1], False, True),) + def test_raise_redu_not_atomic(self): + with self.assertRaises(ValueError) as context: + (OutputVariableDef("energy", [1], True, False, atomic=False),) + def test_model_decorator(self): nf = 2 nloc = 3