diff --git a/include/tvm/relax/expr.h b/include/tvm/relax/expr.h index 5d973e23a70b..bf883e6d2f4a 100644 --- a/include/tvm/relax/expr.h +++ b/include/tvm/relax/expr.h @@ -32,6 +32,9 @@ namespace tvm { namespace relax { using relay::Id; +using relay::Call; +using relay::Tuple; +using relay::TupleGetItem; using ExprNode = RelayExprNode; using Expr = RelayExpr; @@ -388,7 +391,7 @@ class FunctionNode : public BaseFuncNode { static constexpr const char* _type_key = "relax.expr.Function"; static constexpr const bool _type_has_method_sequal_reduce = true; static constexpr const bool _type_has_method_shash_reduce = true; - TVM_DECLARE_FINAL_OBJECT_INFO(FunctionNode, ExprNode); + TVM_DECLARE_FINAL_OBJECT_INFO(FunctionNode, BaseFuncNode); }; class Function : public Expr { @@ -398,6 +401,37 @@ class Function : public Expr { TVM_DEFINE_OBJECT_REF_METHODS(Function, Expr, FunctionNode); }; + +/*! \brief The extern function, which can represent packed function. */ +class ExternFuncNode : public BaseFuncNode { + public: + /*! \brief The name of global symbol. */ + String global_symbol; + + void VisitAttrs(AttrVisitor* v) { + v->Visit("global_symbol", &global_symbol); + } + + bool SEqualReduce(const ExternFuncNode* other, SEqualReducer equal) const { + return equal(global_symbol, other->global_symbol); + } + + void SHashReduce(SHashReducer hash_reduce) const { + hash_reduce(global_symbol); + } + + static constexpr const char* _type_key = "relax.expr.ExternFunc"; + static constexpr const bool _type_has_method_sequal_reduce = true; + static constexpr const bool _type_has_method_shash_reduce = true; + TVM_DECLARE_FINAL_OBJECT_INFO(ExternFuncNode, BaseFuncNode); +}; + +class ExternFunc : public Expr { + public: + TVM_DLL ExternFunc(String global_symbol); + TVM_DEFINE_OBJECT_REF_METHODS(ExternFunc, Expr, ExternFuncNode); +}; + } // namespace relax } // namespace tvm diff --git a/python/tvm/relax/__init__.py b/python/tvm/relax/__init__.py index ed324cd60e4e..cf2f4c0c751a 100644 --- a/python/tvm/relax/__init__.py +++ b/python/tvm/relax/__init__.py @@ -19,6 +19,7 @@ from . import expr from . import ty from . import vm +from . import op # Expr @@ -36,10 +37,13 @@ DataflowBlock = expr.DataflowBlock SeqExpr = expr.SeqExpr ShapeExpr = expr.ShapeExpr +Tuple = expr.Tuple Function = expr.Function +ExternFunc = expr.ExternFunc # helper functions const = expr.const +extern = expr.extern # Type ShapeType = ty.ShapeType @@ -49,3 +53,6 @@ ExecBuilder = exec_builder.ExecBuilder VirtualMachine = vm.VirtualMachine load_exec_from_file = vm.load_exec_from_file + +# Operator +from .op.base import call_dps diff --git a/python/tvm/relax/expr.py b/python/tvm/relax/expr.py index b68aa34a869a..3681d9582685 100644 --- a/python/tvm/relax/expr.py +++ b/python/tvm/relax/expr.py @@ -16,15 +16,17 @@ # under the License. from typing import List, Optional, Union, Dict import tvm._ffi -from ..ir.base import Node, Span, SourceName -from ..relay.base import Id +from ..ir import Node, Span, SourceName, BaseFunc +from ..runtime import String +from ..relay import Id, Tuple, TupleGetItem from ..tir import PrimExpr from . import _ffi_api from .. import relay -GlobalVar = relay.GlobalVar Expr = relay.Expr Type = relay.Type +GlobalVar = relay.GlobalVar +Call = relay.Call const = relay.const @@ -106,7 +108,7 @@ def __init__(self, blocks: List[BindingBlock], body: Expr) -> None: @tvm._ffi.register_object("relax.expr.Function") -class Function(Expr): +class Function(BaseFunc): name: Optional[GlobalVar] params: List[Var] body: Expr @@ -116,3 +118,14 @@ def __init__(self, params: List[Var], body: Expr, ret_type: Type, name: Optional[GlobalVar] = None) -> None: self.__init_handle_by_constructor__(_ffi_api.Function, name, params, body, ret_type) + + +@tvm._ffi.register_object("relax.expr.ExternFunc") +class ExternFunc(BaseFunc): + global_symbol: String + + def __init__(self, global_symbol: String) -> None: + self.__init_handle_by_constructor__(_ffi_api.ExternFunc, global_symbol) + +def extern(name): + return ExternFunc(name) diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/__init__.py new file mode 100644 index 000000000000..32d469f0400b --- /dev/null +++ b/python/tvm/relax/op/__init__.py @@ -0,0 +1,21 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=wildcard-import, redefined-builtin +"""Relax core operators.""" + +# Operators +from .base import * diff --git a/python/tvm/relax/op/_ffi_api.py b/python/tvm/relax/op/_ffi_api.py new file mode 100644 index 000000000000..c9cc2e3160c0 --- /dev/null +++ b/python/tvm/relax/op/_ffi_api.py @@ -0,0 +1,18 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +import tvm._ffi + +tvm._ffi._init_api("relax.op", __name__) diff --git a/python/tvm/relax/op/base.py b/python/tvm/relax/op/base.py new file mode 100644 index 000000000000..767dbdea15b6 --- /dev/null +++ b/python/tvm/relax/op/base.py @@ -0,0 +1,47 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +from ...ir import BaseFunc +from ..expr import Expr, ShapeExpr, Tuple, Call +from . import _ffi_api +from typing import Union, List + +def call_dps(shape: Union[ShapeExpr, List[int]], + func: Expr, + args: Union[Tuple, List[Expr]]) -> Call: + """ + Call a destination-passing-style function and return the output. + + Parameters + ---------- + shape: ShapeExpr + The output shape. + + func : ExternFunc or PrimFunc + The destination-passing-style function. + + args : Tuple[Expr] + The input arguments. + + Returns + ------- + ret: Call + A call node for the call_dps operator. + """ + if isinstance(shape, (list, tuple)): + shape = ShapeExpr(shape) + if isinstance(args, (list, tuple)): + args = Tuple(args) + return _ffi_api.call_dps(shape, func, args) diff --git a/python/tvm/relay/__init__.py b/python/tvm/relay/__init__.py index 97842738e5cd..966d7003008e 100644 --- a/python/tvm/relay/__init__.py +++ b/python/tvm/relay/__init__.py @@ -70,6 +70,7 @@ # Span Span = base.Span SourceName = base.SourceName +Id = base.Id # Type Type = ty.Type diff --git a/src/relax/expr.cc b/src/relax/expr.cc index 1b5901fe9a4a..324df1895c64 100644 --- a/src/relax/expr.cc +++ b/src/relax/expr.cc @@ -16,7 +16,6 @@ * specific language governing permissions and limitations * under the License. */ - #include namespace tvm { @@ -155,6 +154,8 @@ TVM_REGISTER_GLOBAL("relax.SeqExpr") }); +TVM_REGISTER_NODE_TYPE(FunctionNode); + Function::Function(runtime::Optional name, Array params, Expr body, @@ -167,8 +168,6 @@ Function::Function(runtime::Optional name, data_ = std::move(n); } -TVM_REGISTER_NODE_TYPE(FunctionNode); - TVM_REGISTER_GLOBAL("relax.Function") .set_body_typed([](runtime::Optional name, Array params, @@ -177,5 +176,18 @@ TVM_REGISTER_GLOBAL("relax.Function") return Function(name, params, body, ret_type); }); +TVM_REGISTER_NODE_TYPE(ExternFuncNode); + +ExternFunc::ExternFunc(String global_symbol) { + ObjectPtr n = make_object(); + n->global_symbol = std::move(global_symbol); + data_ = std::move(n); +} + +TVM_REGISTER_GLOBAL("relax.ExternFunc") +.set_body_typed([](String global_symbol) { + return ExternFunc(global_symbol); +}); + } // namespace relax } // namespace tvm diff --git a/src/relax/op.cc b/src/relax/op.cc new file mode 100644 index 000000000000..c3c8ed232917 --- /dev/null +++ b/src/relax/op.cc @@ -0,0 +1,40 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one +* or more contributor license agreements. See the NOTICE file +* distributed with this work for additional information +* regarding copyright ownership. The ASF licenses this file +* to you under the Apache License, Version 2.0 (the +* "License"); you may not use this file except in compliance +* with the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an +* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +* KIND, either express or implied. See the License for the +* specific language governing permissions and limitations +* under the License. +*/ +#include +#include + +namespace tvm { +namespace relax { + +Expr MakeCallDPS(ShapeExpr shape, Expr func, Tuple args) { + static const Op& op = Op::Get("call_dps"); + return Call(op, {shape, func, args}, {}, {}); +} + +TVM_REGISTER_GLOBAL("relax.op.call_dps") +.set_body_typed(MakeCallDPS); + +RELAY_REGISTER_OP("call_dps") +.set_num_inputs(3) +.add_argument("shape", "ShapeExpr", "The output shape.") +.add_argument("func", "Expr", "The destination-passing-style function.") +.add_argument("args", "Tuple", "The input arguments."); + +} // namespace relax +} // namespace tvm diff --git a/tests/python/relax/test_op.py b/tests/python/relax/test_op.py new file mode 100644 index 000000000000..8ef5abe5b4ac --- /dev/null +++ b/tests/python/relax/test_op.py @@ -0,0 +1,46 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import tvm +from tvm import tir +from tvm import relax as rx +from tvm.script import ty +from tvm.ir import TensorType +import numpy as np + +@tvm.register_func("test.op.identity") +def identity_packed(a): + return tvm.nd.array(a.asnumpy) + +@tvm.script.tir +def identity_tir(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, [54, 96]) + B = tir.match_buffer(b, [54, 96]) + + with tir.block([54, 96], "compute") as [vi, vj]: + B[vi, vj] = A[vi, vj] + + +def test_call_dps() -> None: + shape_anno = [54, 96] + type_anno = rx.DynTensorType(2, "float32") + v0 = rx.Var("v0", shape_anno, type_anno) + v1 = rx.call_dps([54, 96], rx.extern("test.op.identity"), [v0]) + v1 = rx.call_dps([54, 96], identity_tir, [v0]) + + +if __name__ == "__main__": + test_call_dps()