diff --git a/deepmd/pt/utils/env_mat_stat.py b/deepmd/pt/utils/env_mat_stat.py index 3af03bda97..70b7228440 100644 --- a/deepmd/pt/utils/env_mat_stat.py +++ b/deepmd/pt/utils/env_mat_stat.py @@ -101,6 +101,14 @@ def iter( dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE, ) + if self.last_dim == 4: + radial_only = False + elif self.last_dim == 1: + radial_only = True + else: + raise ValueError( + "last_dim should be 1 for raial-only or 4 for full descriptor." + ) for system in data: coord, atype, box, natoms = ( system["coord"], @@ -130,6 +138,7 @@ def iter( self.descriptor.get_rcut(), # TODO: export rcut_smth from DescriptorBlock self.descriptor.rcut_smth, + radial_only, ) # reshape to nframes * nloc at the atom level, # so nframes/mixed_type do not matter diff --git a/source/tests/pt/model/test_descriptor_se_r.py b/source/tests/pt/model/test_descriptor_se_r.py index c999f06863..5b8b6c9251 100644 --- a/source/tests/pt/model/test_descriptor_se_r.py +++ b/source/tests/pt/model/test_descriptor_se_r.py @@ -15,6 +15,9 @@ from deepmd.pt.utils.env import ( PRECISION_DICT, ) +from deepmd.pt.utils.env_mat_stat import ( + EnvMatStatSe, +) from .test_env_mat import ( TestCaseSingleFrameWithNlist, @@ -103,13 +106,61 @@ def test_consistency( err_msg=err_msg, ) + def test_load_stat(self): + rng = np.random.default_rng() + _, _, nnei = self.nlist.shape + davg = rng.normal(size=(self.nt, nnei, 1)) + dstd = rng.normal(size=(self.nt, nnei, 1)) + dstd = 0.1 + np.abs(dstd) + + for idt, prec in itertools.product( + [False, True], + ["float64", "float32"], + ): + dtype = PRECISION_DICT[prec] + + # sea new impl + dd0 = DescrptSeR( + self.rcut, + self.rcut_smth, + self.sel, + precision=prec, + resnet_dt=idt, + old_impl=False, + ) + dd0.mean = torch.tensor(davg, dtype=dtype, device=env.DEVICE) + dd0.dstd = torch.tensor(dstd, dtype=dtype, device=env.DEVICE) + dd1 = DescrptSeR.deserialize(dd0.serialize()) + dd1.compute_input_stats( + [ + { + "r0": None, + "coord": torch.from_numpy(self.coord_ext) + .reshape(-1, self.nall, 3) + .to(env.DEVICE), + "atype": torch.from_numpy(self.atype_ext).to(env.DEVICE), + "box": None, + "natoms": self.nall, + } + ] + ) + + with self.assertRaises(ValueError) as cm: + ev = EnvMatStatSe(dd1) + ev.last_dim = 3 + ev.load_or_compute_stats([]) + self.assertEqual( + "last_dim should be 1 for raial-only or 4 for full descriptor.", + str(cm.exception), + ) + def test_jit( self, ): rng = np.random.default_rng() _, _, nnei = self.nlist.shape - davg = rng.normal(size=(self.nt, nnei, 4)) - dstd = rng.normal(size=(self.nt, nnei, 4)) + davg = rng.normal(size=(self.nt, nnei, 1)) + dstd = rng.normal(size=(self.nt, nnei, 1)) dstd = 0.1 + np.abs(dstd) for idt, prec in itertools.product(