Skip to content

Commit

Permalink
throw errros if rmin is no less than rmax (#3393)
Browse files Browse the repository at this point in the history
when rmin==rmax, the previous implementation of compute_smooth_weight
will give all nan. In theory, it should not happen.

---------

Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
  • Loading branch information
njzjz authored Mar 5, 2024
1 parent 268a0fc commit b0171ce
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 6 deletions.
2 changes: 2 additions & 0 deletions deepmd/dpmodel/utils/env_mat.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ def compute_smooth_weight(
rmax: float,
):
"""Compute smooth weight for descriptor elements."""
if rmin >= rmax:
raise ValueError("rmin should be less than rmax.")
min_mask = distance <= rmin
max_mask = distance >= rmax
mid_mask = np.logical_not(np.logical_or(min_mask, max_mask))
Expand Down
2 changes: 2 additions & 0 deletions deepmd/pt/utils/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,8 @@ def build_neighbor_list(

def compute_smooth_weight(distance, rmin: float, rmax: float):
"""Compute smooth weight for descriptor elements."""
if rmin >= rmax:
raise ValueError("rmin should be less than rmax.")
min_mask = distance <= rmin
max_mask = distance >= rmax
mid_mask = torch.logical_not(torch.logical_or(min_mask, max_mask))
Expand Down
8 changes: 4 additions & 4 deletions source/tests/common/dpmodel/test_linear_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ def test_pairwise(self, mock_loadtxt):
nlist = np.array([[[1], [-1]]])

ds = DescrptSeA(
rcut=0.3,
rcut_smth=0.4,
rcut_smth=0.3,
rcut=0.4,
sel=[3],
)
ft = InvarFitting(
Expand Down Expand Up @@ -122,8 +122,8 @@ def setUp(self, mock_loadtxt):
],
dtype=int,
).reshape([1, self.nloc, sum(self.sel)])
self.rcut = 0.4
self.rcut_smth = 2.2
self.rcut_smth = 0.4
self.rcut = 2.2

file_path = "dummy_path"
mock_loadtxt.return_value = np.array(
Expand Down
4 changes: 2 additions & 2 deletions source/tests/pt/model/test_linear_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,8 @@ def test_pairwise(self, mock_loadtxt):
nlist = torch.tensor([[[1], [-1]]], device=env.DEVICE)

ds = DescrptSeA(
rcut=0.3,
rcut_smth=0.4,
rcut_smth=0.3,
rcut=0.4,
sel=[3],
).to(env.DEVICE)
ft = InvarFitting(
Expand Down

0 comments on commit b0171ce

Please sign in to comment.