diff --git a/include/tvm/tir/op.h b/include/tvm/tir/op.h index 5b63016d2f9d8..58630b82f1a9f 100644 --- a/include/tvm/tir/op.h +++ b/include/tvm/tir/op.h @@ -787,6 +787,15 @@ TVM_DLL PrimExpr min(PrimExpr source, Array axis, Array TVM_DLL PrimExpr prod(PrimExpr source, Array axis, Array init = {}, Span span = Span()); +/*! + * \brief Calculate fmod(x, y) + * \param x Left operand. + * \param y Right operand. + * \param span The location of this operation in the source. + * \return The result expression. + */ +TVM_DLL PrimExpr fmod(PrimExpr x, PrimExpr y, Span span = Span()); + /*! * \brief Calculate floor(x) * \param x The input expression. @@ -887,6 +896,7 @@ TVM_DECLARE_INTRIN_UNARY(rsqrt); TVM_DECLARE_INTRIN_UNARY(log); TVM_DECLARE_INTRIN_UNARY(log2); TVM_DECLARE_INTRIN_UNARY(log10); +TVM_DECLARE_INTRIN_UNARY(log1p); TVM_DECLARE_INTRIN_UNARY(popcount); TVM_DECLARE_INTRIN_UNARY(tan); TVM_DECLARE_INTRIN_UNARY(cos); diff --git a/python/tvm/script/builder/_ffi_api.py b/python/tvm/script/builder/_ffi_api.py index 3410494ded4d7..98d8618ad7f16 100644 --- a/python/tvm/script/builder/_ffi_api.py +++ b/python/tvm/script/builder/_ffi_api.py @@ -17,4 +17,4 @@ """FFI APIs for tvm.script.builder""" import tvm._ffi -tvm._ffi._init_api("script.builder", __name__) # pylint: disable=protected-access +tvm._ffi._init_api("script.builder", __name__) # pylint: disable=protected-access diff --git a/python/tvm/script/builder/tir/__init__.py b/python/tvm/script/builder/tir/__init__.py index e6e431e5bced7..4aa0309b8741f 100644 --- a/python/tvm/script/builder/tir/__init__.py +++ b/python/tvm/script/builder/tir/__init__.py @@ -31,3 +31,4 @@ ) from .prim_func_frame import arg, prim_func from .var import Buffer +from .op import * diff --git a/python/tvm/script/builder/tir/_ffi_api.py b/python/tvm/script/builder/tir/_ffi_api.py index 4e40e7261fd34..c0f5204f22ed6 100644 --- a/python/tvm/script/builder/tir/_ffi_api.py +++ b/python/tvm/script/builder/tir/_ffi_api.py @@ -17,4 +17,4 @@ """FFI APIs for tvm.script.builder.tir""" import tvm._ffi -tvm._ffi._init_api("script.builder.tir", __name__) # pylint: disable=protected-access +tvm._ffi._init_api("script.builder.tir", __name__) # pylint: disable=protected-access diff --git a/python/tvm/script/builder/tir/op.py b/python/tvm/script/builder/tir/op.py new file mode 100644 index 0000000000000..d70f1f0a29207 --- /dev/null +++ b/python/tvm/script/builder/tir/op.py @@ -0,0 +1,161 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""TVM Script TIR Op""" + +from . import _ffi_api + + +from tvm.tir.op import abs, popcount, nextafter, copysign, fmod +from tvm.tir.op import ( + floor, + floordiv, + floormod, + ceil, + round, + trunc, + truncdiv, + truncmod, + nearbyint, +) +from tvm.tir.op import ( + hypot, + ldexp, + power, + exp, + exp2, + exp10, + erf, + sqrt, + rsqrt, + log, + log2, + log10, + log1p, + sigmoid, +) +from tvm.tir.op import isnan, isfinite, isinf +from tvm.tir.op import cos, cosh, sin, sinh, tan, tanh +from tvm.tir.op import acos, acosh, asin, asinh, atan, atanh +from tvm.tir.op import atan2, clz, comm_reducer, infinity, reinterpret +from tvm.tir.op import min_value, max_value, if_then_else +from tvm.tir.op import call_packed, call_extern +from tvm.tir.expr import Select, Ramp, Broadcast, Shuffle +from tvm.tir.generic import cast + + +def boolean(expr): + return _ffi_api.PrimType("bool", expr) + + +def int8(expr): + return _ffi_api.PrimType("int8", expr) + + +def int16(expr): + return _ffi_api.PrimType("int16", expr) + + +def int32(expr): + return _ffi_api.PrimType("int32", expr) + + +def int64(expr): + return _ffi_api.PrimType("int64", expr) + + +def uint8(expr): + return _ffi_api.PrimType("uint8", expr) + + +def uint16(expr): + return _ffi_api.PrimType("uint16", expr) + + +def uint32(expr): + return _ffi_api.PrimType("uint32", expr) + + +def uint64(expr): + return _ffi_api.PrimType("uint64", expr) + + +def float8(expr): + return _ffi_api.PrimType("float8", expr) + + +def float16(expr): + return _ffi_api.PrimType("float16", expr) + + +def float32(expr): + return _ffi_api.PrimType("float32", expr) + + +def float64(expr): + return _ffi_api.PrimType("float64", expr) + + +def min(a, b, span=None): + """Compute the minimum value of two expressions. + + Parameters + ---------- + a : PrimExpr + The left hand operand + + b : PrimExpr + The right hand operand + + span : Optional[Span] + The location of this operator in the source. + + Returns + ------- + res : PrimExpr + The result expression. + + Note + ---- + This is the default integer division behavior in C. + """ + return _ffi_api.min(a, b, span) # type: ignore + + +def max(a, b, span=None): + """Compute the maximum value of two expressions. + + Parameters + ---------- + a : PrimExpr + The left hand operand + + b : PrimExpr + The right hand operand + + span : Optional[Span] + The location of this operator in the source. + + Returns + ------- + res : PrimExpr + The result expression. + + Note + ---- + This is the default integer division behavior in C. + """ + return _ffi_api.max(a, b, span) # type: ignore diff --git a/python/tvm/script/builder/tir/var.py b/python/tvm/script/builder/tir/var.py index 18a8ecd59bbea..4c4163cb941a2 100644 --- a/python/tvm/script/builder/tir/var.py +++ b/python/tvm/script/builder/tir/var.py @@ -20,10 +20,12 @@ from . import _ffi_api -def Buffer( # pylint: disable=invalid-name +def Buffer( # pylint: disable=invalid-name shape, dtype, name="buffer", storage_scope="", ) -> tir.Buffer: - return _ffi_api.Buffer(shape, dtype, name, storage_scope) # pylint: disable=no-member # type: ignore + return _ffi_api.Buffer( + shape, dtype, name, storage_scope + ) # pylint: disable=no-member # type: ignore diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py index 2d201bb0dab65..173f4b8c4dbee 100644 --- a/python/tvm/tir/__init__.py +++ b/python/tvm/tir/__init__.py @@ -45,16 +45,35 @@ from .function import PrimFunc, TensorIntrin, IndexMap from .op import call_packed, call_intrin, call_pure_extern, call_extern -from .op import call_llvm_intrin, call_llvm_pure_intrin, ret, all, any, min_value, max_value, trace +from .op import ( + call_llvm_intrin, + call_llvm_pure_intrin, + ret, + all, + any, + min_value, + max_value, + trace, +) from .op import exp, exp2, exp10, log, log2, log10, log1p, ldexp, clz from .op import sin, sinh, asin, asinh from .op import cos, cosh, acos, acosh from .op import tan, tanh, atan, atan2, atanh from .op import erf, sigmoid, sqrt, rsqrt, floor, ceil, hypot -from .op import trunc, abs, round, nextafter, nearbyint, power, popcount, fmod, if_then_else +from .op import ( + trunc, + abs, + round, + nextafter, + nearbyint, + power, + popcount, + fmod, + if_then_else, +) from .op import isnan, isfinite, isinf, copysign from .op import div, indexdiv, indexmod, truncdiv, truncmod, floordiv, floormod -from .op import comm_reducer, min, max, sum +from .op import comm_reducer, min, max, sum, infinity, reinterpret from .op import q_multiply_shift from .schedule import StmtSRef, BlockScope, ScheduleState, Schedule, ScheduleError diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py index de3ca5fa8d5b2..dc6058036fe6c 100644 --- a/python/tvm/tir/op.py +++ b/python/tvm/tir/op.py @@ -124,7 +124,10 @@ def call_pure_extern(dtype, func_name, *args, span=None): The call expression. """ return Call( - dtype, Op.get("tir.call_pure_extern"), convert((StringImm(func_name),) + args), span + dtype, + Op.get("tir.call_pure_extern"), + convert((StringImm(func_name),) + args), + span, ) @@ -151,7 +154,10 @@ def call_extern(dtype, func_name, *args, span=None): The call expression. """ return Call( - dtype, Op.get("tir.call_extern"), convert((StringImm(func_name),) + args), span=span + dtype, + Op.get("tir.call_extern"), + convert((StringImm(func_name),) + args), + span=span, ) @@ -183,7 +189,11 @@ def call_llvm_intrin(dtype, name, *args, span=None): llvm_id = codegen.llvm_lookup_intrinsic_id(name) assert llvm_id != 0, "%s is not an LLVM intrinsic" % name return call_intrin( - dtype, Op.get("tir.call_llvm_intrin"), tvm.tir.const(llvm_id, "uint32"), *args, span=span + dtype, + Op.get("tir.call_llvm_intrin"), + tvm.tir.const(llvm_id, "uint32"), + *args, + span=span, ) @@ -367,6 +377,47 @@ def max_value(dtype: str, span: Optional[Span] = None) -> Any: return _ffi_api.max_value(dtype, span) # type: ignore +def infinity(dtype: str, span: Optional[Span] = None) -> Any: + """infinity value of dtype + + Parameters + ---------- + dtype : str + The data type. + + span : Optional[Span] + The location of this operator in the source code. + + Returns + ------- + value : tvm.Expr + The infinity value of dtype. + """ + return _ffi_api.infinity(dtype, span) # type: ignore + + +def reinterpret(dtype, value, span=None) -> Any: + """infinity value of dtype + + Parameters + ---------- + dtype : str + The data type. + + value : PrimExpr + The input value. + + span : Optional[Span] + The location of this operator in the source code. + + Returns + ------- + value : tvm.Expr + The reinterpret cast value of dtype. + """ + return _ffi_api.reinterpret(dtype, value, span) # type: ignore + + def exp(x): """Take exponential of input x. diff --git a/src/script/builder/tir/op.cc b/src/script/builder/tir/op.cc new file mode 100644 index 0000000000000..777ac8a4a4079 --- /dev/null +++ b/src/script/builder/tir/op.cc @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "./op.h" + +namespace tvm { +namespace script { +namespace builder { +namespace tir { + +PrimExpr prim_type(String type_name, PrimExpr expr) { + return cast(DataType(runtime::String2DLDataType(type_name)), expr); +} + +TVM_REGISTER_GLOBAL("script.builder.tir.PrimType").set_body_typed(prim_type); +TVM_REGISTER_GLOBAL("script.builder.tir.min").set_body_typed([](PrimExpr a, PrimExpr b, Span span) { + return tvm::min(a, b, span); +}); +TVM_REGISTER_GLOBAL("script.builder.tir.max").set_body_typed([](PrimExpr a, PrimExpr b, Span span) { + return tvm::max(a, b, span); +}); + +} // namespace tir +} // namespace builder +} // namespace script +} // namespace tvm diff --git a/src/script/builder/tir/op.h b/src/script/builder/tir/op.h new file mode 100644 index 0000000000000..9f7a668f83302 --- /dev/null +++ b/src/script/builder/tir/op.h @@ -0,0 +1,121 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_SCRIPT_BUILDER_TIR_OP_H_ +#define TVM_SCRIPT_BUILDER_TIR_OP_H_ + +#include +#include + +#include "../builder.h" + +namespace tvm { +namespace script { +namespace builder { +namespace tir { + +PrimExpr int8(PrimExpr expr) { return cast(DataType::Int(8), expr); } +PrimExpr int16(PrimExpr expr) { return cast(DataType::Int(16), expr); } +PrimExpr int32(PrimExpr expr) { return cast(DataType::Int(32), expr); } +PrimExpr int64(PrimExpr expr) { return cast(DataType::Int(64), expr); } + +PrimExpr uint8(PrimExpr expr) { return cast(DataType::UInt(8), expr); } +PrimExpr uint16(PrimExpr expr) { return cast(DataType::UInt(16), expr); } +PrimExpr uint32(PrimExpr expr) { return cast(DataType::UInt(32), expr); } +PrimExpr uint64(PrimExpr expr) { return cast(DataType::UInt(64), expr); } + +PrimExpr float8(PrimExpr expr) { return cast(DataType::Float(8), expr); } +PrimExpr float16(PrimExpr expr) { return cast(DataType::Float(16), expr); } +PrimExpr float32(PrimExpr expr) { return cast(DataType::Float(32), expr); } +PrimExpr float64(PrimExpr expr) { return cast(DataType::Float(64), expr); } + +PrimExpr bool_(PrimExpr expr) { return cast(DataType::Bool(), expr); } + +PrimExpr prim_type(String type_name, PrimExpr expr); + +using tvm::cast; +using tvm::if_then_else; +using tvm::infinity; +using tvm::max; +using tvm::max_value; +using tvm::min; +using tvm::min_value; +using tvm::reinterpret; + +using tvm::ceil; +using tvm::floor; +using tvm::floordiv; +using tvm::floormod; +using tvm::nearbyint; +using tvm::round; +using tvm::trunc; +using tvm::truncdiv; +using tvm::truncmod; + +using tvm::abs; +using tvm::copysign; +using tvm::fmod; +using tvm::nextafter; +using tvm::popcount; + +using tvm::erf; +using tvm::exp; +using tvm::exp10; +using tvm::exp2; +using tvm::hypot; +using tvm::ldexp; +using tvm::log; +using tvm::log10; +using tvm::log1p; +using tvm::log2; +using tvm::pow; +using tvm::rsqrt; +using tvm::sigmoid; +using tvm::sqrt; + +using tvm::acos; +using tvm::acosh; +using tvm::asin; +using tvm::asinh; +using tvm::atan; +using tvm::atan2; +using tvm::atanh; +using tvm::clz; +using tvm::cos; +using tvm::cosh; +using tvm::sin; +using tvm::sinh; +using tvm::tan; +using tvm::tanh; + +using tvm::isfinite; +using tvm::isinf; +using tvm::isnan; + +using tvm::tir::Broadcast; +using tvm::tir::CommReducer; +using tvm::tir::Ramp; +using tvm::tir::Select; +using tvm::tir::Shuffle; + +} // namespace tir +} // namespace builder +} // namespace script +} // namespace tvm + +#endif // TVM_SCRIPT_BUILDER_TIR_OP_H_ diff --git a/src/tir/op/op.cc b/src/tir/op/op.cc index 696d82be721fa..9ceec7dca7ea7 100644 --- a/src/tir/op/op.cc +++ b/src/tir/op/op.cc @@ -943,6 +943,10 @@ TVM_REGISTER_GLOBAL("tir.trunc").set_body_typed(tvm::trunc); TVM_REGISTER_GLOBAL("tir._cast").set_body_typed(tvm::cast); +TVM_REGISTER_GLOBAL("tir.infinity").set_body_typed(tvm::infinity); + +TVM_REGISTER_GLOBAL("tir.reinterpret").set_body_typed(tvm::reinterpret); + // operator overloading, smarter than make #define REGISTER_MAKE_BINARY_OP(Node, Func) \ TVM_REGISTER_GLOBAL("tir." #Node).set_body_typed([](PrimExpr a, PrimExpr b, Span span) { \