diff --git a/src/relax/ir/expr_functor.cc b/src/relax/ir/expr_functor.cc index dbfaf60fecfc..63c74db7e33e 100644 --- a/src/relax/ir/expr_functor.cc +++ b/src/relax/ir/expr_functor.cc @@ -257,6 +257,7 @@ RELAX_EXPR_VISITOR_VISIT_BINDING_IMPL(DataTypeImmNode); void ExprVisitor::VisitBinding_(const MatchCastNode* binding) { this->VisitExpr(binding->value); + this->VisitExprDepStructInfoField(binding->struct_info); this->VisitVarDef(binding->var); } @@ -690,16 +691,25 @@ void ExprMutator::ReEmitBinding(const VarBindingNode* binding, Expr new_value) { } void ExprMutator::VisitBinding_(const MatchCastNode* binding) { - Var new_var = this->VisitVarDef(binding->var); Expr new_value = this->VisitExpr(binding->value); + StructInfo new_struct_info = this->VisitExprDepStructInfoField(binding->struct_info); - // re-emit old binding if nothing changes - if (new_var.same_as(binding->var) && new_value.same_as(binding->value)) { + Var new_var = this->VisitVarDef(binding->var); + + if (new_var.same_as(binding->var) && new_value.same_as(binding->value) && + new_struct_info.same_as(binding->struct_info)) { + // re-emit old binding if nothing changes builder_->EmitNormalized(GetRef(binding)); - } else { - new_value = builder_->NormalizeArgument(new_value); - builder_->EmitNormalized(MatchCast(new_var, new_value, binding->struct_info, binding->span)); + return; } + + new_value = builder_->NormalizeArgument(new_value); + new_var = WithStructInfo(new_var, new_struct_info); + + var_remap_[binding->var->vid] = new_var; + var_remap_[new_var->vid] = new_var; + + builder_->EmitNormalized(MatchCast(new_var, new_value, new_struct_info, binding->span)); } BindingBlock ExprMutator::VisitBindingBlock_(const BindingBlockNode* block) { diff --git a/tests/python/relax/test_bind_symbolic_vars.py b/tests/python/relax/test_bind_symbolic_vars.py index 82798c56dfff..18246d224b65 100644 --- a/tests/python/relax/test_bind_symbolic_vars.py +++ b/tests/python/relax/test_bind_symbolic_vars.py @@ -286,5 +286,27 @@ def expected(A: R.Tensor(["M", 32])): tvm.ir.assert_structural_equal(expected, after) +def test_bind_inside_match_cast(): + """Symbolic variables may occur within R.match_cast""" + + @R.function(private=True) + def before(A: R.Tensor(["M", "N"]), B: R.Tensor(ndim=2)): + M = T.int64() + N = T.int64() + C = R.match_cast(B, R.Tensor([M, N])) + D = R.add(A, C) + return D + + @R.function(private=True) + def expected(A: R.Tensor(["M", 32]), B: R.Tensor(ndim=2)): + M = T.int64() + C = R.match_cast(B, R.Tensor([M, 32])) + D = R.add(A, C) + return D + + after = before.bind_symbolic_vars({"N": 32}) + tvm.ir.assert_structural_equal(expected, after) + + if __name__ == "__main__": tvm.testing.main()