From d3ca9d72bdea229f106e1aca50a115f2a0cdbcde Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Wed, 6 Mar 2024 01:07:37 -0500 Subject: [PATCH] pt: avoid D2H in se_e2_a (#3424) sec is used as a slice index, so it should not stored on the GPU, otherwise, D2H will happen to create the tensor with the shape. Before: ![image](https://github.com/deepmodeling/deepmd-kit/assets/9496702/f5a520d8-ed83-4520-aed0-d8fed547c293) After: ![image](https://github.com/deepmodeling/deepmd-kit/assets/9496702/5548632b-3099-4fe2-ab53-5c570abd714a) Signed-off-by: Jinzhe Zeng --- deepmd/pt/model/descriptor/se_a.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/deepmd/pt/model/descriptor/se_a.py b/deepmd/pt/model/descriptor/se_a.py index 3a18f150a4..c4b2c772f8 100644 --- a/deepmd/pt/model/descriptor/se_a.py +++ b/deepmd/pt/model/descriptor/se_a.py @@ -342,9 +342,8 @@ def __init__( self.reinit_exclude(exclude_types) self.sel = sel - self.sec = torch.tensor( - np.append([0], np.cumsum(self.sel)), dtype=int, device=env.DEVICE - ) + # should be on CPU to avoid D2H, as it is used as slice index + self.sec = [0, *np.cumsum(self.sel).tolist()] self.split_sel = self.sel self.nnei = sum(sel) self.ndescrpt = self.nnei * 4