From 3c6ca5d92be2f147608df545073bbfd702d49be5 Mon Sep 17 00:00:00 2001
From: Eric Lunderberg
Date: Thu, 27 Jun 2024 09:52:09 -0500
Subject: [PATCH] [Bugfix][Relax] Set purity=false for LazySetOutput (#17119)

The `relax.transform.LazySetOutput` transformation updates a Relax
function to produce output from a `fset_output` callback.  In the
initial implementation, the `fset_output` was marked as a pure
function, which allowed it to be erroneously removed from a function.
This commit updates the `relax::FuncStructInfo` used to annotate
`fset_output`, marking it as an impure function.
---
 src/relax/transform/lazy_transform_params.cc  |  3 ++-
 .../test_transform_lazy_transform_params.py   | 24 +++++++++----------
 2 files changed, 14 insertions(+), 13 deletions(-)

diff --git a/src/relax/transform/lazy_transform_params.cc b/src/relax/transform/lazy_transform_params.cc
index fb401e1b6787..f55b93ff3d3a 100644
--- a/src/relax/transform/lazy_transform_params.cc
+++ b/src/relax/transform/lazy_transform_params.cc
@@ -149,7 +149,7 @@ class LazyOutputMutator : public ExprMutator {
 
     Var fset_output("fset_output",
                     FuncStructInfo({PrimStructInfo(DataType::Int(64)), ObjectStructInfo()},
-                                   TupleStructInfo(Array<StructInfo>{})));
+                                   TupleStructInfo(Array<StructInfo>{}), /* purity = */ false));
     plan_ = FunctionPlan{std::move(output_lookup), fset_output};
 
     std::optional<int64_t> num_input_params = GetNumInputParams(func);
@@ -189,6 +189,7 @@ class LazyOutputMutator : public ExprMutator {
       auto write_ptr = node.CopyOnWrite();
       write_ptr->params = new_params;
       write_ptr->body = new_body;
+      write_ptr->is_pure = false;
     }
     if (num_input_params.has_value()) {
       node = WithAttr(node, attr::kNumInput, Integer(num_input_params.value() + 1));
diff --git a/tests/python/relax/test_transform_lazy_transform_params.py b/tests/python/relax/test_transform_lazy_transform_params.py
index 040aea28909d..278ac825f7a7 100644
--- a/tests/python/relax/test_transform_lazy_transform_params.py
+++ b/tests/python/relax/test_transform_lazy_transform_params.py
@@ -1002,11 +1002,11 @@ def transform_params(A: R.Tensor([16, 16], "float32"), B: R.Tensor([16, 16], "fl
 
     @I.ir_module
     class Expected:
-        @R.function
+        @R.function(pure=False)
         def transform_params(
             A: R.Tensor([16, 16], "float32"),
             B: R.Tensor([16, 16], "float32"),
-            fset_output: R.Callable([R.Prim("int64"), R.Object], R.Tuple([])),
+            fset_output: R.Callable([R.Prim("int64"), R.Object], R.Tuple([]), purity=False),
         ):
             C = R.multiply(A, R.const(2, "float32"))
             fset_output(R.prim_value(1), C)
@@ -1036,11 +1036,11 @@ def transform_params(A: R.Tensor([16, 16], "float32"), B: R.Tensor([16, 16], "fl
 
     @I.ir_module
     class Expected:
-        @R.function
+        @R.function(pure=False)
         def transform_params(
             A: R.Tensor([16, 16], "float32"),
             B: R.Tensor([16, 16], "float32"),
-            fset_output: R.Callable([R.Prim("int64"), R.Object], R.Tuple([])),
+            fset_output: R.Callable([R.Prim("int64"), R.Object], R.Tuple([]), purity=False),
         ):
             fset_output(R.prim_value(1), B)
             C = R.multiply(A, R.const(2, "float32"))
@@ -1070,10 +1070,10 @@ def transform_params(A: R.Tensor([16, 16], "float32"), B: R.Tensor([16, 16], "fl
 
     @I.ir_module
     class Expected:
-        @R.function
+        @R.function(pure=False)
         def transform_params(
             A: R.Tensor([16, 16], "float32"),
-            fset_output: R.Callable([R.Prim("int64"), R.Object], R.Tuple([])),
+            fset_output: R.Callable([R.Prim("int64"), R.Object], R.Tuple([]), purity=False),
             B: R.Tensor([16, 16], "float32"),
         ):
             R.func_attr({"num_input": 2})
@@ -1105,11 +1105,11 @@ def transform_params(A: R.Tensor([16, 16], "float32"), B: R.Tensor([16, 16], "fl
 
     @I.ir_module
     class Expected:
-        @R.function
+        @R.function(pure=False)
        def transform_params(
             A: R.Tensor([16, 16], "float32"),
             B: R.Tensor([16, 16], "float32"),
-            fset_output: R.Callable([R.Prim("int64"), R.Object], R.Tuple([])),
+            fset_output: R.Callable([R.Prim("int64"), R.Object], R.Tuple([]), purity=False),
         ):
             C = R.multiply(A, R.const(2, "float32"))
             D = R.add(C, B)
@@ -1140,11 +1140,11 @@ def transform_params(A: R.Tensor([16, 16], "float32"), B: R.Tensor([16, 16], "fl
 
     @I.ir_module
     class Expected:
-        @R.function
+        @R.function(pure=False)
         def transform_params(
             A: R.Tensor([16, 16], "float32"),
             B: R.Tensor([16, 16], "float32"),
-            fset_output: R.Callable([R.Prim("int64"), R.Object], R.Tuple([])),
+            fset_output: R.Callable([R.Prim("int64"), R.Object], R.Tuple([]), purity=False),
         ):
             C = R.multiply(A, R.const(2, "float32"))
             fset_output(R.prim_value(0), C)
@@ -1171,11 +1171,11 @@ def transform_params(A: R.Tensor([16, 16], "float32"), B: R.Tensor([16, 16], "fl
 
     @I.ir_module
     class Expected:
-        @R.function
+        @R.function(pure=False)
         def transform_params(
             A: R.Tensor([16, 16], "float32"),
             B: R.Tensor([16, 16], "float32"),
-            fset_output: R.Callable([R.Prim("int64"), R.Object], R.Tuple([])),
+            fset_output: R.Callable([R.Prim("int64"), R.Object], R.Tuple([]), purity=False),
         ):
             C = R.multiply(A, R.const(2, "float32"))
             D = R.add(C, B)
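
Reviewer note (not part of the patch): the sketch below shows how the pass named
in the commit message, `relax.transform.LazySetOutput`, might be applied to a
module whose `transform_params` returns a tuple, mirroring the test cases above.
The zero-argument pass constructor and the `Before` module shown here are
assumptions for illustration, not code from this change.

# Reviewer sketch, assuming a zero-argument LazySetOutput() constructor as
# implied by the test file touched in this patch.
from tvm import relax
from tvm.script import ir as I
from tvm.script import relax as R


@I.ir_module
class Before:
    @R.function
    def transform_params(
        A: R.Tensor([16, 16], "float32"), B: R.Tensor([16, 16], "float32")
    ):
        C = R.multiply(A, R.const(2, "float32"))
        # Each element of the returned tuple becomes one fset_output call.
        return (C, B)


# After the pass, transform_params takes an extra impure `fset_output` callback
# and is itself marked @R.function(pure=False), which is what this patch
# enforces so that the fset_output calls cannot be dropped as dead code.
After = relax.transform.LazySetOutput()(Before)
After.show()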