Skip to content

Commit

Permalink
pnnx drop sdap scale=None for compatiblity with old torch (#5107)
Browse files Browse the repository at this point in the history
  • Loading branch information
nihui authored Oct 27, 2023
1 parent 14e14a9 commit 3116e02
Showing 1 changed file with 11 additions and 0 deletions.
11 changes: 11 additions & 0 deletions tools/pnnx/src/pass_level2/F_scaled_dot_product_attention.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,17 @@ pnnx.Output output 1 0 out
{
return "F.scaled_dot_product_attention";
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
{
GraphRewriterPass::write(op, captured_params, captured_attrs);

if (captured_params.at("scale").type == 0)
{
// drop scale=None for compatiblity with old torch
op->params.erase("scale");
}
}
};

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_scaled_dot_product_attention_1, 10)
Expand Down

0 comments on commit 3116e02

Please sign in to comment.