Skip to content

Commit

Permalink
pt: avoid D2H in se_e2_a (#3424)
Browse files Browse the repository at this point in the history
sec is used as a slice index, so it should not stored on the GPU,
otherwise, D2H will happen to create the tensor with the shape.

Before:

![image](https://github.com/deepmodeling/deepmd-kit/assets/9496702/f5a520d8-ed83-4520-aed0-d8fed547c293)

After:

![image](https://github.com/deepmodeling/deepmd-kit/assets/9496702/5548632b-3099-4fe2-ab53-5c570abd714a)

Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
  • Loading branch information
njzjz authored Mar 6, 2024
1 parent 278e6b8 commit d3ca9d7
Showing 1 changed file with 2 additions and 3 deletions.
5 changes: 2 additions & 3 deletions deepmd/pt/model/descriptor/se_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,9 +342,8 @@ def __init__(
self.reinit_exclude(exclude_types)

self.sel = sel
self.sec = torch.tensor(
np.append([0], np.cumsum(self.sel)), dtype=int, device=env.DEVICE
)
# should be on CPU to avoid D2H, as it is used as slice index
self.sec = [0, *np.cumsum(self.sel).tolist()]
self.split_sel = self.sel
self.nnei = sum(sel)
self.ndescrpt = self.nnei * 4
Expand Down

0 comments on commit d3ca9d7

Please sign in to comment.