[Bugfix][Relax] Set purity=false for LazySetOutput (#17119)
The `relax.transform.LazySetOutput` transformation rewrites a Relax function
to produce its outputs through an `fset_output` callback.  In the initial
implementation, `fset_output` was marked as a pure function, which
allowed calls to it to be erroneously removed from the function.

This commit updates the `relax::FuncStructInfo` used to annotate
`fset_output`, marking it as an impure function.
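
For illustration, a minimal sketch of how the pass is typically applied and what the fix guarantees. The `Before` module, the zero-argument pass constructor, and the `is_pure` check are assumptions modeled on the tests below, not part of this commit:

import tvm
from tvm import relax
from tvm.script import ir as I, relax as R

@I.ir_module
class Before:
    @R.function
    def transform_params(A: R.Tensor([16, 16], "float32")):
        C = R.multiply(A, R.const(2, "float32"))
        return (C,)

# Apply the pass (invocation assumed to match the tests below).
After = relax.transform.LazySetOutput()(Before)

# transform_params now receives an fset_output callback and must be impure:
# if its calls to fset_output were treated as pure, a later pass could drop
# them and silently lose outputs.
assert not After["transform_params"].is_pure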
Lunderberg authored Jun 27, 2024
1 parent 73cad19 commit 3c6ca5d
Showing 2 changed files with 14 additions and 13 deletions.
3 changes: 2 additions & 1 deletion src/relax/transform/lazy_transform_params.cc
@@ -149,7 +149,7 @@ class LazyOutputMutator : public ExprMutator {

Var fset_output("fset_output",
FuncStructInfo({PrimStructInfo(DataType::Int(64)), ObjectStructInfo()},
- TupleStructInfo(Array<StructInfo>{})));
+ TupleStructInfo(Array<StructInfo>{}), /* purity = */ false));
plan_ = FunctionPlan{std::move(output_lookup), fset_output};

std::optional<int64_t> num_input_params = GetNumInputParams(func);
@@ -189,6 +189,7 @@ class LazyOutputMutator : public ExprMutator {
auto write_ptr = node.CopyOnWrite();
write_ptr->params = new_params;
write_ptr->body = new_body;
+ write_ptr->is_pure = false;
}
if (num_input_params.has_value()) {
node = WithAttr(node, attr::kNumInput, Integer(num_input_params.value() + 1));
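A rough Python equivalent of the `FuncStructInfo` constructed above, shown only to illustrate what the `purity` flag means; the keyword argument names are assumed to mirror the C++ constructor:

from tvm.relax import FuncStructInfo, ObjectStructInfo, PrimStructInfo, TupleStructInfo

# Struct info for the fset_output callback: (int64 index, object value) -> ()
fset_output_sinfo = FuncStructInfo(
    params=[PrimStructInfo("int64"), ObjectStructInfo()],
    ret=TupleStructInfo([]),
    purity=False,  # calls have side effects and must not be dropped as dead code
)
assert not fset_output_sinfo.purity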
24 changes: 12 additions & 12 deletions tests/python/relax/test_transform_lazy_transform_params.py
@@ -1002,11 +1002,11 @@ def transform_params(A: R.Tensor([16, 16], "float32"), B: R.Tensor([16, 16], "fl

@I.ir_module
class Expected:
- @R.function
+ @R.function(pure=False)
def transform_params(
A: R.Tensor([16, 16], "float32"),
B: R.Tensor([16, 16], "float32"),
- fset_output: R.Callable([R.Prim("int64"), R.Object], R.Tuple([])),
+ fset_output: R.Callable([R.Prim("int64"), R.Object], R.Tuple([]), purity=False),
):
C = R.multiply(A, R.const(2, "float32"))
fset_output(R.prim_value(1), C)
@@ -1036,11 +1036,11 @@ def transform_params(A: R.Tensor([16, 16], "float32"), B: R.Tensor([16, 16], "fl

@I.ir_module
class Expected:
- @R.function
+ @R.function(pure=False)
def transform_params(
A: R.Tensor([16, 16], "float32"),
B: R.Tensor([16, 16], "float32"),
- fset_output: R.Callable([R.Prim("int64"), R.Object], R.Tuple([])),
+ fset_output: R.Callable([R.Prim("int64"), R.Object], R.Tuple([]), purity=False),
):
fset_output(R.prim_value(1), B)
C = R.multiply(A, R.const(2, "float32"))
@@ -1070,10 +1070,10 @@ def transform_params(A: R.Tensor([16, 16], "float32"), B: R.Tensor([16, 16], "fl

@I.ir_module
class Expected:
- @R.function
+ @R.function(pure=False)
def transform_params(
A: R.Tensor([16, 16], "float32"),
- fset_output: R.Callable([R.Prim("int64"), R.Object], R.Tuple([])),
+ fset_output: R.Callable([R.Prim("int64"), R.Object], R.Tuple([]), purity=False),
B: R.Tensor([16, 16], "float32"),
):
R.func_attr({"num_input": 2})
@@ -1105,11 +1105,11 @@ def transform_params(A: R.Tensor([16, 16], "float32"), B: R.Tensor([16, 16], "fl

@I.ir_module
class Expected:
- @R.function
+ @R.function(pure=False)
def transform_params(
A: R.Tensor([16, 16], "float32"),
B: R.Tensor([16, 16], "float32"),
- fset_output: R.Callable([R.Prim("int64"), R.Object], R.Tuple([])),
+ fset_output: R.Callable([R.Prim("int64"), R.Object], R.Tuple([]), purity=False),
):
C = R.multiply(A, R.const(2, "float32"))
D = R.add(C, B)
@@ -1140,11 +1140,11 @@ def transform_params(A: R.Tensor([16, 16], "float32"), B: R.Tensor([16, 16], "fl

@I.ir_module
class Expected:
- @R.function
+ @R.function(pure=False)
def transform_params(
A: R.Tensor([16, 16], "float32"),
B: R.Tensor([16, 16], "float32"),
- fset_output: R.Callable([R.Prim("int64"), R.Object], R.Tuple([])),
+ fset_output: R.Callable([R.Prim("int64"), R.Object], R.Tuple([]), purity=False),
):
C = R.multiply(A, R.const(2, "float32"))
fset_output(R.prim_value(0), C)
@@ -1171,11 +1171,11 @@ def transform_params(A: R.Tensor([16, 16], "float32"), B: R.Tensor([16, 16], "fl

@I.ir_module
class Expected:
- @R.function
+ @R.function(pure=False)
def transform_params(
A: R.Tensor([16, 16], "float32"),
B: R.Tensor([16, 16], "float32"),
- fset_output: R.Callable([R.Prim("int64"), R.Object], R.Tuple([])),
+ fset_output: R.Callable([R.Prim("int64"), R.Object], R.Tuple([]), purity=False),
):
C = R.multiply(A, R.const(2, "float32"))
D = R.add(C, B)
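Each expected module above changes in the same two places: the `fset_output` parameter is annotated as an impure callable (`purity=False`), and `transform_params` itself is declared with `@R.function(pure=False)`, since a function that claims to be pure may not call an impure callee. A hedged sketch of the usual verification pattern for these tests; the `Before`/`Expected` names are assumed to match the test file:

import tvm
from tvm import relax

# Run the pass on the input module and compare against the expected module.
After = relax.transform.LazySetOutput()(Before)
tvm.ir.assert_structural_equal(After, Expected)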
