From fa94e8e3383eb27e8eb35fcce698049f59ec537b Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Mon, 22 Jul 2024 19:03:51 -0400 Subject: [PATCH 1/3] fix: support zero atoms input Signed-off-by: Jinzhe Zeng --- .../dpmodel/atomic_model/base_atomic_model.py | 3 +- deepmd/dpmodel/descriptor/dpa1.py | 20 ++-- deepmd/dpmodel/descriptor/repformers.py | 2 +- deepmd/dpmodel/descriptor/se_r.py | 2 +- deepmd/dpmodel/utils/nlist.py | 11 +- .../model/atomic_model/base_atomic_model.py | 5 +- deepmd/pt/model/descriptor/repformers.py | 2 +- deepmd/pt/model/descriptor/se_a.py | 7 +- deepmd/pt/model/descriptor/se_atten.py | 8 +- deepmd/pt/model/descriptor/se_r.py | 5 +- deepmd/pt/model/descriptor/se_t.py | 5 +- deepmd/pt/model/model/spin_model.py | 7 +- deepmd/pt/model/task/dipole.py | 2 +- deepmd/pt/model/task/polarizability.py | 2 +- deepmd/pt/utils/nlist.py | 11 +- .../universal/common/cases/model/utils.py | 103 ++++++++++++++++++ 16 files changed, 160 insertions(+), 35 deletions(-) diff --git a/deepmd/dpmodel/atomic_model/base_atomic_model.py b/deepmd/dpmodel/atomic_model/base_atomic_model.py index bb884ef45c..3012e46c18 100644 --- a/deepmd/dpmodel/atomic_model/base_atomic_model.py +++ b/deepmd/dpmodel/atomic_model/base_atomic_model.py @@ -200,8 +200,9 @@ def forward_common_atomic( for kk in ret_dict.keys(): out_shape = ret_dict[kk].shape + out_shape2 = np.prod(out_shape[2:]) ret_dict[kk] = ( - ret_dict[kk].reshape([out_shape[0], out_shape[1], -1]) + ret_dict[kk].reshape([out_shape[0], out_shape[1], out_shape2]) * atom_mask[:, :, None] ).reshape(out_shape) ret_dict["mask"] = atom_mask diff --git a/deepmd/dpmodel/descriptor/dpa1.py b/deepmd/dpmodel/descriptor/dpa1.py index 4eae05560f..360df6a591 100644 --- a/deepmd/dpmodel/descriptor/dpa1.py +++ b/deepmd/dpmodel/descriptor/dpa1.py @@ -487,7 +487,9 @@ def call( ) # nf x nloc x (ng x ng1 + tebd_dim) if self.concat_output_tebd: - grrg = np.concatenate([grrg, atype_embd.reshape(nf, nloc, -1)], axis=-1) + grrg = np.concatenate( + [grrg, atype_embd.reshape(nf, nloc, self.tebd_dim)], axis=-1 + ) return grrg, rot_mat, None, None, sw def serialize(self) -> dict: @@ -834,7 +836,8 @@ def cal_g( embedding_idx, ): nfnl, nnei = ss.shape[0:2] - ss = ss.reshape(nfnl, nnei, -1) + shape2 = np.prod(ss.shape[2:]) + ss = ss.reshape(nfnl, nnei, shape2) # nfnl x nnei x ng gg = self.embeddings[embedding_idx].call(ss) return gg @@ -846,7 +849,8 @@ def cal_g_strip( ): assert self.embeddings_strip is not None nfnl, nnei = ss.shape[0:2] - ss = ss.reshape(nfnl, nnei, -1) + shape2 = np.prod(ss.shape[2:]) + ss = ss.reshape(nfnl, nnei, shape2) # nfnl x nnei x ng gg = self.embeddings_strip[embedding_idx].call(ss) return gg @@ -875,7 +879,7 @@ def call( # nfnl x nnei x 1 sw = sw.reshape(nf * nloc, nnei, 1) # nfnl x tebd_dim - atype_embd = atype_embd_ext[:, :nloc, :].reshape(nf * nloc, -1) + atype_embd = atype_embd_ext[:, :nloc, :].reshape(nf * nloc, self.tebd_dim) # nfnl x nnei x tebd_dim atype_embd_nnei = np.tile(atype_embd[:, np.newaxis, :], (1, nnei, 1)) # nfnl x nnei @@ -941,10 +945,10 @@ def call( GLOBAL_NP_FLOAT_PRECISION ) return ( - grrg.reshape(-1, nloc, self.filter_neuron[-1] * self.axis_neuron), - gg.reshape(-1, nloc, self.nnei, self.filter_neuron[-1]), - dmatrix.reshape(-1, nloc, self.nnei, 4)[..., 1:], - gr[..., 1:].reshape(-1, nloc, self.filter_neuron[-1], 3), + grrg.reshape(nf, nloc, self.filter_neuron[-1] * self.axis_neuron), + gg.reshape(nf, nloc, self.nnei, self.filter_neuron[-1]), + dmatrix.reshape(nf, nloc, self.nnei, 4)[..., 1:], + gr[..., 1:].reshape(nf, nloc, self.filter_neuron[-1], 3), sw, ) diff --git a/deepmd/dpmodel/descriptor/repformers.py b/deepmd/dpmodel/descriptor/repformers.py index 67c72e8d31..af286a35e7 100644 --- a/deepmd/dpmodel/descriptor/repformers.py +++ b/deepmd/dpmodel/descriptor/repformers.py @@ -395,7 +395,7 @@ def call( h2g2 = _cal_hg(g2, h2, nlist_mask, sw, smooth=self.smooth, epsilon=self.epsilon) # (nf x nloc) x ng2 x 3 rot_mat = np.transpose(h2g2, (0, 1, 3, 2)) - return g1, g2, h2, rot_mat.reshape(-1, nloc, self.dim_emb, 3), sw + return g1, g2, h2, rot_mat.reshape(nf, nloc, self.dim_emb, 3), sw def has_message_passing(self) -> bool: """Returns whether the descriptor block has message passing.""" diff --git a/deepmd/dpmodel/descriptor/se_r.py b/deepmd/dpmodel/descriptor/se_r.py index 20a6fe49dd..4b89e1dd90 100644 --- a/deepmd/dpmodel/descriptor/se_r.py +++ b/deepmd/dpmodel/descriptor/se_r.py @@ -341,7 +341,7 @@ def call( res_rescale = 1.0 / 5.0 res = xyz_scatter * res_rescale - res = res.reshape(nf, nloc, -1).astype(GLOBAL_NP_FLOAT_PRECISION) + res = res.reshape(nf, nloc, ng).astype(GLOBAL_NP_FLOAT_PRECISION) return res, None, None, None, ww def serialize(self) -> dict: diff --git a/deepmd/dpmodel/utils/nlist.py b/deepmd/dpmodel/utils/nlist.py index 018f50f1a5..9f04678fdb 100644 --- a/deepmd/dpmodel/utils/nlist.py +++ b/deepmd/dpmodel/utils/nlist.py @@ -95,12 +95,15 @@ def build_neighbor_list( nall = coord.shape[1] // 3 # fill virtual atoms with large coords so they are not neighbors of any # real atom. - xmax = np.max(coord) + 2.0 * rcut + if coord.size > 0: + xmax = np.max(coord) + 2.0 * rcut + else: + xmax = 2.0 * rcut # nf x nall is_vir = atype < 0 - coord1 = np.where(is_vir[:, :, None], xmax, coord.reshape(-1, nall, 3)).reshape( - -1, nall * 3 - ) + coord1 = np.where( + is_vir[:, :, None], xmax, coord.reshape(batch_size, nall, 3) + ).reshape(batch_size, nall * 3) if isinstance(sel, int): sel = [sel] nsel = sum(sel) diff --git a/deepmd/pt/model/atomic_model/base_atomic_model.py b/deepmd/pt/model/atomic_model/base_atomic_model.py index 6a42393310..4ec78ec94f 100644 --- a/deepmd/pt/model/atomic_model/base_atomic_model.py +++ b/deepmd/pt/model/atomic_model/base_atomic_model.py @@ -256,8 +256,11 @@ def forward_common_atomic( for kk in ret_dict.keys(): out_shape = ret_dict[kk].shape + out_shape2 = 1 + for ss in out_shape[2:]: + out_shape2 *= ss ret_dict[kk] = ( - ret_dict[kk].reshape([out_shape[0], out_shape[1], -1]) + ret_dict[kk].reshape([out_shape[0], out_shape[1], out_shape2]) * atom_mask[:, :, None] ).view(out_shape) ret_dict["mask"] = atom_mask diff --git a/deepmd/pt/model/descriptor/repformers.py b/deepmd/pt/model/descriptor/repformers.py index 8653d79140..2eb225f120 100644 --- a/deepmd/pt/model/descriptor/repformers.py +++ b/deepmd/pt/model/descriptor/repformers.py @@ -505,7 +505,7 @@ def forward( # (nb x nloc) x ng2 x 3 rot_mat = torch.permute(h2g2, (0, 1, 3, 2)) - return g1, g2, h2, rot_mat.view(-1, nloc, self.dim_emb, 3), sw + return g1, g2, h2, rot_mat.view(nframes, nloc, self.dim_emb, 3), sw def compute_input_stats( self, diff --git a/deepmd/pt/model/descriptor/se_a.py b/deepmd/pt/model/descriptor/se_a.py index 8e51b03fc2..f9b8e88766 100644 --- a/deepmd/pt/model/descriptor/se_a.py +++ b/deepmd/pt/model/descriptor/se_a.py @@ -617,6 +617,7 @@ def forward( - `torch.Tensor`: descriptor matrix with shape [nframes, natoms[0]*self.filter_neuron[-1]*self.axis_neuron]. """ del extended_atype_embd, mapping + nf = nlist.shape[0] nloc = nlist.shape[1] atype = extended_atype[:, :nloc] dmatrix, diff, sw = prod_env_mat( @@ -657,7 +658,7 @@ def forward( device=extended_coord.device, ) # nfnl x nnei - exclude_mask = self.emask(nlist, extended_atype).view(nfnl, -1) + exclude_mask = self.emask(nlist, extended_atype).view(nfnl, self.nnei) for embedding_idx, ll in enumerate(self.filter_layers.networks): if self.type_one_side: ii = embedding_idx @@ -698,8 +699,8 @@ def forward( result = torch.matmul( xyz_scatter_1, xyz_scatter_2 ) # shape is [nframes*nall, self.filter_neuron[-1], self.axis_neuron] - result = result.view(-1, nloc, self.filter_neuron[-1] * self.axis_neuron) - rot_mat = rot_mat.view([-1, nloc] + list(rot_mat.shape[1:])) # noqa:RUF005 + result = result.view(nf, nloc, self.filter_neuron[-1] * self.axis_neuron) + rot_mat = rot_mat.view([nf, nloc] + list(rot_mat.shape[1:])) # noqa:RUF005 return ( result.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION), rot_mat.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION), diff --git a/deepmd/pt/model/descriptor/se_atten.py b/deepmd/pt/model/descriptor/se_atten.py index a30869f24a..44a144ea17 100644 --- a/deepmd/pt/model/descriptor/se_atten.py +++ b/deepmd/pt/model/descriptor/se_atten.py @@ -580,10 +580,10 @@ def forward( xyz_scatter_1, xyz_scatter_2 ) # shape is [nframes*nloc, self.filter_neuron[-1], self.axis_neuron] return ( - result.view(-1, nloc, self.filter_neuron[-1] * self.axis_neuron), - gg.view(-1, nloc, self.nnei, self.filter_neuron[-1]), - dmatrix.view(-1, nloc, self.nnei, 4)[..., 1:], - rot_mat.view(-1, nloc, self.filter_neuron[-1], 3), + result.view(nframes, nloc, self.filter_neuron[-1] * self.axis_neuron), + gg.view(nframes, nloc, self.nnei, self.filter_neuron[-1]), + dmatrix.view(nframes, nloc, self.nnei, 4)[..., 1:], + rot_mat.view(nframes, nloc, self.filter_neuron[-1], 3), sw, ) diff --git a/deepmd/pt/model/descriptor/se_r.py b/deepmd/pt/model/descriptor/se_r.py index bdf0bb8c17..9c9556f93e 100644 --- a/deepmd/pt/model/descriptor/se_r.py +++ b/deepmd/pt/model/descriptor/se_r.py @@ -343,6 +343,7 @@ def forward( """ del mapping + nf = nlist.shape[0] nloc = nlist.shape[1] atype = atype_ext[:, :nloc] dmatrix, diff, sw = prod_env_mat( @@ -367,7 +368,7 @@ def forward( ) # nfnl x nnei - exclude_mask = self.emask(nlist, atype_ext).view(nfnl, -1) + exclude_mask = self.emask(nlist, atype_ext).view(nfnl, self.nnei) for ii, ll in enumerate(self.filter_layers.networks): # nfnl x nt mm = exclude_mask[:, self.sec[ii] : self.sec[ii + 1]] @@ -381,7 +382,7 @@ def forward( res_rescale = 1.0 / 5.0 result = xyz_scatter * res_rescale - result = result.view(-1, nloc, self.filter_neuron[-1]) + result = result.view(nf, nloc, self.filter_neuron[-1]) return ( result.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION), None, diff --git a/deepmd/pt/model/descriptor/se_t.py b/deepmd/pt/model/descriptor/se_t.py index 2fe5c16059..d84bae603b 100644 --- a/deepmd/pt/model/descriptor/se_t.py +++ b/deepmd/pt/model/descriptor/se_t.py @@ -665,6 +665,7 @@ def forward( """ del extended_atype_embd, mapping + nf = nlist.shape[0] nloc = nlist.shape[1] atype = extended_atype[:, :nloc] dmatrix, diff, sw = prod_env_mat( @@ -687,7 +688,7 @@ def forward( device=extended_coord.device, ) # nfnl x nnei - exclude_mask = self.emask(nlist, extended_atype).view(nfnl, -1) + exclude_mask = self.emask(nlist, extended_atype).view(nfnl, self.nnei) for embedding_idx, ll in enumerate(self.filter_layers.networks): ti = embedding_idx % self.ntypes nei_type_j = self.sel[ti] @@ -714,7 +715,7 @@ def forward( res_ij = res_ij * (1.0 / float(nei_type_i) / float(nei_type_j)) result += res_ij # xyz_scatter /= (self.nnei * self.nnei) - result = result.view(-1, nloc, self.filter_neuron[-1]) + result = result.view(nf, nloc, self.filter_neuron[-1]) return ( result.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION), None, diff --git a/deepmd/pt/model/model/spin_model.py b/deepmd/pt/model/model/spin_model.py index 72e6797ea8..79eb4dd8f4 100644 --- a/deepmd/pt/model/model/spin_model.py +++ b/deepmd/pt/model/model/spin_model.py @@ -121,7 +121,12 @@ def process_spin_output( out_real, out_mag = torch.split(out_tensor, [nloc, nloc], dim=1) if add_mag: out_real = out_real + out_mag - out_mag = (out_mag.view([nframes, nloc, -1]) * atomic_mask).view(out_mag.shape) + shape2 = 1 + for ss in out_real.shape[2:]: + shape2 *= ss + out_mag = (out_mag.view([nframes, nloc, shape2]) * atomic_mask).view( + out_mag.shape + ) return out_real, out_mag, atomic_mask > 0.0 def process_spin_output_lower( diff --git a/deepmd/pt/model/task/dipole.py b/deepmd/pt/model/task/dipole.py index 52636d8d95..30c5a341a7 100644 --- a/deepmd/pt/model/task/dipole.py +++ b/deepmd/pt/model/task/dipole.py @@ -191,7 +191,7 @@ def forward( # (nframes * nloc, 1, m1) out = out.view(-1, 1, self.embedding_width) # (nframes * nloc, m1, 3) - gr = gr.view(nframes * nloc, -1, 3) + gr = gr.view(nframes * nloc, self.embedding_width, 3) # (nframes, nloc, 3) out = torch.bmm(out, gr).squeeze(-2).view(nframes, nloc, 3) return {self.var_name: out.to(env.GLOBAL_PT_FLOAT_PRECISION)} diff --git a/deepmd/pt/model/task/polarizability.py b/deepmd/pt/model/task/polarizability.py index 4bf4e3c1c5..7345fa296c 100644 --- a/deepmd/pt/model/task/polarizability.py +++ b/deepmd/pt/model/task/polarizability.py @@ -242,7 +242,7 @@ def forward( self.var_name ] out = out * (self.scale.to(atype.device))[atype] - gr = gr.view(nframes * nloc, -1, 3) # (nframes * nloc, m1, 3) + gr = gr.view(nframes * nloc, self.embedding_width, 3) # (nframes * nloc, m1, 3) if self.fit_diag: out = out.reshape(-1, self.embedding_width) diff --git a/deepmd/pt/utils/nlist.py b/deepmd/pt/utils/nlist.py index a24a5aef72..0eac7bb52a 100644 --- a/deepmd/pt/utils/nlist.py +++ b/deepmd/pt/utils/nlist.py @@ -99,12 +99,15 @@ def build_neighbor_list( nall = coord.shape[1] // 3 # fill virtual atoms with large coords so they are not neighbors of any # real atom. - xmax = torch.max(coord) + 2.0 * rcut + if coord.numel() > 0: + xmax = torch.max(coord) + 2.0 * rcut + else: + xmax = torch.zeros(1, dtype=coord.dtype, device=coord.device) + 2.0 * rcut # nf x nall is_vir = atype < 0 - coord1 = torch.where(is_vir[:, :, None], xmax, coord.view(-1, nall, 3)).view( - -1, nall * 3 - ) + coord1 = torch.where( + is_vir[:, :, None], xmax, coord.view(batch_size, nall, 3) + ).view(batch_size, nall * 3) if isinstance(sel, int): sel = [sel] nsel = sum(sel) diff --git a/source/tests/universal/common/cases/model/utils.py b/source/tests/universal/common/cases/model/utils.py index 68a5498d32..d35dffecc8 100644 --- a/source/tests/universal/common/cases/model/utils.py +++ b/source/tests/universal/common/cases/model/utils.py @@ -221,6 +221,109 @@ def test_forward(self): continue np.testing.assert_allclose(rr1, rr2, atol=aprec) + def test_zero_forward(self): + test_spin = getattr(self, "test_spin", False) + nf = 1 + natoms = 0 + aprec = ( + 0 + if self.aprec_dict.get("test_forward", None) is None + else self.aprec_dict["test_forward"] + ) + rng = np.random.default_rng(GLOBAL_SEED) + coord = np.zeros((nf, 0, 3), dtype=np.float64) + atype = np.zeros([nf, 0], dtype=int) + spin = np.zeros([nf, 0], dtype=np.float64) + cell = 6.0 * np.eye(3, dtype=np.float64).reshape([nf, 9]) + coord_ext, atype_ext, mapping, nlist = extend_input_and_build_neighbor_list( + coord, + atype, + self.expected_rcut + 1.0 if test_spin else self.expected_rcut, + self.expected_sel, + mixed_types=self.module.mixed_types(), + box=cell, + ) + spin_ext = np.take_along_axis( + spin.reshape(nf, -1, 3), + np.repeat(np.expand_dims(mapping, axis=-1), 3, axis=-1), + axis=1, + ) + aparam = None + fparam = None + if self.module.get_dim_aparam() > 0: + aparam = rng.random([nf, natoms, self.module.get_dim_aparam()]) + if self.module.get_dim_fparam() > 0: + fparam = rng.random([nf, self.module.get_dim_fparam()]) + ret = [] + ret_lower = [] + for module in self.modules_to_test: + module = self.forward_wrapper(module) + input_dict = { + "coord": coord, + "atype": atype, + "box": cell, + "aparam": aparam, + "fparam": fparam, + } + if test_spin: + input_dict["spin"] = spin + ret.append(module(**input_dict)) + + input_dict_lower = { + "extended_coord": coord_ext, + "extended_atype": atype_ext, + "nlist": nlist, + "mapping": mapping, + "aparam": aparam, + "fparam": fparam, + } + if test_spin: + input_dict_lower["extended_spin"] = spin_ext + + ret_lower.append(module.forward_lower(**input_dict_lower)) + for kk in ret[0]: + subret = [] + for rr in ret: + if rr is not None: + subret.append(rr[kk]) + if len(subret): + for ii, rr in enumerate(subret[1:]): + if subret[0] is None: + assert rr is None + else: + np.testing.assert_allclose( + subret[0], rr, err_msg=f"compare {kk} between 0 and {ii}" + ) + for kk in ret_lower[0].keys(): + subret = [] + for rr in ret_lower: + if rr is not None: + subret.append(rr[kk]) + if len(subret): + for ii, rr in enumerate(subret[1:]): + if subret[0] is None: + assert rr is None + else: + np.testing.assert_allclose( + subret[0], rr, err_msg=f"compare {kk} between 0 and {ii}" + ) + same_keys = set(ret[0].keys()) & set(ret_lower[0].keys()) + self.assertTrue(same_keys) + for key in same_keys: + for rr in ret: + if rr[key] is not None: + rr1 = rr[key] + break + else: + continue + for rr in ret_lower: + if rr[key] is not None: + rr2 = rr[key] + break + else: + continue + np.testing.assert_allclose(rr1, rr2, atol=aprec) + @unittest.skipIf(TEST_DEVICE != "cpu", "Only test on CPU.") def test_permutation(self): """Test permutation.""" From 29c1b58e122408329d041c5e5f2fd10021f625f6 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Mon, 22 Jul 2024 19:38:12 -0400 Subject: [PATCH 2/3] fix shape Signed-off-by: Jinzhe Zeng --- deepmd/pt/model/model/spin_model.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/deepmd/pt/model/model/spin_model.py b/deepmd/pt/model/model/spin_model.py index 79eb4dd8f4..551c0b86b2 100644 --- a/deepmd/pt/model/model/spin_model.py +++ b/deepmd/pt/model/model/spin_model.py @@ -167,8 +167,11 @@ def process_spin_output_lower( ) if add_mag: extended_out_real = extended_out_real + extended_out_mag + shape2 = 1 + for ss in extended_out_tensor.shape[2:]: + shape2 *= ss extended_out_mag = ( - extended_out_mag.view([nframes, nall, -1]) * atomic_mask + extended_out_mag.view([nframes, nall, shape2]) * atomic_mask ).view(extended_out_mag.shape) return extended_out_real, extended_out_mag, atomic_mask > 0.0 From 747e1c91289919a5b5503c4fccd9677ff019ef6c Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Mon, 22 Jul 2024 22:42:31 -0400 Subject: [PATCH 3/3] Update source/tests/universal/common/cases/model/utils.py Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> Signed-off-by: Jinzhe Zeng --- source/tests/universal/common/cases/model/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/source/tests/universal/common/cases/model/utils.py b/source/tests/universal/common/cases/model/utils.py index d35dffecc8..2be55a6337 100644 --- a/source/tests/universal/common/cases/model/utils.py +++ b/source/tests/universal/common/cases/model/utils.py @@ -294,7 +294,7 @@ def test_zero_forward(self): np.testing.assert_allclose( subret[0], rr, err_msg=f"compare {kk} between 0 and {ii}" ) - for kk in ret_lower[0].keys(): + for kk in ret_lower[0]: subret = [] for rr in ret_lower: if rr is not None: