diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py index 8de386dfccbb..c728375b7965 100644 --- a/python/tvm/relax/transform/transform.py +++ b/python/tvm/relax/transform/transform.py @@ -19,68 +19,82 @@ import tvm.ir from . import _ffi_api + @tvm._ffi.register_object("relax.FunctionPass") class FunctionPass(tvm.ir.transform.Pass): """A pass that works on each tvm.relax.Function in a module. A function pass class should be created through `function_pass`. """ -def FMARewrite() -> tvm.transform.Pass: + +def FMARewrite() -> tvm.ir.transform.Pass: """Perform fused multiply add rewriting in dataflow blocks. Returns ------- - ret: tvm.transform.Pass + ret: tvm.ir.transform.Pass """ return _ffi_api.FMARewrite() -def ToNonDataflow() -> tvm.transform.Pass: +def ToNonDataflow() -> tvm.ir.transform.Pass: """Transform all dataflow structure to non-dataflow version. Returns ------- - ret: tvm.transform.Pass + ret: tvm.ir.transform.Pass """ return _ffi_api.ToNonDataflow() -def CallDPSRewrite() -> tvm.transform.Pass: +def CallDPSRewrite() -> tvm.ir.transform.Pass: """Perform explicit tensor allocation for call_dps. Returns ------- - ret: tvm.transform.Pass + ret: tvm.ir.transform.Pass """ return _ffi_api.CallDPSRewrite() -def VMMemoryLower() -> tvm.transform.Pass: +def VMMemoryLower() -> tvm.ir.transform.Pass: """Perform memory lowering. Lowers the relax.builtin.alloc_tensor intrinsic to VM intrinsics. Returns ------- - ret: tvm.transform.Pass + ret: tvm.ir.transform.Pass """ return _ffi_api.VMMemoryLower() -def VMShapeLower() -> tvm.transform.Pass: - """Lower the shape expressions in relax to VM shape heap manipulations and generate related +def VMShapeLower() -> tvm.ir.transform.Pass: + """Lower the shape expressions in relax to VM shape heap manipulations and generate related TIR functions to do shape calculations. Returns ------- - ret: tvm.transform.Pass + ret: tvm.ir.transform.Pass """ return _ffi_api.VMShapeLower() -def ToANF() -> tvm.transform.Pass: +def ToANF() -> tvm.ir.transform.Pass: """Transforming Relax IR to A-normal form. Returns ------- - ret: tvm.transform.Pass + ret: tvm.ir.transform.Pass """ return _ffi_api.ToANF() + + +def ResolveGlobals() -> tvm.ir.transform.Pass: + """Resolve global variables using string equality. This ensures all GlobalVars in the IR refer + to the correct GlobalVar of the input IRModule. An error is reported if any GlobalVar cannot be + resolved. + + Returns + ------- + ret: tvm.ir.transform.Pass + """ + return _ffi_api.ResolveGlobals() diff --git a/python/tvm/script/parser.py b/python/tvm/script/parser.py index 908af081c958..c734efca3007 100644 --- a/python/tvm/script/parser.py +++ b/python/tvm/script/parser.py @@ -29,7 +29,7 @@ from synr import ast, Transformer, to_ast import tvm -from tvm import IRModule +from tvm import IRModule, relax from tvm._ffi.base import TVMError from tvm.ir import GlobalVar from tvm.ir.function import BaseFunc @@ -1381,5 +1381,9 @@ def ir_module(input_module: type) -> IRModule: func_dict = { name: f for name, f in input_module.__dict__.items() if isinstance(f, BaseFunc) } - return IRModule(func_dict) + mod = IRModule(func_dict) + mod = relax.transform.ResolveGlobals()(mod) + # FIXME(@altanh): where is the source map? + return mod + raise TypeError("Only class definitions are supported.") diff --git a/python/tvm/script/relax/parser.py b/python/tvm/script/relax/parser.py index 634083014d4c..7b6e65fd602c 100644 --- a/python/tvm/script/relax/parser.py +++ b/python/tvm/script/relax/parser.py @@ -970,9 +970,12 @@ def transform_expr(self, expr: ast.Expr) -> relax.Expr: var_name = expr.id.name if _is_registered(var_name, op_set=self._registered_ops): return relay.op.get(var_name) - if var_name not in self.scope: - self.report_error("undefined variable", expr.span) - return self.scope[var_name] + if var_name in self.scope: + return self.scope[var_name] + # NOTE: this is a "hack" to get around Python eagerly parsing class method decorators + # first (meaning we need to resolve them after the functions are parsed). These + # GlobalVars need to be resolved using string equality only. + return relay.GlobalVar(var_name) elif isinstance(expr, ast.Constant): # FIXME(@altanh): use internal representation that doesn't have precision limits here diff --git a/src/printer/relax_script_printer.cc b/src/printer/relax_script_printer.cc index 8b29c086d7ca..6095bd643e3a 100644 --- a/src/printer/relax_script_printer.cc +++ b/src/printer/relax_script_printer.cc @@ -497,13 +497,18 @@ Doc RelaxScriptPrinter::PrintFunctionDef(const Doc& name, const relax::Function& } Doc RelaxScriptPrinter::PrintVarAnnotation(const relax::Var& var) { + // TODO(@altanh): we should consider moving annotation into binding Doc doc; - if (var->type_annotation.defined()) { + Type annotation = var->checked_type_; + if (!annotation.defined()) { + annotation = var->type_annotation.value_or(Type()); + } + if (annotation.defined()) { doc << ": "; - if (const relax::DynTensorTypeNode* tty = var->type_annotation.as()) { + if (const relax::DynTensorTypeNode* tty = annotation.as()) { doc << PrintTensorAnnotation(GetRef(tty), var->shape_); } else { - doc << Print(var->type_annotation); + doc << Print(annotation); } } return doc; diff --git a/src/relax/ir/block_builder.cc b/src/relax/ir/block_builder.cc index f9f0d9262a1c..5b95e191247c 100644 --- a/src/relax/ir/block_builder.cc +++ b/src/relax/ir/block_builder.cc @@ -203,10 +203,10 @@ class BlockBuilderNode::ExprNormalizer : public ExprFunctor { private: /*! - * \brief Memoization map for expressions using Id for equality of variables. - */ + * \brief Memoization map for expressions using Id for equality of variables. + */ class ExprMemo { - public: + public: Optional Get(const Expr& expr) { if (const VarNode* var = expr.as()) { auto it = var_memo_.find(var->vid); @@ -230,7 +230,7 @@ class BlockBuilderNode::ExprNormalizer : public ExprFunctor { } } - private: + private: std::unordered_map var_memo_; std::unordered_map expr_memo_; }; @@ -370,7 +370,9 @@ Var BlockBuilderNode::Emit(const Expr& expr, bool is_dataflow, std::string name_ Var BlockBuilderNode::Emit(const VarBinding& binding) { BlockFrame* cur_frame = CurrentFrame(); if (cur_frame->is_dataflow) { - ICHECK(binding->var.as()); + ICHECK(binding->var.as()) + << "Emit can only be used for local bindings in a dataflow block, use EmitOutput for " + "output bindings instead"; } cur_frame->bindings.push_back(binding); binding_table_[binding->var->vid] = binding->value; @@ -408,9 +410,11 @@ Var BlockBuilderNode::EmitMatchShape(const Expr& value, const Array& p Var BlockBuilderNode::EmitMatchShape(const MatchShape& binding) { BlockFrame* cur_frame = CurrentFrame(); - if (cur_frame->is_dataflow && binding->var.defined()) { - ICHECK(!binding->var.as()) - << "cannot bind DataflowVar outside dataflow block."; + if (binding->var.defined()) { + ICHECK(!cur_frame->is_dataflow || binding->var.as()) + << "EmitMatchShape can only be used for local bindings in a dataflow block."; + ICHECK(cur_frame->is_dataflow || !binding->var.as()) + << "cannot emit dataflow vars outside a dataflow block: " << binding->var->name_hint(); } cur_frame->bindings.push_back(binding); // TODO(@altanh, @yuchen): what value should we bind? Consider @@ -511,13 +515,9 @@ BlockBuilderNode::BlockFrame* BlockBuilderNode::CurrentFrame() { return &block_stack_.top(); } -NameTable* BlockBuilderNode::name_table() { - return name_table_.get(); -} +NameTable* BlockBuilderNode::name_table() { return name_table_.get(); } -BlockBuilder BlockBuilder::Create() { - return BlockBuilder(make_object()); -} +BlockBuilder BlockBuilder::Create() { return BlockBuilder(make_object()); } TVM_REGISTER_GLOBAL("relax.BlockBuilderCreate").set_body_typed(BlockBuilder::Create); diff --git a/src/relax/ir/expr_functor.cc b/src/relax/ir/expr_functor.cc index bb1a1c58d96c..12ff2413b8fe 100644 --- a/src/relax/ir/expr_functor.cc +++ b/src/relax/ir/expr_functor.cc @@ -354,11 +354,20 @@ void ExprMutator::VisitBinding_(const VarBindingNode* binding) { Expr new_value = this->VisitExpr(binding->value); Var new_var = this->VisitVarDef(binding->var); - if (new_var.same_as(binding->var) && new_value.same_as(binding->value)) { - // no-op if there is no change - builder_->Emit(GetRef(binding)); - return; - } + auto emit = [this](VarBinding b) { + if (this->builder_->CurrentBlockIsDataFlow() && !b->var.as()) { + this->builder_->EmitOutput(b); + } else { + this->builder_->Emit(b); + } + }; + + // FIXME(@altanh): try to clean up all the fast paths and ty/shape infer, it's getting unwieldy + // if (new_var.same_as(binding->var) && new_value.same_as(binding->value)) { + // // no-op if there is no change + // emit(GetRef(binding)); + // return; + // } { Var temp = WithShapeAndType(new_var, new_value->shape_, new_value->checked_type_); @@ -368,11 +377,7 @@ void ExprMutator::VisitBinding_(const VarBindingNode* binding) { } } - if (builder_->CurrentBlockIsDataFlow() && !new_var.as()) { - builder_->EmitOutput(VarBinding(new_var, new_value)); - } else { - builder_->Emit(VarBinding(new_var, new_value)); - } + emit(VarBinding(new_var, new_value)); } void ExprMutator::VisitBinding_(const MatchShapeNode* binding) { @@ -387,8 +392,8 @@ void ExprMutator::VisitBinding_(const MatchShapeNode* binding) { if (new_value->checked_type_.defined() && new_value->checked_type_.as()) { new_shape = new_pattern; } - Var temp = - WithShapeAndType(this->VisitVarDef(binding->var), new_shape, new_value->checked_type_); + new_var = this->VisitVarDef(binding->var); + Var temp = WithShapeAndType(new_var, new_shape, new_value->checked_type_); if (!temp.same_as(new_var)) { new_var = temp; this->var_remap_[binding->var->vid] = new_var; diff --git a/src/relax/op/tensor/binary.h b/src/relax/op/tensor/binary.h index 0684d189f4d6..a3f5bc3a49ba 100644 --- a/src/relax/op/tensor/binary.h +++ b/src/relax/op/tensor/binary.h @@ -81,7 +81,8 @@ Type InferTypeBinaryBroadcast(const Call& call, DiagnosticContext diag_ctx) { auto* t1 = rhs_type.as(); if (!t0 || !t1) { diag_ctx.EmitFatal(Diagnostic::Error(call->span) - << "Both lhs and rhs should be DynTensor for broadcasting"); + << "Both lhs and rhs should be DynTensor for broadcasting, but got " + << lhs_type->GetTypeKey() << " and " << rhs_type->GetTypeKey()); } DataType output_dtype; diff --git a/src/relax/transform/resolve_globals.cc b/src/relax/transform/resolve_globals.cc new file mode 100644 index 000000000000..2851a97d5b88 --- /dev/null +++ b/src/relax/transform/resolve_globals.cc @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file src/relax/transform/resolve_globals.cc + * \brief Resolve GlobalVars using string equality. + */ +#include +#include + +namespace tvm { +namespace relax { + +class GlobalVarResolver : public ExprMutator { + public: + GlobalVarResolver(IRModule mod, DiagnosticContext diag_ctx) : mod_(mod), diag_ctx_(diag_ctx) {} + + Expr VisitExpr_(const GlobalVarNode* gvar) { + if (!mod_->ContainGlobalVar(gvar->name_hint)) { + diag_ctx_.Emit(Diagnostic::Error(gvar->span) + << "undefined variable/global \"" << gvar->name_hint << "\""); + return GetRef(gvar); + } + return mod_->GetGlobalVar(gvar->name_hint); + } + + private: + /*! \brief the IRModule used for GlobalVar lookup. */ + IRModule mod_; + DiagnosticContext diag_ctx_; +}; + +namespace transform { + +Pass ResolveGlobals() { + runtime::TypedPackedFunc pass_func = + [](Function f, IRModule m, PassContext pc) { + // TODO(@altanh): make sure pc always has diag_ctx? + GlobalVarResolver resolver(m, pc->diag_ctx.value()); + return Downcast(resolver.VisitExpr(f)); + }; + return CreateFunctionPass(pass_func, 0, "ResolveGlobals", {}); +} + +TVM_REGISTER_GLOBAL("relax.transform.ResolveGlobals").set_body_typed(ResolveGlobals); + +} // namespace transform + +} // namespace relax +} // namespace tvm diff --git a/tests/python/relax/test_parser.py b/tests/python/relax/test_parser.py index 76999e329fe7..4f715efa8912 100644 --- a/tests/python/relax/test_parser.py +++ b/tests/python/relax/test_parser.py @@ -496,45 +496,35 @@ def f(x: Tensor): def test_class_irmodule(): - # FIXME(@altanh): Python class method decorators are executed eagerly before the class - # decorator, which means each function is parsed in isolation. This means we cannot resolve - # global variables at parsing time (or indeed any undefined identifier), so we either need to - # 1. defer parsing in the function decorators (so that the ir_module decorator can populate - # global variables first), although this means non-IRModule uses of the function decorators - # will no longer return Function/PrimFunc but some kind of wrapper type. This could cause - # problems if we pass them directly to things that expect Function/PrimFuncs. - # 2. parse every undefined identifier to a placeholder node (e.g. "UndefinedVar"), and run an - # IRModule -> IRModule pass that tries to resolve identifiers. - src = """@tvm.script.ir_module -class MyModule: - @T.prim_func - def my_matmul(a: T.handle, b: T.handle, c: T.handle) -> None: - A = T.match_buffer(a, (128, 128)) - B = T.match_buffer(b, (128, 128)) - C = T.match_buffer(c, (128, 128)) - - for i, j, k in T.grid(128, 128, 128): - with T.block(): - vi, vj, vk = T.axis.remap("SSR", [i, j, k]) - with T.init(): - C[vi, vj] = 0.0 - C[vi, vj] += A[vi, vk] * B[vj, vk] + @tvm.script.ir_module + class MyModule: + @T.prim_func + def my_matmul(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (128, 128)) + B = T.match_buffer(b, (128, 128)) + C = T.match_buffer(c, (128, 128)) - @R.function - def f(x: Tensor[(n, n), _]) -> Tensor: - return g(x) + for i, j, k in T.grid(128, 128, 128): + with T.block(): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = 0.0 + C[vi, vj] += A[vi, vk] * B[vj, vk] - @R.function - def g(y: Tensor[(n, n), _]) -> Tensor: - return relax.call_dps((n, n), my_matmul, (y, y)) + @R.function + def f(x: Tensor[(n, n), _]) -> Tensor: + return g(x) - @R.function - def h(x, y, z): - _ = my_matmul(x, y, z) - return z -""" + @R.function + def g(y: Tensor[(n, n), _]) -> Tensor: + return relax.call_dps((n, n), my_matmul, (y, y)) + + @R.function + def h(x, y, z): + _ = my_matmul(x, y, z) + return z - my_module = tvm.script.relax.parser.from_source(src) + my_module = MyModule assert isinstance(my_module, tvm.IRModule) var_f = my_module.get_global_var("f") diff --git a/tests/python/relax/test_printer.py b/tests/python/relax/test_printer.py index d8c630ffe898..e7f9d5ae49ff 100644 --- a/tests/python/relax/test_printer.py +++ b/tests/python/relax/test_printer.py @@ -181,35 +181,33 @@ def foo(x: Tensor): def test_class_irmodule(): - # FIXME(@altanh): see comment in test_parser.py - src = """@tvm.script.ir_module -class MyModule: - @T.prim_func - def my_matmul(a: T.handle, b: T.handle, c: T.handle) -> None: - A = T.match_buffer(a, (128, 128)) - B = T.match_buffer(b, (128, 128)) - C = T.match_buffer(c, (128, 128)) - - for i, j, k in T.grid(128, 128, 128): - with T.block(): - vi, vj, vk = T.axis.remap("SSR", [i, j, k]) - with T.init(): - C[vi, vj] = 0.0 - C[vi, vj] += A[vi, vk] * B[vj, vk] + @tvm.script.ir_module + class MyModule: + @T.prim_func + def my_matmul(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (128, 128)) + B = T.match_buffer(b, (128, 128)) + C = T.match_buffer(c, (128, 128)) - @R.function - def f(x: Tensor[(n, n), _]) -> Tensor: - return g(x) + for i, j, k in T.grid(128, 128, 128): + with T.block(): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = 0.0 + C[vi, vj] += A[vi, vk] * B[vj, vk] - @R.function - def g(y: Tensor[(n, n), _]) -> Tensor: - return relax.call_dps((n, n), my_matmul, (y, y)) + @R.function + def f(x: Tensor[(n, n), _]) -> Tensor: + return g(x) - @R.function - def h(x, y, z): - _ = my_matmul(x, y, z) - return z -""" + @R.function + def g(y: Tensor[(n, n), _]) -> Tensor: + return relax.call_dps((n, n), my_matmul, (y, y)) + + @R.function + def h(x, y, z): + _ = my_matmul(x, y, z) + return z - my_module = tvm.script.relax.parser.from_source(src) + my_module = MyModule check_roundtrip(my_module) diff --git a/tests/python/relax/test_transform.py b/tests/python/relax/test_transform.py index cafca7246334..abd6fdb0f99e 100644 --- a/tests/python/relax/test_transform.py +++ b/tests/python/relax/test_transform.py @@ -20,6 +20,7 @@ from tvm import relax from tvm import tir from tvm.ir import structural_equal +from tvm.ir.base import assert_structural_equal from tvm.ir.module import IRModule import tvm.script @@ -211,30 +212,30 @@ def foo(x: Tensor[_, "float32"]) -> Shape: def test_vm_shape_lowering_func_param_with_shape(): - src = """@tvm.script.ir_module -class InputModule: - @T.prim_func - def tir_matmul(x: T.handle, y: T.handle, z: T.handle) -> None: - T.func_attr({"global_symbol": "tir_matmul"}) - m = T.var("int32") - n = T.var("int32") - k = T.var("int32") - A = T.match_buffer(x, (m,n)) - B = T.match_buffer(y, (n,k)) - C = T.match_buffer(z, (m,k)) - - for i, j, k in T.grid(m, k, n): - with T.block("matmul"): - vi, vj, vk = T.axis.remap("SSR", [i, j, k]) - with T.init(): - C[vi, vj] = T.float32(0) - C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] - @R.function - def foo(x:Tensor[(m, n), "float32"], w:Tensor[(n, k), "float32"]) -> Tensor: - gv0 = R.call_dps((m, k), tir_matmul, (x, w)) - return gv0 -""" - mod = tvm.script.relax.parser.from_source(src) + @tvm.script.ir_module + class InputModule: + @T.prim_func + def tir_matmul(x: T.handle, y: T.handle, z: T.handle) -> None: + T.func_attr({"global_symbol": "tir_matmul"}) + m = T.var("int32") + n = T.var("int32") + k = T.var("int32") + A = T.match_buffer(x, (m,n)) + B = T.match_buffer(y, (n,k)) + C = T.match_buffer(z, (m,k)) + + for i, j, k in T.grid(m, k, n): + with T.block("matmul"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = T.float32(0) + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + @R.function + def foo(x:Tensor[(m, n), "float32"], w:Tensor[(n, k), "float32"]) -> Tensor: + gv0 = R.call_dps((m, k), tir_matmul, (x, w)) + return gv0 + + mod = InputModule # after vm shape lowering new_mod = relax.transform.VMShapeLower()(mod) @@ -289,9 +290,24 @@ def f(x: Tensor[_, "float32"]): gv2 = relax.add(gv, gv1) return (gv, gv2) - # TODO(@altanh): fix this once type inference works properly...? - assert R.parser.astext(new_mod) == R.parser.astext(TestToANFExpected) + assert_structural_equal(new_mod, TestToANFExpected, map_free_vars=True) + + +def test_to_anf_no_op(): + @tvm.script.ir_module + class TestANFNoOp: + @R.function + def foo(x: Tensor[(m, n), "float32"]): + with relax.dataflow(): + lv0 = relax.call_dps((m, n), "test.op.identity", (x,)) + gv0 = relax.call_dps((m, n), "test.op.identity", (lv0,)) + relax.output(gv0) + return gv0 + + mod = TestANFNoOp + mod_post = relax.transform.ToANF()(mod) + assert_structural_equal(mod, mod_post) if __name__ == "__main__": @@ -302,3 +318,4 @@ def f(x: Tensor[_, "float32"]): test_vm_shape_lowering() test_vm_shape_lowering_func_param_with_shape() test_to_anf() + test_to_anf_no_op() diff --git a/tests/python/relax/test_vm.py b/tests/python/relax/test_vm.py index 052cb32744a4..76e0f838a0a4 100644 --- a/tests/python/relax/test_vm.py +++ b/tests/python/relax/test_vm.py @@ -247,42 +247,40 @@ def foo(x: Tensor[(3, 4), "float32"], y: Tensor[(3, 4), "float32"]): def test_vm_compile_stage1(): - # FIXME(@altanh): see comment in test_parser.py - src = """@tvm.script.ir_module -class TestVMCompileStage1: - @T.prim_func - def shape_func0(heap: T.handle) -> None: - # function attr dict - T.func_attr({"global_symbol": "shape_func0"}) - H = T.match_buffer( - heap, - [T.int64(4)], - dtype="int64", - elem_offset=T.int64(0), - align=128, - offset_factor=1, - ) - # body - T.store( - H.data, T.int64(2), (T.load("int64", H.data, T.int64(0)) * T.int64(2)), True - ) - T.store( - H.data, T.int64(3), (T.load("int64", H.data, T.int64(1)) * T.int64(3)), True - ) + @tvm.script.ir_module + class TestVMCompileStage1: + @T.prim_func + def shape_func0(heap: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "shape_func0"}) + H = T.match_buffer( + heap, + [T.int64(4)], + dtype="int64", + elem_offset=T.int64(0), + align=128, + offset_factor=1, + ) + # body + T.store( + H.data, T.int64(2), (T.load("int64", H.data, T.int64(0)) * T.int64(2)), True + ) + T.store( + H.data, T.int64(3), (T.load("int64", H.data, T.int64(1)) * T.int64(3)), True + ) - @R.function - def foo(x: Tensor[_, "float32"]) -> Shape: - shape_heap: Tensor[(4,), "int64"] = R.call_packed( - "vm.builtin.alloc_shape_heap", (4,) - ) - gv0 = R.call_packed("vm.builtin.shape_of", x) - gv1 = R.call_packed("vm.builtin.store_shape", gv0, shape_heap, (0, 1)) - gv2 = shape_func0(shape_heap) - gv3 = R.call_packed("vm.builtin.load_shape", shape_heap, (2, 3)) - return gv3 -""" - - mod = R.parser.from_source(src) + @R.function + def foo(x: Tensor[_, "float32"]) -> Shape: + shape_heap: Tensor[(4,), "int64"] = relax.call_packed( + "vm.builtin.alloc_shape_heap", (4,) + ) + gv0 = relax.call_packed("vm.builtin.shape_of", x) + gv1 = relax.call_packed("vm.builtin.store_shape", gv0, shape_heap, (0, 1)) + gv2 = shape_func0(shape_heap) + gv3 = relax.call_packed("vm.builtin.load_shape", shape_heap, (2, 3)) + return gv3 + + mod = TestVMCompileStage1 code = R.parser.astext(mod) target = tvm.target.Target("llvm") target_host = tvm.target.Target("llvm") @@ -363,32 +361,32 @@ def foo(x: Tensor[_, "float32"]) -> Tensor: np.testing.assert_allclose(np.tile(inp.asnumpy(), (1, 2)), res.asnumpy()) def test_vm_compile_e2e_func_param_with_shape(): - src = """@tvm.script.ir_module -class TestVMCompileE2E2: - @T.prim_func - def tir_matmul(x: T.handle, y: T.handle, z: T.handle) -> None: - T.func_attr({"global_symbol": "tir_matmul"}) - m = T.var("int32") - n = T.var("int32") - k = T.var("int32") - A = T.match_buffer(x, (m,n)) - B = T.match_buffer(y, (n,k)) - C = T.match_buffer(z, (m,k)) - - for i, j, k in T.grid(m, k, n): - with T.block("matmul"): - vi, vj, vk = T.axis.remap("SSR", [i, j, k]) - with T.init(): - C[vi, vj] = T.float32(0) - C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] - - @R.function - def func(x:Tensor[(m, n), "float32"], w:Tensor[(n, k), "float32"]) -> Tensor: - gv0 = R.call_dps((m, k), tir_matmul, (x, w)) - return gv0 -""" - - mod = tvm.script.relax.parser.from_source(src) + @tvm.script.ir_module + class TestVMCompileE2E2: + @T.prim_func + def tir_matmul(x: T.handle, y: T.handle, z: T.handle) -> None: + T.func_attr({"global_symbol": "tir_matmul"}) + m = T.var("int32") + n = T.var("int32") + k = T.var("int32") + A = T.match_buffer(x, (m,n)) + B = T.match_buffer(y, (n,k)) + C = T.match_buffer(z, (m,k)) + + for i, j, k in T.grid(m, k, n): + with T.block("matmul"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = T.float32(0) + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + + @R.function + def func(x:Tensor[(m, n), "float32"], w:Tensor[(n, k), "float32"]) -> Tensor: + gv0 = R.call_dps((m, k), tir_matmul, (x, w)) + return gv0 + + + mod = TestVMCompileE2E2 target = tvm.target.Target("llvm") target_host = tvm.target.Target("llvm")