diff --git a/python/tvm/script/parser/tir/parser.py b/python/tvm/script/parser/tir/parser.py index dfecaacdf655..7d81fecedbca 100644 --- a/python/tvm/script/parser/tir/parser.py +++ b/python/tvm/script/parser/tir/parser.py @@ -406,13 +406,25 @@ def visit_expr_stmt(self: Parser, node: doc.Expr) -> None: The doc AST Expr node. """ res = self.eval_expr(node.value) - if isinstance(res, Frame): + if res is None: + pass + elif isinstance(res, Frame): res.add_callback(partial(res.__exit__, None, None, None)) res.__enter__() elif isinstance(res, PrimExpr): T.evaluate(res) elif isinstance(res, (int, bool)): T.evaluate(tvm.tir.const(res)) + elif isinstance(res, tvm.relay.Call) and not res.args: + # Using GlobalVar.__call__ with no arguments is ambiguous, as + # each IR has a different function Call representation. If + # this occurs, convert to the TIR representation. + T.evaluate(tvm.tir.call_tir(res.op)) + elif isinstance(res, str): + # Ignore docstrings + pass + else: + self.report_error(node, f"Parsing resulted in unexpected type {type(res)}") @dispatch.register(token="tir", type_name="If") diff --git a/tests/python/unittest/test_tvmscript_roundtrip.py b/tests/python/unittest/test_tvmscript_roundtrip.py index 2ea7d3ec6579..3e09887e0d0d 100644 --- a/tests/python/unittest/test_tvmscript_roundtrip.py +++ b/tests/python/unittest/test_tvmscript_roundtrip.py @@ -3855,6 +3855,24 @@ def func(): return func +def subroutine_call_without_arguments(): + @I.ir_module + class mod: + @T.prim_func + def main(): + # Should be equivalent to the bare "mod.subroutine()", but + # that relies on `GlobalVar.__call__` returning the + # correct IR type. Previously, this instead returned a + # `relay.Call` object. + tir.call_tir(mod.subroutine) + + @T.prim_func + def subroutine(): + T.evaluate(0) + + return mod + + ir_generator = tvm.testing.parameter( launch_env_thread, opt_gemm_normalize, @@ -3929,6 +3947,7 @@ def func(): undefined_shape_in_decl_buffer, undefined_stride_in_decl_buffer, undefined_elem_offset_in_decl_buffer, + subroutine_call_without_arguments, )