diff --git a/deepmd/pt/model/descriptor/repformers.py b/deepmd/pt/model/descriptor/repformers.py index 1e20a1abf..e3772b9b8 100644 --- a/deepmd/pt/model/descriptor/repformers.py +++ b/deepmd/pt/model/descriptor/repformers.py @@ -492,6 +492,8 @@ def forward( if not self.direct_dist: g2, h2 = torch.split(dmatrix, [1, 3], dim=-1) if self.custom_radial: + assert self.radial_module is not None + assert self.cutoff_module is not None rr = torch.linalg.norm(diff, dim=-1) g2 = self.radial_module(rr) * self.cutoff_module(rr).unsqueeze(-1) g2 = g2.view(nframes, nloc, nnei, -1) diff --git a/deepmd/pt/model/descriptor/se_atten.py b/deepmd/pt/model/descriptor/se_atten.py index 346fea790..5e2b276ad 100644 --- a/deepmd/pt/model/descriptor/se_atten.py +++ b/deepmd/pt/model/descriptor/se_atten.py @@ -566,6 +566,8 @@ def forward( if not self.custom_radial: ss = rr[:, :, :1] else: + assert self.radial_module is not None + assert self.cutoff_module is not None dist = torch.linalg.norm(diff, dim=-1) ss = self.radial_module(dist) * self.cutoff_module(dist).unsqueeze(-1) ss = ss.view(nframes * nloc, nnei, -1)