Skip to content

Commit

Permalink
fix: fix errors for zero atom inputs (deepmodeling#4005)
Browse files Browse the repository at this point in the history
`reshape((0, -1))` is not allowed 

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit


- **New Features**
- Improved tensor reshaping logic across multiple models to enhance
handling of varying input shapes.
- Added a new test case to validate functionality when no atoms are
present.

- **Bug Fixes**
- Enhanced error handling for empty coordinate inputs to improve
function stability.

- **Documentation**
- Updated comments and structure for clarity in tensor manipulations
across various models.

- **Refactor**
- Introduced explicit dimension definitions in tensor reshaping to avoid
reliance on automatic inference.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
  • Loading branch information
2 people authored and mtaillefumier committed Sep 18, 2024
1 parent 1a877d6 commit 60ae6c8
Show file tree
Hide file tree
Showing 16 changed files with 164 additions and 36 deletions.
3 changes: 2 additions & 1 deletion deepmd/dpmodel/atomic_model/base_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,8 +200,9 @@ def forward_common_atomic(

for kk in ret_dict.keys():
out_shape = ret_dict[kk].shape
out_shape2 = np.prod(out_shape[2:])
ret_dict[kk] = (
ret_dict[kk].reshape([out_shape[0], out_shape[1], -1])
ret_dict[kk].reshape([out_shape[0], out_shape[1], out_shape2])
* atom_mask[:, :, None]
).reshape(out_shape)
ret_dict["mask"] = atom_mask
Expand Down
20 changes: 12 additions & 8 deletions deepmd/dpmodel/descriptor/dpa1.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,7 +487,9 @@ def call(
)
# nf x nloc x (ng x ng1 + tebd_dim)
if self.concat_output_tebd:
grrg = np.concatenate([grrg, atype_embd.reshape(nf, nloc, -1)], axis=-1)
grrg = np.concatenate(
[grrg, atype_embd.reshape(nf, nloc, self.tebd_dim)], axis=-1
)
return grrg, rot_mat, None, None, sw

def serialize(self) -> dict:
Expand Down Expand Up @@ -834,7 +836,8 @@ def cal_g(
embedding_idx,
):
nfnl, nnei = ss.shape[0:2]
ss = ss.reshape(nfnl, nnei, -1)
shape2 = np.prod(ss.shape[2:])
ss = ss.reshape(nfnl, nnei, shape2)
# nfnl x nnei x ng
gg = self.embeddings[embedding_idx].call(ss)
return gg
Expand All @@ -846,7 +849,8 @@ def cal_g_strip(
):
assert self.embeddings_strip is not None
nfnl, nnei = ss.shape[0:2]
ss = ss.reshape(nfnl, nnei, -1)
shape2 = np.prod(ss.shape[2:])
ss = ss.reshape(nfnl, nnei, shape2)
# nfnl x nnei x ng
gg = self.embeddings_strip[embedding_idx].call(ss)
return gg
Expand Down Expand Up @@ -875,7 +879,7 @@ def call(
# nfnl x nnei x 1
sw = sw.reshape(nf * nloc, nnei, 1)
# nfnl x tebd_dim
atype_embd = atype_embd_ext[:, :nloc, :].reshape(nf * nloc, -1)
atype_embd = atype_embd_ext[:, :nloc, :].reshape(nf * nloc, self.tebd_dim)
# nfnl x nnei x tebd_dim
atype_embd_nnei = np.tile(atype_embd[:, np.newaxis, :], (1, nnei, 1))
# nfnl x nnei
Expand Down Expand Up @@ -941,10 +945,10 @@ def call(
GLOBAL_NP_FLOAT_PRECISION
)
return (
grrg.reshape(-1, nloc, self.filter_neuron[-1] * self.axis_neuron),
gg.reshape(-1, nloc, self.nnei, self.filter_neuron[-1]),
dmatrix.reshape(-1, nloc, self.nnei, 4)[..., 1:],
gr[..., 1:].reshape(-1, nloc, self.filter_neuron[-1], 3),
grrg.reshape(nf, nloc, self.filter_neuron[-1] * self.axis_neuron),
gg.reshape(nf, nloc, self.nnei, self.filter_neuron[-1]),
dmatrix.reshape(nf, nloc, self.nnei, 4)[..., 1:],
gr[..., 1:].reshape(nf, nloc, self.filter_neuron[-1], 3),
sw,
)

Expand Down
2 changes: 1 addition & 1 deletion deepmd/dpmodel/descriptor/repformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,7 +395,7 @@ def call(
h2g2 = _cal_hg(g2, h2, nlist_mask, sw, smooth=self.smooth, epsilon=self.epsilon)
# (nf x nloc) x ng2 x 3
rot_mat = np.transpose(h2g2, (0, 1, 3, 2))
return g1, g2, h2, rot_mat.reshape(-1, nloc, self.dim_emb, 3), sw
return g1, g2, h2, rot_mat.reshape(nf, nloc, self.dim_emb, 3), sw

def has_message_passing(self) -> bool:
"""Returns whether the descriptor block has message passing."""
Expand Down
2 changes: 1 addition & 1 deletion deepmd/dpmodel/descriptor/se_r.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,7 @@ def call(

res_rescale = 1.0 / 5.0
res = xyz_scatter * res_rescale
res = res.reshape(nf, nloc, -1).astype(GLOBAL_NP_FLOAT_PRECISION)
res = res.reshape(nf, nloc, ng).astype(GLOBAL_NP_FLOAT_PRECISION)
return res, None, None, None, ww

def serialize(self) -> dict:
Expand Down
11 changes: 7 additions & 4 deletions deepmd/dpmodel/utils/nlist.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,12 +95,15 @@ def build_neighbor_list(
nall = coord.shape[1] // 3
# fill virtual atoms with large coords so they are not neighbors of any
# real atom.
xmax = np.max(coord) + 2.0 * rcut
if coord.size > 0:
xmax = np.max(coord) + 2.0 * rcut
else:
xmax = 2.0 * rcut
# nf x nall
is_vir = atype < 0
coord1 = np.where(is_vir[:, :, None], xmax, coord.reshape(-1, nall, 3)).reshape(
-1, nall * 3
)
coord1 = np.where(
is_vir[:, :, None], xmax, coord.reshape(batch_size, nall, 3)
).reshape(batch_size, nall * 3)
if isinstance(sel, int):
sel = [sel]
nsel = sum(sel)
Expand Down
5 changes: 4 additions & 1 deletion deepmd/pt/model/atomic_model/base_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,8 +256,11 @@ def forward_common_atomic(

for kk in ret_dict.keys():
out_shape = ret_dict[kk].shape
out_shape2 = 1
for ss in out_shape[2:]:
out_shape2 *= ss
ret_dict[kk] = (
ret_dict[kk].reshape([out_shape[0], out_shape[1], -1])
ret_dict[kk].reshape([out_shape[0], out_shape[1], out_shape2])
* atom_mask[:, :, None]
).view(out_shape)
ret_dict["mask"] = atom_mask
Expand Down
2 changes: 1 addition & 1 deletion deepmd/pt/model/descriptor/repformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -505,7 +505,7 @@ def forward(
# (nb x nloc) x ng2 x 3
rot_mat = torch.permute(h2g2, (0, 1, 3, 2))

return g1, g2, h2, rot_mat.view(-1, nloc, self.dim_emb, 3), sw
return g1, g2, h2, rot_mat.view(nframes, nloc, self.dim_emb, 3), sw

def compute_input_stats(
self,
Expand Down
7 changes: 4 additions & 3 deletions deepmd/pt/model/descriptor/se_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -617,6 +617,7 @@ def forward(
- `torch.Tensor`: descriptor matrix with shape [nframes, natoms[0]*self.filter_neuron[-1]*self.axis_neuron].
"""
del extended_atype_embd, mapping
nf = nlist.shape[0]
nloc = nlist.shape[1]
atype = extended_atype[:, :nloc]
dmatrix, diff, sw = prod_env_mat(
Expand Down Expand Up @@ -657,7 +658,7 @@ def forward(
device=extended_coord.device,
)
# nfnl x nnei
exclude_mask = self.emask(nlist, extended_atype).view(nfnl, -1)
exclude_mask = self.emask(nlist, extended_atype).view(nfnl, self.nnei)
for embedding_idx, ll in enumerate(self.filter_layers.networks):
if self.type_one_side:
ii = embedding_idx
Expand Down Expand Up @@ -698,8 +699,8 @@ def forward(
result = torch.matmul(
xyz_scatter_1, xyz_scatter_2
) # shape is [nframes*nall, self.filter_neuron[-1], self.axis_neuron]
result = result.view(-1, nloc, self.filter_neuron[-1] * self.axis_neuron)
rot_mat = rot_mat.view([-1, nloc] + list(rot_mat.shape[1:])) # noqa:RUF005
result = result.view(nf, nloc, self.filter_neuron[-1] * self.axis_neuron)
rot_mat = rot_mat.view([nf, nloc] + list(rot_mat.shape[1:])) # noqa:RUF005
return (
result.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
rot_mat.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
Expand Down
8 changes: 4 additions & 4 deletions deepmd/pt/model/descriptor/se_atten.py
Original file line number Diff line number Diff line change
Expand Up @@ -580,10 +580,10 @@ def forward(
xyz_scatter_1, xyz_scatter_2
) # shape is [nframes*nloc, self.filter_neuron[-1], self.axis_neuron]
return (
result.view(-1, nloc, self.filter_neuron[-1] * self.axis_neuron),
gg.view(-1, nloc, self.nnei, self.filter_neuron[-1]),
dmatrix.view(-1, nloc, self.nnei, 4)[..., 1:],
rot_mat.view(-1, nloc, self.filter_neuron[-1], 3),
result.view(nframes, nloc, self.filter_neuron[-1] * self.axis_neuron),
gg.view(nframes, nloc, self.nnei, self.filter_neuron[-1]),
dmatrix.view(nframes, nloc, self.nnei, 4)[..., 1:],
rot_mat.view(nframes, nloc, self.filter_neuron[-1], 3),
sw,
)

Expand Down
5 changes: 3 additions & 2 deletions deepmd/pt/model/descriptor/se_r.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,7 @@ def forward(
"""
del mapping
nf = nlist.shape[0]
nloc = nlist.shape[1]
atype = atype_ext[:, :nloc]
dmatrix, diff, sw = prod_env_mat(
Expand All @@ -367,7 +368,7 @@ def forward(
)

# nfnl x nnei
exclude_mask = self.emask(nlist, atype_ext).view(nfnl, -1)
exclude_mask = self.emask(nlist, atype_ext).view(nfnl, self.nnei)
for ii, ll in enumerate(self.filter_layers.networks):
# nfnl x nt
mm = exclude_mask[:, self.sec[ii] : self.sec[ii + 1]]
Expand All @@ -381,7 +382,7 @@ def forward(

res_rescale = 1.0 / 5.0
result = xyz_scatter * res_rescale
result = result.view(-1, nloc, self.filter_neuron[-1])
result = result.view(nf, nloc, self.filter_neuron[-1])
return (
result.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
None,
Expand Down
5 changes: 3 additions & 2 deletions deepmd/pt/model/descriptor/se_t.py
Original file line number Diff line number Diff line change
Expand Up @@ -665,6 +665,7 @@ def forward(
"""
del extended_atype_embd, mapping
nf = nlist.shape[0]
nloc = nlist.shape[1]
atype = extended_atype[:, :nloc]
dmatrix, diff, sw = prod_env_mat(
Expand All @@ -687,7 +688,7 @@ def forward(
device=extended_coord.device,
)
# nfnl x nnei
exclude_mask = self.emask(nlist, extended_atype).view(nfnl, -1)
exclude_mask = self.emask(nlist, extended_atype).view(nfnl, self.nnei)
for embedding_idx, ll in enumerate(self.filter_layers.networks):
ti = embedding_idx % self.ntypes
nei_type_j = self.sel[ti]
Expand All @@ -714,7 +715,7 @@ def forward(
res_ij = res_ij * (1.0 / float(nei_type_i) / float(nei_type_j))
result += res_ij
# xyz_scatter /= (self.nnei * self.nnei)
result = result.view(-1, nloc, self.filter_neuron[-1])
result = result.view(nf, nloc, self.filter_neuron[-1])
return (
result.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
None,
Expand Down
12 changes: 10 additions & 2 deletions deepmd/pt/model/model/spin_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,12 @@ def process_spin_output(
out_real, out_mag = torch.split(out_tensor, [nloc, nloc], dim=1)
if add_mag:
out_real = out_real + out_mag
out_mag = (out_mag.view([nframes, nloc, -1]) * atomic_mask).view(out_mag.shape)
shape2 = 1
for ss in out_real.shape[2:]:
shape2 *= ss
out_mag = (out_mag.view([nframes, nloc, shape2]) * atomic_mask).view(
out_mag.shape
)
return out_real, out_mag, atomic_mask > 0.0

def process_spin_output_lower(
Expand Down Expand Up @@ -162,8 +167,11 @@ def process_spin_output_lower(
)
if add_mag:
extended_out_real = extended_out_real + extended_out_mag
shape2 = 1
for ss in extended_out_tensor.shape[2:]:
shape2 *= ss
extended_out_mag = (
extended_out_mag.view([nframes, nall, -1]) * atomic_mask
extended_out_mag.view([nframes, nall, shape2]) * atomic_mask
).view(extended_out_mag.shape)
return extended_out_real, extended_out_mag, atomic_mask > 0.0

Expand Down
2 changes: 1 addition & 1 deletion deepmd/pt/model/task/dipole.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ def forward(
# (nframes * nloc, 1, m1)
out = out.view(-1, 1, self.embedding_width)
# (nframes * nloc, m1, 3)
gr = gr.view(nframes * nloc, -1, 3)
gr = gr.view(nframes * nloc, self.embedding_width, 3)
# (nframes, nloc, 3)
out = torch.bmm(out, gr).squeeze(-2).view(nframes, nloc, 3)
return {self.var_name: out.to(env.GLOBAL_PT_FLOAT_PRECISION)}
Expand Down
2 changes: 1 addition & 1 deletion deepmd/pt/model/task/polarizability.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ def forward(
self.var_name
]
out = out * (self.scale.to(atype.device))[atype]
gr = gr.view(nframes * nloc, -1, 3) # (nframes * nloc, m1, 3)
gr = gr.view(nframes * nloc, self.embedding_width, 3) # (nframes * nloc, m1, 3)

if self.fit_diag:
out = out.reshape(-1, self.embedding_width)
Expand Down
11 changes: 7 additions & 4 deletions deepmd/pt/utils/nlist.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,12 +99,15 @@ def build_neighbor_list(
nall = coord.shape[1] // 3
# fill virtual atoms with large coords so they are not neighbors of any
# real atom.
xmax = torch.max(coord) + 2.0 * rcut
if coord.numel() > 0:
xmax = torch.max(coord) + 2.0 * rcut
else:
xmax = torch.zeros(1, dtype=coord.dtype, device=coord.device) + 2.0 * rcut
# nf x nall
is_vir = atype < 0
coord1 = torch.where(is_vir[:, :, None], xmax, coord.view(-1, nall, 3)).view(
-1, nall * 3
)
coord1 = torch.where(
is_vir[:, :, None], xmax, coord.view(batch_size, nall, 3)
).view(batch_size, nall * 3)
if isinstance(sel, int):
sel = [sel]
nsel = sum(sel)
Expand Down
103 changes: 103 additions & 0 deletions source/tests/universal/common/cases/model/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,109 @@ def test_forward(self):
continue
np.testing.assert_allclose(rr1, rr2, atol=aprec)

def test_zero_forward(self):
test_spin = getattr(self, "test_spin", False)
nf = 1
natoms = 0
aprec = (
0
if self.aprec_dict.get("test_forward", None) is None
else self.aprec_dict["test_forward"]
)
rng = np.random.default_rng(GLOBAL_SEED)
coord = np.zeros((nf, 0, 3), dtype=np.float64)
atype = np.zeros([nf, 0], dtype=int)
spin = np.zeros([nf, 0], dtype=np.float64)
cell = 6.0 * np.eye(3, dtype=np.float64).reshape([nf, 9])
coord_ext, atype_ext, mapping, nlist = extend_input_and_build_neighbor_list(
coord,
atype,
self.expected_rcut + 1.0 if test_spin else self.expected_rcut,
self.expected_sel,
mixed_types=self.module.mixed_types(),
box=cell,
)
spin_ext = np.take_along_axis(
spin.reshape(nf, -1, 3),
np.repeat(np.expand_dims(mapping, axis=-1), 3, axis=-1),
axis=1,
)
aparam = None
fparam = None
if self.module.get_dim_aparam() > 0:
aparam = rng.random([nf, natoms, self.module.get_dim_aparam()])
if self.module.get_dim_fparam() > 0:
fparam = rng.random([nf, self.module.get_dim_fparam()])
ret = []
ret_lower = []
for module in self.modules_to_test:
module = self.forward_wrapper(module)
input_dict = {
"coord": coord,
"atype": atype,
"box": cell,
"aparam": aparam,
"fparam": fparam,
}
if test_spin:
input_dict["spin"] = spin
ret.append(module(**input_dict))

input_dict_lower = {
"extended_coord": coord_ext,
"extended_atype": atype_ext,
"nlist": nlist,
"mapping": mapping,
"aparam": aparam,
"fparam": fparam,
}
if test_spin:
input_dict_lower["extended_spin"] = spin_ext

ret_lower.append(module.forward_lower(**input_dict_lower))
for kk in ret[0]:
subret = []
for rr in ret:
if rr is not None:
subret.append(rr[kk])
if len(subret):
for ii, rr in enumerate(subret[1:]):
if subret[0] is None:
assert rr is None
else:
np.testing.assert_allclose(
subret[0], rr, err_msg=f"compare {kk} between 0 and {ii}"
)
for kk in ret_lower[0]:
subret = []
for rr in ret_lower:
if rr is not None:
subret.append(rr[kk])
if len(subret):
for ii, rr in enumerate(subret[1:]):
if subret[0] is None:
assert rr is None
else:
np.testing.assert_allclose(
subret[0], rr, err_msg=f"compare {kk} between 0 and {ii}"
)
same_keys = set(ret[0].keys()) & set(ret_lower[0].keys())
self.assertTrue(same_keys)
for key in same_keys:
for rr in ret:
if rr[key] is not None:
rr1 = rr[key]
break
else:
continue
for rr in ret_lower:
if rr[key] is not None:
rr2 = rr[key]
break
else:
continue
np.testing.assert_allclose(rr1, rr2, atol=aprec)

@unittest.skipIf(TEST_DEVICE != "cpu", "Only test on CPU.")
def test_permutation(self):
"""Test permutation."""
Expand Down

0 comments on commit 60ae6c8

Please sign in to comment.