From 9856f7a66a1de00cf0931395aee9afbbf0ad7bd4 Mon Sep 17 00:00:00 2001 From: Yuchen Jin Date: Thu, 4 Nov 2021 11:05:42 -0700 Subject: [PATCH] VM compiler refactor (#25) * Return Instruction::Arg for each CodeGenLLVM::VisitExpr_. * Change VMCompiler to be an Object from ModuleNode. * Introduce intrinsics and attrs. * Generic handling of attribute codegen. * Do to-non-dataflow transform in call_dps_rewrite. * Back to special attr handling. * Address comments. * Standalone to_non_dataflow pass; more tests. * Rename decode/make shape to store/load shape. * Update. * Fix namespace, add comments. * rebase * Rename files. * nit --- include/tvm/relax/attrs/memory.h | 14 +- include/tvm/relax/attrs/shape.h | 44 +++ python/tvm/relax/transform/transform.py | 15 +- python/tvm/relax/vm.py | 4 +- src/relax/backend/vm/codegen_vm.cc | 336 ++++++++++++++++++ .../compiler.h => backend/vm/codegen_vm.h} | 26 +- .../vm/vm_memory_lower.cc} | 45 ++- .../vm/vm_shape_lower.cc} | 47 +-- src/relax/op/op.cc | 62 ++++ src/relax/transform/call_dps_rewrite.cc | 9 +- src/relax/vm/builtin.cc | 21 +- src/relax/vm/compiler.cc | 291 --------------- tests/python/relax/test_transform.py | 79 ++-- tests/python/relax/test_vm.py | 49 +-- 14 files changed, 596 insertions(+), 446 deletions(-) create mode 100644 include/tvm/relax/attrs/shape.h create mode 100644 src/relax/backend/vm/codegen_vm.cc rename src/relax/{vm/compiler.h => backend/vm/codegen_vm.h} (74%) rename src/relax/{transform/memory_rewrite.cc => backend/vm/vm_memory_lower.cc} (72%) rename src/relax/{transform/shape_lower.cc => backend/vm/vm_shape_lower.cc} (83%) delete mode 100644 src/relax/vm/compiler.cc diff --git a/include/tvm/relax/attrs/memory.h b/include/tvm/relax/attrs/memory.h index 91988906a214..1e5da74cc32f 100644 --- a/include/tvm/relax/attrs/memory.h +++ b/include/tvm/relax/attrs/memory.h @@ -29,29 +29,31 @@ namespace tvm { namespace relax { /*! - * \brief Options for allocating storage. + * \brief Attributes for allocating storage. */ struct AllocStorageAttrs : public tvm::AttrsNode { - DataType dtype; - int device_id; int device_type; + DataType dtype; TVM_DECLARE_ATTRS(AllocStorageAttrs, "relax.attrs.AllocStorageAttrs") { + TVM_ATTR_FIELD(device_type).describe("The device type on which to allocate memory."); TVM_ATTR_FIELD(dtype) .describe("The dtype of the tensor to allocate.") .set_default(DataType::Float(32, 1)); - TVM_ATTR_FIELD(device_id).describe("The device id on which to allocate memory."); - TVM_ATTR_FIELD(device_type).describe("The device type on which to allocate memory."); } }; /*! - * \brief Options for allocating tensors. + * \brief Attributes for allocating tensors. */ struct AllocTensorAttrs : public tvm::AttrsNode { + int offset; DataType dtype; TVM_DECLARE_ATTRS(AllocTensorAttrs, "relax.attrs.AllocTensorAttrs") { + TVM_ATTR_FIELD(offset) + .describe("Storage offset to allocate the tensor.") + .set_default(0); TVM_ATTR_FIELD(dtype) .describe("The dtype of the tensor to allocate.") .set_default(DataType::Float(32, 1)); diff --git a/include/tvm/relax/attrs/shape.h b/include/tvm/relax/attrs/shape.h new file mode 100644 index 000000000000..9c4aaad24b28 --- /dev/null +++ b/include/tvm/relax/attrs/shape.h @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/attrs/shape.h + * \brief Attributes for shape operators. + */ +#ifndef TVM_RELAX_ATTRS_SHAPE_H_ +#define TVM_RELAX_ATTRS_SHAPE_H_ + +#include + +namespace tvm { +namespace relax { +/*! + * \brief Attributes for decoding/making shape to/from VM heap. + */ +struct ShapeHeapAttrs : public tvm::AttrsNode { + Array indices; + + TVM_DECLARE_ATTRS(ShapeHeapAttrs, "relax.attrs.ShapeHeapAttrs") { + TVM_ATTR_FIELD(indices).describe("The indices of the heap to store/load the shape to/from."); + } +}; + +} // namespace relax +} // namespace tvm +#endif // TVM_RELAX_ATTRS_SHAPE_H_ diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py index 12dccd26e6b2..aedda8157b54 100644 --- a/python/tvm/relax/transform/transform.py +++ b/python/tvm/relax/transform/transform.py @@ -30,6 +30,7 @@ def fma_rewrite(expr): """ return _ffi_api.fma_rewrite(expr) + def to_non_dataflow(mod: IRModule) -> IRModule: """Transform all dataflow structure to non-dataflow version. @@ -42,7 +43,7 @@ def to_non_dataflow(mod: IRModule) -> IRModule: def call_dps_rewrite(mod: IRModule) -> IRModule: - """Perform explicit memory allocation for call_dps. + """Perform explicit tensor allocation for call_dps. Parameters ---------- @@ -52,23 +53,23 @@ def call_dps_rewrite(mod: IRModule) -> IRModule: return _ffi_api.call_dps_rewrite(mod) -def memory_lower(mod: IRModule) -> IRModule: - """Perform memory lowering. Lower the relax.builtin.alloc_tensor op to VM builtin functions. +def vm_memory_lower(mod: IRModule) -> IRModule: + """Perform memory lowering. Lowers the relax.builtin.alloc_tensor intrinsic to VM intrinsics. Parameters ---------- mod : tvm.IRModule The input module. """ - return _ffi_api.memory_lower(mod) + return _ffi_api.vm_memory_lower(mod) -def shape_lower(mod: IRModule) -> IRModule: - """Lower the shape expression in relax to shape heap and TIR functions. +def vm_shape_lower(mod: IRModule) -> IRModule: + """Lower the shape expression in relax to VM shape heap and TIR functions. Parameters ---------- mod : tvm.IRModule The input module. """ - return _ffi_api.shape_lower(mod) + return _ffi_api.vm_shape_lower(mod) diff --git a/python/tvm/relax/vm.py b/python/tvm/relax/vm.py index 8e7258b3b1ed..a094ab7a64d4 100644 --- a/python/tvm/relax/vm.py +++ b/python/tvm/relax/vm.py @@ -167,7 +167,7 @@ def build(mod: tvm.IRModule, """ new_mod = transform.to_non_dataflow(mod) new_mod = transform.call_dps_rewrite(new_mod) - new_mod = transform.memory_lower(new_mod) - new_mod = transform.shape_lower(new_mod) + new_mod = transform.vm_memory_lower(new_mod) + new_mod = transform.vm_shape_lower(new_mod) ex, lib = _ffi_api.VMBuild(new_mod, target, target_host) return ex, lib diff --git a/src/relax/backend/vm/codegen_vm.cc b/src/relax/backend/vm/codegen_vm.cc new file mode 100644 index 000000000000..a40c749a4072 --- /dev/null +++ b/src/relax/backend/vm/codegen_vm.cc @@ -0,0 +1,336 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relax/backend/vm/codegen_vm.cc + * \brief A compiler to compile an IRModule to VM executable. + */ + +#include "codegen_vm.h" + +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +namespace tvm { +namespace relax { +namespace relax_vm { + +using namespace relax; + +/*! + * \brief A class to generate VM executable for Relax functions. + */ +class CodeGenVM : public ExprFunctor { + public: + explicit CodeGenVM(ExecBuilderNode* builder) { + builder_ = GetRef(builder); + } + + protected: + size_t NewRegister() { return registers_num_++; } + + // TODO(@yuchen): add visitors for IfNode when goto and if instructions are introduced to relax vm. + + // TODO(@yuchen): when we support closure, this visitor should return a register that + // contains the closure object. + Instruction::Arg VisitExpr_(const FunctionNode* func_node) { + if (func_node->name.defined()) { + builder_->EmitFunction(func_node->name.value()->name_hint, func_node->params.size()); + } else { + // TODO(@yuchen): handle local functions that capture local vars outside the func + // TODO(@yuchen): a renaming pass to resolve name conflicts, e.g. the input module has a + // function named "local_funcN" + // lift the local func to a global func and compile it normally + builder_->EmitFunction("local_func" + std::to_string(local_func_counter_++), + func_node->params.size()); + } + for (Var param : func_node->params) { + Instruction::Arg reg = this->VisitExpr(param); + this->var_register_map_.insert({param, reg.data}); + } + Instruction::Arg ret = ExprFunctor::VisitExpr(func_node->body); + builder_->EmitRet(ret.data); + return ret; + } + + Instruction::Arg VisitExpr_(const SeqExprNode* op) { + for (auto block : op->blocks) { + for (Binding binding : block->bindings) { + ICHECK(binding->IsInstance()); + Expr value = Downcast(binding)->value; + Var var = Downcast(binding)->var; + Instruction::Arg reg = this->VisitExpr(value); + this->var_register_map_.insert({var, reg.data}); + } + } + + Instruction::Arg ret_reg = this->VisitExpr(op->body); + return ret_reg; + } + + Instruction::Arg VisitExpr_(const CallNode* op) { + if (op->op.as()) { + // special case generate for the intrinsics whose attribute fields + // cannot be represented by args in the CallNode + const Call& call = GetRef(op); + if (op->op == alloc_storage_op_) { + return EmitAllocStorage(call); + } else if (op->op == alloc_tensor_op_) { + return EmitAllocTensor(call); + } else if (op->op == store_shape_op_ || op->op == load_shape_op_) { + return EmitShape(call); + } else { + // every "normal" operator is lowered to a global var in the IR module. The Attrs for those ops + // are handled in a pass when lowering them to TIR. + LOG(FATAL) << "CodeGenVM cannot handle this intrinsic now:\n" << op->op; + } + } + String name; + if (auto* extern_func = op->op.as()) { + name = extern_func->global_symbol; + } else if (auto* gvar = op->op.as()) { + name = gvar->name_hint; + } else { + LOG(FATAL) << "CodeGenVM does not support calls to " << op->op->GetTypeKey(); + } + std::vector args; + for (auto arg : op->args) { + args.push_back(this->VisitExpr(arg)); + } + size_t arg_register = NewRegister(); + builder_->EmitCall(name, args, arg_register); + + return Instruction::Arg(Instruction::kRegister, arg_register); + } + + Instruction::Arg VisitExpr_(const VarNode* op) { + auto it = this->var_register_map_.find(GetRef(op)); + if (it != this->var_register_map_.end()) { + return Instruction::Arg(Instruction::kRegister, it->second); + } else { + return Instruction::Arg(Instruction::kRegister, NewRegister()); + } + } + + Instruction::Arg VisitExpr_(const ShapeExprNode* op) { + ShapeExpr sh = GetRef(op); + ICHECK(IsConstantShape(sh)) + << "should only use constant shape after shape lowering: " + << sh->values; + std::vector shape; + for (PrimExpr e : sh->values) { + shape.push_back(Downcast(e)->value); + } + auto shape_tuple = ShapeTuple(shape); + TVMRetValue shape_tuple_value; + shape_tuple_value = shape_tuple; + Index index = builder_->EmitConstant(shape_tuple_value); + return Instruction::Arg(Instruction::kConstIdx, index); + } + + Instruction::Arg EmitAllocStorage(const Call& call_node) { + // Handle args of the call + std::vector args; + args.push_back(Instruction::Arg(Instruction::kVMStateRegister)); + for (Expr arg: call_node->args) { + args.push_back(ConvertArg(arg)); + } + + // Handle attrs of the call + auto alloc_attrs = call_node->attrs.as(); + ICHECK(alloc_attrs != nullptr) << "must be AllocStorageAttrs"; + int device_type = alloc_attrs->device_type; + args.push_back(Instruction::Arg(Instruction::kImmediate, device_type)); + DataType dtype = alloc_attrs->dtype; + TVMRetValue data_type; + data_type = dtype; + Index index = this->builder_->EmitConstant(data_type); + args.push_back(Instruction::Arg(Instruction::kConstIdx, index)); + + size_t arg_register = NewRegister(); + builder_->EmitCall("vm.builtin.alloc_storage", args, arg_register); + return Instruction::Arg(Instruction::kRegister, arg_register); + } + + Instruction::Arg EmitAllocTensor(const Call& call_node) { + // Handle args of the call + std::vector args; + for (Expr arg: call_node->args) { + args.push_back(ConvertArg(arg)); + } + + // Handle attrs of the call + auto alloc_attrs = call_node->attrs.as(); + ICHECK(alloc_attrs != nullptr) << "must be AllocTensorAttrs"; + int offset = alloc_attrs->offset; + args.push_back(Instruction::Arg(Instruction::kImmediate, offset)); + DataType dtype = alloc_attrs->dtype; + TVMRetValue data_type; + data_type = dtype; + Index index = this->builder_->EmitConstant(data_type); + args.push_back(Instruction::Arg(Instruction::kConstIdx, index)); + + size_t arg_register = NewRegister(); + builder_->EmitCall("vm.builtin.alloc_tensor", args, arg_register); + return Instruction::Arg(Instruction::kRegister, arg_register); + } + + Instruction::Arg EmitShape(const Call& call_node) { + // Handle args of the call + std::vector args; + for (Expr arg: call_node->args) { + args.push_back(ConvertArg(arg)); + } + + // Handle attrs of the call + auto shape_attrs = call_node->attrs.as(); + ICHECK(shape_attrs != nullptr) << "must be ShapeHeapAttrs"; + std::vector indices_vec; + for (Integer ind : shape_attrs->indices) { + indices_vec.push_back(ind); + } + ShapeTuple indices = ShapeTuple(indices_vec); + TVMRetValue indices_const; + indices_const = indices; + Index index = builder_->EmitConstant(indices_const); + args.push_back(Instruction::Arg(Instruction::kConstIdx, index)); + + size_t arg_register = NewRegister(); + if (call_node->op == store_shape_op_) { + builder_->EmitCall("vm.builtin.store_shape", args, arg_register); + } else if (call_node->op == load_shape_op_) { + builder_->EmitCall("vm.builtin.load_shape", args, arg_register); + } + return Instruction::Arg(Instruction::kRegister, arg_register); + } + + bool IsConstantShape(ShapeExpr shape) const { + for (PrimExpr e : shape->values) { + if (!e->IsInstance()) { + return false; + } + } + return true; + } + + Instruction::Arg ConvertArg(Expr arg) { + if (arg->IsInstance()) { + Var var = Downcast(arg); + auto reg = this->var_register_map_.find(Downcast(arg)); + ICHECK(reg != this->var_register_map_.end()) + << var->name_hint() << "(" << var << ")" << " not in the register map."; + return Instruction::Arg(Instruction::kRegister, reg->second); + } else if (arg->IsInstance()) { + ShapeExpr sh = Downcast(arg); + ICHECK(IsConstantShape(sh)) + << "should only use constant shape after shape lowering: " + << sh->values; + std::vector shape; + for (PrimExpr e : sh->values) { + shape.push_back(Downcast(e)->value); + } + auto shape_tuple = ShapeTuple(shape); + TVMRetValue shape_tuple_value; + shape_tuple_value = shape_tuple; + Index index = builder_->EmitConstant(shape_tuple_value); + return Instruction::Arg(Instruction::kConstIdx, index); + } else { + LOG(FATAL) << "CodeGenVM does not this argument type:\n" << arg->GetTypeKey(); + } + return Instruction::Arg(); + } + + std::vector ConvertArgs(const Call& call) { + std::vector ret; + for (size_t i = 0; i < call->args.size(); ++i) { + ret.push_back(ConvertArg(call->args[i])); + } + return ret; + } + + /*! \brief A counter for naming local functions. */ + int local_func_counter_ = 0; + /*! \brief Internal ExecBuilder. */ + relax::ExecBuilder builder_; + /*! \brief Total number of virtual registers allocated. */ + size_t registers_num_ = 0; + /*! \brief Map from var to register number. */ + std::unordered_map var_register_map_; + /*! \brief Cache ops that need to be frequently used later to reduce lookup overhead. */ + const Op& alloc_storage_op_ = Op::Get("relax.vm.builtin.alloc_storage"); + const Op& alloc_tensor_op_ = Op::Get("relax.vm.builtin.alloc_tensor"); + const Op& store_shape_op_ = Op::Get("relax.vm.builtin.store_shape"); + const Op& load_shape_op_ = Op::Get("relax.vm.builtin.load_shape"); +}; + +void VMCompiler::Compile(IRModule mod, Target target, Target target_host) { + builder_ = relax::ExecBuilderNode::Create(); + + IRModule tir_mod; + IRModule rx_mod; + for (auto& p : mod->functions) { + auto gvar = p.first; + + BaseFunc func = p.second; + if (func.as()) { + tir_mod->Add(gvar, func); + } else if (func.as()) { + rx_mod->Add(gvar, func); + } else { + LOG(FATAL) << "Cannot handle such function node now:\n" << func; + } + } + lib_ = tvm::build(tir_mod, target, target_host); + + CodeGenVM compiler(builder_.operator->()); + for (auto& p : rx_mod->functions) { + compiler.VisitExpr(p.second); + } +} + +Executable VMCompiler::GetExec() { + return builder_->Get(); +} + +runtime::Module VMCompiler::GetLib() { + return lib_; +} + +Array Build(IRModule mod, Target target, Target target_host) { + auto compiler = make_object(); + compiler->Compile(mod, target, target_host); + Executable exec = compiler->GetExec(); + Module lib = compiler->GetLib(); + return Array({exec, lib}); +} + +TVM_REGISTER_GLOBAL("relax.VMBuild") +.set_body_typed(Build); + +} // namespace relax_vm +} // namespace relax +} // namespace tvm diff --git a/src/relax/vm/compiler.h b/src/relax/backend/vm/codegen_vm.h similarity index 74% rename from src/relax/vm/compiler.h rename to src/relax/backend/vm/codegen_vm.h index 55036c4e6f37..36c78f7e3ec6 100644 --- a/src/relax/vm/compiler.h +++ b/src/relax/backend/vm/codegen_vm.h @@ -18,27 +18,29 @@ */ /*! - * \file src/relax/vm/compiler.h - * \brief A compiler to compile a relay::Module to the VM executable. + * \file src/relax/backend/vm/codegen_vm.h + * \brief A compiler to compile an IRModule to VM executable. */ -#ifndef TVM_RELAX_VM_COMPILER_H_ -#define TVM_RELAX_VM_COMPILER_H_ +#ifndef TVM_RELAX_BACKEND_VM_COMPILER_H_ +#define TVM_RELAX_BACKEND_VM_COMPILER_H_ -#include #include #include #include +#include #include namespace tvm { -namespace runtime { +namespace relax { namespace relax_vm { using tvm::Target; +using namespace tvm::runtime::relax_vm; +using namespace tvm::runtime; -class VMCompiler : public runtime::ModuleNode { +class VMCompiler : public Object { public: /*! * \brief Compile the functions in a Module. @@ -56,9 +58,9 @@ class VMCompiler : public runtime::ModuleNode { */ Module GetLib(); - virtual PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self); - - const char* type_key() const { return "relax.VMCompiler"; } + static constexpr const uint32_t _type_index = TypeIndex::kDynamic; + static constexpr const char* _type_key = "relax.VMCompiler"; + TVM_DECLARE_FINAL_OBJECT_INFO(ExecutableNode, Object); protected: /*! \brief Internal executable builder. */ @@ -68,7 +70,7 @@ class VMCompiler : public runtime::ModuleNode { }; } // namespace relax_vm -} // namespace runtime +} // namespace relax } // namespace tvm -#endif // TVM_RELAX_VM_COMPILER_H_ +#endif // TVM_RELAX_BACKEND_VM_COMPILER_H_ diff --git a/src/relax/transform/memory_rewrite.cc b/src/relax/backend/vm/vm_memory_lower.cc similarity index 72% rename from src/relax/transform/memory_rewrite.cc rename to src/relax/backend/vm/vm_memory_lower.cc index c80ecc088981..c994aee7bc18 100644 --- a/src/relax/transform/memory_rewrite.cc +++ b/src/relax/backend/vm/vm_memory_lower.cc @@ -17,7 +17,7 @@ * under the License. */ /*! - * \file src/relax/transform/memory_rewrite.cc + * \file src/relax/backend/vm/vm_memory_lower.cc * \brief */ #include @@ -25,10 +25,11 @@ #include #include -#include "../../relay/transforms/pattern_utils.h" +#include "../../../relay/transforms/pattern_utils.h" namespace tvm { namespace relax { +namespace vm { // ================== // MemLowerMutator @@ -36,13 +37,12 @@ namespace relax { // Example: // x = relax.builtin.alloc_tensor((m, n)) // --> -// gv0 = relax.call_packed("vm.builtin.alloc_storage", (m * n), alignment, device_type, -// relax.attrs.AllocStorageAttrs) gv1 = relax.call_packed("vm.builtin.alloc_tensor", gv0, offset, -// (m, n), relax.attrs.AllocTensorAttrs) +// gv0 = relax.call_packed("relax.vm.builtin.alloc_storage", (m * n), relax.attrs.AllocStorageAttrs) +// gv1 = relax.call_packed("relax.vm.builtin.alloc_tensor", gv0, (m, n), relax.attrs.AllocTensorAttrs) -class MemLowerMutator : public ExprMutator { +class VMMemLowerMutator : public ExprMutator { public: - explicit MemLowerMutator(IRModule mod) { mod_ = mod; } + explicit VMMemLowerMutator(IRModule mod) { mod_ = mod; } IRModule Lower() { IRModule ret_mod = IRModule(); @@ -87,30 +87,27 @@ class MemLowerMutator : public ExprMutator { call = expr.as(); static const Op& alloc_tensor_op = Op::Get("relax.builtin.alloc_tensor"); + static const Op& vm_alloc_storage_op = Op::Get("relax.vm.builtin.alloc_storage"); + static const Op& vm_alloc_tensor_op = Op::Get("relax.vm.builtin.alloc_tensor"); + // TODO(@yuchen): memory planning if (call->op == alloc_tensor_op) { - ShapeExpr tensor_shape = Downcast(call->args[0]); + ShapeExpr output_shape = Downcast(call->args[0]); + // TODO(@yuchen): Get the type of input x, options: add an attr to relax.builtin.alloc_tensor - Type tensor_type = DynTensorType(tensor_shape->values.size(), DataType::Float(32)); - Expr storage_size = ComputeStorageSize(tensor_shape, tensor_type); - ShapeExpr alignment = ShapeExpr({IntImm(DataType::Int(64), 64)}); - ShapeExpr device_type = ShapeExpr({IntImm(DataType::Int(64), 1)}); + Type tensor_type = DynTensorType(output_shape->values.size(), DataType::Float(32)); + Expr storage_size = ComputeStorageSize(output_shape, tensor_type); auto storage_attr = make_object(); storage_attr->dtype = DataType::Float(32); storage_attr->device_type = 1; - Var storage = - builder_->Emit(Call(ExternFunc("vm.builtin.alloc_storage"), - {storage_size, alignment}, Attrs(storage_attr)), - "storage"); - - ShapeExpr offset = ShapeExpr({IntImm(DataType::Int(64), 0)}); + Var storage = builder_->Emit(Call(vm_alloc_storage_op, {storage_size}, Attrs(storage_attr)), "storage"); auto tensor_attr = make_object(); + tensor_attr->offset = 0; tensor_attr->dtype = DataType::Float(32); Expr shape = call->args[0]; - return builder_->Emit( - Call(ExternFunc("vm.builtin.alloc_tensor"), {storage, offset, shape}, Attrs(tensor_attr)), - "tensor"); + Var tensor = builder_->Emit(Call(vm_alloc_tensor_op, {storage, shape}, Attrs(tensor_attr)), "tensor"); + return tensor; } return GetRef(call); @@ -120,9 +117,11 @@ class MemLowerMutator : public ExprMutator { IRModule mod_; }; -TVM_REGISTER_GLOBAL("relax.transform.memory_lower").set_body_typed([](IRModule mod) { - return MemLowerMutator(mod).Lower(); +TVM_REGISTER_GLOBAL("relax.transform.vm_memory_lower") +.set_body_typed([](IRModule mod) { + return VMMemLowerMutator(mod).Lower(); }); +} // namespace vm } // namespace relax } // namespace tvm diff --git a/src/relax/transform/shape_lower.cc b/src/relax/backend/vm/vm_shape_lower.cc similarity index 83% rename from src/relax/transform/shape_lower.cc rename to src/relax/backend/vm/vm_shape_lower.cc index 6f58ab2df3c6..25e02785ec27 100644 --- a/src/relax/transform/shape_lower.cc +++ b/src/relax/backend/vm/vm_shape_lower.cc @@ -17,7 +17,7 @@ * under the License. */ /*! - * \file src/relax/transform/shape_lower.cc + * \file src/relax/backend/vm/vm_shape_lower.cc * \brief */ #include @@ -25,19 +25,19 @@ #include #include #include - -#include "../../printer/text_printer.h" +#include namespace tvm { namespace relax { +namespace vm { -class ShapeLowerMutator : public ExprMutator { +class VMShapeLowerMutator : public ExprMutator { public: static DataType ShapeDType() { return DataType::Int(64); }; - explicit ShapeLowerMutator(IRModule mod) { mod_ = mod; } + explicit VMShapeLowerMutator(IRModule mod) { mod_ = mod; } IRModule Lower() { ret_mod_ = IRModule(); @@ -46,7 +46,6 @@ class ShapeLowerMutator : public ExprMutator { if (p.second->IsInstance()) { // prepare mapping and heap var expr2slot_ = PrepareExpr2Slot(Downcast(func)); - // LOG(INFO) << "mapping: " << expr2slot_; heap_size_ = IntImm(ShapeDType(), expr2slot_.size()); DynTensorType heap_type(1, ShapeDType()); shape_heap_ = Var("shape_heap", ShapeExpr({heap_size_}), heap_type); @@ -61,14 +60,17 @@ class ShapeLowerMutator : public ExprMutator { void VisitMatchShape(const MatchShape& binding) override { Expr shape = ExprMutator::VisitExpr(binding->value); + static const Op& store_shape_op = Op::Get("relax.vm.builtin.store_shape"); + auto store_shape_attr = make_object(); + Array pattern = binding->pattern; - Array indices; + Array indices; for (size_t i = 0; i < pattern.size(); ++i) { - IntImm idx = expr2slot_.at(pattern[i]); + int idx = expr2slot_.at(pattern[i]); indices.push_back(idx); } - builder_->Emit(Call(ExternFunc("vm.builtin.decode_shape"), - {shape, shape_heap_, ShapeExpr(indices)}), "_decode_shape"); + store_shape_attr->indices = indices; + builder_->Emit(Call(store_shape_op, {shape, shape_heap_}, Attrs(store_shape_attr)), "gv"); } Expr VisitExpr_(const ShapeExprNode* node) override { @@ -84,12 +86,15 @@ class ShapeLowerMutator : public ExprMutator { ret_mod_->Add(shape_func_var, func); // construct shape - Array indices; + Array indices; for (PrimExpr e : node->values) { indices.push_back(expr2slot_.at(e)); } - return builder_->Emit(Call(ExternFunc("vm.builtin.make_shape"), - {shape_heap_, ShapeExpr(indices)}), "sh"); + static const Op& load_shape_op = Op::Get("relax.vm.builtin.load_shape"); + auto load_shape_attr = make_object(); + load_shape_attr->indices = indices; + + return builder_->Emit(Call(load_shape_op, {shape_heap_}, Attrs(load_shape_attr)), "sh"); } Expr VisitExpr_(const FunctionNode* node) override { @@ -134,7 +139,7 @@ class ShapeLowerMutator : public ExprMutator { for (PrimExpr e : s->values) { Map var_mapping = BuildVarMapping(e, buffer); PrimExpr value = tir::Substitute(e, var_mapping); - IntImm idx = expr2slot_.at(e); + int idx = expr2slot_.at(e); seq.push_back(tir::Store(buffer->data, value, idx, tir::const_true())); } tir::Stmt body = tir::SeqStmt(seq); @@ -157,16 +162,15 @@ class ShapeLowerMutator : public ExprMutator { return ret; } - Map PrepareExpr2Slot(Function expr) const { + Map PrepareExpr2Slot(Function expr) const { int cnt = 0; - Map ret; + Map ret; auto func = [&](const Expr& e) { if (e->IsInstance()) { ShapeExpr shape = Downcast(e); for (auto prim_e : shape->values) { if (ret.count(prim_e) == 0) { - IntImm idx(ShapeDType(), cnt++); - ret.Set(prim_e, idx); + ret.Set(prim_e, cnt++); } } } @@ -192,13 +196,14 @@ class ShapeLowerMutator : public ExprMutator { // function-wise members IntImm heap_size_; Var shape_heap_; - Map expr2slot_; + Map expr2slot_; }; -TVM_REGISTER_GLOBAL("relax.transform.shape_lower") +TVM_REGISTER_GLOBAL("relax.transform.vm_shape_lower") .set_body_typed([](IRModule mod) { - return ShapeLowerMutator(mod).Lower(); + return VMShapeLowerMutator(mod).Lower(); }); +} // namespace vm } // namespace relax } // namespace tvm diff --git a/src/relax/op/op.cc b/src/relax/op/op.cc index c7e1e58419de..e27626e27f8d 100644 --- a/src/relax/op/op.cc +++ b/src/relax/op/op.cc @@ -17,6 +17,7 @@ * under the License. */ #include +#include #include #include @@ -92,5 +93,66 @@ Expr MakeAllocTensor(Expr shape) { TVM_REGISTER_GLOBAL("relax.op.builtin.alloc_tensor") .set_body_typed(MakeAllocTensor); +// vm alloc_storage + +RELAY_REGISTER_OP("relax.vm.builtin.alloc_storage") +.set_attrs_type() +.set_num_inputs(1) +.add_argument("size", "Expr", "The size of the storage to allocate."); + +Expr MakeVMAllocStorage(Expr size) { + static const Op& op = Op::Get("relax.vm.builtin.alloc_storage"); + return Call(op, {size}, {}, {}); +} + +TVM_REGISTER_GLOBAL("relax.op.vm.builtin.alloc_storage") +.set_body_typed(MakeVMAllocStorage); + +// vm alloc_tensor + +RELAY_REGISTER_OP("relax.vm.builtin.alloc_tensor") +.set_attrs_type() +.set_num_inputs(1) +.add_argument("shape", "Expr", "The shape of the tensor to allocate."); + +Expr MakeVMAllocTensor(Expr shape) { + static const Op& op = Op::Get("relax.vm.builtin.alloc_tensor"); + return Call(op, {shape}, {}, {}); +} + +TVM_REGISTER_GLOBAL("relax.op.vm.builtin.alloc_tensor") +.set_body_typed(MakeVMAllocTensor); + +// vm store_shape + +RELAY_REGISTER_OP("relax.vm.builtin.store_shape") +.set_attrs_type() +.set_num_inputs(2) +.add_argument("shape", "Expr", "The shape to be stored.") +.add_argument("heap", "Expr", "The heap to store the shape."); + +Expr MakeStoreShape(Expr shape, Expr heap) { + static const Op& op = Op::Get("relax.vm.builtin.store_shape"); + return Call(op, {shape, heap}, {}, {}); +} + +TVM_REGISTER_GLOBAL("relax.op.vm.builtin.store_shape") +.set_body_typed(MakeStoreShape); + +// vm load_shape + +RELAY_REGISTER_OP("relax.vm.builtin.load_shape") +.set_attrs_type() +.set_num_inputs(1) +.add_argument("heap", "Expr", "The heap to load the shape from."); + +Expr MakeLoadShape(Expr heap) { + static const Op& op = Op::Get("relax.vm.builtin.load_shape"); + return Call(op, {heap}, {}, {}); +} + +TVM_REGISTER_GLOBAL("relax.op.vm.builtin.load_shape") +.set_body_typed(MakeLoadShape); + } // namespace relax } // namespace tvm diff --git a/src/relax/transform/call_dps_rewrite.cc b/src/relax/transform/call_dps_rewrite.cc index 453ce198a2d1..76ef84589817 100644 --- a/src/relax/transform/call_dps_rewrite.cc +++ b/src/relax/transform/call_dps_rewrite.cc @@ -32,11 +32,12 @@ namespace relax { // ================== // CallDPSMutator +// Perform explicit tensor allocation for call_dps. // Example: -// y: Tensor[n, m] = rx.call_dps((n, m), op.identity, (x)) +// lv0: Tensor[n, m] = rx.call_dps((n, m), op.identity, (x)) // --> -// lv0 = rx.call("relax.builtin.alloc_tensor", [n, m]) -// rx.call_packed(op.identity, x, lv0) +// gv0 = rx.call("relax.builtin.alloc_tensor", [n, m]) +// rx.call_packed(op.identity, x, gv0) class CallDPSMutator : public ExprMutator { public: @@ -58,8 +59,6 @@ class CallDPSMutator : public ExprMutator { // post-order mutation Expr expr = ExprMutator::VisitExpr_(call); call = expr.as(); - // TODO(@yuchen, @altanh): using mutate cause infinite recursion - // Expr expr = ExprMutator::Mutate(GetRef(call)); static const Op& call_dps_op = Op::Get("relax.call_dps"); static const Op& alloc_tensor_op = Op::Get("relax.builtin.alloc_tensor"); diff --git a/src/relax/vm/builtin.cc b/src/relax/vm/builtin.cc index 97de0f2370f2..b847a2fe8988 100644 --- a/src/relax/vm/builtin.cc +++ b/src/relax/vm/builtin.cc @@ -29,6 +29,7 @@ #include #include #include +#include namespace tvm { namespace runtime { @@ -46,7 +47,7 @@ TVM_REGISTER_GLOBAL("vm.builtin.alloc_shape_heap") return NDArray::Empty(size, DLDataType{kDLInt, 64, 1}, DLDevice{kDLCPU, 0}); }); -TVM_REGISTER_GLOBAL("vm.builtin.decode_shape") +TVM_REGISTER_GLOBAL("vm.builtin.store_shape") .set_body_typed([](ShapeTuple shape, NDArray heap, ShapeTuple indexes) { int64_t* heap_data = reinterpret_cast(heap.ToDLPack()->dl_tensor.data); for (size_t i = 0; i < indexes.size(); ++i) { @@ -56,7 +57,7 @@ TVM_REGISTER_GLOBAL("vm.builtin.decode_shape") } }); -TVM_REGISTER_GLOBAL("vm.builtin.make_shape") +TVM_REGISTER_GLOBAL("vm.builtin.load_shape") .set_body_typed([](NDArray heap, ShapeTuple indexes) { int64_t* heap_data = reinterpret_cast(heap.ToDLPack()->dl_tensor.data); std::vector shape; @@ -69,14 +70,12 @@ TVM_REGISTER_GLOBAL("vm.builtin.make_shape") }); TVM_REGISTER_GLOBAL("vm.builtin.alloc_storage") -.set_body_typed([](void* vm_state_ptr, ShapeTuple buffer_size, ShapeTuple alignment, Index device_type, - DLDataType dtype_hint) { +.set_body_typed([](void* vm_state_ptr, ShapeTuple buffer_size, Index device_type, DLDataType dtype_hint) { + int alignment = runtime::kAllocAlignment; ICHECK_EQ(buffer_size.size(), 1); - ICHECK_EQ(alignment.size(), 1); VMState* vm_state = static_cast(vm_state_ptr); int64_t size_imm = buffer_size[0]; - int64_t align_imm = alignment[0]; - DLOG(INFO) << "AllocStorage: allocation_size=" << size_imm << ", alignment=" << align_imm + DLOG(INFO) << "AllocStorage: allocation_size=" << size_imm << ", alignment=" << alignment << ", dtype_hint=" << runtime::DLDataType2String(dtype_hint) << ", device_type=" << device_type; @@ -85,16 +84,14 @@ TVM_REGISTER_GLOBAL("vm.builtin.alloc_storage") << "Memory allocator for device " << device_type << " has not been initialized"; auto* alloc = vm_state->allocators[device_type]; ICHECK(alloc) << "Did you forget to init the VirtualMachine with devices?"; - storage_obj->buffer = alloc->Alloc(size_imm, align_imm, dtype_hint); + storage_obj->buffer = alloc->Alloc(size_imm, alignment, dtype_hint); Storage storage(storage_obj); return storage; }); TVM_REGISTER_GLOBAL("vm.builtin.alloc_tensor") -.set_body_typed([](Storage storage, ShapeTuple offset, ShapeTuple shape, DLDataType dtype) { - ICHECK_EQ(offset.size(), 1); - int64_t offset_imm = offset[0]; - auto tensor = storage->AllocNDArray(offset_imm, shape, dtype); +.set_body_typed([](Storage storage, ShapeTuple shape, Index offset, DLDataType dtype) { + auto tensor = storage->AllocNDArray(offset, shape, dtype); return tensor; }); diff --git a/src/relax/vm/compiler.cc b/src/relax/vm/compiler.cc deleted file mode 100644 index 595ba9255a29..000000000000 --- a/src/relax/vm/compiler.cc +++ /dev/null @@ -1,291 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file src/relax/vm/compiler.cc - * \brief A compiler from relay::Module to the VM byte code. - */ - -#include "compiler.h" - -#include -#include -#include -#include -#include - -#include -#include -#include - -namespace tvm { -namespace runtime { -namespace relax_vm { - -using namespace relax; - -class VMCompilerImpl : public ExprVisitor { - public: - explicit VMCompilerImpl(ExecBuilderNode* builder) { - builder_ = GetRef(builder); - } - - protected: - /*! \brief A counter for naming local functions. */ - int local_func_counter_ = 0; - - // TODO(@yuchen): support visiting other IR nodes - void VisitExpr_(const FunctionNode* func_node) { - if (func_node->name.defined()) { - builder_->EmitFunction(func_node->name.value()->name_hint, func_node->params.size()); - } else { - // TODO(@yuchen): handle local functions that capture local vars outside the func - // TODO(@yuchen): a renaming pass to resolve name conflicts, e.g. the input module has a - // function named "local_funcN" - // lift the local func to a global func and compile it normally - builder_->EmitFunction("local_func" + std::to_string(local_func_counter_++), - func_node->params.size()); - } - for (auto param : func_node->params) { - NewRegister(param); - } - ExprVisitor::VisitExpr_(func_node); - } - - void VisitExpr_(const SeqExprNode* op) { - for (auto block : op->blocks) { - this->VisitBindingBlock(block); - } - // find the function return value and emit the output - auto ret_reg = this->var_register_map_.find(Downcast(op->body)); - ICHECK(ret_reg != this->var_register_map_.end()); - builder_->EmitRet(ret_reg->second); - } - - // TODO: visit call node - void VisitVarBinding(const VarBinding& binding) { - Var var = binding->var; - // TODO(@yuchen): support other nodes than Call - if (binding->value.as()){ - Call call_node = Downcast(binding->value); - if (auto* extern_func = call_node->op.as()) { - String name = extern_func->global_symbol; - if (name == "vm.builtin.alloc_storage") { - EmitAllocStorage(call_node, var); - } else if (name == "vm.builtin.alloc_tensor") { - EmitAllocTensor(call_node, var); - } else { - // Normal packed function without attributes - std::vector args = ConvertArgs(call_node); - // TODO(@yuchen): what if the packed func has void return (no need to write to the dst - // register)? - builder_->EmitCall(name, args, NewRegister(var)); - } - } else if (auto* gvar = call_node->op.as()) { - String name = gvar->name_hint; - std::vector args = ConvertArgs(call_node); - // TODO: global_var mangling - builder_->EmitCall(name, args, NewRegister(var)); - } else { - LOG(FATAL) << "TODO: support compiling everything other than extern functions."; - } - } else if (const VarNode* var_node = binding->value.as()) { - const Var& rhs_var = GetRef(var_node); - auto rhs_var_reg = this->var_register_map_.find(rhs_var); - ICHECK(rhs_var_reg != this->var_register_map_.end()); - this->var_register_map_.insert({var, rhs_var_reg->second}); - } else { - LOG(FATAL) << "TODO: support compiling everything other than Call and Var."; - } - } - - void EmitAllocStorage(const Call& call_node, const Var& var) { - Attrs attrs = call_node->attrs; - - // Get dtype and device_type from the attributes. - auto alloc_attrs = attrs.as(); - ICHECK(alloc_attrs != nullptr) << "must be the AllocStorage attrs"; - DataType dtype = alloc_attrs->dtype; - int device_type = alloc_attrs->device_type; - - std::vector args; - args.push_back(Instruction::Arg(Instruction::kVMStateRegister)); - for (Expr arg: call_node->args) { - args.push_back(ConvertArg(arg)); - } - args.push_back(Instruction::Arg(Instruction::kImmediate, device_type)); - - // store dtype in constant pool - TVMRetValue data_type; - data_type = dtype; - Index index = this->builder_->EmitConstant(data_type); - args.push_back(Instruction::Arg(Instruction::kConstIdx, index)); - - builder_->EmitCall("vm.builtin.alloc_storage", args, NewRegister(var)); - } - - void EmitAllocTensor(const Call& call_node, const Var& var) { - Attrs attrs = call_node->attrs; - - // Get dtype from the attributes. - auto alloc_attrs = attrs.as(); - ICHECK(alloc_attrs != nullptr) << "must be the AllocTensor attrs"; - DataType dtype = alloc_attrs->dtype; - - std::vector args; - for (Expr arg: call_node->args) { - args.push_back(ConvertArg(arg)); - } - - // store dtype in constant pool - TVMRetValue data_type; - data_type = dtype; - Index index = builder_->EmitConstant(data_type); - args.push_back(Instruction::Arg(Instruction::kConstIdx, index)); - - builder_->EmitCall("vm.builtin.alloc_tensor", args, NewRegister(var)); - } - - size_t NewRegister(Var var) { - size_t reg = this->registers_num_++; - this->var_register_map_.insert({var, reg}); - return reg; - } - - bool IsConstantShape(ShapeExpr shape) const { - for (PrimExpr e : shape->values) { - if (!e->IsInstance()) { - return false; - } - } - return true; - } - - // TODO: recursive Expr -> instr::arg, ExprFunctor, like llvm builder - Instruction::Arg ConvertArg(Expr arg) { - if (arg->IsInstance()) { - Var var = Downcast(arg); - auto reg = this->var_register_map_.find(Downcast(arg)); - ICHECK(reg != this->var_register_map_.end()) - << var->name_hint() << "(" << var << ")" << " not in the register map."; - return Instruction::Arg(Instruction::kRegister, reg->second); - } else if (arg->IsInstance()) { - ShapeExpr sh = Downcast(arg); - ICHECK(IsConstantShape(sh)) - << "should only use constant shape after shape lowering: " - << sh->values; - std::vector shape; - for (PrimExpr e : sh->values) { - shape.push_back(Downcast(e)->value); - } - auto shape_tuple = ShapeTuple(shape); - TVMRetValue shape_tuple_value; - shape_tuple_value = shape_tuple; - Index index = builder_->EmitConstant(shape_tuple_value); - return Instruction::Arg(Instruction::kConstIdx, index); - } else { - LOG(FATAL) << "not supported argument type."; - } - return Instruction::Arg(); - } - - std::vector ConvertArgs(const Call& call) { - std::vector ret; - for (size_t i = 0; i < call->args.size(); ++i) { - ret.push_back(ConvertArg(call->args[i])); - } - return ret; - } - - /*! \brief Internal ExecBuilder. */ - relax::ExecBuilder builder_; - /*! \brief Total number of virtual registers allocated. */ - size_t registers_num_ = 0; - /*! \brief Map from var to register number. */ - std::unordered_map var_register_map_; -}; - -PackedFunc VMCompiler::GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) { - if (name == "compile") { - return PackedFunc([this](TVMArgs args, TVMRetValue* rv) { - ICHECK_EQ(args.num_args, 3); - IRModule mod = args[0]; - this->Compile(mod, args[1], args[2]); - }); - } else if (name == "get_executable") { - return PackedFunc([this](TVMArgs args, TVMRetValue* rv) { *rv = this->GetExec(); }); - } else { - LOG(FATAL) << "Unknown packed function: " << name; - return PackedFunc([name](TVMArgs args, TVMRetValue* rv) {}); - } -} - -void VMCompiler::Compile(IRModule mod, Target target, Target target_host) { - // Reset internal builder - builder_ = relax::ExecBuilderNode::Create(); - - IRModule tir_mod; - IRModule rx_mod; - for (auto& p : mod->functions) { - auto gvar = p.first; - - BaseFunc func = p.second; - if (func.as()) { - tir_mod->Add(gvar, func); - } else if (func.as()) { - rx_mod->Add(gvar, func); - } else { - LOG(FATAL) << "Cannot handle such function node now:\n" << func; - } - } - lib_ = tvm::build(tir_mod, target, target_host); - - VMCompilerImpl compiler(builder_.operator->()); - for (auto& p : rx_mod->functions) { - compiler.VisitExpr(p.second); - } -} - -Executable VMCompiler::GetExec() { - return builder_->Get(); -} - -runtime::Module VMCompiler::GetLib() { - return lib_; -} - -runtime::Module CreateVMCompiler() { - auto compiler = make_object(); - return runtime::Module(compiler); -} - -Array Build(IRModule mod, Target target, Target target_host) { - auto compiler = make_object(); - compiler->Compile(mod, target, target_host); - Executable exec = compiler->GetExec(); - Module lib = compiler->GetLib(); - return Array({exec, lib}); -} - -TVM_REGISTER_GLOBAL("relax.VMBuild") -.set_body_typed(Build); - -} // namespace relax_vm -} // namespace runtime -} // namespace tvm diff --git a/tests/python/relax/test_transform.py b/tests/python/relax/test_transform.py index 247d7b5662e6..fa6a9e3ee6e4 100644 --- a/tests/python/relax/test_transform.py +++ b/tests/python/relax/test_transform.py @@ -24,8 +24,6 @@ import tvm.script from tvm.script import relax as R -import numpy as np - def test_fma_rewrite(): m = tir.Var("m", "int32") @@ -68,7 +66,7 @@ def test_fma_rewrite(): def test_to_non_dataflow(): @tvm.script.ir_module - class TestToNoneDataflow: + class TestToNonDataflow: @R.function def foo(x: Tensor[(m, n), "float32"]): with relax.dataflow(): @@ -77,7 +75,7 @@ def foo(x: Tensor[(m, n), "float32"]): relax.output(gv1) return gv1 - mod = TestToNoneDataflow + mod = TestToNonDataflow old_vars = [] @@ -92,12 +90,10 @@ def fvisit(e): new_mod = relax.transform.to_non_dataflow(mod) new_vars = [] - def fvisit(e): if isinstance(e, relax.Var): nonlocal new_vars new_vars.append(e) - relax.analysis.post_order_visit(new_mod["foo"], fvisit) assert x == new_vars[1] @@ -119,7 +115,6 @@ def foo(x: Tensor[(m, n), "float32"]): return gv0 mod = TestCallDpsRewrite - code = R.parser.astext(mod) # before rewrite v0 = mod["foo"].body.blocks[0].bindings[0].var @@ -130,9 +125,7 @@ def foo(x: Tensor[(m, n), "float32"]): # after rewrite new_mod = relax.transform.call_dps_rewrite(mod) func = new_mod["foo"] - code = R.parser.astext(new_mod) - # the dataflow block has changed to binding block due to the rewriting block = func.body.blocks[0] assert not isinstance(block, relax.DataflowBlock) @@ -145,51 +138,75 @@ def foo(x: Tensor[(m, n), "float32"]): assert s2.op.global_symbol == "test.op.identity" -def test_memory_lower(): +def test_vm_memory_lower(): @tvm.script.ir_module - class TestMemoryLower: + class TestVMMemoryLower: @R.function def foo(x: Tensor[(m, n), "float32"]): alloc = relax.builtin.alloc_tensor((m, n)) _ = relax.call_packed("test.op.identity", (x,), alloc) gv0 = alloc return gv0 + + mod = TestVMMemoryLower - mod = TestMemoryLower - - # after memory lowering - new_mod = relax.transform.memory_lower(mod) + # after vm memory lowering + new_mod = relax.transform.vm_memory_lower(mod) + func = new_mod["foo"] assert isinstance(new_mod, tvm.IRModule) - assert isinstance(new_mod["foo"], tvm.relax.expr.Function) - code = R.parser.astext(new_mod) - assert "vm.builtin.alloc_storage" in code - assert "vm.builtin.alloc_tensor" in code + assert isinstance(func, tvm.relax.expr.Function) + + block = func.body.blocks[0] + s1 = block.bindings[0].value + assert isinstance(s1, tvm.relay.Call) + assert s1.op.name == "relax.vm.builtin.alloc_storage" + s2 = block.bindings[1].value + assert isinstance(s2, tvm.relay.Call) + s4 = block.bindings[3].value + assert isinstance(s4, tvm.relay.Call) + assert isinstance(s4.op, relax.ExternFunc) + assert s4.op.global_symbol == "test.op.identity" -def test_shape_lowering(): +def test_vm_shape_lowering(): @tvm.script.ir_module - class TestShapeLower: + class TestVMShapeLower: @R.function def foo(x: Tensor[_, "float32"]) -> Shape: sh = relax.call_packed("vm.builtin.shape_of", x) relax.match_shape(sh, (n, m)) return (n * 2, m * 3) - mod = TestShapeLower - new_mod = relax.transform.shape_lower(mod) + mod = TestVMShapeLower + + # after vm shape lowering + new_mod = relax.transform.vm_shape_lower(mod) + assert isinstance(new_mod, tvm.IRModule) assert isinstance(new_mod["shape_func"], tvm.tir.function.PrimFunc) - assert isinstance(new_mod["foo"], tvm.relax.expr.Function) - code = R.parser.astext(new_mod) - assert "alloc_shape_heap" in code - assert "decode_shape" in code - assert "make_shape" in code - + func = new_mod["foo"] + assert isinstance(func, tvm.relax.expr.Function) + + s1 = func.body.blocks[0].bindings[0].value + assert isinstance(s1.op, relax.ExternFunc) + assert s1.op.global_symbol == "vm.builtin.alloc_shape_heap" + s2 = func.body.blocks[1].bindings[0].value + assert isinstance(s2.op, relax.ExternFunc) + assert s2.op.global_symbol == "vm.builtin.shape_of" + s3 = func.body.blocks[1].bindings[1].value + assert isinstance(s3, tvm.relay.Call) + assert s3.op.name == "relax.vm.builtin.store_shape" + s4 = func.body.blocks[2].bindings[0].value + assert isinstance(s4.op, relax.GlobalVar) + assert s4.op.name_hint == "shape_func" + s5 = func.body.blocks[2].bindings[1].value + assert isinstance(s5, tvm.relay.Call) + assert s5.op.name == "relax.vm.builtin.load_shape" if __name__ == "__main__": test_fma_rewrite() test_to_non_dataflow() test_call_dps_rewrite() - test_memory_lower() - test_shape_lowering() + test_vm_memory_lower() + test_vm_shape_lowering() diff --git a/tests/python/relax/test_vm.py b/tests/python/relax/test_vm.py index a8b62b2552bb..66526fbf28eb 100644 --- a/tests/python/relax/test_vm.py +++ b/tests/python/relax/test_vm.py @@ -126,12 +126,8 @@ def test_vm_constant_serialize(): inp = tvm.nd.array(np.random.rand(4, 6).astype(np.float32)) ib = relax.ExecBuilder() with ib.function("main", num_inputs=1): - ib.emit_call( - "vm.builtin.alloc_storage", - args=[ib.vm_state(), (24,), (8,), ib.imm(1), dtype], - dst=ib.r(1), - ) - ib.emit_call("vm.builtin.alloc_tensor", args=[ib.r(1), (0,), shape, dtype], dst=ib.r(2)) + ib.emit_call("vm.builtin.alloc_storage", args=[ib.vm_state(), (24,), ib.imm(1), dtype], dst=ib.r(1)) + ib.emit_call("vm.builtin.alloc_tensor", args=[ib.r(1), shape, ib.imm(0), dtype], dst=ib.r(2)) ib.emit_call("test.vm.identity", args=[ib.r(0), ib.r(2)]) ib.emit_ret(ib.r(2)) exec0 = ib.get() @@ -220,12 +216,8 @@ def test_vm_storage(): shape = (4, 6) ib = relax.ExecBuilder() with ib.function("main", num_inputs=0): - ib.emit_call( - "vm.builtin.alloc_storage", - args=[ib.vm_state(), (24,), (8,), ib.imm(1), dtype], - dst=ib.r(1), - ) - ib.emit_call("vm.builtin.alloc_tensor", args=[ib.r(1), (0,), shape, dtype], dst=ib.r(2)) + ib.emit_call("vm.builtin.alloc_storage", args=[ib.vm_state(), (24,), ib.imm(1), dtype], dst=ib.r(1)) + ib.emit_call("vm.builtin.alloc_tensor", args=[ib.r(1), shape, ib.imm(0), dtype], dst=ib.r(2)) ib.emit_ret(ib.r(2)) ex = ib.get() vm = relax.VirtualMachine(ex, tvm.cpu()) @@ -239,34 +231,19 @@ def test_vm_compile_stage0(): @tvm.script.ir_module class TestVMCompileStage0: @R.function - def foo(x: Tensor[(3, 4), "float32"]): - y = relax.call_packed( - "vm.builtin.alloc_storage", - (12,), - (64,), - device_id=0, - device_type=1, - attrs_type_key="relax.attrs.AllocStorageAttrs", - ) - z = relax.call_packed( - "vm.builtin.alloc_tensor", - y, - (0,), - (3, 4), - attrs_type_key="relax.attrs.AllocTensorAttrs", - ) - w = relax.call_packed("test.vm.identity", x, z) - return z + def foo(x: Tensor[(3, 4), "float32"], y: Tensor[(3, 4), "float32"]): + z = relax.call_packed("test.vm.identity", x, y) + return y mod = TestVMCompileStage0 target = tvm.target.Target("llvm") target_host = tvm.target.Target("llvm") ex, lib = relax.vm.build(mod, target, target_host) - inp = tvm.nd.array(np.random.rand(3, 4).astype(np.float32)) + inp1 = tvm.nd.array(np.random.rand(3,4).astype(np.float32)) + inp2 = tvm.nd.array(np.random.rand(3,4).astype(np.float32)) vm = relax.VirtualMachine(ex, tvm.cpu(), mod=lib) - res = vm["foo"](inp) - np.testing.assert_allclose(inp.asnumpy(), res.asnumpy()) - res = vm["foo"](inp) + vm["foo"](inp1, inp2) + np.testing.assert_allclose(inp2.asnumpy(), inp1.asnumpy()) def test_vm_compile_stage1(): @@ -299,9 +276,9 @@ def foo(x: Tensor[_, "float32"]) -> Shape: "vm.builtin.alloc_shape_heap", (4,) ) gv0 = relax.call_packed("vm.builtin.shape_of", x) - gv1 = relax.call_packed("vm.builtin.decode_shape", gv0, shape_heap, (0, 1)) + gv1 = relax.call_packed("vm.builtin.store_shape", gv0, shape_heap, (0, 1)) gv2 = shape_func0(shape_heap) - gv3 = relax.call_packed("vm.builtin.make_shape", shape_heap, (2, 3)) + gv3 = relax.call_packed("vm.builtin.load_shape", shape_heap, (2, 3)) return gv3 """