Skip to content

Commit

Permalink
fix se_r consistency
Browse files Browse the repository at this point in the history
Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
  • Loading branch information
njzjz committed Feb 29, 2024
1 parent 665d716 commit 42d47f9
Show file tree
Hide file tree
Showing 4 changed files with 6 additions and 6 deletions.
4 changes: 2 additions & 2 deletions deepmd/dpmodel/descriptor/se_r.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,9 +276,9 @@ def call(
gg = self.cal_g(tr, tt)
gg = np.mean(gg, axis=2)
# nf x nloc x ng x 1
xyz_scatter += gg
xyz_scatter += gg * (self.sel[tt] / self.nnei)

Check warning on line 279 in deepmd/dpmodel/descriptor/se_r.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/descriptor/se_r.py#L279

Added line #L279 was not covered by tests

res_rescale = 1.0 / 10.0
res_rescale = 1.0 / 5.0

Check warning on line 281 in deepmd/dpmodel/descriptor/se_r.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/descriptor/se_r.py#L281

Added line #L281 was not covered by tests
res = xyz_scatter * res_rescale
res = res.reshape(nf, nloc, -1).astype(GLOBAL_NP_FLOAT_PRECISION)
return res, None, None, None, ww
Expand Down
4 changes: 2 additions & 2 deletions deepmd/pt/model/descriptor/se_r.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,9 +258,9 @@ def forward(
# nfnl x nt x ng
gg = ll.forward(ss)
gg = torch.mean(gg, dim=1).unsqueeze(1)
xyz_scatter += gg
xyz_scatter += gg * (self.sel[ii] / self.nnei)

Check warning on line 261 in deepmd/pt/model/descriptor/se_r.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/descriptor/se_r.py#L261

Added line #L261 was not covered by tests

res_rescale = 1.0 / 10.0
res_rescale = 1.0 / 5.0

Check warning on line 263 in deepmd/pt/model/descriptor/se_r.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/descriptor/se_r.py#L263

Added line #L263 was not covered by tests
result = xyz_scatter * res_rescale
result = result.view(-1, nloc, self.filter_neuron[-1])
return (
Expand Down
2 changes: 1 addition & 1 deletion source/tests/consistent/descriptor/test_se_e2_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def data(self) -> dict:
precision,
) = self.param
return {
"sel": [10, 10],
"sel": [9, 10],
"rcut_smth": 5.80,
"rcut": 6.00,
"neuron": [6, 12, 24],
Expand Down
2 changes: 1 addition & 1 deletion source/tests/consistent/descriptor/test_se_r.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def data(self) -> dict:
precision,
) = self.param
return {
"sel": [10, 10],
"sel": [9, 10],
"rcut_smth": 5.80,
"rcut": 6.00,
"neuron": [6, 12, 24],
Expand Down

0 comments on commit 42d47f9

Please sign in to comment.