From 3116e028d28e08a48b21d667af8dd051ecf09c4b Mon Sep 17 00:00:00 2001
From: nihui
Date: Fri, 27 Oct 2023 11:50:52 +0800
Subject: [PATCH] pnnx drop sdpa scale=None for compatibility with old torch
 (#5107)

---
 .../pass_level2/F_scaled_dot_product_attention.cpp | 11 +++++++++++
 1 file changed, 11 insertions(+)

diff --git a/tools/pnnx/src/pass_level2/F_scaled_dot_product_attention.cpp b/tools/pnnx/src/pass_level2/F_scaled_dot_product_attention.cpp
index 8dcfafaf12b..ecfc6d8acaf 100644
--- a/tools/pnnx/src/pass_level2/F_scaled_dot_product_attention.cpp
+++ b/tools/pnnx/src/pass_level2/F_scaled_dot_product_attention.cpp
@@ -65,6 +65,17 @@ pnnx.Output output 1 0 out
     {
         return "F.scaled_dot_product_attention";
     }
+
+    void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
+    {
+        GraphRewriterPass::write(op, captured_params, captured_attrs);
+
+        if (captured_params.at("scale").type == 0)
+        {
+            // drop scale=None for compatibility with old torch
+            op->params.erase("scale");
+        }
+    }
 };
 
 REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_scaled_dot_product_attention_1, 10)
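
Context for the check above: the patch treats a captured parameter whose type field is 0 as a Python None (per its own comment, "drop scale=None"), and the scale keyword on F.scaled_dot_product_attention only exists in newer torch releases, so forwarding scale=None in pnnx's generated code would fail on older torch. The following is a minimal, self-contained C++ sketch of the same drop-a-None-parameter pattern; the Param struct here is a hypothetical stand-in for pnnx's real Parameter class, with only the type-0-means-None convention taken from the patch.

    #include <cstdio>
    #include <map>
    #include <string>

    // hypothetical stand-in for pnnx's Parameter: type == 0 means Python None,
    // mirroring the captured_params.at("scale").type == 0 check in the patch
    struct Param
    {
        int type;
        float f; // payload when the parameter holds a float such as scale
    };

    int main()
    {
        std::map<std::string, Param> params;
        params["scale"] = Param{0, 0.f}; // captured as scale=None

        // same pattern as the pass: erase a None scale entirely, so the emitted
        // call becomes F.scaled_dot_product_attention(q, k, v) rather than
        // F.scaled_dot_product_attention(q, k, v, scale=None)
        if (params.at("scale").type == 0)
            params.erase("scale");

        printf("scale kept: %s\n", params.count("scale") ? "yes" : "no");
        return 0;
    }

Run as-is, this prints "scale kept: no"; with a non-zero type (a concrete float scale) the entry survives and would be emitted as usual, which is exactly the behavior the write() override adds.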