From 93ac49b47c0c1acde808be3cc38b2de34fe8f3d7 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Thu, 18 Apr 2024 04:02:31 -0400 Subject: [PATCH] fix(tf): fix foat32 for exclude_types in se_atten_v2 (#3682) Fix type issue in previous PR #3651. --------- Signed-off-by: Jinzhe Zeng Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> (cherry picked from commit cac87152eb0010dba3246e24599e9fd8748a039a) --- deepmd/descriptor/se_atten.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/deepmd/descriptor/se_atten.py b/deepmd/descriptor/se_atten.py index ced2d08f18..b9227916e4 100644 --- a/deepmd/descriptor/se_atten.py +++ b/deepmd/descriptor/se_atten.py @@ -673,7 +673,11 @@ def _pass_filter( ), ) self.recovered_switch *= tf.reshape( - tf.slice(tf.reshape(mask, [-1, 4]), [0, 0], [-1, 1]), + tf.slice( + tf.reshape(tf.cast(mask, self.filter_precision), [-1, 4]), + [0, 0], + [-1, 1], + ), [-1, natoms[0], self.sel_all_a[0]], ) else: