Commit
fix: Bugfix in Linear-to-AddMM Fusion lowering pass
- Fix 2 bugs in linear-to-addmm lowering pass:
  - Lowering pass did not explore nested sub-blocks of a node, of the sort contained in `prim::If` when `bias=None`
  - Lowering pass did not insert fused linear code inside sub-blocks of `prim::If` even when the original function call occurred within such a block
    - The latter causes issues when the control-flow switches between two versions of `aten::linear`, only one of which is a valid operation. Thus, evaluating both branches can cause compilation to crash, as invalid Tensor shapes can be encountered
- Update implementation to run recursively through all nested blocks within all nodes
- Update implementation to remove the use of `RegisterRewritePattern` paradigm for Tensor biases, as the rewrite does not always place the subgraph in the desired location
- Add regression test cases to isolate both bugs
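The recursive traversal and in-place insertion described above can be illustrated with a toy model. This is a minimal sketch in Python, not the pass itself: the `Node` and `Block` classes, the `lower_linear` function, and the string-based inputs are all hypothetical stand-ins for the real TorchScript IR, and the replacement ops (`aten::t` followed by `aten::addmm` or `aten::matmul`) are illustrative names that may differ from the subgraph the actual pass emits. The sketch shows two things: every nested block of every node is visited, and the replacement nodes are placed in the same block as the original `aten::linear` call, so a branch of `prim::If` that was never meant to run stays inside its branch.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class Node:
    """Hypothetical IR node: an op kind, input names, and nested blocks."""
    kind: str
    inputs: List[str] = field(default_factory=list)
    blocks: List["Block"] = field(default_factory=list)


@dataclass
class Block:
    """Hypothetical IR block: an ordered list of nodes."""
    nodes: List[Node] = field(default_factory=list)


def lower_linear(block: Block) -> int:
    """Rewrite every aten::linear in `block` and in all nested sub-blocks.

    Returns the number of rewritten nodes. Replacements are emitted into
    the block that held the original node, never hoisted to a parent.
    """
    rewritten = 0
    new_nodes: List[Node] = []
    for node in block.nodes:
        # Recurse into sub-blocks (e.g. the two branches of prim::If).
        for sub_block in node.blocks:
            rewritten += lower_linear(sub_block)

        if node.kind == "aten::linear":
            x, weight, bias = node.inputs
            new_nodes.append(Node("aten::t", [weight]))
            if bias == "None":
                # No bias: a plain matrix multiply is enough.
                new_nodes.append(Node("aten::matmul", [x, weight + ".T"]))
            else:
                # Tensor bias: fuse the add into the multiply.
                new_nodes.append(Node("aten::addmm", [bias, x, weight + ".T"]))
            rewritten += 1
        else:
            new_nodes.append(node)
    block.nodes = new_nodes
    return rewritten


def example_graph() -> Block:
    """Toy graph: a prim::If whose branches each call aten::linear."""
    no_bias = Block([Node("aten::linear", ["x", "w", "None"])])
    with_bias = Block([Node("aten::linear", ["x", "w", "b"])])
    return Block([Node("prim::If", ["cond"], [no_bias, with_bias])])


if __name__ == "__main__":
    graph = example_graph()
    count = lower_linear(graph)
    print(count, [n.kind for n in graph.nodes])
```

In this toy, a non-recursive pass would only look at the top-level `prim::If` node and rewrite nothing, which mirrors the first bug. Rebuilding each block's own node list, rather than appending replacements to the outer block, mirrors the fix for the second bug.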
Showing 2 changed files with 119 additions and 27 deletions.