Skip to content

Commit

Permalink
[BugFix][TVMScript]fix var capturing order error (apache#13640)
Browse files Browse the repository at this point in the history
This PR try to fix the following bug:

```python
def test_var_capturing_order():
    b = 2

    @T.prim_func
    def test_case():
        k: T.int32 = b


if __name__ == "__main__":
    b = 1
```

In the prim func `test_case`, the vaule of b should be 2, rather than 1. The parser wrongly uses global vars to shadow the value of nonlocal vars, which should be reversed.

Co-authored-by: lightzhan-intellif <zhan.liang@intellif.com>
  • Loading branch information
2 people authored and fzi-peccia committed Mar 27, 2023
1 parent c798ed1 commit 8798c93
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 1 deletion.
2 changes: 1 addition & 1 deletion python/tvm/script/parser/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ def inspect_function_capture(func: Callable) -> Dict[str, Any]:
The function variables map with non-local or global variables.
"""
captured = {
**inspect.getclosurevars(func).nonlocals,
**func.__globals__, # type: ignore
**inspect.getclosurevars(func).nonlocals,
}
return captured

Expand Down
17 changes: 17 additions & 0 deletions tests/python/unittest/test_tvmscript_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,24 @@ def func_ref():
tvm.ir.assert_structural_equal(test_case, func_ref)


def test_var_capturing_order():
b = 2

@T.prim_func
def test_case():
k: T.int32 = b

@T.prim_func
def func_ref():
k: T.int32 = 2
T.evaluate(0)

tvm.ir.assert_structural_equal(test_case, func_ref)


if __name__ == "__main__":
a = numpy.zeros((10, 10), dtype="int8")
test_multi_element_array_in_outmost_namespace()
test_different_dtype_assignment_to_var()
b = 1
test_var_capturing_order()

0 comments on commit 8798c93

Please sign in to comment.