diff --git a/include/tvm/relay/expr.h b/include/tvm/relay/expr.h index 779bcc34272f..3c156dfd7481 100644 --- a/include/tvm/relay/expr.h +++ b/include/tvm/relay/expr.h @@ -226,7 +226,7 @@ class CallNode : public ExprNode { /*! * \brief The operator(function) being invoked * - * - It can be relay::Op which corresponds to the primitive operators. + * - It can be tvm::Op which corresponds to the primitive operators. * - It can also be user defined functions (Function, GlobalVar, Var). */ Expr op; diff --git a/include/tvm/tir/builtin.h b/include/tvm/tir/builtin.h new file mode 100644 index 000000000000..96526ccfcfb2 --- /dev/null +++ b/include/tvm/tir/builtin.h @@ -0,0 +1,540 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/tir/builtin.h + * \brief TIR builtin intrinsics. + * + * TIR builtin intrinsics are stored as tvm::Op. + * They are processed in the same way as we process Ops. + * + * It is not necessary to create a function for every Op, + * as we can obtain them through Op::Get. + * + * This file contains the most commonly used intrinsics or + * those that have special semantics and need compiler support. + */ +#ifndef TVM_TIR_BUILTIN_H_ +#define TVM_TIR_BUILTIN_H_ + +#include +#include + +namespace tvm { +namespace tir { + +/*! \brief Collection of builtin intrinsics as ops */ +namespace builtin { +/*! + * \brief Reinterpret the value using the target type. + */ +TVM_DLL const Op& reinterpret(); + +/*! + * \brief Marks that a condition is likely to be true. + */ +TVM_DLL const Op& likely(); + +/*! + * \brief Bitwise and operator. + */ +TVM_DLL const Op& bitwise_and(); + +/*! + * \brief Bitwise or operator. + */ +TVM_DLL const Op& bitwise_or(); + +/*! + * \brief Bitwise xor operator. + */ +TVM_DLL const Op& bitwise_xor(); + +/*! + * \brief Bitwise not operator. + */ +TVM_DLL const Op& bitwise_not(); + +/*! + * \brief Left shift + */ +TVM_DLL const Op& shift_left(); + +/*! + * \brief Right shift + */ +TVM_DLL const Op& shift_right(); + +/*! + * \brief See pseudo code + * + * Construct a big uint that may not be representable by int64 + * + * Expr large_uint_imm(uint32_t v0, uint32_t v1) { + * return (v1 << 32) | v0; + * } + */ +TVM_DLL const Op& large_uint_imm(); + +/*! + * \brief See pseudo code + * + * Handle address_of(Load *op) { + * return &op->buffer_var[index]; + * } + */ +TVM_DLL const Op& address_of(); + +/*! + * \brief Same as select, used for unsafe memory access. + * + * Type tvm_if_then_else(cond, a, b) { + * return cond ? a : b; + * } + */ +TVM_DLL const Op& if_then_else(); + +/*! + * \brief See pseudo code + * + * bool isnullptr(void* handle) { + * return handle == nullptr + * } + */ +TVM_DLL const Op& isnullptr(); + +/*! + * \brief Check if value is nan + */ +TVM_DLL const Op& isnan(); + +/*!
+ * \brief Popcount + */ +TVM_DLL const Op& popcount(); + +/*! + * \brief Fused multiply add + * + * Type fma(a, b, c) { + * return a * b + c; + * } + */ +TVM_DLL const Op& fma(); + +/*! + * \brief Call an extern C function with given name + * and signature from the types of args in the runtime environment. + * + * Type call_extern(name, args...) { + * return dlsym(name)(args...); + * } + * + * \note This intrinsic does not provide any type checking, + * and is mainly used for backward compatibility reasons. + * Always consider using pre-registered and typed tvm::Op first. + */ +TVM_DLL const Op& call_extern(); + +/*! + * \brief Call an LLVM intrinsic with a given intrinsic id + * and signature from the types of args in the runtime environment. + * + * Type call_llvm_intrin(intrin_id, args...) { + * return dlsym(name)(args...); + * } + * + * \note This op does not provide any type checking. + */ +TVM_DLL const Op& call_llvm_intrin(); + +/*! + * \brief Call an SPIRV GLSL450 intrinsic. + * + * Type call_spirv_glsl450(intrin_id, args...) { + * return dlsym(name)(args...); + * } + * + * \note This op does not provide any type checking. + */ +TVM_DLL const Op& call_spirv_glsl450(); + +// TODO(tvm-team) revisit the builtins below +// some of them can simply become ops with special codegen attr. +/*! + * \brief Prefetch a cacheline + */ +TVM_DLL const Op& prefetch(); + +/*! + * \brief Get head access address with memory access pattern info. + * + * This operator also marks the range of the memory access. + * The offset and extent are in unit of the DType(including vectorization factor). + * rw_mask is a bit_mask setting whether the access is a read(1) or write(2). + * The access is assumed to happen in the current expression. + * + * PtrType tvm_access_ptr(Expr dtype, DType* data, + * int offset, int extent, + * int rw_mask) { + * // DType == dtype.type(); + * return &data[offset]; + * } + */ +TVM_DLL const Op& tvm_access_ptr(); + +/*! + * \brief Create a function local static handle that initializes to nullptr. + * It can be used to cache function local static resources. + */ +TVM_DLL const Op& tvm_static_handle(); + +/*! + * \brief Return a unique context id, used for hint of workspace separation. + * Different context ids guarantee not having overlapping workspace. + */ +TVM_DLL const Op& tvm_context_id(); + +/*! + * \brief tvm_tuple is not an actual function and cannot be code-generated. + * It is used to represent tuple structure in value field of AttrStmt, + * for the sake of giving a hint to optimization. + * + * Handle tvm_tuple(value0, value1, ..., value_n); + */ +TVM_DLL const Op& tvm_tuple(); + +/*! + * \brief See pseudo code + * + * Type tvm_struct_get(StructType* arr, int index, int field_id) { + * return arr[index]->field; + * } + * \sa TVMStructFieldKind + */ +TVM_DLL const Op& tvm_struct_get(); + +/*! + * \brief See pseudo code + * + * Handle tvm_struct_set(StructType* arr, int index, int field_id, value) { + * arr[index]->field = value; + * } + * \sa TVMStructFieldKind + */ +TVM_DLL const Op& tvm_struct_set(); + +/*! + * \brief See pseudo code + * + * void tvm_throw_last_error() { + * throw TVMGetLastError(); + * } + */ +TVM_DLL const Op& tvm_throw_last_error(); + +/*! + * \brief See pseudo code + * + * dtype in {shape, array, arg_value, arg_tcode} + * + * Handle tvm_stack_alloca(string dtype, int num) { + * return new on stack dtype[num]; + * } + */ +TVM_DLL const Op& tvm_stack_alloca(); + +/*! + * \brief Allocate a shape tuple on stack, return the handle.
+ * + * Handle tvm_stack_make_shape(list args) { + * ret = alloca stack int64_t[len(args)]; + * for i in range(len(args)): + * ret[i] = args[i] + * return &ret[0]; + * } + */ +TVM_DLL const Op& tvm_stack_make_shape(); + +/*! + * \brief Allocate an NDArray(DLTensor) on stack, return the handle. + * + * Type tvm_stack_make_array(Expr data, + * Expr shape, + * Expr strides, + * Expr ndim, + * Expr dtype, + * Expr elem_offset) { + * ret = alloca stack DLTensor(); + * ret->data = data; + * ret->shape = shape; + * ret->strides = strides != 0 ? strides : nullptr; + * ret->ndim = ndim; + * ret->dtype = dtype.type(); + * ret->byte_offset = elem_offset * sizeof(dtype); + * return ret; + * } + */ +TVM_DLL const Op& tvm_stack_make_array(); + +/*! + * \brief See pseudo code + * + * int tvm_call_packed(name, TVMValue* args) { + * ModuleNode* env = GetCurrentEnv(); + * const PackedFunc* f = env->GetFuncFromEnv(name); + * (*f)(args, type_code_of(args), len(args)); + * return 0; + * } + */ +TVM_DLL const Op& tvm_call_packed(); + +/*! + * \brief See pseudo code + * + * int tvm_call_trace_packed(name, TVMValue* args) { + * ModuleNode* env = GetCurrentEnv(); + * const PackedFunc* f = env->GetFuncFromEnv(name); + * (*f)(args, type_code_of(args), len(args)); + * return 0; + * } + */ +TVM_DLL const Op& tvm_call_trace_packed(); + +/*! + * \brief See pseudo code + * Mark the content as a thread local context; this can be optimized + * by making the call only once at thread start. + * + * Do not allow nesting (getting a thread context from another). + * + * Handle tvm_thread_context(Expr call) { + * return call; + * } + */ +TVM_DLL const Op& tvm_thread_context(); + +/*! + * \brief Lowered version of call packed; the space for values and + * type codes is explicitly allocated. + * + * int tvm_call_packed_lowered(name, + * TVMValue* value_stack, + * int* tcode_stack, + * int begin, + * int end) { + * ModuleNode* env = GetCurrentEnv(); + * const PackedFunc* f = env->GetFuncFromEnv(name); + * f->CallPacked(TVMArgs(value_stack[begin:end], + * tcode_stack[begin:end]), + * TVMRetValue(value_stack + end, tcode_stack + end)); + * } + */ +TVM_DLL const Op& tvm_call_packed_lowered(); + +/*! + * \brief Lowered version of the trace intrinsic; the space for values and + * type codes is explicitly allocated. The return value is the + * (end - 1) value on the stack. + * + * int tvm_call_trace_packed_lowered(name, + * TVMValue* value_stack, + * int* tcode_stack, + * int begin, + * int end) { + * ModuleNode* env = GetCurrentEnv(); + * const PackedFunc* f = env->GetFuncFromEnv(name); + * f->CallPacked(TVMArgs(value_stack[begin:end], + * tcode_stack[begin:end]), + * TVMRetValue(value_stack + end, tcode_stack + end)); + * } + */ +TVM_DLL const Op& tvm_call_trace_packed_lowered(); + +/*! + * \brief See pseudo code + * + * int tvm_storage_sync(std::string storage_scope) { + * __sync(storage_scope); + * return 0; + * } + */ +TVM_DLL const Op& tvm_storage_sync(); + +/*!
+ * \brief See pseudo code + * + * Type tvm_warp_shuffle(mask, Type value, warp_id, width, warp_size) { + * return (value passed in by warp indicated by this_warp_id); + * } + * + * Type tvm_warp_shuffle_up(mask, Type value, offset, width, warp_size) { + * return (value passed in by warp indicated by this_warp_id - offset); + * } + * + * Type tvm_warp_shuffle_down(mask, Type value, offset, width, warp_size) { + * return (value passed in by warp indicated by this_warp_id + offset); + * } + * + * unsigned tvm_warp_activemask() { + * return (32-bit mask of currently active threads in the calling warp); + * } + * + * Parameter warp_id indicates the source thread ID in a warp. + * + * Parameter offset indicates the relative distance to this_warp_id. + * + * Parameter width indicates the number of threads involved in one + * shuffle. See CUDA document for __shfl_sync, __shfl_up_sync, + * __shfl_down_sync and __activemask. + * + * Parameter warp_size is the size of a warp, which helps a backend + * to determine whether the width parameter is legal. + * + */ +TVM_DLL const Op& tvm_warp_shuffle(); +TVM_DLL const Op& tvm_warp_shuffle_up(); +TVM_DLL const Op& tvm_warp_shuffle_down(); +TVM_DLL const Op& tvm_warp_activemask(); + +/*! + * \brief Initialize the global barrier. + * Call this at the beginning of a kernel that needs the global barrier. + */ +TVM_DLL const Op& tvm_global_barrier_kinit(); + +/*! + * \brief See pseudo code + * + * void tvm_thread_allreduce(UIntImm size, Expr source0, ..., Expr cond, + * Var reduce_temp0, .., Var thread_idx1, ...) { + * // constraint by the other thread_idx remain the same. + * // reduce_temp is used to save intermediate result. + * reduce_temp0, ... = reduce(combiner, source0, ..., cond + * over [thread_idx1, thread_idx2] passed by any caller) + * } + */ +TVM_DLL const Op& tvm_thread_allreduce(); + +// TODO(tvm-team) TensorCore specific intrinsics should be directly registered under +// cuda. namespace and used through op. +/*! + * \brief tvm intrinsic for tensor core load operators. + * + * void tvm_load_matrix_sync(Var fragment, UIntImm m, UIntImm, n, UIntImm k, + * Expr index, Expr buffer_ptr, Expr stride, + * StringImm layout) { + * // m, n, k are the shape of wmma fragment. + * // Determine fragment layout(column-major or row-major) by layout. + * // fragments must be in 'wmma.matrix_a' or 'wmma.matrix_b' scope. + * nvcuda::wmma::load_matrix_sync(fragment[index], buffer_ptr, stride); + * } + */ +TVM_DLL const Op& tvm_load_matrix_sync(); + +/*! + * \brief tvm intrinsic for tensor core mma_sync operators. + * + * void tvm_mma_sync(Var fragment_d, Expr index_d, + * Var fragment_a, Expr index_a, + * Var fragment_b, Expr index_b, + * Var fragment_c, Expr index_c) { + * nvcuda::wmma::mma_sync(fragment_d[index_d], fragment_a[index_a], + * fragment_b[index_b], fragment_c[index_c]); + * } + */ +TVM_DLL const Op& tvm_mma_sync(); + +/*! + * \brief tvm intrinsic for tensor core bmma_sync operators. + * + * void tvm_bmma_sync(Var fragment_d, Expr index_d, + * Var fragment_a, Expr index_a, + * Var fragment_b, Expr index_b, + * Var fragment_c, Expr index_c) { + * nvcuda::wmma::bmma_sync(fragment_d[index_d], fragment_a[index_a], + * fragment_b[index_b], fragment_c[index_c]); + * } + */ +TVM_DLL const Op& tvm_bmma_sync(); + +/*! + * \brief tvm intrinsic for tensor core fill_fragment operators.
+ * + * void tvm_fill_fragment(Var fragment, UIntImm m, UIntImm, n, UIntImm k, + * Expr index, Expr value) { + * // m, n, k are the shape of wmma fragment + * // fragments must be in 'wmma.accumulator' scope. + * nvcuda::wmma::fill_fragment(fragment[index], value); + * } + */ +TVM_DLL const Op& tvm_fill_fragment(); + +/*! + * \brief tvm intrinsic for tensor core store operators. + * + * void tvm_store_matrix_sync(Var fragment, UIntImm m, UIntImm, n, UIntImm k, + * Expr index, Expr buffer_ptr, Expr stride, + * StringImm layout) { + * // m, n, k are the shape of wmma fragment + * // fragments must be in 'wmma.accumulator' scope. + * nvcuda::wmma::store_matrix_sync(fragment[index], buffer_ptr, stride, layout); + * } + */ +TVM_DLL const Op& tvm_store_matrix_sync(); + +// TODO(tvm-team) replace the usage of the vector operations by Shuffle. +/*! + * \brief Get the high level half of the vector + */ +TVM_DLL const Op& vectorhigh(); + +/*! + * \brief Get the low-level half of the vector + */ +TVM_DLL const Op& vectorlow(); + +/*! + * \brief Concat two vectors. + */ +TVM_DLL const Op& vectorcombine(); + +/*! \brief The kind of structure field info used in intrinsic */ +enum TVMStructFieldKind : int { + // array head address + kArrAddr, + kArrData, + kArrShape, + kArrStrides, + kArrNDim, + kArrTypeCode, + kArrTypeBits, + kArrTypeLanes, + kArrByteOffset, + kArrDeviceId, + kArrDeviceType, + kArrKindBound_, + // TVMValue field + kTVMValueContent, + kTVMValueKindBound_ +}; +} // namespace builtin +} // namespace tir +} // namespace tvm +#endif // TVM_TIR_BUILTIN_H_ diff --git a/include/tvm/tir/expr.h b/include/tvm/tir/expr.h index 1518d1ff548e..a51f70984011 100644 --- a/include/tvm/tir/expr.h +++ b/include/tvm/tir/expr.h @@ -888,8 +888,14 @@ class CallNode : public PrimExprNode { /*! \brief Intrinsic functions that are pure. */ PureIntrinsic = 5 }; - /*! \brief The name of the function/intrinsic. */ - String name; + /*! + * \brief The operator(function) being invoked + * + * - It can be tvm::Op which corresponds to the primitive operators(intrinsics). + * - It can also be another function in the IRModule (GlobalVar). + */ + RelayExpr op; + /*! \brief The arguments. */ Array args; /*! \brief Type of calls. */ @@ -897,19 +903,19 @@ class CallNode : public PrimExprNode { void VisitAttrs(AttrVisitor* v) { v->Visit("dtype", &dtype); - v->Visit("name", &name); + v->Visit("op", &op); v->Visit("args", &args); v->Visit("call_type", &call_type); } bool SEqualReduce(const CallNode* other, SEqualReducer equal) const { - return equal(dtype, other->dtype) && equal(name, other->name) && equal(args, other->args) && + return equal(dtype, other->dtype) && equal(op, other->op) && equal(args, other->args) && equal(call_type, other->call_type); } void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(dtype); - hash_reduce(name); + hash_reduce(op); hash_reduce(args); hash_reduce(call_type); } @@ -917,37 +923,8 @@ class CallNode : public PrimExprNode { /*! \return Whether call node is pure. */ bool is_pure() const { return (call_type == PureExtern || call_type == PureIntrinsic); } - /*! - * \return Whether call node corresponds to a defined intrinsic. - * \param intrin_name The name of the intrinsic. - */ - bool is_intrinsic(const char* intrin_name) const { - return ((call_type == Intrinsic || call_type == PureIntrinsic) && name == intrin_name); - } - - /*! \return Whether call node can be vectorized. 
*/ - bool is_vectorizable() const; - static constexpr const char* _type_key = "tir.Call"; TVM_DECLARE_FINAL_OBJECT_INFO(CallNode, PrimExprNode); - - // Build-in intrinsics - static constexpr const char* reinterpret = "reinterpret"; - static constexpr const char* bitwise_and = "bitwise_and"; - static constexpr const char* bitwise_not = "bitwise_not"; - static constexpr const char* bitwise_xor = "bitwise_xor"; - static constexpr const char* bitwise_or = "bitwise_or"; - static constexpr const char* shift_left = "shift_left"; - static constexpr const char* shift_right = "shift_right"; - static constexpr const char* popcount = "popcount"; - static constexpr const char* likely = "likely"; - static constexpr const char* prefetch = "prefetch"; - static constexpr const char* isnan = "isnan"; - static constexpr const char* isfinite = "isfinite"; - static constexpr const char* isinf = "isinf"; - - /*! \brief Vectorizable intrinsic list. */ - static const char* vectorizable_intrinsics[]; }; /*! @@ -958,7 +935,7 @@ class Call : public PrimExpr { public: using CallType = CallNode::CallType; - TVM_DLL Call(DataType dtype, String name, Array args, CallType call_type); + TVM_DLL Call(DataType dtype, RelayExpr op, Array args, CallType call_type); TVM_DEFINE_OBJECT_REF_METHODS(Call, PrimExpr, CallNode); }; @@ -1167,358 +1144,6 @@ inline std::unordered_map as_unordered_map(const Map& dmap) { } return ret; } - -/*! \brief namespace of TVM Intrinsic functions */ -namespace intrinsic { -/*! - * \brief See pesudo code - * - * Construct a big uint that may not be representable by int64 - * - * Expr tvm_large_uint_imm(uint32_t v0, uin32_t v1) { - * return (v1 << 32) | v0; - * } - */ -constexpr const char* tvm_large_uint_imm = "tvm_large_uint_imm"; -/*! - * \brief See pesudo code - * - * Handle tvm_address_of(Load *op) { - * return &op->buffer_var[index]; - * } - */ -constexpr const char* tvm_address_of = "tvm_address_of"; -/*! - * \brief Same as select, used for unsafe memory access. - * - * Type tvm_if_then_else(cond, a, b) { - * return cond ? a : b; - * } - */ -constexpr const char* tvm_if_then_else = "tvm_if_then_else"; -/*! - * \brief Get head access address with memory access pattern info. - * - * This operator also marks range of the memory access - * The offset and extent are in unit of the DType(including vectorization factor). - * rw_mask is a bit_mask setting whether the access is a read(1) or write(2). - * The access is assume to happen in the current expression. - * - * PtrType tvm_access_ptr(Expr dtype, DType* data, - * int offset, int extent, - * int rw_mask) { - * // DType == dtype.type(); - * return &data[offset]; - * } - */ -constexpr const char* tvm_access_ptr = "tvm_access_ptr"; -/*! - * \brief Create a function local static handle that iniitalizes to nullptr. - * can be used to cache function local static resources. - */ -constexpr const char* tvm_static_handle = "tvm_static_handle"; -/*! - * \brief Return a unique context id, used for hint of workspace separation. - * Different context id ganrantees not having overlapping workspace. - */ -constexpr const char* tvm_context_id = "tvm_context_id"; -/*! - * \brief tvm_tuple is not an actual function and cannot codegen. - * It is used to represent tuple structure in value field of AttrStmt, - * for the sake of giving hint to optimization. - * - * Handle tvm_tuple(value0, value1, ..., value_n); - */ -constexpr const char* tvm_tuple = "tvm_tuple"; -/*! 
- * \brief See pesudo code - * - * Type tvm_struct_get(StructType* arr, int index, int field_id) { - * return arr[index]->field; - * } - * \sa TVMStructFieldKind - */ -constexpr const char* tvm_struct_get = "tvm_struct_get"; -/*! - * \brief See pesudo code - * - * Handle tvm_struct_set(StructType* arr, int index, int field_id, value) { - * arr[index]->field = value; - * } - * \sa TVMStructFieldKind - */ -constexpr const char* tvm_struct_set = "tvm_struct_set"; -/*! - * \brief See pesudo code - * - * bool tvm_handle_is_null(void* handle) { - * return handle == nullptr - * } - */ -constexpr const char* tvm_handle_is_null = "tvm_handle_is_null"; -/*! - * \brief See pesudo code - * - * void tvm_throw_last_error() { - * throw TVMGetLastError(); - * } - */ -constexpr const char* tvm_throw_last_error = "tvm_throw_last_error"; -/*! - * \brief See pesudo code - * - * dtype in {shape, array, arg_value, arg_tcode} - * - * Handle tvm_stack_alloca(string dtype, int num) { - * return new on stack dtype[num]; - * } - */ -constexpr const char* tvm_stack_alloca = "tvm_stack_alloca"; -/*! - * \brief Allocate a shape tuple on stack, return the handle. - * - * Handle tvm_stack_make_shape(list args) { - * ret = alloca stack int64_t[len(args)]; - * for i in range(len(args)): - * ret[i] = args[i] - * return &ret[0]; - * } - */ -constexpr const char* tvm_stack_make_shape = "tvm_stack_make_shape"; -/*! - * \brief Allocate a NDArray(DLTensor) on stack, return the handle. - * - * Type tvm_stack_make_array(Expr data, - * Expr shape, - * Expr strides, - * Expr ndim, - * Expr dtype, - * Expr elem_offset) { - * ret = alloca stack DLTensor(); - * ret->data = data; - * ret->shape = shape; - * ret->strides = strides != 0 ? strides : nullptr; - * ret->ndim = ndim; - * ret->dtype = dtype.type(); - * ret->byte_offset = elem_offset * sizeof(dtype); - * return ret; - * } - */ -constexpr const char* tvm_stack_make_array = "tvm_stack_make_array"; -/*! - * \brief See pesudo code - * - * int tvm_call_packed(name, TVMValue* args) { - * ModuleNode* env = GetCurrentEnv(); - * const PackedFunc* f = env->GetFuncFromEnv(name); - * (*f)(args, type_code_of(args), len(args)); - * return 0; - * } - */ -constexpr const char* tvm_call_packed = "tvm_call_packed"; -/*! - * \brief See pesudo code - * - * int tvm_call_trace_packed(name, TVMValue* args) { - * ModuleNode* env = GetCurrentEnv(); - * const PackedFunc* f = env->GetFuncFromEnv(name); - * (*f)(args, type_code_of(args), len(args)); - * return 0; - * } - */ -constexpr const char* tvm_call_trace_packed = "tvm_call_trace_packed"; -/*! - * \brief See pesudo code - * Mark the content as thread local context, can get optimized - * by only call the call once at thread start. - * - * Do not allow nesting(getting a thread context from another). - * - * Handle tvm_thread_context(Expr call) { - * return call; - * } - */ -constexpr const char* tvm_thread_context = "tvm_thread_context"; -/*! - * \brief Lowered version of call packed, the space of value and - * type codes are explicitly allocated. - * - * int tvm_call_packed_lowered(name, - * TVMValue* value_stack, - * int* tcode_stack, - * int begin, - * int end) { - * ModuleNode* env = GetCurrentEnv(); - * const PackedFunc* f = env->GetFuncFromEnv(name); - * f->CallPacked(TVMArgs(value_stack[begin:end], - * tcode_stack[begin:end]), - * TVMRetValue(value_stack + end, tcode_stack + end)); - * } - */ -constexpr const char* tvm_call_packed_lowered = "tvm_call_packed_lowered"; -/*! 
- * \brief Lowered version of trace intrinsic, the space of value and - * type codes are explicitly allocated. The return value is the - * (end - 1) value on the stack. - * - * int tvm_call_trace_packed_lowered(name, - * TVMValue* value_stack, - * int* tcode_stack, - * int begin, - * int end) { - * ModuleNode* env = GetCurrentEnv(); - * const PackedFunc* f = env->GetFuncFromEnv(name); - * f->CallPacked(TVMArgs(value_stack[begin:end], - * tcode_stack[begin:end]), - * TVMRetValue(value_stack + end, tcode_stack + end)); - * } - */ -constexpr const char* tvm_call_trace_packed_lowered = "tvm_call_trace_packed_lowered"; -/*! - * \brief See pseudo code - * - * int tvm_storage_sync(std::string storage_scope) { - * __sync(storage_scope); - * return 0; - * } - */ -constexpr const char* tvm_storage_sync = "tvm_storage_sync"; - -/*! - * \brief See pseudo code - * - * Type tvm_warp_shuffle(mask, Type value, warp_id, width, warp_size) { - * return (value passed in by warp indicated by this_warp_id); - * } - * - * Type tvm_warp_shuffle_up(mask, Type value, offset, width, warp_size) { - * return (value passed in by warp indicated by this_warp_id - offset); - * } - * - * Type tvm_warp_shuffle_down(mask, Type value, offset, width, warp_size) { - * return (value passed in by warp indicated by this_warp_id + offset); - * } - * - * unsigned tvm_warp_activemask() { - * return (32-bit mask of currently active threads in the calling warp); - * } - * - * Parameter warp_id indicates the source thread ID in a warp. - * - * Parameter offset indicates the relative distance to this_warp_id. - * - * Parameter width indicates the number of threads involved in one - * shuffle. See CUDA document for __shfl_sync, __shfl_up_sync, - * __shfl_down_sync and __activemask. - * - * Parameter warp_size is the size of a warp, which helps a backend - * to determine wheter the width paramter is legal. - * - */ -constexpr const char* tvm_warp_shuffle = "tvm_warp_shuffle"; -constexpr const char* tvm_warp_shuffle_up = "tvm_warp_shuffle_up"; -constexpr const char* tvm_warp_shuffle_down = "tvm_warp_shuffle_down"; -constexpr const char* tvm_warp_activemask = "tvm_warp_activemask"; - -/*! - * \brief Initialize the global barrier. - * Call this at beginning of kernel that need global barrier. - */ -constexpr const char* tvm_global_barrier_kinit = "tvm_global_barrier_kinit"; -/*! - * \brief See pesudo code - * - * void tvm_thread_allreduce(UIntImm size, Expr source0, ..., Expr cond, - * Var reduce_temp0, .., Var thread_idx1, ...) { - * // constraint by the other thread_idx remain the same. - * // reduce_temp is used to save intermediate result. - * reduce_temp0, ... = reduce(combiner, source0, ..., cond - * over [thread_idx1, thread_idx2] passed by any caller) - * } - */ -constexpr const char* tvm_thread_allreduce = "tvm_thread_allreduce"; -/*! - * \brief tvm intrinsic for tensor core load operators. - * - * void tvm_load_matrix_sync(Var fragment, UIntImm m, UIntImm, n, UIntImm k, - * Expr index, Expr buffer_ptr, Expr stride, - * StringImm layout) { - * // m, n, k are the shape of wmma fragment. - * // Determine fragment layout(column-major or row major) by layout. - * // fragments must be in 'wmma.matrix_a' or 'wmma.matrix_b' scope. - * nvcuda::wmma::load_matrix_sync(fragment[index], buffer_ptr, stride); - * } - */ -constexpr const char* tvm_load_matrix_sync = "tvm_load_matrix_sync"; -/*! - * \brief tvm intrinsic for tensor core mma_sync operators. 
- * - * void tvm_mma_sync(Var fragment_d, Expr index_d, - * Var fragment_a, Expr index_a, - * Var fragment_b, Expr index_b, - * Var fragment_c, Expr index_c) { - * nvcuda::wmma::mma_sync(fragment_d[index_d], fragment_a[index_a], - * fragment_b[index_b], fragment_c[index_c]); - * } - */ -constexpr const char* tvm_mma_sync = "tvm_mma_sync"; -/*! - * \brief tvm intrinsic for tensor core bmma_sync operators. - * - * void tvm_bmma_sync(Var fragment_d, Expr index_d, - * Var fragment_a, Expr index_a, - * Var fragment_b, Expr index_b, - * Var fragment_c, Expr index_c) { - * nvcuda::wmma::bmma_sync(fragment_d[index_d], fragment_a[index_a], - * fragment_b[index_b], fragment_c[index_c]); - * } - */ -constexpr const char* tvm_bmma_sync = "tvm_bmma_sync"; -/*! - * \brief tvm intrinsic for tensor core fill_fragment operators. - * - * void tvm_fill_fragment(Var fragment, UIntImm m, UIntImm, n, UIntImm k, - * Expr index, Expr value) { - * // m, n, k are the shape of wmma fragment - * // fragments must be in 'wmma.accumulator' scope. - * nvcuda::wmma::fill_fragment(fragment[index], value); - * } - */ -constexpr const char* tvm_fill_fragment = "tvm_fill_fragment"; -/*! - * \brief tvm intrinsic for tensor core store operators. - * - * void tvm_store_matrix_sync(Var fragment, UIntImm m, UIntImm, n, UIntImm k, - * Expr index, Expr buffer_ptr, Expr stride, - * StringImm layout) { - * // m, n, k are the shape of wmma fragment - * // fragments must be in 'wmma.accumulator' scope. - * nvcuda::wmma::store_matrix_sync(fragment[index], buffer_ptr, stride, layout); - * } - */ -constexpr const char* tvm_store_matrix_sync = "tvm_store_matrix_sync"; - -/*! \brief The kind of structure field info used in intrinsic */ -enum TVMStructFieldKind : int { - // array head address - kArrAddr, - kArrData, - kArrShape, - kArrStrides, - kArrNDim, - kArrTypeCode, - kArrTypeBits, - kArrTypeLanes, - kArrByteOffset, - kArrDeviceId, - kArrDeviceType, - kArrKindBound_, - // TVMValue field - kTVMValueContent, - kTVMValueKindBound_ -}; -} // namespace intrinsic - } // namespace tir } // namespace tvm diff --git a/include/tvm/tir/function.h b/include/tvm/tir/function.h index 919391e36b96..caddd99eeb2c 100644 --- a/include/tvm/tir/function.h +++ b/include/tvm/tir/function.h @@ -87,8 +87,6 @@ class PrimFuncNode : public BaseFuncNode { * While we could have express parameter unpacking and constraint using * normal statements, making buffer_map as first class citizen of PrimFunc * will make program analysis much easier. - * - * \note This field can be nullptr */ Map buffer_map; @@ -144,7 +142,7 @@ class PrimFunc : public BaseFunc { * \param attrs Additional function attributes. 
*/ TVM_DLL PrimFunc(Array params, Stmt body, Type ret_type = VoidType(), - Map buffer_map = NullValue>(), + Map buffer_map = Map(), DictAttrs attrs = NullValue()); TVM_DEFINE_OBJECT_REF_METHODS(PrimFunc, BaseFunc, PrimFuncNode); diff --git a/include/tvm/tir/op.h b/include/tvm/tir/op.h index 2948bb2cc20e..286b6d75cb82 100644 --- a/include/tvm/tir/op.h +++ b/include/tvm/tir/op.h @@ -28,6 +28,7 @@ #ifndef TVM_TIR_OP_H_ #define TVM_TIR_OP_H_ +#include #include #include #include @@ -552,9 +553,10 @@ TVM_DLL PrimExpr trunc(PrimExpr x); TVM_DLL PrimExpr LargeUIntImm(DataType dtype, int64_t low, int64_t high); // Intrinsic operators -#define TVM_DECLARE_INTRIN_UNARY(OpName) \ - inline PrimExpr OpName(PrimExpr x) { \ - return tir::Call(x.dtype(), #OpName, {x}, tir::CallNode::PureIntrinsic); \ +#define TVM_DECLARE_INTRIN_UNARY(OpName) \ + inline PrimExpr OpName(PrimExpr x) { \ + static const Op& op = Op::Get("tir." #OpName); \ + return tir::Call(x.dtype(), op, {x}, tir::CallNode::PureIntrinsic); \ } TVM_DECLARE_INTRIN_UNARY(exp); diff --git a/include/tvm/tir/op_attr_types.h b/include/tvm/tir/op_attr_types.h new file mode 100644 index 000000000000..d7c13500d90e --- /dev/null +++ b/include/tvm/tir/op_attr_types.h @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/tir/op_attr_types.h + * \brief Attribute types in the Op registry for TIR ops. + * + * These attributes can be set via OpRegEntry::set_attr + * + * \sa tvm/ir/op.h + */ +#ifndef TVM_TIR_OP_ATTR_TYPES_H_ +#define TVM_TIR_OP_ATTR_TYPES_H_ + +#include + +namespace tvm { +namespace tir { + +/*! + * \brief Global symbol of the op after lowering. + */ +using TGlobalSymbol = String; + +/*! + * \brief Whether the op is overloaded for vector form. + */ +using TVectorizable = bool; + +} // namespace tir +} // namespace tvm +#endif // TVM_TIR_OP_ATTR_TYPES_H_ diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h index be1c567198d9..b928aec7bcf2 100644 --- a/include/tvm/tir/stmt.h +++ b/include/tvm/tir/stmt.h @@ -1004,9 +1004,7 @@ inline bool IsPragmaKey(const std::string& attr_key) { * \param dtype The data type * \return Expr a expression with dtype. */ -inline PrimExpr TypeAnnotation(DataType dtype) { - return tir::Call(dtype, "type_annotation", {}, tir::CallNode::PureIntrinsic); -} +TVM_DLL PrimExpr TypeAnnotation(DataType dtype); // overload printing of for type. 
TVM_DLL std::ostream& operator<<(std::ostream& os, ForType for_type); diff --git a/python/tvm/contrib/nvcc.py b/python/tvm/contrib/nvcc.py index fc8232053b5f..8c3d34af174e 100644 --- a/python/tvm/contrib/nvcc.py +++ b/python/tvm/contrib/nvcc.py @@ -98,7 +98,8 @@ def compile_cuda(code, (out, _) = proc.communicate() if proc.returncode != 0: - msg = "Compilation error:\n" + msg = code + msg += "\nCompilation error:\n" msg += py_str(out) raise RuntimeError(msg) diff --git a/python/tvm/target/datatype.py b/python/tvm/target/datatype.py index e42ac6b37806..f93a943cd9cf 100644 --- a/python/tvm/target/datatype.py +++ b/python/tvm/target/datatype.py @@ -18,8 +18,9 @@ import tvm._ffi import tvm.runtime._ffi_api -from tvm.runtime import convert, DataType -from tvm.tir.expr import Call as _Call, Cast as _Cast, FloatImm as _FloatImm +from tvm.runtime import DataType +import tvm.tir +from tvm.tir.expr import Cast as _Cast, FloatImm as _FloatImm def register(type_name, type_code): @@ -135,9 +136,7 @@ def lower(op): if t.lanes > 1: dtype += "x" + str(t.lanes) if isinstance(op, (_Cast, _FloatImm)): - return _Call(dtype, extern_func_name, convert([op.value]), - _Call.Extern) - return _Call(dtype, extern_func_name, convert([op.a, op.b]), - _Call.Extern) + return tvm.tir.call_pure_extern(dtype, extern_func_name, op.value) + return tvm.tir.call_pure_extern(dtype, extern_func_name, op.a, op.b) return lower diff --git a/python/tvm/target/intrin.py b/python/tvm/target/intrin.py index acb0efe0ea64..78da8a60d24b 100644 --- a/python/tvm/target/intrin.py +++ b/python/tvm/target/intrin.py @@ -83,10 +83,14 @@ def _rule_float_suffix(op): -------- register_intrin_rule : The registeration function for intrin rule. """ + name = op.name + assert name.startswith("tir.") + prefix = name[4:] + if op.dtype == "float32": - return call_pure_extern(op.dtype, "%sf" % op.name, *op.args) + return call_pure_extern(op.dtype, "%sf" % prefix, *op.args) if op.dtype == "float64": - return call_pure_extern(op.dtype, op.name, *op.args) + return call_pure_extern(op.dtype, prefix, *op.args) return op @@ -111,7 +115,7 @@ def _rule_float_direct(op): register_intrin_rule : The registeration function for intrin rule. """ if str(op.dtype).startswith("float"): - return call_pure_extern(op.dtype, op.name, *op.args) + return call_pure_extern(op.dtype, op.op.name[4:], *op.args) return None # opencl pattern for exp diff --git a/python/tvm/te/hybrid/calls.py b/python/tvm/te/hybrid/calls.py index dfbb185a7eb4..a119c20754f4 100644 --- a/python/tvm/te/hybrid/calls.py +++ b/python/tvm/te/hybrid/calls.py @@ -148,7 +148,7 @@ def likely(func_id, args): _internal_assert(args.__len__() == 1, \ "Only one expression can be likely") _internal_assert(func_id == "likely", "This function cannot be directly invoked!") - return call_pure_intrin(args[0].dtype, 'likely', *args) + return call_pure_intrin(args[0].dtype, 'tir.likely', *args) def max_num_threads(func_id, args): diff --git a/python/tvm/tir/expr.py b/python/tvm/tir/expr.py index 3b580efe2b62..386badf3e8aa 100644 --- a/python/tvm/tir/expr.py +++ b/python/tvm/tir/expr.py @@ -30,7 +30,7 @@ import tvm._ffi from tvm.runtime import Object, ObjectGeneric, DataType, DataTypeCode, const -from tvm.ir import PrimExpr +from tvm.ir import PrimExpr, Op import tvm.ir._ffi_api from . import generic as _generic from . 
import _ffi_api @@ -144,7 +144,7 @@ def __rxor__(self, other): def __invert__(self): if _dtype_is_float(self): raise RuntimeError("Cannot use ~ operator on float type Expr.") - return _ffi_api.Call(self.dtype, "bitwise_not", [self], Call.PureIntrinsic) + return _ffi_api.bitwise_not(self) def __lt__(self, other): return _ffi_api._OpLT(self, other) @@ -968,8 +968,9 @@ class Call(PrimExprWithOp): dtype : str The return data type - name : str - The name of the function + op : Union[RelayExpr, str] + The function to be called, or the name + to the global tvm.Op args : list of Expr The input arguments to the call @@ -982,9 +983,16 @@ class Call(PrimExprWithOp): PureExtern = 2 Intrinsic = 4 PureIntrinsic = 5 - def __init__(self, dtype, name, args, call_type): - self.__init_handle_by_constructor__( - _ffi_api.Call, dtype, name, args, call_type) + def __init__(self, dtype, op, args, call_type): + if isinstance(op, str): + if not op.startswith("tir."): + raise ValueError( + ("Cannot handle str op argument %s. This function only handles str " + + "argument with the tir namespace. If you are " + + "certain about the intrinsic name, pass in Op.get(name) instead") % op) + op = Op.get(op) + self.__init_handle_by_constructor__( + _ffi_api.Call, dtype, op, args, call_type) @tvm._ffi.register_object("tir.Let") diff --git a/python/tvm/tir/ir_builder.py b/python/tvm/tir/ir_builder.py index 47ba2e2c805c..089127c6f0ff 100644 --- a/python/tvm/tir/ir_builder.py +++ b/python/tvm/tir/ir_builder.py @@ -379,7 +379,7 @@ def likely(self, expr): expr : Expr The expression will likely tag. """ - return _expr.Call(expr.dtype, "likely", [expr], + return _expr.Call(expr.dtype, "tir.likely", [expr], _expr.Call.PureIntrinsic) def get(self): diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py index 929d422ccc43..6826241ac1a6 100644 --- a/python/tvm/tir/op.py +++ b/python/tvm/tir/op.py @@ -18,10 +18,10 @@ """Operators used in TIR expression.""" import tvm._ffi from tvm.runtime import convert, const -from tvm.ir import Array +from tvm.ir import Array, Op from .buffer import Buffer -from .expr import Call, Var, CommReducer +from .expr import Call, StringImm, Var, CommReducer from . import _ffi_api @@ -29,9 +29,9 @@ def _pack_buffer(buf): """Build intrinsics that packs the buffer. """ assert buf.shape - shape = Call("handle", "tvm_stack_make_shape", buf.shape, + shape = Call("handle", "tir.tvm_stack_make_shape", buf.shape, Call.Intrinsic) - strides = Call("handle", "tvm_stack_make_shape", buf.strides, + strides = Call("handle", "tir.tvm_stack_make_shape", buf.strides, Call.Intrinsic) if buf.strides else 0 pack_args = [buf.data, shape, @@ -39,7 +39,7 @@ def _pack_buffer(buf): len(buf.shape), const(0, dtype=buf.dtype), buf.elem_offset] - return Call("handle", "tvm_stack_make_array", + return Call("handle", Op.get("tir.tvm_stack_make_array"), pack_args, Call.Intrinsic) def call_packed(*args): @@ -68,7 +68,7 @@ def call_packed(*args): """ call_args = [_pack_buffer(x) if isinstance(x, Buffer) else x for x in args] return Call( - "int32", "tvm_call_packed", call_args, Call.Intrinsic) + "int32", Op.get("tir.tvm_call_packed"), call_args, Call.Intrinsic) def call_pure_intrin(dtype, func_name, *args): @@ -145,7 +145,7 @@ def call_pure_extern(dtype, func_name, *args): The call expression. 
""" return Call( - dtype, func_name, convert(args), Call.PureExtern) + dtype, Op.get("tir.call_extern"), convert((StringImm(func_name),) + args), Call.PureExtern) def call_extern(dtype, func_name, *args): @@ -168,7 +168,7 @@ def call_extern(dtype, func_name, *args): The call expression. """ return Call( - dtype, func_name, convert(args), Call.Extern) + dtype, Op.get("tir.call_extern"), convert((StringImm(func_name),) + args), Call.Extern) def call_llvm_intrin(dtype, name, *args): @@ -194,7 +194,8 @@ def call_llvm_intrin(dtype, name, *args): from tvm.target import codegen llvm_id = codegen.llvm_lookup_intrinsic_id(name) assert llvm_id != 0, "%s is not an LLVM intrinsic" % name - return call_pure_intrin(dtype, 'llvm_intrin', tvm.tir.const(llvm_id, 'uint32'), *args) + return call_pure_intrin(dtype, Op.get("tir.call_llvm_intrin"), + tvm.tir.const(llvm_id, 'uint32'), *args) def any(*args): @@ -278,7 +279,7 @@ def trace(args, trace_action="tvm.default_trace_action"): call_args = [_pack_buffer(x) if isinstance(x, Buffer) else x for x in args] call_args.insert(0, trace_action) return tvm.tir.Call( - args[-1].dtype, "tvm_call_trace_packed", call_args, tvm.tir.Call.Intrinsic) + args[-1].dtype, Op.get("tir.tvm_call_trace_packed"), call_args, tvm.tir.Call.Intrinsic) @@ -327,7 +328,7 @@ def exp(x): y : PrimExpr The result. """ - return call_pure_intrin(x.dtype, "exp", x) + return call_pure_intrin(x.dtype, "tir.exp", x) def exp2(x): @@ -343,7 +344,7 @@ def exp2(x): y : PrimExpr The result. """ - return call_pure_intrin(x.dtype, "exp2", x) + return call_pure_intrin(x.dtype, "tir.exp2", x) def exp10(x): @@ -359,7 +360,7 @@ def exp10(x): y : PrimExpr The result. """ - return call_pure_intrin(x.dtype, "exp10", x) + return call_pure_intrin(x.dtype, "tir.exp10", x) def erf(x): @@ -375,7 +376,7 @@ def erf(x): y : PrimExpr The result. """ - return call_pure_intrin(x.dtype, "erf", x) + return call_pure_intrin(x.dtype, "tir.erf", x) def tanh(x): @@ -391,7 +392,7 @@ def tanh(x): y : PrimExpr The result. """ - return call_pure_intrin(x.dtype, "tanh", x) + return call_pure_intrin(x.dtype, "tir.tanh", x) def sigmoid(x): @@ -407,7 +408,7 @@ def sigmoid(x): y : PrimExpr The result. """ - return call_pure_intrin(x.dtype, "sigmoid", x) + return call_pure_intrin(x.dtype, "tir.sigmoid", x) def log(x): @@ -423,7 +424,7 @@ def log(x): y : PrimExpr The result. """ - return call_pure_intrin(x.dtype, "log", x) + return call_pure_intrin(x.dtype, "tir.log", x) def log2(x): @@ -439,7 +440,7 @@ def log2(x): y : PrimExpr The result. """ - return call_pure_intrin(x.dtype, "log2", x) + return call_pure_intrin(x.dtype, "tir.log2", x) def log10(x): @@ -455,7 +456,7 @@ def log10(x): y : PrimExpr The result. """ - return call_pure_intrin(x.dtype, "log10", x) + return call_pure_intrin(x.dtype, "tir.log10", x) def log1p(x): @@ -471,7 +472,7 @@ def log1p(x): y : PrimExpr The result. """ - return call_pure_intrin(x.dtype, "log1p", x) + return call_pure_intrin(x.dtype, "tir.log1p", x) def tan(x): @@ -487,7 +488,7 @@ def tan(x): y : PrimExpr The result. """ - return call_pure_intrin(x.dtype, "tan", x) + return call_pure_intrin(x.dtype, "tir.tan", x) def cos(x): @@ -503,7 +504,7 @@ def cos(x): y : PrimExpr The result. """ - return call_pure_intrin(x.dtype, "cos", x) + return call_pure_intrin(x.dtype, "tir.cos", x) def cosh(x): @@ -519,7 +520,7 @@ def cosh(x): y : PrimExpr The result. 
""" - return call_pure_intrin(x.dtype, "cosh", x) + return call_pure_intrin(x.dtype, "tir.cosh", x) def acos(x): @@ -535,7 +536,7 @@ def acos(x): y : PrimExpr The result. """ - return call_pure_intrin(x.dtype, "acos", x) + return call_pure_intrin(x.dtype, "tir.acos", x) def acosh(x): @@ -551,7 +552,7 @@ def acosh(x): y : PrimExpr The result. """ - return call_pure_intrin(x.dtype, "acosh", x) + return call_pure_intrin(x.dtype, "tir.acosh", x) def sin(x): @@ -567,7 +568,7 @@ def sin(x): y : PrimExpr The result. """ - return call_pure_intrin(x.dtype, "sin", x) + return call_pure_intrin(x.dtype, "tir.sin", x) def sinh(x): @@ -583,7 +584,7 @@ def sinh(x): y : PrimExpr The result. """ - return call_pure_intrin(x.dtype, "sinh", x) + return call_pure_intrin(x.dtype, "tir.sinh", x) def asin(x): @@ -599,7 +600,7 @@ def asin(x): y : PrimExpr The result. """ - return call_pure_intrin(x.dtype, "asin", x) + return call_pure_intrin(x.dtype, "tir.asin", x) def asinh(x): @@ -615,7 +616,7 @@ def asinh(x): y : PrimExpr The result. """ - return call_pure_intrin(x.dtype, "asinh", x) + return call_pure_intrin(x.dtype, "tir.asinh", x) def atan(x): @@ -631,7 +632,7 @@ def atan(x): y : PrimExpr The result. """ - return call_pure_intrin(x.dtype, "atan", x) + return call_pure_intrin(x.dtype, "tir.atan", x) def atanh(x): @@ -647,7 +648,7 @@ def atanh(x): y : PrimExpr The result. """ - return call_pure_intrin(x.dtype, "atanh", x) + return call_pure_intrin(x.dtype, "tir.atanh", x) def atan2(x1, x2): @@ -666,7 +667,7 @@ def atan2(x1, x2): y : PrimExpr The result. """ - return call_pure_intrin(x1.dtype, "atan2", x1, x2) + return call_pure_intrin(x1.dtype, "tir.atan2", x1, x2) def sqrt(x): @@ -682,7 +683,7 @@ def sqrt(x): y : PrimExpr The result. """ - return call_pure_intrin(x.dtype, "sqrt", x) + return call_pure_intrin(x.dtype, "tir.sqrt", x) def rsqrt(x): @@ -698,7 +699,7 @@ def rsqrt(x): y : PrimExpr The result. """ - return call_pure_intrin(x.dtype, "rsqrt", x) + return call_pure_intrin(x.dtype, "tir.rsqrt", x) def floor(x): @@ -823,7 +824,7 @@ def nextafter(x1, x2): y : PrimExpr The result. """ - return call_pure_intrin(x1.dtype, "nextafter", x1, x2) + return call_pure_intrin(x1.dtype, "tir.nextafter", x1, x2) def hypot(x1, x2): @@ -842,7 +843,7 @@ def hypot(x1, x2): y : PrimExpr The result. """ - return call_pure_intrin(x1.dtype, "hypot", x1, x2) + return call_pure_intrin(x1.dtype, "tir.hypot", x1, x2) def copysign(x1, x2): @@ -861,7 +862,7 @@ def copysign(x1, x2): y : PrimExpr The result. """ - return call_pure_intrin(x1.dtype, "copysign", x1, x2) + return call_pure_intrin(x1.dtype, "tir.copysign", x1, x2) def ldexp(x1, x2): @@ -880,7 +881,7 @@ def ldexp(x1, x2): y : PrimExpr The result. """ - return call_pure_intrin(x1.dtype, "ldexp", x1, x2) + return call_pure_intrin(x1.dtype, "tir.ldexp", x1, x2) def isnan(x): @@ -963,7 +964,7 @@ def popcount(x): y : PrimExpr The result. """ - return call_pure_intrin(x.dtype, "popcount", x) + return call_pure_intrin(x.dtype, "tir.popcount", x) def fmod(x, y): """Return the remainder of x divided by y with the same sign as x. @@ -980,7 +981,7 @@ def fmod(x, y): z : PrimExpr The result. 
""" - return call_pure_intrin(x.dtype, "fmod", x, y) + return call_pure_intrin(x.dtype, "tir.fmod", x, y) def if_then_else(cond, t, f): diff --git a/src/arith/const_int_bound.cc b/src/arith/const_int_bound.cc index c33990cd1f4f..8c90249f4f17 100644 --- a/src/arith/const_int_bound.cc +++ b/src/arith/const_int_bound.cc @@ -22,6 +22,7 @@ */ #include #include +#include #include #include @@ -284,9 +285,10 @@ class ConstIntBoundAnalyzer::Impl Entry VisitExpr_(const CallNode* op) final { // only special handle >> and & which can be // used for index calculation. - if (op->is_intrinsic(CallNode::shift_right)) { + + if (op->op.same_as(tir::builtin::shift_right())) { return VisitRightShift(op); - } else if (op->is_intrinsic(CallNode::bitwise_and)) { + } else if (op->op.same_as(tir::builtin::bitwise_and())) { return VisitBitwiseAnd(op); } else { return Everything(op->dtype); diff --git a/src/arith/ir_mutator_with_analyzer.cc b/src/arith/ir_mutator_with_analyzer.cc index 84e2093dcf98..c367d0c9f9d8 100644 --- a/src/arith/ir_mutator_with_analyzer.cc +++ b/src/arith/ir_mutator_with_analyzer.cc @@ -56,8 +56,10 @@ Stmt IRMutatorWithAnalyzer::VisitStmt_(const LetStmtNode* op) { Stmt IRMutatorWithAnalyzer::VisitStmt_(const IfThenElseNode* op) { PrimExpr condition = this->VisitExpr(op->condition); PrimExpr real_condition = condition; + static auto op_likely = Op::Get("tir.likely"); + if (auto call = condition.as()) { - if (call->is_intrinsic(CallNode::likely)) { + if (call->op.same_as(op_likely)) { real_condition = call->args[0]; } } @@ -122,7 +124,8 @@ Stmt IRMutatorWithAnalyzer::VisitStmt_(const AssertStmtNode* op) { PrimExpr IRMutatorWithAnalyzer::VisitExpr_(const CallNode* op) { // add condition context to if_then_else - if (op->is_intrinsic(tir::intrinsic::tvm_if_then_else)) { + static auto op_if_then_else = Op::Get("tir.if_then_else"); + if (op->op.same_as(op_if_then_else)) { PrimExpr cond = this->VisitExpr(op->args[0]); PrimExpr true_value, false_value; { @@ -143,7 +146,7 @@ PrimExpr IRMutatorWithAnalyzer::VisitExpr_(const CallNode* op) { false_value.same_as(op->args[2])) { return GetRef(op); } else { - return Call(op->dtype, op->name, {cond, true_value, false_value}, op->call_type); + return Call(op->dtype, op->op, {cond, true_value, false_value}, op->call_type); } } return StmtExprMutator::VisitExpr_(op); diff --git a/src/arith/modular_set.cc b/src/arith/modular_set.cc index 3457674d4ed3..108f08c4f78f 100644 --- a/src/arith/modular_set.cc +++ b/src/arith/modular_set.cc @@ -23,6 +23,7 @@ */ #include #include +#include #include #include @@ -203,7 +204,7 @@ class ModularSetAnalyzer::Impl : public ExprFunctor> which can be // used for index calculation. 
- if (op->is_intrinsic(CallNode::shift_right)) { + if (op->op.same_as(tir::builtin::shift_right())) { return VisitRightShift(op); } else { return Everything(); diff --git a/src/arith/pattern_match.h b/src/arith/pattern_match.h index ff01941e4acf..de8425146bbf 100644 --- a/src/arith/pattern_match.h +++ b/src/arith/pattern_match.h @@ -66,6 +66,7 @@ #define TVM_ARITH_PATTERN_MATCH_H_ #include +#include #include #include @@ -655,7 +656,7 @@ class PCallExpr : public Pattern> { bool Match_(const ObjectRef& node) const { if (const tir::CallNode* ptr = node.as()) { if (ptr->args.size() != sizeof...(TArgs)) return false; - if (ptr->name != Op::kName) return false; + if (!ptr->op.same_as(Op::GetOp())) return false; detail::PCallExprMatchFunctor fmatch(ptr); detail::tuple_for_each(fmatch, args_); return fmatch.matched_; @@ -675,45 +676,45 @@ class PCallExpr : public Pattern> { }; // arithemetic intrinsics -#define TVM_PATTERN_BINARY_INTRIN(FuncName, OpName, IntrinStr) \ +#define TVM_PATTERN_BINARY_INTRIN(FuncName, OpName, IntrinOpName) \ struct OpName { \ static PrimExpr Eval(Array args) { \ - return tir::Call(args[0].dtype(), kName, args, tir::CallNode::PureIntrinsic); \ + return tir::Call(args[0].dtype(), GetOp(), args, tir::CallNode::PureIntrinsic); \ } \ - static constexpr const char* kName = IntrinStr; \ + static const Op& GetOp() { return tir::builtin::IntrinOpName(); } \ }; \ template \ inline PCallExpr FuncName(const Pattern& a, const Pattern& b) { \ return PCallExpr(a.derived(), b.derived()); \ } -TVM_PATTERN_BINARY_INTRIN(operator<<, PLeftShiftOp, "shift_left"); -TVM_PATTERN_BINARY_INTRIN(operator>>, PRightShiftOp, "shift_right"); -TVM_PATTERN_BINARY_INTRIN(operator&, PBitwiseAndOp, "bitwise_and"); -TVM_PATTERN_BINARY_INTRIN(operator|, PBitwiseOrOp, "bitwise_or"); -TVM_PATTERN_BINARY_INTRIN(operator^, PBitwiseXorOp, "bitwise_xor"); +TVM_PATTERN_BINARY_INTRIN(operator<<, PLeftShiftOp, shift_left); +TVM_PATTERN_BINARY_INTRIN(operator>>, PRightShiftOp, shift_right); +TVM_PATTERN_BINARY_INTRIN(operator&, PBitwiseAndOp, bitwise_and); +TVM_PATTERN_BINARY_INTRIN(operator|, PBitwiseOrOp, bitwise_or); +TVM_PATTERN_BINARY_INTRIN(operator^, PBitwiseXorOp, bitwise_xor); // unary intrinsics -#define TVM_PATTERN_UNARY_INTRIN(FuncName, OpName, IntrinStr) \ - struct OpName { \ - static PrimExpr Eval(Array args) { \ - return tir::Call(args[0].dtype(), kName, args, tir::CallNode::PureIntrinsic); \ - } \ - static constexpr const char* kName = IntrinStr; \ - }; \ - template \ - inline PCallExpr FuncName(const Pattern& a) { \ - return PCallExpr(a.derived()); \ +#define TVM_PATTERN_UNARY_INTRIN(FuncName, OpName, IntrinOpName) \ + struct OpName { \ + static PrimExpr Eval(Array args) { \ + return tir::Call(args[0].dtype(), GetOp(), args, tir::CallNode::PureIntrinsic); \ + } \ + static const Op& GetOp() { return tir::builtin::IntrinOpName(); } \ + }; \ + template \ + inline PCallExpr FuncName(const Pattern& a) { \ + return PCallExpr(a.derived()); \ } -TVM_PATTERN_UNARY_INTRIN(operator~, PBitwiseNotOp, "bitwise_not"); +TVM_PATTERN_UNARY_INTRIN(operator~, PBitwiseNotOp, bitwise_not); // if_then_else struct PIfThenElseOp { static PrimExpr Eval(Array args) { - return tir::Call(args[1].dtype(), kName, args, tir::CallNode::PureIntrinsic); + return tir::Call(args[1].dtype(), GetOp(), args, tir::CallNode::PureIntrinsic); } - static constexpr const char* kName = "tvm_if_then_else"; + static const Op& GetOp() { return tir::builtin::if_then_else(); } }; /*! 
diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc index 4887ef0ee47d..6758c9b569a8 100644 --- a/src/arith/rewrite_simplify.cc +++ b/src/arith/rewrite_simplify.cc @@ -25,6 +25,7 @@ #include "rewrite_simplify.h" #include +#include #include #include @@ -1508,21 +1509,22 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const CallNode* op) { PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as(); if (op == nullptr) return ret; - if (op->is_intrinsic(CallNode::likely) && is_const(op->args[0])) { + + if (op->op.same_as(tir::builtin::likely()) && is_const(op->args[0])) { return op->args[0]; - } else if (op->is_intrinsic(CallNode::shift_right)) { + } else if (op->op.same_as(tir::builtin::shift_right())) { if (op->args[0].as() && op->args[1].as()) { // the operator overload will eagerly constant fold. return op->args[0] >> op->args[1]; } - } else if (op->is_intrinsic(CallNode::bitwise_and)) { + } else if (op->op.same_as(tir::builtin::bitwise_and())) { if (op->args[0].as() && op->args[1].as()) { // the operator overload will eagerly constant fold. return op->args[0] & op->args[1]; } } ExprDeepEqual expr_equal; - if (op->is_intrinsic(CallNode::likely)) { + if (op->op.same_as(tir::builtin::likely())) { for (const auto& constraint : literal_constraints_) { // Cases such as for (i, 0, bound) {if (likely(iter_var < bound)) { .. } } if (expr_equal(constraint, op->args[0])) { diff --git a/src/contrib/hybrid/codegen_hybrid.cc b/src/contrib/hybrid/codegen_hybrid.cc index e08f39f8135d..0d5d654c3f6e 100644 --- a/src/contrib/hybrid/codegen_hybrid.cc +++ b/src/contrib/hybrid/codegen_hybrid.cc @@ -23,6 +23,7 @@ #include "codegen_hybrid.h" #include +#include #include #include @@ -216,29 +217,43 @@ void CodeGenHybrid::VisitExpr_(const ProducerLoadNode* op, std::ostream& os) { os << "]"; } void CodeGenHybrid::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*) - if (op->is_intrinsic(CallNode::bitwise_and)) { + if (op->op.same_as(builtin::bitwise_and())) { PrintBinaryIntrinsitc(op, "&", os, this); - } else if (op->is_intrinsic(CallNode::bitwise_xor)) { + } else if (op->op.same_as(builtin::bitwise_xor())) { PrintBinaryIntrinsitc(op, "^", os, this); - } else if (op->is_intrinsic(CallNode::bitwise_or)) { + } else if (op->op.same_as(builtin::bitwise_or())) { PrintBinaryIntrinsitc(op, "|", os, this); - } else if (op->is_intrinsic(CallNode::shift_left)) { + } else if (op->op.same_as(builtin::shift_left())) { PrintBinaryIntrinsitc(op, "<<", os, this); - } else if (op->is_intrinsic(CallNode::shift_right)) { + } else if (op->op.same_as(builtin::shift_right())) { PrintBinaryIntrinsitc(op, ">>", os, this); - } else if (op->is_intrinsic(CallNode::bitwise_not)) { + } else if (op->op.same_as(builtin::bitwise_not())) { CHECK_EQ(op->args.size(), 1U); os << "(~"; PrintExpr(op->args[0], os); os << ')'; - } else if (op->is_intrinsic(intrinsic::tvm_if_then_else)) { + } else if (op->op.same_as(builtin::if_then_else())) { PrintExpr(op->args[1], os); os << " if "; PrintExpr(op->args[0], os); os << " else "; PrintExpr(op->args[2], os); + } else if (op->op.same_as(builtin::call_extern())) { + StringImm fname = Downcast(op->args[0]); + os << fname << "("; + for (size_t i = 1; i < op->args.size(); i++) { + PrintExpr(op->args[i], os); + if (i < op->args.size() - 1) { + os << ", "; + } + } + os << ")"; } else { - os << op->name << "("; + auto* ptr_op = op->op.as(); + CHECK(ptr_op != nullptr); + std::string name = ptr_op->name; + CHECK_EQ(name.compare(0, 4, "tir."), 0); + os << name.substr(4) << "(";
for (size_t i = 0; i < op->args.size(); i++) { PrintExpr(op->args[i], os); if (i < op->args.size() - 1) { diff --git a/src/ir/op.cc b/src/ir/op.cc index 63d223050ff5..45c31963695c 100644 --- a/src/ir/op.cc +++ b/src/ir/op.cc @@ -42,7 +42,7 @@ using OpRegistry = AttrRegistry; // find operator by name const Op& Op::Get(const String& name) { const OpRegEntry* reg = OpRegistry::Global()->Get(name); - CHECK(reg != nullptr) << "Operator " << name << " is not registered"; + CHECK(reg != nullptr) << "AttributeError: Operator " << name << " is not registered"; return reg->op(); } diff --git a/src/printer/tir_text_printer.cc b/src/printer/tir_text_printer.cc index 29927379f17d..233a73954c93 100644 --- a/src/printer/tir_text_printer.cc +++ b/src/printer/tir_text_printer.cc @@ -345,7 +345,14 @@ inline const char* CallType2String(CallNode::CallType t) { Doc TIRTextPrinter::VisitExpr_(const CallNode* op) { Doc doc; - doc << "@" << Doc::Text(op->name) << "("; + if (auto* ptr_op = op->op.as()) { + doc << "@" << Doc::Text(ptr_op->name) << "("; + } else { + // TODO(bohan): Print out the name by he global var in the module. + auto* op_gvar = op->op.as(); + CHECK(op_gvar != nullptr); + doc << "@" << Doc::Text(op_gvar->name_hint) << "("; + } std::vector args; for (const auto& arg : op->args) { args.push_back(Print(arg)); @@ -370,7 +377,7 @@ Doc TIRTextPrinter::VisitExpr_(const ReduceNode* op) { Doc TIRTextPrinter::VisitStmt_(const LetStmtNode* op) { Doc doc; - doc << "let " << Print(op->var) << " = " << Print(op->value) << PrintBody(op->body); + doc << "let " << Print(op->var) << " = " << Print(op->value) << Doc::NewLine() << Print(op->body); return doc; } @@ -389,8 +396,8 @@ Doc TIRTextPrinter::VisitStmt_(const AttrStmtNode* op) { Doc TIRTextPrinter::VisitStmt_(const AssertStmtNode* op) { Doc doc; - doc << "assert(" << Print(op->condition) << ", " << Print(op->message) << ")" - << PrintBody(op->body); + doc << "assert(" << Print(op->condition) << ", " << Print(op->message) << ")" << Doc::NewLine() + << Print(op->body); return doc; } diff --git a/src/relay/transforms/pass_util.h b/src/relay/transforms/pass_util.h index 35bbb234dbc1..5f5876212b62 100644 --- a/src/relay/transforms/pass_util.h +++ b/src/relay/transforms/pass_util.h @@ -121,7 +121,7 @@ inline bool IsAtomic(const Expr& e) { * \return compiler_begin op */ inline const Op& CompilerBeginOp() { - static Op op = Op::Get("annotation.compiler_begin"); + static auto op = Op::Get("annotation.compiler_begin"); return op; } @@ -131,7 +131,7 @@ inline const Op& CompilerBeginOp() { * \return compiler_end op */ inline const Op& CompilerEndOp() { - static Op op = Op::Get("annotation.compiler_end"); + static auto op = Op::Get("annotation.compiler_end"); return op; } diff --git a/src/target/intrin_rule.h b/src/target/intrin_rule.h index 5a23e83af219..36e553900d00 100644 --- a/src/target/intrin_rule.h +++ b/src/target/intrin_rule.h @@ -25,6 +25,7 @@ #define TVM_TARGET_INTRIN_RULE_H_ #include +#include #include #include @@ -58,9 +59,20 @@ inline void DispatchExtern(const TVMArgs& args, TVMRetValue* rv) { PrimExpr e = args[0]; const CallNode* call = e.as(); CHECK(call != nullptr); - std::string name = T()(call->dtype, call->name); + // Use string based dispatch to extern for backward compact + // TODO(tvm-team) replace once the new dispatching system is inplace. 
+ const OpNode* op = call->op.as(); + CHECK(op != nullptr); + std::string name = op->name; + CHECK_EQ(name.substr(0, 4), "tir."); + name = T()(call->dtype, name.substr(4)); + if (name.length() != 0) { - *rv = Call(call->dtype, name, call->args, CallNode::PureExtern); + Array new_args = {StringImm(name)}; + for (auto arg : call->args) { + new_args.push_back(arg); + } + *rv = Call(call->dtype, tir::builtin::call_extern(), new_args, CallNode::PureExtern); } else { *rv = e; } diff --git a/src/target/llvm/codegen_arm.cc b/src/target/llvm/codegen_arm.cc index 991d4730a136..13ce59d54b82 100644 --- a/src/target/llvm/codegen_arm.cc +++ b/src/target/llvm/codegen_arm.cc @@ -46,7 +46,7 @@ class CodeGenARM final : public CodeGenCPU { }; llvm::Value* CodeGenARM::CreateIntrinsic(const CallNode* op) { - if (op->is_intrinsic("llvm_intrin")) { + if (op->op.same_as(builtin_call_llvm_intrin_)) { llvm::Intrinsic::ID id = static_cast(Downcast(op->args[0])->value); if (id == ::llvm::Intrinsic::ctpop) { PrimExpr e = ARMPopcount(op); @@ -70,7 +70,7 @@ PrimExpr CodeGenARM::ARMPopcount(const CallNode* call) { vcnt_args.push_back(IntImm(DataType::UInt(32), ctpop_id)); vcnt_args.push_back(IntImm(DataType::UInt(32), 1)); vcnt_args.push_back(e); - return tir::Call(call->dtype, "llvm_intrin", vcnt_args, CallNode::PureIntrinsic); + return tir::Call(call->dtype, builtin_call_llvm_intrin_, vcnt_args, CallNode::PureIntrinsic); } // Popcount lowering rule: @@ -94,14 +94,16 @@ PrimExpr CodeGenARM::ARMPopcount(const CallNode* call) { vcnt8_args.push_back(IntImm(DataType::UInt(32), ctpop_id)); vcnt8_args.push_back(IntImm(DataType::UInt(32), 1)); vcnt8_args.push_back(input8); - PrimExpr vcnt8 = tir::Call(uint8_type, "llvm_intrin", vcnt8_args, CallNode::PureIntrinsic); + PrimExpr vcnt8 = + tir::Call(uint8_type, builtin_call_llvm_intrin_, vcnt8_args, CallNode::PureIntrinsic); // Accumulation 8->16bit Array vcnt16_args; vcnt16_args.push_back(IntImm(DataType::UInt(32), vpaddlu_id)); vcnt16_args.push_back(IntImm(DataType::UInt(32), 1)); vcnt16_args.push_back(vcnt8); - PrimExpr vcnt16 = tir::Call(uint16_type, "llvm_intrin", vcnt16_args, CallNode::PureIntrinsic); + PrimExpr vcnt16 = + tir::Call(uint16_type, builtin_call_llvm_intrin_, vcnt16_args, CallNode::PureIntrinsic); if (call->dtype.bits() == 16) { return vcnt16; } @@ -111,7 +113,8 @@ PrimExpr CodeGenARM::ARMPopcount(const CallNode* call) { vcnt32_args.push_back(IntImm(DataType::UInt(32), vpaddlu_id)); vcnt32_args.push_back(IntImm(DataType::UInt(32), 1)); vcnt32_args.push_back(vcnt16); - PrimExpr vcnt32 = tir::Call(uint32_type, "llvm_intrin", vcnt32_args, CallNode::PureIntrinsic); + PrimExpr vcnt32 = + tir::Call(uint32_type, builtin_call_llvm_intrin_, vcnt32_args, CallNode::PureIntrinsic); if (call->dtype.bits() == 32) { return vcnt32; } @@ -121,7 +124,7 @@ PrimExpr CodeGenARM::ARMPopcount(const CallNode* call) { vcnt64_args.push_back(IntImm(DataType::UInt(32), vpaddlu_id)); vcnt64_args.push_back(IntImm(DataType::UInt(32), 1)); vcnt64_args.push_back(vcnt32); - return tir::Call(call->dtype, "llvm_intrin", vcnt64_args, CallNode::PureIntrinsic); + return tir::Call(call->dtype, builtin_call_llvm_intrin_, vcnt64_args, CallNode::PureIntrinsic); } TVM_REGISTER_GLOBAL("tvm.codegen.llvm.target_arm") diff --git a/src/target/llvm/codegen_cpu.cc b/src/target/llvm/codegen_cpu.cc index 6ad050ace9a3..f855dd5b83b2 100644 --- a/src/target/llvm/codegen_cpu.cc +++ b/src/target/llvm/codegen_cpu.cc @@ -226,7 +226,7 @@ std::unique_ptr CodeGenCPU::Finish() { } llvm::Value* 
CodeGenCPU::CreateStructRefPtr(DataType t, llvm::Value* buf, llvm::Value* index, int kind) { - if (kind < intrinsic::kArrKindBound_) { + if (kind < builtin::kArrKindBound_) { if (buf->getType() == t_void_p_) { buf = builder_->CreatePointerCast(buf, t_tvm_array_->getPointerTo()); } else { @@ -234,40 +234,40 @@ llvm::Value* CodeGenCPU::CreateStructRefPtr(DataType t, llvm::Value* buf, llvm:: } } switch (kind) { - case intrinsic::kArrAddr: { + case builtin::kArrAddr: { return builder_->CreateInBoundsGEP(buf, index); } - case intrinsic::kArrData: { + case builtin::kArrData: { return builder_->CreateInBoundsGEP(buf, {index, ConstInt32(0)}); } - case intrinsic::kArrShape: { + case builtin::kArrShape: { return builder_->CreateInBoundsGEP(buf, {index, ConstInt32(4)}); } - case intrinsic::kArrStrides: { + case builtin::kArrStrides: { return builder_->CreateInBoundsGEP(buf, {index, ConstInt32(5)}); } - case intrinsic::kArrNDim: { + case builtin::kArrNDim: { return builder_->CreateInBoundsGEP(buf, {index, ConstInt32(2)}); } - case intrinsic::kArrTypeCode: { + case builtin::kArrTypeCode: { return builder_->CreateInBoundsGEP(buf, {index, ConstInt32(3), ConstInt32(0)}); } - case intrinsic::kArrTypeBits: { + case builtin::kArrTypeBits: { return builder_->CreateInBoundsGEP(buf, {index, ConstInt32(3), ConstInt32(1)}); } - case intrinsic::kArrTypeLanes: { + case builtin::kArrTypeLanes: { return builder_->CreateInBoundsGEP(buf, {index, ConstInt32(3), ConstInt32(2)}); } - case intrinsic::kArrByteOffset: { + case builtin::kArrByteOffset: { return builder_->CreateInBoundsGEP(buf, {index, ConstInt32(6)}); } - case intrinsic::kArrDeviceId: { + case builtin::kArrDeviceId: { return builder_->CreateInBoundsGEP(buf, {index, ConstInt32(1), ConstInt32(1)}); } - case intrinsic::kArrDeviceType: { + case builtin::kArrDeviceType: { return builder_->CreateInBoundsGEP(buf, {index, ConstInt32(1), ConstInt32(0)}); } - case intrinsic::kTVMValueContent: { + case builtin::kTVMValueContent: { CHECK_EQ(t.lanes(), 1); CHECK(t.is_handle() || t.bits() == 64); if (t.is_int()) { @@ -289,23 +289,23 @@ llvm::Value* CodeGenCPU::CreateStructRefPtr(DataType t, llvm::Value* buf, llvm:: } } -llvm::Value* CodeGenCPU::CreateCallExtern(const CallNode* op) { - std::vector arg_values(op->args.size()); - for (size_t i = 0; i < op->args.size(); ++i) { - arg_values[i] = MakeValue(op->args[i]); +llvm::Value* CodeGenCPU::CreateCallExtern(Type ret_type, String global_symbol, + const Array& args, bool skip_first_arg) { + std::vector arg_values; + for (size_t i = static_cast(skip_first_arg); i < args.size(); ++i) { + arg_values.push_back(MakeValue(args[i])); } std::vector arg_types; for (llvm::Value* v : arg_values) { arg_types.push_back(v->getType()); } - llvm::FunctionType* ftype = - llvm::FunctionType::get(GetLLVMType(GetRef(op)), arg_types, false); + llvm::FunctionType* ftype = llvm::FunctionType::get(GetLLVMType(ret_type), arg_types, false); // Check if it is available in global function table as injected function. 
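The DLTensor/TVMValue field selectors (kArrData, kArrShape, ..., kTVMValueContent) now live in tvm::tir::builtin and are consumed through tvm_struct_get / tvm_struct_set calls, which the switch above turns into GEPs. A sketch of how a lowering pass would read the data pointer of a DLTensor handle under the new names (helper name and call type are illustrative):

#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>

using namespace tvm;
using namespace tvm::tir;

// tvm_struct_get(handle, index, kArrData)  ~  ((DLTensor*)handle)[index].data
PrimExpr LoadDataPointer(Var handle, PrimExpr index) {
  return Call(DataType::Handle(), builtin::tvm_struct_get(),
              {handle, index, IntImm(DataType::Int(32), builtin::kArrData)},
              CallNode::PureIntrinsic);
}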
- auto it = gv_func_map_.find(op->name); + auto it = gv_func_map_.find(global_symbol); if (it != gv_func_map_.end()) { if (it->second == nullptr) { - gv_func_map_[op->name] = InitContextPtr(ftype->getPointerTo(), "__" + op->name); - it = gv_func_map_.find(op->name); + gv_func_map_[global_symbol] = InitContextPtr(ftype->getPointerTo(), "__" + global_symbol); + it = gv_func_map_.find(global_symbol); } #if TVM_LLVM_VERSION >= 90 auto ext_callee = llvm::FunctionCallee(ftype, GetContextPtr(it->second)); @@ -314,10 +314,10 @@ llvm::Value* CodeGenCPU::CreateCallExtern(const CallNode* op) { #endif return builder_->CreateCall(ext_callee, arg_values); } else { - llvm::Function* f = module_->getFunction(op->name); + llvm::Function* f = module_->getFunction(global_symbol); if (f == nullptr) { f = llvm::Function::Create(ftype, llvm::Function::ExternalLinkage, - op->name.operator llvm::StringRef(), module_.get()); + global_symbol.operator llvm::StringRef(), module_.get()); } #if TVM_LLVM_VERSION >= 90 auto ext_callee = llvm::FunctionCallee(f); @@ -773,38 +773,38 @@ void CodeGenCPU::AddStartupFunction() { } llvm::Value* CodeGenCPU::CreateIntrinsic(const CallNode* op) { - if (op->is_intrinsic(intrinsic::tvm_call_packed_lowered)) { + if (op->op.same_as(builtin::tvm_call_packed_lowered())) { return CreateCallPacked(op); - } else if (op->is_intrinsic(intrinsic::tvm_call_trace_packed_lowered)) { + } else if (op->op.same_as(builtin::tvm_call_trace_packed_lowered())) { return CreateCallTracePacked(op); - } else if (op->is_intrinsic(intrinsic::tvm_static_handle)) { + } else if (op->op.same_as(builtin::tvm_static_handle())) { return CreateStaticHandle(); - } else if (op->is_intrinsic(intrinsic::tvm_throw_last_error)) { + } else if (op->op.same_as(builtin::tvm_throw_last_error())) { builder_->CreateRet(ConstInt32(-1)); return ConstInt32(-1); - } else if (op->is_intrinsic(intrinsic::tvm_struct_get)) { + } else if (op->op.same_as(builtin::tvm_struct_get())) { CHECK_EQ(op->args.size(), 3U); int kind = op->args[2].as()->value; llvm::Value* ref = this->CreateStructRefPtr(op->dtype, MakeValue(op->args[0]), MakeValue(op->args[1]), kind); - if (kind == intrinsic::kArrAddr) { + if (kind == builtin::kArrAddr) { return builder_->CreatePointerCast(ref, t_void_p_); } else { return builder_->CreateLoad(ref); } - } else if (op->is_intrinsic(intrinsic::tvm_struct_set)) { + } else if (op->op.same_as(builtin::tvm_struct_set())) { CHECK_EQ(op->args.size(), 4U); int kind = op->args[2].as()->value; llvm::Value* value = MakeValue(op->args[3]); llvm::Value* ref = this->CreateStructRefPtr(op->args[3].dtype(), MakeValue(op->args[0]), MakeValue(op->args[1]), kind); - CHECK(kind != intrinsic::kArrAddr); + CHECK(kind != builtin::kArrAddr); if (value->getType()->isPointerTy()) { value = builder_->CreatePointerCast(value, ref->getType()->getPointerElementType()); } builder_->CreateStore(value, ref); return ConstInt32(0); - } else if (op->is_intrinsic(intrinsic::tvm_stack_alloca)) { + } else if (op->op.same_as(builtin::tvm_stack_alloca())) { CHECK_EQ(op->args.size(), 2U); const std::string& type = op->args[0].as()->value; return WithFunctionEntry([&]() -> llvm::AllocaInst* { diff --git a/src/target/llvm/codegen_cpu.h b/src/target/llvm/codegen_cpu.h index 7a14b8fdc959..fdeab4130782 100644 --- a/src/target/llvm/codegen_cpu.h +++ b/src/target/llvm/codegen_cpu.h @@ -47,7 +47,8 @@ class CodeGenCPU : public CodeGenLLVM { void VisitStmt_(const AttrStmtNode* op) override; void VisitStmt_(const ForNode* op) override; llvm::Value* CreateIntrinsic(const 
CallNode* op) override; - llvm::Value* CreateCallExtern(const CallNode* op) override; + llvm::Value* CreateCallExtern(Type ret_type, String global_symbol, const Array& args, + bool skip_first_arg) override; protected: void AddStartupFunction() final; @@ -122,7 +123,7 @@ class CodeGenCPU : public CodeGenLLVM { llvm::GlobalVariable* gv_tvm_api_set_last_error_{nullptr}; llvm::GlobalVariable* gv_tvm_parallel_launch_{nullptr}; llvm::GlobalVariable* gv_tvm_parallel_barrier_{nullptr}; - std::unordered_map gv_func_map_; + std::unordered_map gv_func_map_; // context for direct dynamic lookup llvm::Function* f_tvm_func_call_{nullptr}; llvm::Function* f_tvm_get_func_from_env_{nullptr}; diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index 85e3de5844fd..49f14c31d07f 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -653,19 +653,19 @@ llvm::Value* CodeGenLLVM::GetVarValue(const VarNode* v) const { return it->second; } -llvm::Value* CodeGenLLVM::CreateCallExtern(const CallNode* op) { +llvm::Value* CodeGenLLVM::CreateCallExtern(Type ret_type, String global_symbol, + const Array& args, bool skip_first_arg) { std::vector arg_value; std::vector arg_type; - for (size_t i = 0; i < op->args.size(); ++i) { - arg_value.push_back(MakeValue(op->args[i])); + for (size_t i = static_cast(skip_first_arg); i < args.size(); ++i) { + arg_value.push_back(MakeValue(args[i])); arg_type.push_back(arg_value.back()->getType()); } - llvm::FunctionType* ftype = - llvm::FunctionType::get(GetLLVMType(GetRef(op)), arg_type, false); - llvm::Function* f = module_->getFunction(op->name); + llvm::FunctionType* ftype = llvm::FunctionType::get(GetLLVMType(ret_type), arg_type, false); + llvm::Function* f = module_->getFunction(global_symbol); if (f == nullptr) { f = llvm::Function::Create(ftype, llvm::Function::ExternalLinkage, - op->name.operator llvm::StringRef(), module_.get()); + global_symbol.operator llvm::StringRef(), module_.get()); } llvm::CallInst* call = builder_->CreateCall(f, arg_value); return call; @@ -738,7 +738,7 @@ llvm::Function* CodeGenLLVM::GetIntrinsicDecl(llvm::Intrinsic::ID id, llvm::Type } llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) { - if (op->is_intrinsic("llvm_intrin")) { + if (op->op.same_as(builtin_call_llvm_intrin_)) { CHECK_GE(op->args.size(), 2U); llvm::Intrinsic::ID id = static_cast(Downcast(op->args[0])->value); int64_t num_signature = Downcast(op->args[1])->value; @@ -759,30 +759,29 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) { // type as LLVM. llvm::Type* return_type = (id != llvm::Intrinsic::prefetch) ? 
GetLLVMType(GetRef(op)) : llvm::Type::getVoidTy(*ctx_); - llvm::Function* f = GetIntrinsicDecl(id, return_type, arg_type); CHECK(f) << "Cannot find intrinsic declaration, possible type mismatch: " << llvm::Intrinsic::getName(id, {}); return builder_->CreateCall(f, arg_value); - } else if (op->is_intrinsic(CallNode::bitwise_and)) { + } else if (op->op.same_as(builtin::bitwise_and())) { return builder_->CreateAnd(MakeValue(op->args[0]), MakeValue(op->args[1])); - } else if (op->is_intrinsic(CallNode::bitwise_or)) { + } else if (op->op.same_as(builtin::bitwise_or())) { return builder_->CreateOr(MakeValue(op->args[0]), MakeValue(op->args[1])); - } else if (op->is_intrinsic(CallNode::bitwise_not)) { + } else if (op->op.same_as(builtin::bitwise_not())) { return builder_->CreateNot(MakeValue(op->args[0])); - } else if (op->is_intrinsic(CallNode::bitwise_xor)) { + } else if (op->op.same_as(builtin::bitwise_xor())) { return builder_->CreateXor(MakeValue(op->args[0]), MakeValue(op->args[1])); - } else if (op->is_intrinsic(CallNode::shift_left)) { + } else if (op->op.same_as(builtin::shift_left())) { return builder_->CreateShl(MakeValue(op->args[0]), MakeValue(op->args[1])); - } else if (op->is_intrinsic(CallNode::shift_right)) { + } else if (op->op.same_as(builtin::shift_right())) { if (op->args[0].dtype().is_int()) { return builder_->CreateAShr(MakeValue(op->args[0]), MakeValue(op->args[1])); } else { return builder_->CreateLShr(MakeValue(op->args[0]), MakeValue(op->args[1])); } - } else if (op->is_intrinsic(intrinsic::tvm_storage_sync)) { + } else if (op->op.same_as(builtin::tvm_storage_sync())) { return CreateStorageSync(op); - } else if (op->is_intrinsic(intrinsic::tvm_address_of)) { + } else if (op->op.same_as(builtin::address_of())) { const LoadNode* l = op->args[0].as(); CHECK(op->args.size() == 1 && l); const RampNode* r = l->index.as(); @@ -797,17 +796,17 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) { addrspace = llvm::dyn_cast(ptr->getType())->getAddressSpace(); } return builder_->CreatePointerCast(ptr, t_char_->getPointerTo(addrspace)); - } else if (op->is_intrinsic(CallNode::reinterpret) && is_zero(op->args[0])) { + } else if (op->op.same_as(builtin::reinterpret()) && is_zero(op->args[0])) { return llvm::Constant::getNullValue(t_void_p_); - } else if (op->is_intrinsic(intrinsic::tvm_handle_is_null)) { + } else if (op->op.same_as(builtin::isnullptr())) { return builder_->CreateIsNull(MakeValue(op->args[0])); - } else if (op->is_intrinsic(intrinsic::tvm_large_uint_imm)) { + } else if (op->op.same_as(builtin::large_uint_imm())) { CHECK_EQ(op->args.size(), 2U); uint64_t low = static_cast(Downcast(op->args[0])->value); uint64_t high = static_cast(Downcast(op->args[1])->value); uint64_t val = (high << 32U) | low; return llvm::ConstantInt::get(DTypeToLLVMType(op->dtype), val); - } else if (op->is_intrinsic(intrinsic::tvm_if_then_else)) { + } else if (op->op.same_as(builtin::if_then_else())) { CHECK_EQ(op->args[0].dtype().lanes(), 1) << "if_then_else can only take scalar condition"; using llvm::BasicBlock; BasicBlock* then_block = BasicBlock::Create(*ctx_, "if_then", function_); @@ -827,22 +826,22 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) { value->addIncoming(then_value, then_value_block); value->addIncoming(else_value, else_value_block); return value; - } else if (op->is_intrinsic(CallNode::reinterpret)) { + } else if (op->op.same_as(builtin::reinterpret())) { llvm::Type* target = DTypeToLLVMType(op->dtype); return 
builder_->CreateBitCast(MakeValue(op->args[0]), target); - } else if (op->is_intrinsic(CallNode::isnan)) { + } else if (op->op.same_as(builtin::isnan())) { // TODO(hgt312): set fast math flag llvm::Value* a = MakeValue(op->args[0]); return builder_->CreateFCmpUNO(a, a); - } else if (op->is_intrinsic("vectorlow")) { + } else if (op->op.same_as(builtin::vectorlow())) { llvm::Value* v = MakeValue(op->args[0]); int l = llvm::cast(v->getType())->getNumElements(); return CreateVecSlice(v, 0, l / 2); - } else if (op->is_intrinsic("vectorhigh")) { + } else if (op->op.same_as(builtin::vectorhigh())) { llvm::Value* v = MakeValue(op->args[0]); int l = llvm::cast(v->getType())->getNumElements(); return CreateVecSlice(v, l / 2, l / 2); - } else if (op->is_intrinsic("vectorcombine")) { + } else if (op->op.same_as(builtin::vectorcombine())) { llvm::Value* v0 = MakeValue(op->args[0]); llvm::Value* v1 = MakeValue(op->args[1]); int num_elems = llvm::cast(v0->getType())->getNumElements() * 2; @@ -856,7 +855,7 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) { } return builder_->CreateShuffleVector(v0, v1, indices); } else { - LOG(FATAL) << "unknown intrinsic " << op->name; + LOG(FATAL) << "unknown intrinsic " << op->op; return nullptr; } } @@ -1076,13 +1075,24 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const LoadNode* op) { } llvm::Value* CodeGenLLVM::VisitExpr_(const CallNode* op) { - if (op->call_type == CallNode::Intrinsic || op->call_type == CallNode::PureIntrinsic) { - return CreateIntrinsic(op); - } else if (op->call_type == CallNode::Extern || op->call_type == CallNode::PureExtern) { - return CreateCallExtern(op); + if (auto* ptr_op = op->op.as()) { + auto call_op = GetRef(ptr_op); + if (op->op.same_as(builtin_call_extern_)) { + // call extern intrinsic + CHECK_GE(op->args.size(), 1U); + auto global_symbol = Downcast(op->args[0]); + return this->CreateCallExtern(GetType(GetRef(op)), global_symbol->value, op->args, + true); + } else if (op_attr_global_symbol_.count(call_op)) { + // call extern if the op itself have a global symbol. + return this->CreateCallExtern(GetType(GetRef(op)), op_attr_global_symbol_[call_op], + op->args, false); + } else { + return CreateIntrinsic(op); + } } else { - LOG(FATAL) << "Unknown call type " - << "name= " << op->name << " call_type= " << op->call_type; + CHECK(op->op.as()); + LOG(FATAL) << "Do not yet support cross function call"; return nullptr; } } diff --git a/src/target/llvm/codegen_llvm.h b/src/target/llvm/codegen_llvm.h index 0bca2a169ba4..2bfe047038b0 100644 --- a/src/target/llvm/codegen_llvm.h +++ b/src/target/llvm/codegen_llvm.h @@ -32,6 +32,7 @@ #include #include #include +#include #include #include @@ -175,7 +176,9 @@ class CodeGenLLVM : public ExprFunctor, // create intrinstic given call virtual llvm::Value* CreateIntrinsic(const CallNode* op); // create extern function call - virtual llvm::Value* CreateCallExtern(const CallNode* op); + // skip first arg mode used for call extern intrinsic. + virtual llvm::Value* CreateCallExtern(Type ret_type, String global_symbol, + const Array& args, bool skip_first_arg); // Get the corresponding thread index virtual llvm::Value* GetThreadIndex(const IterVar& iv); // Get the corresponding thread index @@ -319,6 +322,11 @@ class CodeGenLLVM : public ExprFunctor, std::unordered_set alias_var_set_; // set of volatile buffer. std::unordered_set volatile_buf_; + // Cache potential common path ops to slightly improve lookup time. + // global symbol table. 
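The LLVM CallNode visitor above now splits extern calls into two cases: builtin::call_extern() carries its symbol as args[0] (hence skip_first_arg = true), while an op registered with a "TGlobalSymbol" attribute supplies the symbol itself and keeps all of its arguments. A condensed sketch of that dispatch, assuming the attribute is stored as a String; EmitExtern and EmitIntrinsic are stand-ins for CreateCallExtern and CreateIntrinsic:

#include <tvm/ir/op.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>

using namespace tvm;
using namespace tvm::tir;

// Stand-ins for the backend hooks (declarations only).
void EmitExtern(const String& symbol, const Array<PrimExpr>& args, bool skip_first_arg);
void EmitIntrinsic(const CallNode* call);

void DispatchCall(const CallNode* call, const OpAttrMap<String>& global_symbol_map) {
  if (call->op.same_as(builtin::call_extern())) {
    // args[0] names the callee; the real arguments start at index 1.
    EmitExtern(Downcast<StringImm>(call->args[0])->value, call->args,
               /*skip_first_arg=*/true);
    return;
  }
  auto* op_node = call->op.as<OpNode>();
  CHECK(op_node != nullptr) << "cross-function calls are not supported here";
  Op op = GetRef<Op>(op_node);
  if (global_symbol_map.count(op)) {
    EmitExtern(global_symbol_map[op], call->args, /*skip_first_arg=*/false);
  } else {
    EmitIntrinsic(call);  // handled by the backend's builtin lowering
  }
}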
+ OpAttrMap op_attr_global_symbol_ = Op::GetAttrMap("TGlobalSymbol"); + const Op& builtin_call_extern_ = builtin::call_extern(); + const Op& builtin_call_llvm_intrin_ = builtin::call_llvm_intrin(); /*! \brief Helper struct for debug infos. */ struct DebugInfo { std::unique_ptr di_builder_; diff --git a/src/target/llvm/codegen_nvptx.cc b/src/target/llvm/codegen_nvptx.cc index bc47ce1b1014..71c8e78030c2 100644 --- a/src/target/llvm/codegen_nvptx.cc +++ b/src/target/llvm/codegen_nvptx.cc @@ -197,11 +197,11 @@ static bool GetWarpShuffleIntrinsic(const CallNode* op, llvm::Intrinsic::ID* id) llvm::Intrinsic::nvvm_shfl_down_i32, llvm::Intrinsic::nvvm_shfl_down_f32}; int offset = 0; - if (op->is_intrinsic(intrinsic::tvm_warp_shuffle)) { + if (op->op.same_as(builtin::tvm_warp_shuffle())) { offset = 0; - } else if (op->is_intrinsic(intrinsic::tvm_warp_shuffle_up)) { + } else if (op->op.same_as(builtin::tvm_warp_shuffle_up())) { offset = 2; - } else if (op->is_intrinsic(intrinsic::tvm_warp_shuffle_down)) { + } else if (op->op.same_as(builtin::tvm_warp_shuffle_down())) { offset = 4; } else { return false; @@ -226,7 +226,7 @@ llvm::Value* CodeGenNVPTX::CreateIntrinsic(const CallNode* op) { llvm::Type* return_type = arg_type[0]; llvm::Function* func = GetIntrinsicDecl(id, return_type, arg_type); return builder_->CreateCall(func, arg_value); - } else if (op->is_intrinsic(intrinsic::tvm_warp_activemask)) { + } else if (op->op.same_as(builtin::tvm_warp_activemask())) { // Only nvptx target may keep this intrinsic at this point. // PTX assembly: asm "activemask.b32 r1;" auto fty = llvm::FunctionType::get(t_int32_, false); diff --git a/src/target/llvm/codegen_x86_64.cc b/src/target/llvm/codegen_x86_64.cc index edffda287c7b..5d269fa4d513 100644 --- a/src/target/llvm/codegen_x86_64.cc +++ b/src/target/llvm/codegen_x86_64.cc @@ -89,7 +89,7 @@ llvm::Value* CodeGenX86_64::VisitExpr_(const CastNode* op) { ::llvm::Intrinsic::x86_avx512_mask_vcvtph2ps_512, 16, DTypeToLLVMType(DataType::Float(32, from.lanes())), { - MakeValue(tir::Call(DataType::Int(16, from.lanes()), tir::CallNode::reinterpret, + MakeValue(tir::Call(DataType::Int(16, from.lanes()), tir::builtin::reinterpret(), {op->value}, tir::CallNode::PureIntrinsic)), MakeValue(tir::Broadcast(FloatImm(DataType::Float(32), 0), from.lanes())), /*mask=*/MakeValue(IntImm(DataType::Int(16), -1)), @@ -105,7 +105,7 @@ llvm::Value* CodeGenX86_64::VisitExpr_(const CastNode* op) { return CallVectorIntrin( ::llvm::Intrinsic::x86_vcvtph2ps_256, 8, DTypeToLLVMType(DataType::Float(32, from.lanes())), - {MakeValue(tir::Call(DataType::Int(16, from.lanes()), tir::CallNode::reinterpret, + {MakeValue(tir::Call(DataType::Int(16, from.lanes()), tir::builtin::reinterpret(), {op->value}, tir::CallNode::PureIntrinsic))}); } #endif diff --git a/src/target/llvm/intrin_rule_llvm.cc b/src/target/llvm/intrin_rule_llvm.cc index 8804b1e45a6f..abf350e2208a 100644 --- a/src/target/llvm/intrin_rule_llvm.cc +++ b/src/target/llvm/intrin_rule_llvm.cc @@ -39,6 +39,8 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.exp") TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.exp2") .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::exp2, 1>); +// TODO(tvm-team): migrate the legalization transformations as a separate +// set of rules in TIR that can be shared across backends. 
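The legalization rules rewritten in the next hunks compose their replacement expressions from the typed helpers in tvm/tir/op.h (exp, sin, cos, ...), which already produce calls to the registered tir.* ops, instead of assembling Call nodes from name strings. The same style in a standalone sketch (sigmoid here is only an illustration, not part of this patch):

#include <tvm/tir/op.h>

using namespace tvm;

// sigmoid(x) = 1 / (1 + exp(-x)); exp() yields a call to the "tir.exp" op.
PrimExpr LegalizeSigmoid(const PrimExpr& x) {
  PrimExpr one = tir::make_const(x.dtype(), 1.0);
  return one / (one + exp(-x));
}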
TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.exp10") .set_body([](const TVMArgs& targs, TVMRetValue* rv) { using tir::make_const; @@ -48,7 +50,7 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.exp10") CHECK(call != nullptr); const PrimExpr& x = call->args[0]; PrimExpr ln10 = make_const(x.dtype(), 2.302585093); - PrimExpr ret = tir::Call(x.dtype(), "exp", {x * ln10}, tir::CallNode::PureIntrinsic); + PrimExpr ret = exp(x * ln10); *rv = ret; }); @@ -97,8 +99,8 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.tanh") PrimExpr two = make_const(x.dtype(), 2); PrimExpr neg_two = make_const(x.dtype(), -2); - PrimExpr exp_neg2x = tir::Call(x.dtype(), "exp", {neg_two * x}, tir::CallNode::PureIntrinsic); - PrimExpr exp_pos2x = tir::Call(x.dtype(), "exp", {two * x}, tir::CallNode::PureIntrinsic); + PrimExpr exp_neg2x = exp(neg_two * x); + PrimExpr exp_pos2x = exp(two * x); PrimExpr tanh_pos = (one - exp_neg2x) / (one + exp_neg2x); PrimExpr tanh_neg = (exp_pos2x - one) / (exp_pos2x + one); @@ -116,9 +118,7 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.tan").set_body([](const TVMArgs& targs const tir::CallNode* call = e.as(); CHECK(call != nullptr); const PrimExpr& x = call->args[0]; - PrimExpr sin_x = tir::Call(x.dtype(), "sin", {x}, tir::CallNode::PureIntrinsic); - PrimExpr cos_x = tir::Call(x.dtype(), "cos", {x}, tir::CallNode::PureIntrinsic); - PrimExpr tan_x = sin_x / cos_x; + PrimExpr tan_x = sin(x) / cos(x); *rv = tan_x; }); @@ -135,8 +135,8 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.cosh") const PrimExpr& x = call->args[0]; PrimExpr two = make_const(x.dtype(), 2); PrimExpr neg_one = make_const(x.dtype(), -1); - PrimExpr exp_negx = tir::Call(x.dtype(), "exp", {neg_one * x}, tir::CallNode::PureIntrinsic); - PrimExpr exp_posx = tir::Call(x.dtype(), "exp", {x}, tir::CallNode::PureIntrinsic); + PrimExpr exp_negx = exp(neg_one * x); + PrimExpr exp_posx = exp(x); PrimExpr ret = (exp_posx + exp_negx) / two; *rv = ret; }); @@ -154,8 +154,8 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.sinh") const PrimExpr& x = call->args[0]; PrimExpr two = make_const(x.dtype(), 2); PrimExpr neg_one = make_const(x.dtype(), -1); - PrimExpr exp_negx = tir::Call(x.dtype(), "exp", {neg_one * x}, tir::CallNode::PureIntrinsic); - PrimExpr exp_posx = tir::Call(x.dtype(), "exp", {x}, tir::CallNode::PureIntrinsic); + PrimExpr exp_negx = exp(neg_one * x); + PrimExpr exp_posx = exp(x); PrimExpr ret = (exp_posx - exp_negx) / two; *rv = ret; }); diff --git a/src/target/llvm/intrin_rule_llvm.h b/src/target/llvm/intrin_rule_llvm.h index 5613621d77fb..cc9437d25b7e 100644 --- a/src/target/llvm/intrin_rule_llvm.h +++ b/src/target/llvm/intrin_rule_llvm.h @@ -27,6 +27,7 @@ #include #include +#include #include #include @@ -49,7 +50,8 @@ inline void DispatchLLVMPureIntrin(const TVMArgs& targs, TVMRetValue* rv) { for (PrimExpr arg : call->args) { cargs.push_back(arg); } - *rv = tir::Call(call->dtype, "llvm_intrin", cargs, tir::CallNode::PureIntrinsic); + *rv = + tir::Call(call->dtype, tir::builtin::call_llvm_intrin(), cargs, tir::CallNode::PureIntrinsic); } template @@ -64,7 +66,7 @@ inline void DispatchLLVMIntrin(const TVMArgs& targs, TVMRetValue* rv) { for (PrimExpr arg : call->args) { cargs.push_back(arg); } - *rv = tir::Call(call->dtype, "llvm_intrin", cargs, tir::CallNode::Intrinsic); + *rv = tir::Call(call->dtype, tir::builtin::call_llvm_intrin(), cargs, tir::CallNode::Intrinsic); } } // namespace codegen diff --git a/src/target/llvm/intrin_rule_nvptx.cc b/src/target/llvm/intrin_rule_nvptx.cc index 49c2224932a5..a0ffe11da27a 100644 --- 
a/src/target/llvm/intrin_rule_nvptx.cc +++ b/src/target/llvm/intrin_rule_nvptx.cc @@ -23,7 +23,9 @@ #ifdef TVM_LLVM_VERSION #include +#include #include +#include #include @@ -36,10 +38,21 @@ inline void DispatchExternLibDevice(const TVMArgs& args, TVMRetValue* rv) { const CallNode* call = e.as(); CHECK(call != nullptr); CHECK(call->dtype.bits() == 32 || call->dtype.bits() == 64) << "Only support float32 or float64."; + + const OpNode* op = call->op.as(); + CHECK(op != nullptr); + std::string name = op->name; + CHECK_EQ(name.substr(0, 4), "tir."); + std::ostringstream intrinsic_name; - intrinsic_name << "__nv_" << call->name; + intrinsic_name << "__nv_" << name.substr(4); if (call->dtype.bits() == 32) intrinsic_name << "f"; - *rv = Call(call->dtype, intrinsic_name.str(), call->args, CallNode::PureExtern); + + Array new_args = {StringImm(intrinsic_name.str())}; + for (auto arg : call->args) { + new_args.push_back(arg); + } + *rv = Call(call->dtype, builtin::call_extern(), new_args, CallNode::PureExtern); } namespace llvm { diff --git a/src/target/llvm/intrin_rule_rocm.cc b/src/target/llvm/intrin_rule_rocm.cc index 3a2b8ac77f82..07520ae08cc8 100644 --- a/src/target/llvm/intrin_rule_rocm.cc +++ b/src/target/llvm/intrin_rule_rocm.cc @@ -23,6 +23,7 @@ #ifdef TVM_LLVM_VERSION #include +#include #include #include @@ -36,9 +37,21 @@ inline void DispatchExternOCML(const TVMArgs& args, TVMRetValue* rv) { using namespace tir; const CallNode* call = e.as(); CHECK(call != nullptr); + + const OpNode* op = call->op.as(); + CHECK(op != nullptr); + std::string name = op->name; + CHECK_EQ(name.substr(0, 4), "tir."); + std::ostringstream intrinsic_name; - intrinsic_name << "__ocml_" << call->name << "_f" << call->dtype.bits(); - *rv = Call(call->dtype, intrinsic_name.str(), call->args, CallNode::PureExtern); + intrinsic_name << "__ocml_" << name.substr(4) << "_f" << call->dtype.bits(); + + Array new_args = {StringImm(intrinsic_name.str())}; + for (auto arg : call->args) { + new_args.push_back(arg); + } + + *rv = Call(call->dtype, builtin::call_extern(), new_args, CallNode::PureExtern); } inline void DispatchShuffle(const TVMArgs& targs, TVMRetValue* rv) { @@ -53,29 +66,30 @@ inline void DispatchShuffle(const TVMArgs& targs, TVMRetValue* rv) { // get own lane in self (__lane_id) PrimExpr minus_one = tir::make_const(DataType::Int(32), -1); PrimExpr zero = tir::make_zero(DataType::Int(32)); - PrimExpr lo = - Call(DataType::Int(32), "llvm.amdgcn.mbcnt.lo", {minus_one, zero}, CallNode::PureExtern); - PrimExpr self = - Call(DataType::Int(32), "llvm.amdgcn.mbcnt.hi", {minus_one, lo}, CallNode::PureExtern); + PrimExpr lo = Call(DataType::Int(32), builtin::call_extern(), + {StringImm("llvm.amdgcn.mbcnt.lo"), minus_one, zero}, CallNode::PureExtern); + PrimExpr self = Call(DataType::Int(32), builtin::call_extern(), + {StringImm("llvm.amdgcn.mbcnt.hi"), minus_one, lo}, CallNode::PureExtern); // compute lane to get from PrimExpr width = call->args[3]; PrimExpr index; - if (call->name == "tvm_warp_shuffle") { + if (call->op.same_as(builtin::tvm_warp_shuffle())) { PrimExpr src_lane = call->args[2]; index = src_lane + (self & ~(width - 1)); - } else if (call->name == "tvm_warp_shuffle_up") { + } else if (call->op.same_as(builtin::tvm_warp_shuffle_up())) { PrimExpr delta = call->args[2]; index = self - delta; index = Select(index < (self & ~(width - 1)), self, index); } else { - CHECK_EQ(call->name, "tvm_warp_shuffle_down"); + CHECK(call->op.same_as(builtin::tvm_warp_shuffle_down())); PrimExpr delta = call->args[2]; index = 
self + delta; index = Select((self & (width - 1)) + delta >= width, self, index); } PrimExpr res = - Call(var.dtype(), "llvm.amdgcn.ds.bpermute", {index << 2, var}, CallNode::PureExtern); + Call(var.dtype(), builtin::call_extern(), + {StringImm("llvm.amdgcn.ds.bpermute"), index << 2, var}, CallNode::PureExtern); *rv = res; } diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc index 9255d7c80c46..ffeaba06d701 100644 --- a/src/target/source/codegen_c.cc +++ b/src/target/source/codegen_c.cc @@ -223,12 +223,12 @@ std::string CodeGenC::GetBufferRef(DataType t, const VarNode* buffer, PrimExpr i // Print a reference expression to a buffer. std::string CodeGenC::GetStructRef(DataType t, const PrimExpr& buffer, const PrimExpr& index, int kind) { - if (kind < intrinsic::kArrKindBound_) { + if (kind < builtin::kArrKindBound_) { std::ostringstream os; os << "(((DLTensor*)"; this->PrintExpr(buffer, os); os << ")"; - if (kind == intrinsic::kArrAddr) { + if (kind == builtin::kArrAddr) { os << " + "; this->PrintExpr(index, os); os << ")"; @@ -239,34 +239,34 @@ std::string CodeGenC::GetStructRef(DataType t, const PrimExpr& buffer, const Pri os << "]."; // other case: get fields. switch (kind) { - case intrinsic::kArrData: + case builtin::kArrData: os << "data"; break; - case intrinsic::kArrShape: + case builtin::kArrShape: os << "shape"; break; - case intrinsic::kArrStrides: + case builtin::kArrStrides: os << "strides"; break; - case intrinsic::kArrNDim: + case builtin::kArrNDim: os << "ndim"; break; - case intrinsic::kArrTypeCode: + case builtin::kArrTypeCode: os << "dtype.code"; break; - case intrinsic::kArrTypeBits: + case builtin::kArrTypeBits: os << "dtype.bits"; break; - case intrinsic::kArrByteOffset: + case builtin::kArrByteOffset: os << "byte_offset"; break; - case intrinsic::kArrTypeLanes: + case builtin::kArrTypeLanes: os << "dtype.lanes"; break; - case intrinsic::kArrDeviceId: + case builtin::kArrDeviceId: os << "ctx.device_id"; break; - case intrinsic::kArrDeviceType: + case builtin::kArrDeviceType: os << "ctx.device_type"; break; default: @@ -275,7 +275,7 @@ std::string CodeGenC::GetStructRef(DataType t, const PrimExpr& buffer, const Pri os << ')'; return os.str(); } else { - CHECK_LT(kind, intrinsic::kTVMValueKindBound_); + CHECK_LT(kind, builtin::kTVMValueKindBound_); std::ostringstream os; os << "(((TVMValue*)"; this->PrintExpr(buffer, os); @@ -559,80 +559,94 @@ void CodeGenC::VisitExpr_(const NotNode* op, std::ostream& os) { // NOLINT(*) PrintExpr(op->a, os); } -void CodeGenC::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*) - if (op->call_type == CallNode::Extern || op->call_type == CallNode::PureExtern) { - os << op->name << "("; - for (size_t i = 0; i < op->args.size(); i++) { - this->PrintExpr(op->args[i], os); - if (i < op->args.size() - 1) { - os << ", "; - } +void CodeGenC::PrintCallExtern(Type ret_type, String global_symbol, const Array& args, + bool skip_first_arg, std::ostream& os) { // NOLINT(*) + os << global_symbol << "("; + for (size_t i = static_cast(skip_first_arg); i < args.size(); ++i) { + this->PrintExpr(args[i], os); + if (i < args.size() - 1) { + os << ", "; } - os << ")"; - } else if (op->is_intrinsic(CallNode::bitwise_and)) { - PrintBinaryIntrinsic(op, " & ", os, this); - } else if (op->is_intrinsic(intrinsic::tvm_large_uint_imm)) { - CHECK_EQ(op->args.size(), 2U); - uint64_t low = static_cast(Downcast(op->args[0])->value); - uint64_t high = static_cast(Downcast(op->args[1])->value); - uint64_t val = (high << 32U) | low; - 
PrintUIntConst(op->dtype, val, os, this); - } else if (op->is_intrinsic(CallNode::bitwise_xor)) { - PrintBinaryIntrinsic(op, " ^ ", os, this); - } else if (op->is_intrinsic(CallNode::bitwise_or)) { - PrintBinaryIntrinsic(op, " | ", os, this); - } else if (op->is_intrinsic(CallNode::bitwise_not)) { - CHECK_EQ(op->args.size(), 1U); - os << "(~"; - this->PrintExpr(op->args[0], os); - os << ')'; - } else if (op->is_intrinsic(CallNode::shift_left)) { - PrintBinaryIntrinsic(op, " << ", os, this); - } else if (op->is_intrinsic(CallNode::shift_right)) { - PrintBinaryIntrinsic(op, " >> ", os, this); - } else if (op->is_intrinsic(intrinsic::tvm_if_then_else)) { - os << "("; - PrintExpr(op->args[0], os); - os << " ? "; - PrintExpr(op->args[1], os); - os << " : "; - PrintExpr(op->args[2], os); - os << ")"; - } else if (op->is_intrinsic(intrinsic::tvm_address_of)) { - const LoadNode* l = op->args[0].as(); - CHECK(op->args.size() == 1 && l); - os << "(("; - this->PrintType(l->dtype.element_of(), os); - os << " *)" << this->GetVarID(l->buffer_var.get()) << " + "; - this->PrintExpr(l->index, os); - os << ')'; - } else if (op->is_intrinsic(intrinsic::tvm_struct_get)) { - CHECK_EQ(op->args.size(), 3U); - os << GetStructRef(op->dtype, op->args[0], op->args[1], op->args[2].as()->value); - } else if (op->is_intrinsic(intrinsic::tvm_handle_is_null)) { - CHECK_EQ(op->args.size(), 1U); - os << "("; - this->PrintExpr(op->args[0], os); - os << " == NULL)"; - } else if (op->is_intrinsic(CallNode::reinterpret)) { - // generate (*( TYPE *)(&(ARG))) - os << "(*("; - this->PrintType(op->dtype, os); - os << " *)(&("; - this->PrintExpr(op->args[0], os); - os << ")))"; - } else if (op->is_intrinsic(CallNode::isnan)) { - os << "("; - this->PrintExpr(op->args[0], os); - os << " != "; - this->PrintExpr(op->args[0], os); - os << ")"; - } else { - if (op->call_type == CallNode::Intrinsic || op->call_type == CallNode::PureIntrinsic) { - LOG(FATAL) << "Unresolved intrinsic " << op->name << " with return type " << op->dtype; + } + os << ")"; +} + +void CodeGenC::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*) + if (auto* ptr_op = op->op.as()) { + auto call_op = GetRef(ptr_op); + + if (op->op.same_as(builtin_call_extern_)) { + CHECK_GE(op->args.size(), 1U); + auto func = Downcast(op->args[0]); + this->PrintCallExtern(GetType(GetRef(op)), func->value, op->args, true, os); + } else if (op_attr_global_symbol_.count(call_op)) { + // call extern if the op itself have a global symbol. 
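With PrintCallExtern factored out above, the C backend prints the symbol followed by the remaining arguments, whether the symbol comes from args[0] of a call_extern or from the op's "TGlobalSymbol" attribute. A sketch of the TIR that reaches this path and the C it is expected to produce (symbol and operands are placeholders):

#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>

using namespace tvm;
using namespace tvm::tir;

// TIR:  @tir.call_extern("my_helper", x, y)   -->   C:  my_helper(x, y)
PrimExpr MakeExternCall(PrimExpr x, PrimExpr y) {
  Array<PrimExpr> args = {StringImm("my_helper"), x, y};
  return Call(DataType::Float(32), builtin::call_extern(), args,
              CallNode::PureExtern);
}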
+ this->PrintCallExtern(GetType(GetRef(op)), op_attr_global_symbol_[call_op], + op->args, false, os); + } else if (op->op.same_as(builtin::bitwise_and())) { + PrintBinaryIntrinsic(op, " & ", os, this); + } else if (op->op.same_as(builtin::large_uint_imm())) { + CHECK_EQ(op->args.size(), 2U); + uint64_t low = static_cast(Downcast(op->args[0])->value); + uint64_t high = static_cast(Downcast(op->args[1])->value); + uint64_t val = (high << 32U) | low; + PrintUIntConst(op->dtype, val, os, this); + } else if (op->op.same_as(builtin::bitwise_xor())) { + PrintBinaryIntrinsic(op, " ^ ", os, this); + } else if (op->op.same_as(builtin::bitwise_or())) { + PrintBinaryIntrinsic(op, " | ", os, this); + } else if (op->op.same_as(builtin::bitwise_not())) { + CHECK_EQ(op->args.size(), 1U); + os << "(~"; + this->PrintExpr(op->args[0], os); + os << ')'; + } else if (op->op.same_as(builtin::shift_left())) { + PrintBinaryIntrinsic(op, " << ", os, this); + } else if (op->op.same_as(builtin::shift_right())) { + PrintBinaryIntrinsic(op, " >> ", os, this); + } else if (op->op.same_as(builtin::if_then_else())) { + os << "("; + PrintExpr(op->args[0], os); + os << " ? "; + PrintExpr(op->args[1], os); + os << " : "; + PrintExpr(op->args[2], os); + os << ")"; + } else if (op->op.same_as(builtin::address_of())) { + const LoadNode* l = op->args[0].as(); + CHECK(op->args.size() == 1 && l); + os << "(("; + this->PrintType(l->dtype.element_of(), os); + os << " *)" << this->GetVarID(l->buffer_var.get()) << " + "; + this->PrintExpr(l->index, os); + os << ')'; + } else if (op->op.same_as(builtin::tvm_struct_get())) { + CHECK_EQ(op->args.size(), 3U); + os << GetStructRef(op->dtype, op->args[0], op->args[1], op->args[2].as()->value); + } else if (op->op.same_as(builtin::isnullptr())) { + CHECK_EQ(op->args.size(), 1U); + os << "("; + this->PrintExpr(op->args[0], os); + os << " == NULL)"; + } else if (op->op.same_as(builtin::reinterpret())) { + // generate (*( TYPE *)(&(ARG))) + os << "(*("; + this->PrintType(op->dtype, os); + os << " *)(&("; + this->PrintExpr(op->args[0], os); + os << ")))"; + } else if (op->op.same_as(builtin::isnan())) { + os << "("; + this->PrintExpr(op->args[0], os); + os << " != "; + this->PrintExpr(op->args[0], os); + os << ")"; } else { - LOG(FATAL) << "Unresolved call type " << op->call_type; + LOG(FATAL) << "Unresolved call " << op->op; } + } else { + CHECK(op->op.as()); + LOG(FATAL) << "Do not yet support cross function call"; } } @@ -903,10 +917,10 @@ void CodeGenC::VisitStmt_(const EvaluateNode* op) { if (is_const(op->value)) return; const CallNode* call = op->value.as(); if (call) { - if (call->is_intrinsic(intrinsic::tvm_storage_sync)) { + if (call->op.same_as(builtin::tvm_storage_sync())) { this->PrintStorageSync(call); return; - } else if (call->is_intrinsic(intrinsic::tvm_struct_set)) { + } else if (call->op.same_as(builtin::tvm_struct_set())) { CHECK_EQ(call->args.size(), 4); std::string value = PrintExpr(call->args[3]); std::string ref = GetStructRef(call->args[3].dtype(), call->args[0], call->args[1], diff --git a/src/target/source/codegen_c.h b/src/target/source/codegen_c.h index 309eb0681607..9346f87cb3bb 100644 --- a/src/target/source/codegen_c.h +++ b/src/target/source/codegen_c.h @@ -24,10 +24,13 @@ #ifndef TVM_TARGET_SOURCE_CODEGEN_C_H_ #define TVM_TARGET_SOURCE_CODEGEN_C_H_ +#include #include #include +#include #include #include +#include #include #include @@ -219,6 +222,16 @@ class CodeGenC : public ExprFunctor, */ virtual bool IsScopePartOfType() const { return true; } + /*! 
+ * \brief Print external function call. + * \param ret_type The return type. + * \param global_symbol The symbolc of the target function. + * \param args The arguments to the function. + * \param skip_first_arg Whether to skip the first arguments. + * \param os The output stream. + */ + virtual void PrintCallExtern(Type ret_type, String global_symbol, const Array& args, + bool skip_first_arg, std::ostream& os); // NOLINT(*) /*! * \brief If buffer is allocated as type t. * \param buf_var The buffer variable. @@ -245,6 +258,10 @@ class CodeGenC : public ExprFunctor, std::unordered_map alloc_storage_scope_; /*! \brief the data type of allocated buffers */ std::unordered_map handle_data_type_; + /*! \brief Record of ops that have pre-defined global symbol. */ + OpAttrMap op_attr_global_symbol_ = Op::GetAttrMap("TGlobalSymbol"); + // cache commonly used ops + const Op& builtin_call_extern_ = builtin::call_extern(); private: /*! \brief whether to print in SSA form */ diff --git a/src/target/source/codegen_c_host.cc b/src/target/source/codegen_c_host.cc index b11b3d8fc5f9..839962a8c733 100644 --- a/src/target/source/codegen_c_host.cc +++ b/src/target/source/codegen_c_host.cc @@ -175,7 +175,7 @@ void CodeGenCHost::PrintFuncCall(const std::string& packed_func_name, int num_ar } void CodeGenCHost::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*) - if (op->is_intrinsic(intrinsic::tvm_stack_alloca)) { + if (op->op.same_as(builtin::tvm_stack_alloca())) { std::string stack_name = GetUniqueName("stack"); const std::string& type = op->args[0].as()->value; const IntImmNode* num = op->args[1].as(); @@ -197,7 +197,7 @@ void CodeGenCHost::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT this->PrintIndent(); this->stream << "TVMValue " << stack_name << "[" << size << "];\n"; os << stack_name; - } else if (op->is_intrinsic(intrinsic::tvm_call_packed_lowered)) { + } else if (op->op.same_as(builtin::tvm_call_packed_lowered())) { const StringImmNode* s = op->args[0].as(); CHECK(s != nullptr) << "tvm_call_packed_lowered expects first argument as function name"; int64_t begin = op->args[3].as()->value; @@ -216,7 +216,7 @@ void CodeGenCHost::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT } this->PrintGetFuncFromBackend(func_name, packed_func_name); this->PrintFuncCall(packed_func_name, num_args); - } else if (op->is_intrinsic(intrinsic::tvm_throw_last_error)) { + } else if (op->op.same_as(builtin::tvm_throw_last_error())) { this->PrintIndent(); this->stream << "return -1;\n"; } else { diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc index cf7a74f1dcc0..ae5e40acd8f5 100644 --- a/src/target/source/codegen_cuda.cc +++ b/src/target/source/codegen_cuda.cc @@ -429,15 +429,71 @@ void CodeGenCUDA::VisitExpr_(const CastNode* op, std::ostream& os) { os << sret; } +void CodeGenCUDA::PrintCallExtern(Type ret_type, String global_symbol, const Array& args, + bool skip_first_arg, std::ostream& os) { // NOLINT(*) + DataType ret_dtype = GetRuntimeDataType(ret_type); + if (ret_dtype.is_vector()) { + // + // Emit an unsupported vector call + // + // v = intrin_f((float4*)A[0], (float4*)B[0]) + // + // as + // + // float4 __ret; + // { + // float4 __arg0 = ((float4*)A)[0]; + // float4 __arg1 = ((float4*)B)[0]; + // __ret.x = intrin_f(__arg0.x, __arg1.x); + // __ret.y = intrin_f(__arg0.y, __arg1.y); + // __ret.z = intrin_f(__arg0.z, __arg1.z); + // __ret.w = intrin_f(__arg0.w, __arg1.w); + // } + // v = __ret; + // + // Declare the result vector. 
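Both the C and LLVM backends cache an OpAttrMap for "TGlobalSymbol", so any op registered with that attribute (for example the CUDA shuffle ops added later in this patch) is routed through PrintCallExtern without per-op special cases. A sketch of querying the map, assuming the attribute value is a String (helper names are illustrative):

#include <tvm/ir/op.h>

// One registry lookup, shared by every call site in the backend.
static auto global_symbol_map = tvm::Op::GetAttrMap<tvm::String>("TGlobalSymbol");

bool HasExternSymbol(const tvm::Op& op) { return global_symbol_map.count(op) != 0; }

tvm::String ExternSymbol(const tvm::Op& op) { return global_symbol_map[op]; }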
+ std::string sret = GetUniqueName("_"); + this->PrintIndent(); + this->PrintType(ret_dtype, stream); + stream << ' ' << sret << ";\n"; + { + // Load arguments. + std::vector sargs; + size_t arg_begin = static_cast(skip_first_arg); + for (size_t i = arg_begin; i < args.size(); ++i) { + std::string val = SSAGetID(PrintExpr(args[i]), args[i].dtype()); + sargs.push_back(std::move(val)); + } + + // Emit a scalar call for each lane. + for (int i = 0; i < ret_dtype.lanes(); ++i) { + std::ostringstream scall; + scall << global_symbol << "("; + for (size_t j = 0; j < sargs.size(); ++j) { + if (j > 0) scall << ", "; + PrintVecElemLoad(sargs[j], args[arg_begin + j].dtype(), i, scall); + } + scall << ")"; + PrintVecElemStore(sret, ret_dtype, i, scall.str()); + } + } + os << sret; + } else { + CodeGenC::PrintCallExtern(ret_type, global_symbol, args, skip_first_arg, os); + } +} + void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { - // This is only for backward compatibility with __shfl_{up/down}. - // A macro will be used to replace *_sync calls to legacy ones. - if (op->is_intrinsic("__shfl_sync") || op->is_intrinsic("__shfl_up_sync") || - op->is_intrinsic("__shfl_down_sync")) { - enable_warp_shuffle_ = true; + if (auto* ptr_op = op->op.as()) { + Op call_op = GetRef(ptr_op); + // This is only for backward compatibility with __shfl_{up/down}. + // A macro will be used to replace *_sync calls to legacy ones. + if (op_need_warp_shuffle_.get(call_op, false)) { + enable_warp_shuffle_ = true; + } } - if (op->is_intrinsic(intrinsic::tvm_fill_fragment)) { + if (op->op.same_as(builtin::tvm_fill_fragment())) { need_mma_h_ = true; CHECK_EQ(op->args.size(), 6U); os << "nvcuda::wmma::fill_fragment("; @@ -447,7 +503,7 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { os << "], "; this->PrintExpr(op->args[5], os); os << ")"; - } else if (op->is_intrinsic(intrinsic::tvm_load_matrix_sync)) { + } else if (op->op.same_as(builtin::tvm_load_matrix_sync())) { need_mma_h_ = true; CHECK_EQ(op->args.size(), 8U); os << "nvcuda::wmma::load_matrix_sync("; @@ -459,7 +515,7 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { os << ", "; this->PrintExpr(op->args[6], os); os << ")"; - } else if (op->is_intrinsic(intrinsic::tvm_store_matrix_sync)) { + } else if (op->op.same_as(builtin::tvm_store_matrix_sync())) { need_mma_h_ = true; CHECK_EQ(op->args.size(), 8U); os << "nvcuda::wmma::store_matrix_sync("; @@ -476,7 +532,7 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { LOG(FATAL) << "Invalid parameters"; } os << ")"; - } else if (op->is_intrinsic(intrinsic::tvm_mma_sync)) { + } else if (op->op.same_as(builtin::tvm_mma_sync())) { need_mma_h_ = true; CHECK_EQ(op->args.size(), 8U); os << "nvcuda::wmma::mma_sync("; @@ -486,7 +542,7 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { this->PrintExpr(op->args[i * 2 + 1], os); os << "]" << ((i < 3) ? ", " : ")"); } - } else if (op->is_intrinsic(intrinsic::tvm_bmma_sync)) { + } else if (op->op.same_as(builtin::tvm_bmma_sync())) { need_mma_h_ = true; CHECK_EQ(op->args.size(), 8U); os << "nvcuda::wmma::bmma_sync("; @@ -496,51 +552,6 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { this->PrintExpr(op->args[i * 2 + 1], os); os << "]" << ((i < 3) ? 
", " : ")"); } - } else if (op->call_type == CallNode::PureExtern && op->dtype.is_vector()) { - // - // Emit an unsupported vector call - // - // v = intrin_f((float4*)A[0], (float4*)B[0]) - // - // as - // - // float4 __ret; - // { - // float4 __arg0 = ((float4*)A)[0]; - // float4 __arg1 = ((float4*)B)[0]; - // __ret.x = intrin_f(__arg0.x, __arg1.x); - // __ret.y = intrin_f(__arg0.y, __arg1.y); - // __ret.z = intrin_f(__arg0.z, __arg1.z); - // __ret.w = intrin_f(__arg0.w, __arg1.w); - // } - // v = __ret; - // - // Declare the result vector. - std::string sret = GetUniqueName("_"); - this->PrintIndent(); - this->PrintType(op->dtype, stream); - stream << ' ' << sret << ";\n"; - { - // Load arguments. - std::vector sargs; - for (size_t i = 0; i < op->args.size(); ++i) { - std::string val = SSAGetID(PrintExpr(op->args[i]), op->args[i].dtype()); - sargs.push_back(std::move(val)); - } - - // Emit a scalar call for each lane. - for (int i = 0; i < op->dtype.lanes(); ++i) { - std::ostringstream scall; - scall << op->name << "("; - for (size_t j = 0; j < op->args.size(); ++j) { - if (j > 0) scall << ", "; - PrintVecElemLoad(sargs[j], op->args[j].dtype(), i, scall); - } - scall << ")"; - PrintVecElemStore(sret, op->dtype, i, scall.str()); - } - } - os << sret; } else { CodeGenC::VisitExpr_(op, os); } @@ -600,7 +611,7 @@ void CodeGenCUDA::VisitStmt_(const AllocateNode* op) { void CodeGenCUDA::VisitStmt_(const EvaluateNode* op) { if (is_const(op->value)) return; const CallNode* call = op->value.as(); - if (call && call->is_intrinsic(intrinsic::tvm_global_barrier_kinit)) { + if (call && call->op.same_as(builtin::tvm_global_barrier_kinit())) { PrintIndent(); stream << "__shared__ unsigned " << vid_global_barrier_expect_ << ";\n"; PrintIndent(); diff --git a/src/target/source/codegen_cuda.h b/src/target/source/codegen_cuda.h index f9ab0ade2cf2..3cde8e379eb4 100644 --- a/src/target/source/codegen_cuda.h +++ b/src/target/source/codegen_cuda.h @@ -26,6 +26,7 @@ #include #include +#include #include #include @@ -68,6 +69,10 @@ class CodeGenCUDA final : public CodeGenC { void VisitStmt_(const AllocateNode* op) final; void VisitStmt_(const AttrStmtNode* op) final; + protected: + void PrintCallExtern(Type ret_type, String global_symbol, const Array& args, + bool skip_first_arg, std::ostream& os) final; // NOLINT(*) + private: // Handle volatile loads void HandleVolatileLoads(const std::string& value, const LoadNode* op, std::ostream& os) final; @@ -91,6 +96,8 @@ class CodeGenCUDA final : public CodeGenC { bool need_math_constants_h_{false}; // whether need mma.h bool need_mma_h_{false}; + // Op attribute map + OpAttrMap op_need_warp_shuffle_ = Op::GetAttrMap("cuda.need_warp_shuffle"); std::unordered_map fragment_shapes; std::unordered_map fragment_layouts; diff --git a/src/target/source/codegen_metal.cc b/src/target/source/codegen_metal.cc index 2c26ee977639..1c4256c5a166 100644 --- a/src/target/source/codegen_metal.cc +++ b/src/target/source/codegen_metal.cc @@ -270,7 +270,7 @@ void CodeGenMetal::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // N } void CodeGenMetal::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*) - if (op->is_intrinsic(CallNode::reinterpret)) { + if (op->op.same_as(builtin::reinterpret())) { // generate as_type(ARG) os << "(as_type<"; this->PrintType(op->dtype, os); diff --git a/src/target/source/intrin_rule_cuda.cc b/src/target/source/intrin_rule_cuda.cc index 45746b8ef721..53a2799e2725 100644 --- a/src/target/source/intrin_rule_cuda.cc +++ 
b/src/target/source/intrin_rule_cuda.cc @@ -21,6 +21,9 @@ * \file intrin_rule_cuda.cc * \brief CUDA intrinsic rules. */ +#include +#include + #include "../intrin_rule.h" namespace tvm { @@ -93,23 +96,23 @@ struct CUDAPopcount { }; struct CUDAWarpIntrinsic { - const char* operator()(DataType t, const std::string& name) const { - if (name == intrinsic::tvm_warp_shuffle) { - return "__shfl_sync"; - } - if (name == intrinsic::tvm_warp_shuffle_up) { - return "__shfl_up_sync"; - } - if (name == intrinsic::tvm_warp_shuffle_down) { - return "__shfl_down_sync"; - } - if (name == intrinsic::tvm_warp_activemask) { - return "__activemask"; + const Op operator()(DataType t, const Op& orig_op) const { + if (orig_op.same_as(builtin::tvm_warp_shuffle())) { + return Op::Get("tir.cuda.__shfl_sync"); + } else if (orig_op.same_as(builtin::tvm_warp_shuffle_up())) { + return Op::Get("tir.cuda.__shfl_up_sync"); + } else { + CHECK(orig_op.same_as(builtin::tvm_warp_shuffle_down())); + return Op::Get("tir.cuda.__shfl_down_sync"); } - return ""; } }; +static void DispatchCUDAWarpActiveMask(const TVMArgs& args, TVMRetValue* rv) { + Call call = args[0]; + *rv = Call(call->dtype, Op::Get("tir.cuda.__activemask"), call->args, CallNode::PureExtern); +} + template static void DispatchCUDAShuffle(const TVMArgs& args, TVMRetValue* rv) { PrimExpr e = args[0]; @@ -117,8 +120,9 @@ static void DispatchCUDAShuffle(const TVMArgs& args, TVMRetValue* rv) { CHECK(call != nullptr); CHECK_EQ(call->args.size(), 5); // mask, value, warp_id, width, warp_size Array cuda_args{{call->args[0], call->args[1], call->args[2], call->args[3]}}; - const char* name = T()(call->dtype, call->name); - *rv = Call(call->dtype, name, cuda_args, CallNode::PureExtern); + + *rv = + Call(call->dtype, T()(call->dtype, Downcast(call->op)), cuda_args, CallNode::PureExtern); } TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.floor").set_body(DispatchExtern); @@ -175,10 +179,32 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.tvm_warp_shuffle_down") .set_body(DispatchCUDAShuffle); TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.tvm_warp_activemask") - .set_body(DispatchExtern); + .set_body(DispatchCUDAWarpActiveMask); TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.fmod").set_body(DispatchExtern); +// Register low-level builtin ops. +// TODO(tvm-team): consider make CUDA its own subfolder and create a file for low-level builtins. 
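The CUDA shuffle rules now rewrite the generic warp builtins into the target-level ops registered just below, instead of into bare "__shfl_*" name strings, and drop the trailing warp_size argument along the way. The lowering amounts to the following (a sketch mirroring DispatchCUDAShuffle above; the helper name is illustrative):

#include <tvm/ir/op.h>
#include <tvm/tir/expr.h>

using namespace tvm;
using namespace tvm::tir;

// tvm_warp_shuffle(mask, value, src_lane, width, warp_size)
//   -->  tir.cuda.__shfl_sync(mask, value, src_lane, width)
PrimExpr LowerWarpShuffle(const CallNode* call) {
  Array<PrimExpr> cuda_args{call->args[0], call->args[1],
                            call->args[2], call->args[3]};
  return Call(call->dtype, Op::Get("tir.cuda.__shfl_sync"), cuda_args,
              CallNode::PureExtern);
}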
+TVM_REGISTER_OP("tir.cuda.__shfl_sync") + .set_num_inputs(4) + .set_attr("TGlobalSymbol", "__shfl_sync") + .set_attr("cuda.need_warp_shuffle", true); + +TVM_REGISTER_OP("tir.cuda.__shfl_up_sync") + .set_num_inputs(4) + .set_attr("TGlobalSymbol", "__shfl_up_sync") + .set_attr("cuda.need_warp_shuffle", true); + +TVM_REGISTER_OP("tir.cuda.__shfl_down_sync") + .set_num_inputs(4) + .set_attr("TGlobalSymbol", "__shfl_down_sync") + .set_attr("cuda.need_warp_shuffle", true); + +TVM_REGISTER_OP("tir.cuda.__activemask") + .set_num_inputs(0) + .set_attr("TGlobalSymbol", "__activemask") + .set_attr("cuda.need_warp_shuffle", true); + } // namespace intrin } // namespace codegen } // namespace tvm diff --git a/src/target/source/intrin_rule_opencl.cc b/src/target/source/intrin_rule_opencl.cc index 8453b33f8a43..82eabdd96dfe 100644 --- a/src/target/source/intrin_rule_opencl.cc +++ b/src/target/source/intrin_rule_opencl.cc @@ -79,8 +79,8 @@ static void DispatchIntelShuffle(const TVMArgs& args, TVMRetValue* rv) { arith::Analyzer analyzer; CHECK(analyzer.CanProve(call->args[3] == call->args[4])) << "Intel warp shuffle dose not support width != warp_size"; - Array opencl_args{{call->args[1], call->args[2]}}; - *rv = Call(call->dtype, "intel_sub_group_shuffle", opencl_args, CallNode::PureExtern); + Array opencl_args{{StringImm("intel_sub_group_shuffle"), call->args[1], call->args[2]}}; + *rv = Call(call->dtype, builtin::call_extern(), opencl_args, CallNode::PureExtern); } TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.tvm_warp_shuffle").set_body(DispatchIntelShuffle); diff --git a/src/target/spirv/codegen_spirv.cc b/src/target/spirv/codegen_spirv.cc index 699d3953f04c..6c12343c81ec 100644 --- a/src/target/spirv/codegen_spirv.cc +++ b/src/target/spirv/codegen_spirv.cc @@ -24,6 +24,7 @@ #include "codegen_spirv.h" #include +#include #include #include @@ -236,7 +237,7 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const LetNode* op) { } spirv::Value CodeGenSPIRV::VisitExpr_(const CallNode* op) { - if (op->is_intrinsic("spirv_glsl450")) { + if (op->op.same_as(builtin::call_spirv_glsl450())) { CHECK_GE(op->args.size(), 2U); uint32_t inst_id = static_cast(op->args[0].as()->value); std::vector values; @@ -244,31 +245,31 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const CallNode* op) { values.push_back(MakeValue(op->args[i])); } return builder_->CallGLSL450(builder_->GetSType(op->dtype), inst_id, values); - } else if (op->is_intrinsic(CallNode::bitwise_and)) { + } else if (op->op.same_as(builtin::bitwise_and())) { CHECK_EQ(op->args.size(), 2U); spirv::Value a = MakeValue(op->args[0]); spirv::Value b = MakeValue(op->args[1]); return builder_->MakeValue(spv::OpBitwiseAnd, a.stype, a, b); - } else if (op->is_intrinsic(CallNode::bitwise_xor)) { + } else if (op->op.same_as(builtin::bitwise_xor())) { CHECK_EQ(op->args.size(), 2U); spirv::Value a = MakeValue(op->args[0]); spirv::Value b = MakeValue(op->args[1]); return builder_->MakeValue(spv::OpBitwiseXor, a.stype, a, b); - } else if (op->is_intrinsic(CallNode::bitwise_or)) { + } else if (op->op.same_as(builtin::bitwise_or())) { CHECK_EQ(op->args.size(), 2U); spirv::Value a = MakeValue(op->args[0]); spirv::Value b = MakeValue(op->args[1]); return builder_->MakeValue(spv::OpBitwiseOr, a.stype, a, b); - } else if (op->is_intrinsic(CallNode::bitwise_not)) { + } else if (op->op.same_as(builtin::bitwise_not())) { CHECK_EQ(op->args.size(), 1U); spirv::Value a = MakeValue(op->args[0]); return builder_->MakeValue(spv::OpNot, a.stype, a); - } else if (op->is_intrinsic(CallNode::shift_left)) { + 
} else if (op->op.same_as(builtin::shift_left())) { CHECK_EQ(op->args.size(), 2U); spirv::Value a = MakeValue(op->args[0]); spirv::Value b = MakeValue(op->args[1]); return builder_->MakeValue(spv::OpShiftLeftLogical, a.stype, a, b); - } else if (op->is_intrinsic(CallNode::shift_right)) { + } else if (op->op.same_as(builtin::shift_right())) { CHECK_EQ(op->args.size(), 2U); spirv::Value a = MakeValue(op->args[0]); spirv::Value b = MakeValue(op->args[1]); @@ -277,18 +278,18 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const CallNode* op) { } else { return builder_->MakeValue(spv::OpShiftRightLogical, a.stype, a, b); } - } else if (op->is_intrinsic(CallNode::reinterpret)) { + } else if (op->op.same_as(builtin::reinterpret())) { return builder_->MakeValue(spv::OpBitcast, builder_->GetSType(op->dtype), MakeValue(op->args[0])); - } else if (op->is_intrinsic(intrinsic::tvm_large_uint_imm)) { + } else if (op->op.same_as(builtin::large_uint_imm())) { CHECK_EQ(op->args.size(), 2U); uint64_t low = static_cast(Downcast(op->args[0])->value); uint64_t high = static_cast(Downcast(op->args[1])->value); uint64_t val = (high << 32U) | low; return builder_->UIntImm(builder_->GetSType(op->dtype), val); - } else if (op->is_intrinsic(intrinsic::tvm_storage_sync)) { + } else if (op->op.same_as(builtin::tvm_storage_sync())) { return this->CreateStorageSync(op); - } else if (op->is_intrinsic(intrinsic::tvm_if_then_else)) { + } else if (op->op.same_as(builtin::if_then_else())) { CHECK_EQ(op->args.size(), 3U); spirv::Value cond = MakeValue(op->args[0]); spirv::Label then_label = builder_->NewLabel(); @@ -312,14 +313,14 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const CallNode* op) { phi.SetIncoming(0, then_value, then_value_label); phi.SetIncoming(1, else_value, else_value_label); return phi; - } else if (op->is_intrinsic("popcount")) { + } else if (op->op.same_as(builtin::popcount())) { return builder_->MakeValue(spv::OpBitCount, builder_->GetSType(op->dtype), MakeValue(op->args[0])); } else { if (op->call_type == CallNode::Intrinsic || op->call_type == CallNode::PureIntrinsic) { - LOG(FATAL) << "Unresolved intrinsic " << op->name << " with return type " << op->dtype; + LOG(FATAL) << "Unresolved intrinsic " << op->op << " with return type " << op->dtype; } else if (op->call_type == CallNode::Extern || op->call_type == CallNode::PureExtern) { - LOG(FATAL) << "Unresolved extern " << op->name << " with return type " << op->dtype; + LOG(FATAL) << "Unresolved extern " << op->op << " with return type " << op->dtype; } else { LOG(FATAL) << "Unresolved call type " << op->call_type; } diff --git a/src/target/spirv/intrin_rule_spirv.cc b/src/target/spirv/intrin_rule_spirv.cc index a6b254770daa..1b9d2e4e410d 100644 --- a/src/target/spirv/intrin_rule_spirv.cc +++ b/src/target/spirv/intrin_rule_spirv.cc @@ -22,6 +22,7 @@ */ #include #include +#include #include namespace tvm { @@ -43,7 +44,8 @@ inline void DispatchGLSLPureIntrin(const TVMArgs& targs, TVMRetValue* rv) { for (PrimExpr arg : call->args) { cargs.push_back(arg); } - *rv = tir::Call(call->dtype, "spirv_glsl450", cargs, tir::CallNode::PureIntrinsic); + *rv = tir::Call(call->dtype, tir::builtin::call_spirv_glsl450(), cargs, + tir::CallNode::PureIntrinsic); } TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.floor") diff --git a/src/target/stackvm/codegen_stackvm.cc b/src/target/stackvm/codegen_stackvm.cc index 6dd2ca0ecb6c..84b14925877a 100644 --- a/src/target/stackvm/codegen_stackvm.cc +++ b/src/target/stackvm/codegen_stackvm.cc @@ -25,6 +25,7 @@ #include #include #include 
+#include #include #include @@ -41,31 +42,31 @@ using namespace tir; // map struct field kind to runtime variants // We keep two separate enums to ensure runtime/compiler isolation. StackVM::StructFieldKind MapFieldKind(int64_t kind) { - auto val = static_cast(kind); + auto val = static_cast(kind); switch (val) { - case intrinsic::kArrData: + case builtin::kArrData: return StackVM::kArrData; - case intrinsic::kArrShape: + case builtin::kArrShape: return StackVM::kArrShape; - case intrinsic::kArrAddr: + case builtin::kArrAddr: return StackVM::kArrAddr; - case intrinsic::kArrStrides: + case builtin::kArrStrides: return StackVM::kArrStrides; - case intrinsic::kArrNDim: + case builtin::kArrNDim: return StackVM::kArrNDim; - case intrinsic::kArrTypeCode: + case builtin::kArrTypeCode: return StackVM::kArrTypeCode; - case intrinsic::kArrTypeBits: + case builtin::kArrTypeBits: return StackVM::kArrTypeBits; - case intrinsic::kArrTypeLanes: + case builtin::kArrTypeLanes: return StackVM::kArrTypeLanes; - case intrinsic::kArrByteOffset: + case builtin::kArrByteOffset: return StackVM::kArrByteOffset; - case intrinsic::kArrDeviceId: + case builtin::kArrDeviceId: return StackVM::kArrDeviceId; - case intrinsic::kArrDeviceType: + case builtin::kArrDeviceType: return StackVM::kArrDeviceType; - case intrinsic::kTVMValueContent: + case builtin::kTVMValueContent: return StackVM::kTVMValueContent; default: LOG(FATAL) << "Do not know how to map field " << kind; @@ -174,7 +175,7 @@ void CodeGenStackVM::VisitStmt_(const AllocateNode* op) { } void CodeGenStackVM::VisitExpr_(const CallNode* op) { - if (op->is_intrinsic(intrinsic::tvm_address_of)) { + if (op->op.same_as(builtin::address_of())) { const LoadNode* l = op->args[0].as(); CHECK(op->args.size() == 1 && l); this->PushOp(StackVM::LOAD_HEAP, GetVarID(l->buffer_var.get())); @@ -182,9 +183,9 @@ void CodeGenStackVM::VisitExpr_(const CallNode* op) { this->PushOp(StackVM::PUSH_I64, l->dtype.element_of().bytes()); this->PushOp(StackVM::MUL_I64); this->PushOp(StackVM::ADDR_ADD); - } else if (op->is_intrinsic(CallNode::reinterpret)) { + } else if (op->op.same_as(builtin::reinterpret())) { this->Push(op->args[0]); - } else if (op->is_intrinsic(intrinsic::tvm_struct_get)) { + } else if (op->op.same_as(builtin::tvm_struct_get())) { CHECK_EQ(op->args.size(), 3U); int kind = op->args[2].as()->value; this->Push(op->args[0]); @@ -197,7 +198,7 @@ void CodeGenStackVM::VisitExpr_(const CallNode* op) { vm_.code.push_back(code); code.v_int = MapFieldKind(kind); vm_.code.push_back(code); - } else if (op->is_intrinsic(intrinsic::tvm_call_packed_lowered)) { + } else if (op->op.same_as(builtin::tvm_call_packed_lowered())) { CHECK_GE(op->args.size(), 5U); const StringImmNode* s = op->args[0].as(); CHECK(s != nullptr) << "tvm_call_global expect first argument as function name"; @@ -226,7 +227,7 @@ void CodeGenStackVM::VisitExpr_(const CallNode* op) { vm_.code.push_back(code); code.v_int = end; vm_.code.push_back(code); - } else if (op->is_intrinsic(intrinsic::tvm_stack_alloca)) { + } else if (op->op.same_as(builtin::tvm_stack_alloca())) { CHECK_EQ(op->args.size(), 2U); const std::string& type = op->args[0].as()->value; const IntImmNode* num = op->args[1].as(); @@ -249,7 +250,7 @@ void CodeGenStackVM::VisitExpr_(const CallNode* op) { // add stack size to be safe. 
vm_.stack_size += size; this->PushOp(StackVM::TVM_STACK_ALLOCA_BY_8BYTE, static_cast(size)); - } else if (op->name == "TVMBackendAllocWorkspace") { + } else if (op->op.same_as(backend_alloc_workspace_op_)) { CHECK_EQ(op->args.size(), 5U); this->Push(op->args[0]); this->Push(op->args[1]); @@ -257,21 +258,21 @@ void CodeGenStackVM::VisitExpr_(const CallNode* op) { this->Push(op->args[3]); this->Push(op->args[4]); this->PushOp(StackVM::TVM_DEVICE_ALLOCA); - } else if (op->name == "TVMBackendFreeWorkspace") { + } else if (op->op.same_as(backend_free_workspace_op_)) { CHECK_EQ(op->args.size(), 3U); this->Push(op->args[0]); this->Push(op->args[1]); this->Push(op->args[2]); this->PushOp(StackVM::TVM_DEVICE_FREE); - } else if (op->is_intrinsic(intrinsic::tvm_throw_last_error)) { + } else if (op->op.same_as(builtin::tvm_throw_last_error())) { this->PushOp(StackVM::TVM_THROW_LAST_ERROR); - } else if (op->is_intrinsic(intrinsic::tvm_handle_is_null)) { + } else if (op->op.same_as(builtin::isnullptr())) { CHECK_EQ(op->args.size(), 1U); this->Push(op->args[0]); this->PushOp(StackVM::PUSH_I64, 0); this->PushOp(StackVM::EQ_HANDLE); } else { - LOG(FATAL) << "unknown function call " << op->name; + LOG(FATAL) << "unknown function call " << op->op; } } @@ -430,7 +431,7 @@ void CodeGenStackVM::VisitStmt_(const SeqStmtNode* op) { void CodeGenStackVM::VisitStmt_(const EvaluateNode* ev) { if (is_const(ev->value)) return; const CallNode* op = ev->value.as(); - if (op && op->is_intrinsic(intrinsic::tvm_struct_set)) { + if (op && op->op.same_as(builtin::tvm_struct_set())) { CHECK_EQ(op->args.size(), 4U); this->Push(op->args[0]); this->Push(op->args[3]); diff --git a/src/target/stackvm/codegen_stackvm.h b/src/target/stackvm/codegen_stackvm.h index b77c40696de6..480ffc7eb870 100644 --- a/src/target/stackvm/codegen_stackvm.h +++ b/src/target/stackvm/codegen_stackvm.h @@ -27,6 +27,7 @@ #include #include #include +#include #include #include @@ -153,6 +154,9 @@ class CodeGenStackVM : public ExprFunctor, std::unordered_map str_idmap_; /*! 
\brief id of each global function */ std::unordered_map extern_fun_idmap_; + + Op backend_alloc_workspace_op_ = Op::Get("tir.TVMBackendAllocWorkspace"); + Op backend_free_workspace_op_ = Op::Get("tir.TVMBackendFreeWorkspace"); }; } // namespace codegen diff --git a/src/te/autodiff/jacobian.cc b/src/te/autodiff/jacobian.cc index 1834aa3decf7..f6254121b7cb 100644 --- a/src/te/autodiff/jacobian.cc +++ b/src/te/autodiff/jacobian.cc @@ -96,31 +96,30 @@ class JacobianMutator : public ExprMutator { PrimExpr VisitExpr_(const CallNode* op) { PrimExpr expr = GetRef(op); if (op->call_type == CallNode::CallType::PureIntrinsic) { - static std::unordered_set piecewise_const = {"floor", "ceil", "trunc", "round"}; - if (op->name == "exp") { + if (op->op.same_as(op_exp_)) { return Mul(Mutate(op->args[0]), expr); - } else if (op->name == "log") { + } else if (op->op.same_as(op_log_)) { return Div(Mutate(op->args[0]), op->args[0]); - } else if (op->name == "sigmoid") { + } else if (op->op.same_as(op_sigmoid_)) { return Mul(Mutate(op->args[0]), Mul(expr, Sub(FloatImm(expr.dtype(), 1.0), expr))); - } else if (op->name == "sqrt") { + } else if (op->op.same_as(op_sqrt_)) { return Div(Mutate(op->args[0]), Mul(expr, FloatImm(expr.dtype(), 2.0))); - } else if (op->name == "tanh") { + } else if (op->op.same_as(op_tanh_)) { return Mul(Mutate(op->args[0]), Sub(FloatImm(expr.dtype(), 1.0), Mul(expr, expr))); - } else if (op->name == "pow") { + } else if (op->op.same_as(op_pow_)) { auto x = op->args[0], y = op->args[1]; return expr * (Mutate(y) * log(x) + Mutate(x) * y / x); - } else if (op->name == "fabs") { + } else if (op->op.same_as(op_fabs_)) { auto type = op->args[0].dtype(); return Mul(Mutate(op->args[0]), Select(GE(op->args[0], make_zero(type)), FloatImm(type, 1.0), FloatImm(type, -1.0))); - } else if (op->name == intrinsic::tvm_if_then_else) { + } else if (op->op.same_as(op_if_then_else_)) { Array new_args = {op->args[0], Mutate(op->args[1]), Mutate(op->args[2])}; - return Call(op->dtype, op->name, new_args, op->call_type); - } else if (piecewise_const.count(op->name)) { + return Call(op->dtype, op->op, new_args, op->call_type); + } else if (piecewise_const.count(op->op)) { return FloatImm(expr.dtype(), 0.0); } else { - throw dmlc::Error("Derivative of this intrinsic is not implemented: " + op->name); + LOG(FATAL) << "Derivative of this intrinsic is not implemented: " << op->op; } } NOT_IMPLEMENTED; @@ -281,6 +280,17 @@ class JacobianMutator : public ExprMutator { Array indices_; Var input_var_; arith::Analyzer analyzer_; + + const Op& op_exp_ = Op::Get("tir.exp"); + const Op& op_log_ = Op::Get("tir.log"); + const Op& op_sigmoid_ = Op::Get("tir.sigmoid"); + const Op& op_sqrt_ = Op::Get("tir.sqrt"); + const Op& op_tanh_ = Op::Get("tir.tanh"); + const Op& op_pow_ = Op::Get("tir.pow"); + const Op& op_fabs_ = Op::Get("tir.fabs"); + const Op& op_if_then_else_ = Op::Get("tir.if_then_else"); + std::unordered_set piecewise_const = { + Op::Get("tir.floor"), Op::Get("tir.ceil"), Op::Get("tir.trunc"), Op::Get("tir.round")}; }; PrimExpr Derivative(const PrimExpr& expr, const Var& var) { diff --git a/src/te/operation/compute_op.cc b/src/te/operation/compute_op.cc index 1fc0520143fb..b4725c571782 100644 --- a/src/te/operation/compute_op.cc +++ b/src/te/operation/compute_op.cc @@ -27,6 +27,7 @@ #include #include #include +#include #include #include @@ -278,7 +279,7 @@ Stmt BaseComputeOpNode::BuildRealize(const Stage& stage, attr->dim_align_offset}; realize = tir::AttrStmt( t, tir::attr::buffer_dim_align, - 
Call(DataType::Handle(), tir::intrinsic::tvm_tuple, tuple, CallNode::Intrinsic), + Call(DataType::Handle(), tir::builtin::tvm_tuple(), tuple, CallNode::Intrinsic), realize); } } diff --git a/src/te/operation/cross_thread_reduction.cc b/src/te/operation/cross_thread_reduction.cc index e834ff279d05..eeaab301ad03 100644 --- a/src/te/operation/cross_thread_reduction.cc +++ b/src/te/operation/cross_thread_reduction.cc @@ -21,6 +21,8 @@ * \brief Logics related to cross thread reduction, used by ComputeOpNode. * \file cross_thread_reduction.cc */ +#include + #include "compute_op.h" #include "op_util.h" @@ -194,7 +196,7 @@ Stmt MakeCrossThreadReduction(const ComputeOpNode* self, const Stage& stage, // Apply the existing input predicate if any. output_preds.push_back(input_pred); - Stmt reduce_body = Evaluate(Call(DataType::Handle(), tir::intrinsic::tvm_thread_allreduce, + Stmt reduce_body = Evaluate(Call(DataType::Handle(), tir::builtin::tvm_thread_allreduce(), freduce_args, CallNode::Intrinsic)); reduce_body = AttrStmt(reduces[0]->combiner, tir::attr::reduce_scope, make_zero(DataType::Handle()), reduce_body); diff --git a/src/te/operation/extern_op.cc b/src/te/operation/extern_op.cc index ef55c44241b0..01019e43e61c 100644 --- a/src/te/operation/extern_op.cc +++ b/src/te/operation/extern_op.cc @@ -153,7 +153,7 @@ Stmt ExternOpNode::BuildProvide(const Stage& stage, tuple.push_back(buffer->shape[k]); } ret = AttrStmt(bind_spec, tir::attr::buffer_bind_scope, - Call(DataType::Handle(), intrinsic::tvm_tuple, tuple, CallNode::Intrinsic), ret); + Call(DataType::Handle(), builtin::tvm_tuple(), tuple, CallNode::Intrinsic), ret); }; for (size_t i = output_placeholders.size(); i != 0; --i) { f_push_bind(output_placeholders[i - 1], stage->op.output(i - 1)); diff --git a/src/te/operation/tensor_compute_op.cc b/src/te/operation/tensor_compute_op.cc index 8d5265bcb14f..714e8859229d 100644 --- a/src/te/operation/tensor_compute_op.cc +++ b/src/te/operation/tensor_compute_op.cc @@ -24,6 +24,7 @@ #include #include #include +#include #include #include @@ -153,7 +154,7 @@ Stmt TensorComputeOpNode::BuildProvide(const Stage& stage, } input_bind_nest.emplace_back(AttrStmt( bind_spec, tir::attr::buffer_bind_scope, - Call(DataType::Handle(), tir::intrinsic::tvm_tuple, tuple, CallNode::Intrinsic), nop)); + Call(DataType::Handle(), tir::builtin::tvm_tuple(), tuple, CallNode::Intrinsic), nop)); } // output binding @@ -177,7 +178,7 @@ Stmt TensorComputeOpNode::BuildProvide(const Stage& stage, output_bind_nest.emplace_back(AttrStmt( bind_spec, tir::attr::buffer_bind_scope, - Call(DataType::Handle(), tir::intrinsic::tvm_tuple, tuple, CallNode::Intrinsic), nop)); + Call(DataType::Handle(), tir::builtin::tvm_tuple(), tuple, CallNode::Intrinsic), nop)); } // Check variable remap diff --git a/src/te/operation/tensorize.cc b/src/te/operation/tensorize.cc index 82832c927785..dd978a430e4b 100644 --- a/src/te/operation/tensorize.cc +++ b/src/te/operation/tensorize.cc @@ -370,7 +370,7 @@ Stmt MakeTensorize(const ComputeOpNode* self, const Stage& stage, } input_bind_nest.emplace_back(AttrStmt( bind_spec, tir::attr::buffer_bind_scope, - Call(DataType::Handle(), tir::intrinsic::tvm_tuple, tuple, CallNode::Intrinsic), nop)); + Call(DataType::Handle(), tir::builtin::tvm_tuple(), tuple, CallNode::Intrinsic), nop)); } // output binding const ComputeOpNode* intrin_compute = intrin->op.as(); @@ -390,7 +390,7 @@ Stmt MakeTensorize(const ComputeOpNode* self, const Stage& stage, Array bind_spec{buffer, tensor}; 
output_bind_nest.emplace_back(AttrStmt( bind_spec, tir::attr::buffer_bind_scope, - Call(DataType::Handle(), tir::intrinsic::tvm_tuple, tuple, CallNode::Intrinsic), nop)); + Call(DataType::Handle(), tir::builtin::tvm_tuple(), tuple, CallNode::Intrinsic), nop)); } // Check variable remap std::unordered_map vmap; diff --git a/src/te/schedule/schedule_postproc_rewrite_for_tensor_core.cc b/src/te/schedule/schedule_postproc_rewrite_for_tensor_core.cc index 1ff569f29f1f..67121b881a33 100644 --- a/src/te/schedule/schedule_postproc_rewrite_for_tensor_core.cc +++ b/src/te/schedule/schedule_postproc_rewrite_for_tensor_core.cc @@ -29,6 +29,7 @@ #include #include #include +#include #include #include #include @@ -42,7 +43,6 @@ namespace tvm { namespace te { using namespace te; -using intrinsic::tvm_address_of; using runtime::StorageRank; using runtime::StorageScope; using runtime::ThreadScope; @@ -255,9 +255,9 @@ class BodyVisitor : public StmtExprVisitor { } } - void VisitExpr_(const CallNode* op) final { + void VisitExpr_(const ProducerLoadNode* op) final { StmtExprVisitor::VisitExpr_(op); - args_.insert(std::make_pair(op->name, op->args)); + args_.insert(std::make_pair(op->producer->GetNameHint(), op->indices)); } friend class ScheduleAnalyser; @@ -415,7 +415,7 @@ class BufferAnalyser : public StmtExprVisitor { } else if (op->attr_key == tir::attr::buffer_dim_align) { te::Tensor tensor = Downcast(op->node); const CallNode* tuple = op->value.as(); - CHECK(tuple && tuple->is_intrinsic(intrinsic::tvm_tuple)); + CHECK(tuple && tuple->op.same_as(builtin::tvm_tuple())); auto& vinfo = dim_align_[tensor]; size_t dim = tuple->args[0].as()->value; if (dim >= vinfo.size()) { @@ -848,13 +848,13 @@ class TensorCoreIRMutator : public StmtExprMutator { Buffer buffer_b(buffer_node_b); if (ca->dtype == DataType::Int(1) && cb->dtype == DataType::Int(1)) { return Evaluate( - Call(DataType::Handle(), intrinsic::tvm_bmma_sync, + Call(DataType::Handle(), builtin::tvm_bmma_sync(), {buffer->data, buffer->elem_offset, buffer_a->data, buffer_a->elem_offset, buffer_b->data, buffer_b->elem_offset, buffer->data, buffer->elem_offset}, CallNode::Intrinsic)); } else { return Evaluate( - Call(DataType::Handle(), intrinsic::tvm_mma_sync, + Call(DataType::Handle(), builtin::tvm_mma_sync(), {buffer->data, buffer->elem_offset, buffer_a->data, buffer_a->elem_offset, buffer_b->data, buffer_b->elem_offset, buffer->data, buffer->elem_offset}, CallNode::Intrinsic)); @@ -879,7 +879,7 @@ class TensorCoreIRMutator : public StmtExprMutator { auto pload = dst.as(); auto fill_fragment_call = [this, &op](const Buffer& buffer) { - return Evaluate(Call(DataType::Handle(), intrinsic::tvm_fill_fragment, + return Evaluate(Call(DataType::Handle(), builtin::tvm_fill_fragment(), {buffer->data, warp_tile_.m, warp_tile_.n, warp_tile_.k, buffer->elem_offset, op->value}, CallNode::Intrinsic)); @@ -889,11 +889,11 @@ class TensorCoreIRMutator : public StmtExprMutator { return add_buffer_bind_scope_(pload, buffer_node, fill_fragment_call); } - const CallNode* value = op->value.as(); + const ProducerLoadNode* value = op->value.as(); CHECK(value != nullptr) << "Can only load fragment from a buffer"; - auto it = strides_.find(value->name); - CHECK(it != strides_.end()) << "Cannot find stride for " << value->name; + auto it = strides_.find(value->producer->GetNameHint()); + CHECK(it != strides_.end()) << "Cannot find stride for " << value->producer->GetNameHint(); auto strides = it->second; CHECK_GE(strides.size(), 2); PrimExpr stride = strides[strides.size() - 2]; 
@@ -902,7 +902,9 @@ class TensorCoreIRMutator : public StmtExprMutator { PrimExpr warp_y = IntImm(DataType::Int(32), warp_threads_y_); ThreadIdxMutator thread_idx_mutator(warp_y); PrimExpr mutated_value = thread_idx_mutator(op->value); - PrimExpr src = Call(value->dtype, "&", {mutated_value}, CallNode::Extern); + // TODO(tvm-team) The extern function name seems to be a hack. + PrimExpr src = Call(value->dtype, builtin::call_extern(), {StringImm("&"), mutated_value}, + CallNode::Extern); auto pload = dst.as(); PrimExpr matrix_major; @@ -918,7 +920,7 @@ class TensorCoreIRMutator : public StmtExprMutator { } auto load_matrix_call = [this, &src, &stride, &matrix_major](const Buffer& buffer) { - return Evaluate(Call(DataType::Handle(), intrinsic::tvm_load_matrix_sync, + return Evaluate(Call(DataType::Handle(), builtin::tvm_load_matrix_sync(), {buffer->data, warp_tile_.m, warp_tile_.n, warp_tile_.k, buffer->elem_offset, src, stride, matrix_major}, CallNode::Intrinsic)); @@ -941,12 +943,13 @@ class TensorCoreIRMutator : public StmtExprMutator { PrimExpr warp_y = IntImm(DataType::Int(32), warp_threads_y_); ThreadIdxMutator thread_idx_mutator(warp_y); dst = thread_idx_mutator(dst); - dst = Call(DataType::Handle(), "&", {dst}, CallNode::Extern); + dst = + Call(DataType::Handle(), builtin::call_extern(), {StringImm("&"), dst}, CallNode::Extern); auto pload = op->value.as(); auto store_matrix_call = [this, &dst, &stride](const Buffer& buffer) { - return Evaluate(Call(DataType::Handle(), intrinsic::tvm_store_matrix_sync, + return Evaluate(Call(DataType::Handle(), builtin::tvm_store_matrix_sync(), {buffer->data, warp_tile_.m, warp_tile_.n, warp_tile_.k, buffer->elem_offset, dst, stride, StringImm("col_major")}, CallNode::Intrinsic)); @@ -1064,7 +1067,7 @@ class TensorCoreIRMutator : public StmtExprMutator { args.push_back(pload->indices[i]); args.push_back(shape[i]); } - auto tuple = Call(DataType::Handle(), intrinsic::tvm_tuple, args, CallNode::Intrinsic); + auto tuple = Call(DataType::Handle(), builtin::tvm_tuple(), args, CallNode::Intrinsic); Array node = {buffer, tensor}; return AttrStmt(node, "buffer_bind_scope", tuple, call_back(buffer)); } diff --git a/src/tir/analysis/verify_memory.cc b/src/tir/analysis/verify_memory.cc index 8eb846b7d618..12ec270a53cc 100644 --- a/src/tir/analysis/verify_memory.cc +++ b/src/tir/analysis/verify_memory.cc @@ -25,6 +25,7 @@ #include #include #include +#include #include #include @@ -120,7 +121,7 @@ class MemoryAccessVerifier final : protected StmtExprVisitor { const auto& iter = defs_.find(V); if (iter == defs_.end()) return false; const CallNode* C = iter->second.as(); - if (!C || C->name != intrinsic::tvm_struct_get) return false; + if (!C || !C->op.same_as(builtin::tvm_struct_get())) return false; V = C->args[0].as(); } return false; diff --git a/src/tir/ir/buffer.cc b/src/tir/ir/buffer.cc index 4e433fc718b1..6cccfa0fcebf 100644 --- a/src/tir/ir/buffer.cc +++ b/src/tir/ir/buffer.cc @@ -25,6 +25,7 @@ #include #include #include +#include #include #include @@ -376,7 +377,7 @@ PrimExpr Buffer::access_ptr(int access_mask, DataType ptr_type, int content_lane } Array acc_args{e_dtype, self->data, elem_offset, extent, make_const(DataType::Int(32), access_mask)}; - return tir::Call(ptr_type, tir::intrinsic::tvm_access_ptr, acc_args, tir::CallNode::Intrinsic); + return tir::Call(ptr_type, tir::builtin::tvm_access_ptr(), acc_args, tir::CallNode::Intrinsic); } Buffer::Buffer(Var data, DataType dtype, Array shape, Array strides, diff --git a/src/tir/ir/expr.cc 
b/src/tir/ir/expr.cc index 9390feada456..4b20351e2053 100644 --- a/src/tir/ir/expr.cc +++ b/src/tir/ir/expr.cc @@ -698,50 +698,21 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) }); // Call -Call::Call(DataType dtype, String name, Array args, CallType call_type) { +Call::Call(DataType dtype, RelayExpr op, Array args, CallType call_type) { for (size_t i = 0; i < args.size(); ++i) { CHECK(args[i].defined()); } ObjectPtr node = make_object(); node->dtype = dtype; - node->name = std::move(name); + node->op = std::move(op); node->args = std::move(args); node->call_type = call_type; data_ = std::move(node); } -const char* CallNode::vectorizable_intrinsics[] = {"floor", - "ceil", - "sign", - "trunc", - "fabs", - "round", - "exp", - "tanh", - "sqrt", - "log", - "sin", - "cos", - "pow", - "tan", - tir::CallNode::shift_left, - tir::CallNode::shift_right, - tir::CallNode::likely, - tir::CallNode::popcount}; - -bool CallNode::is_vectorizable() const { - size_t cnt = sizeof(CallNode::vectorizable_intrinsics) / sizeof(char*); - for (size_t i = 0; i < cnt; ++i) { - if (name == CallNode::vectorizable_intrinsics[i]) { - return true; - } - } - return false; -} - TVM_REGISTER_GLOBAL("tir.Call") - .set_body_typed([](DataType type, String name, Array args, int call_type) { + .set_body_typed([](DataType type, RelayExpr op, Array args, int call_type) { Array prim_expr_args; for (const auto& it : args) { CHECK(it->IsInstance() || it->IsInstance()); @@ -751,7 +722,7 @@ TVM_REGISTER_GLOBAL("tir.Call") prim_expr_args.push_back(Downcast(it)); } } - return Call(type, name, prim_expr_args, static_cast(call_type)); + return Call(type, op, prim_expr_args, static_cast(call_type)); }); TVM_REGISTER_NODE_TYPE(CallNode); @@ -759,7 +730,13 @@ TVM_REGISTER_NODE_TYPE(CallNode); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); - p->stream << op->name << "("; + if (auto* ptr_op = op->op.as()) { + p->stream << ptr_op->name << "("; + } else { + auto* ptr_gvar = op->op.as(); + CHECK(ptr_gvar != nullptr); + p->stream << "@" << ptr_gvar->name_hint << "("; + } for (size_t i = 0; i < op->args.size(); ++i) { p->Print(op->args[i]); if (i < op->args.size() - 1) { diff --git a/src/tir/ir/expr_functor.cc b/src/tir/ir/expr_functor.cc index b92127b24e2b..98b9fd02c09c 100644 --- a/src/tir/ir/expr_functor.cc +++ b/src/tir/ir/expr_functor.cc @@ -166,7 +166,7 @@ PrimExpr ExprMutator::VisitExpr_(const CallNode* op) { if (args.same_as(op->args)) { return GetRef(op); } else { - return Call(op->dtype, op->name, args, op->call_type); + return Call(op->dtype, op->op, args, op->call_type); } } diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc index 66497755c88a..c3ddb6625d53 100644 --- a/src/tir/ir/stmt.cc +++ b/src/tir/ir/stmt.cc @@ -582,5 +582,13 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->PrintIndent(); p->stream << "}\n"; }); + +PrimExpr TypeAnnotation(DataType dtype) { + static auto op = Op::Get("tir.type_annotation"); + return tir::Call(dtype, op, {}, tir::CallNode::PureIntrinsic); +} + +TVM_REGISTER_OP("tir.type_annotation"); + } // namespace tir } // namespace tvm diff --git a/src/tir/op/builtin.cc b/src/tir/op/builtin.cc new file mode 100644 index 000000000000..8efcf3ff4925 --- /dev/null +++ b/src/tir/op/builtin.cc @@ -0,0 +1,155 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tir/op/builtin.cc + * + * builtin intrinsic operators. + */ +#include +#include +#include +#include + +namespace tvm { +namespace tir { +namespace builtin { + +#define TIR_DEFINE_BUILTIN_FUNC(OpName) \ + const Op& OpName() { \ + static const Op& op = Op::Get("tir." #OpName); \ + return op; \ + } \ + TVM_REGISTER_OP("tir." #OpName) + +TIR_DEFINE_BUILTIN_FUNC(reinterpret).set_num_inputs(1); + +TIR_DEFINE_BUILTIN_FUNC(likely).set_num_inputs(1).set_attr("TVectorizable", true); + +TIR_DEFINE_BUILTIN_FUNC(bitwise_and) + .set_num_inputs(2) + .set_attr("TVectorizable", true); + +TIR_DEFINE_BUILTIN_FUNC(bitwise_or) + .set_num_inputs(2) + .set_attr("TVectorizable", true); + +TIR_DEFINE_BUILTIN_FUNC(bitwise_xor) + .set_num_inputs(2) + .set_attr("TVectorizable", true); + +TIR_DEFINE_BUILTIN_FUNC(bitwise_not) + .set_num_inputs(1) + .set_attr("TVectorizable", true); + +TIR_DEFINE_BUILTIN_FUNC(shift_left) + .set_num_inputs(2) + .set_attr("TVectorizable", true); + +TIR_DEFINE_BUILTIN_FUNC(shift_right) + .set_num_inputs(2) + .set_attr("TVectorizable", true); + +TIR_DEFINE_BUILTIN_FUNC(large_uint_imm).set_num_inputs(2); + +TIR_DEFINE_BUILTIN_FUNC(address_of).set_num_inputs(1); + +TIR_DEFINE_BUILTIN_FUNC(if_then_else).set_num_inputs(3); + +TIR_DEFINE_BUILTIN_FUNC(isnullptr).set_num_inputs(1); + +TIR_DEFINE_BUILTIN_FUNC(isnan).set_num_inputs(1); + +TIR_DEFINE_BUILTIN_FUNC(popcount).set_num_inputs(1); + +TIR_DEFINE_BUILTIN_FUNC(fma).set_num_inputs(3).set_attr("TVectorizable", true); + +TIR_DEFINE_BUILTIN_FUNC(call_extern); + +TIR_DEFINE_BUILTIN_FUNC(call_llvm_intrin); + +TIR_DEFINE_BUILTIN_FUNC(call_spirv_glsl450); + +TIR_DEFINE_BUILTIN_FUNC(prefetch); + +TIR_DEFINE_BUILTIN_FUNC(tvm_access_ptr).set_num_inputs(5); + +TIR_DEFINE_BUILTIN_FUNC(tvm_static_handle).set_num_inputs(0); + +TIR_DEFINE_BUILTIN_FUNC(tvm_context_id).set_num_inputs(0); + +TIR_DEFINE_BUILTIN_FUNC(tvm_tuple); + +TIR_DEFINE_BUILTIN_FUNC(tvm_struct_get).set_num_inputs(3); + +TIR_DEFINE_BUILTIN_FUNC(tvm_struct_set).set_num_inputs(4); + +TIR_DEFINE_BUILTIN_FUNC(tvm_throw_last_error).set_num_inputs(0); + +TIR_DEFINE_BUILTIN_FUNC(tvm_stack_alloca).set_num_inputs(2); + +TIR_DEFINE_BUILTIN_FUNC(tvm_stack_make_shape); + +TIR_DEFINE_BUILTIN_FUNC(tvm_stack_make_array).set_num_inputs(6); + +// When num_inputs is not set, the function is assumed to be variable length. +TIR_DEFINE_BUILTIN_FUNC(tvm_call_packed); + +TIR_DEFINE_BUILTIN_FUNC(tvm_call_trace_packed); + +TIR_DEFINE_BUILTIN_FUNC(tvm_thread_context).set_num_inputs(1); + +TIR_DEFINE_BUILTIN_FUNC(tvm_call_packed_lowered); + +TIR_DEFINE_BUILTIN_FUNC(tvm_call_trace_packed_lowered); + +// TODO(tvm-team) revisit storage sync once we have a good memory hierarchy structure.
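For reference, a minimal sketch of what one of the registrations above expands to (an illustration of the macro only, not additional patch content): the generated accessor returns a cached Op handle, and the trailing TVM_REGISTER_OP(...) call opens the registry entry that the chained setters configure.

    // Expansion of TIR_DEFINE_BUILTIN_FUNC(reinterpret).set_num_inputs(1);
    const Op& reinterpret() {
      static const Op& op = Op::Get("tir.reinterpret");
      return op;
    }
    TVM_REGISTER_OP("tir.reinterpret").set_num_inputs(1);

    // Downstream passes then match the builtin by identity instead of by string name:
    //   if (call->op.same_as(tir::builtin::reinterpret())) { /* handle reinterpret */ }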
+TIR_DEFINE_BUILTIN_FUNC(tvm_storage_sync); + +TIR_DEFINE_BUILTIN_FUNC(tvm_warp_shuffle); + +TIR_DEFINE_BUILTIN_FUNC(tvm_warp_shuffle_up); + +TIR_DEFINE_BUILTIN_FUNC(tvm_warp_shuffle_down); + +TIR_DEFINE_BUILTIN_FUNC(tvm_warp_activemask); + +TIR_DEFINE_BUILTIN_FUNC(tvm_global_barrier_kinit); + +TIR_DEFINE_BUILTIN_FUNC(tvm_thread_allreduce); + +TIR_DEFINE_BUILTIN_FUNC(tvm_load_matrix_sync); + +TIR_DEFINE_BUILTIN_FUNC(tvm_mma_sync); + +TIR_DEFINE_BUILTIN_FUNC(tvm_bmma_sync); + +TIR_DEFINE_BUILTIN_FUNC(tvm_fill_fragment); + +TIR_DEFINE_BUILTIN_FUNC(tvm_store_matrix_sync); + +TIR_DEFINE_BUILTIN_FUNC(vectorhigh); + +TIR_DEFINE_BUILTIN_FUNC(vectorlow); + +TIR_DEFINE_BUILTIN_FUNC(vectorcombine); + +} // namespace builtin +} // namespace tir +} // namespace tvm diff --git a/src/tir/ir/op.cc b/src/tir/op/op.cc similarity index 82% rename from src/tir/ir/op.cc rename to src/tir/op/op.cc index 5ac9f5902c12..f8049eace356 100644 --- a/src/tir/ir/op.cc +++ b/src/tir/op/op.cc @@ -18,12 +18,16 @@ */ /*! - * \file expr_operator.cc + * \file tir/op/op.cc + * + * Common operator definitions for ops in tir/op.h */ #include +#include #include #include +#include #include // Centralized header for constant folders. @@ -33,6 +37,12 @@ namespace tvm { using namespace tir; +// macro to register a unary op +#define TIR_REGISTER_PURE_UNARY_OP(OpName) TVM_REGISTER_OP(OpName).set_num_inputs(1) + +// macro to register a binary op +#define TIR_REGISTER_PURE_BINARY_OP(OpName) TVM_REGISTER_OP(OpName).set_num_inputs(2) + runtime::DataType GetRuntimeDataType(const Type& type) { if (auto* n = type.as()) { return n->dtype; @@ -70,8 +80,9 @@ inline PrimExpr SimpleCast(const DataType& t, PrimExpr value) { return tir::Cast(t, value); } +// LargeUIntImm PrimExpr LargeUIntImm(DataType t, int64_t low, int64_t high) { - return tir::Call(t, tir::intrinsic::tvm_large_uint_imm, + return tir::Call(t, tir::builtin::large_uint_imm(), {make_const(DataType::UInt(32), low), make_const(DataType::UInt(32), high)}, tir::CallNode::PureIntrinsic); } @@ -248,11 +259,13 @@ PrimExpr cast(const DataType& t, PrimExpr value) { } } +// reinterpret PrimExpr reinterpret(const DataType& t, PrimExpr value) { if (value.dtype() == t) return value; - return tir::Call(t, tir::CallNode::reinterpret, {value}, tir::CallNode::PureIntrinsic); + return tir::Call(t, tir::builtin::reinterpret(), {value}, tir::CallNode::PureIntrinsic); } +// operator+ PrimExpr operator+(PrimExpr a, PrimExpr b) { BinaryOpMatchTypes(a, b); PrimExpr ret = arith::TryConstFold(a, b); @@ -360,6 +373,7 @@ PrimExpr max(PrimExpr a, PrimExpr b) { return tir::Max(a, b); } +// if_then_else PrimExpr if_then_else(PrimExpr cond, PrimExpr true_value, PrimExpr false_value) { CHECK(cond.dtype() == DataType::Bool(1)) << "if_then_else only accept the condition to be boolean type."; @@ -371,15 +385,20 @@ PrimExpr if_then_else(PrimExpr cond, PrimExpr true_value, PrimExpr false_value) return false_value; } } - return tir::Call(true_value.dtype(), tir::intrinsic::tvm_if_then_else, + + return tir::Call(true_value.dtype(), tir::builtin::if_then_else(), {cond, true_value, false_value}, tir::CallNode::PureIntrinsic); } +// likely PrimExpr likely(PrimExpr cond) { if (is_const(cond)) return cond; - return tir::Call(cond.dtype(), tir::CallNode::likely, {cond}, tir::CallNode::PureIntrinsic); + return tir::Call(cond.dtype(), tir::builtin::likely(), {cond}, tir::CallNode::PureIntrinsic); } +TVM_REGISTER_OP("tir.likely").set_num_inputs(1); + +// operator> PrimExpr operator>(PrimExpr a, PrimExpr b) {
BinaryOpMatchTypes(a, b); PrimExpr ret = arith::TryConstFold(a, b); @@ -445,6 +464,7 @@ PrimExpr operator!(PrimExpr a) { return tir::Not(a); } +// shift right PrimExpr operator>>(PrimExpr a, PrimExpr b) { CHECK(a.dtype().is_int() || a.dtype().is_uint()); CHECK(b.dtype().is_int() || b.dtype().is_uint()); @@ -460,9 +480,11 @@ PrimExpr operator>>(PrimExpr a, PrimExpr b) { if (pb->value == 0) return a; } }); - return tir::Call(a.dtype(), tir::CallNode::shift_right, {a, b}, tir::CallNode::PureIntrinsic); + + return tir::Call(a.dtype(), tir::builtin::shift_right(), {a, b}, tir::CallNode::PureIntrinsic); } +// shift left PrimExpr operator<<(PrimExpr a, PrimExpr b) { CHECK(a.dtype().is_int() || a.dtype().is_uint()); CHECK(b.dtype().is_int() || b.dtype().is_uint()); @@ -478,9 +500,10 @@ PrimExpr operator<<(PrimExpr a, PrimExpr b) { if (pb->value == 0) return a; } }); - return tir::Call(a.dtype(), tir::CallNode::shift_left, {a, b}, tir::CallNode::PureIntrinsic); + return tir::Call(a.dtype(), tir::builtin::shift_left(), {a, b}, tir::CallNode::PureIntrinsic); } +// bitwise and PrimExpr operator&(PrimExpr a, PrimExpr b) { CHECK(a.dtype().is_int() || a.dtype().is_uint()); CHECK(b.dtype().is_int() || b.dtype().is_uint()); @@ -489,9 +512,10 @@ PrimExpr operator&(PrimExpr a, PrimExpr b) { const DataType& rtype = a.dtype(); if (pa && pb) return IntImm(rtype, (pa->value & pb->value)); }); - return tir::Call(a.dtype(), tir::CallNode::bitwise_and, {a, b}, tir::CallNode::PureIntrinsic); + return tir::Call(a.dtype(), tir::builtin::bitwise_and(), {a, b}, tir::CallNode::PureIntrinsic); } +// bitwise_or PrimExpr operator|(PrimExpr a, PrimExpr b) { CHECK(a.dtype().is_int() || a.dtype().is_uint()); CHECK(b.dtype().is_int() || b.dtype().is_uint()); @@ -500,9 +524,10 @@ PrimExpr operator|(PrimExpr a, PrimExpr b) { const DataType& rtype = a.dtype(); if (pa && pb) return IntImm(rtype, (pa->value | pb->value)); }); - return tir::Call(a.dtype(), tir::CallNode::bitwise_or, {a, b}, tir::CallNode::PureIntrinsic); + return tir::Call(a.dtype(), tir::builtin::bitwise_or(), {a, b}, tir::CallNode::PureIntrinsic); } +// bitwise_xor PrimExpr operator^(PrimExpr a, PrimExpr b) { CHECK(a.dtype().is_int() || a.dtype().is_uint()); CHECK(b.dtype().is_int() || b.dtype().is_uint()); @@ -511,20 +536,30 @@ PrimExpr operator^(PrimExpr a, PrimExpr b) { const DataType& rtype = a.dtype(); if (pa && pb) return IntImm(rtype, (pa->value ^ pb->value)); }); - return tir::Call(a.dtype(), tir::CallNode::bitwise_xor, {a, b}, tir::CallNode::PureIntrinsic); + return tir::Call(a.dtype(), tir::builtin::bitwise_xor(), {a, b}, tir::CallNode::PureIntrinsic); } +// bitwise_not PrimExpr operator~(PrimExpr a) { CHECK(a.dtype().is_int() || a.dtype().is_uint()); - return tir::Call(a.dtype(), tir::CallNode::bitwise_not, {a}, tir::CallNode::PureIntrinsic); + return tir::Call(a.dtype(), tir::builtin::bitwise_not(), {a}, tir::CallNode::PureIntrinsic); } +TVM_REGISTER_OP("tir.bitwise_not"); + +TVM_REGISTER_GLOBAL("tir.bitwise_not").set_body_typed([](PrimExpr a) { return ~a; }); + +// pow PrimExpr pow(PrimExpr x, PrimExpr y) { BinaryOpMatchTypes(x, y); CHECK(x.dtype().is_float()) << "power only applies to float"; - return tir::Call(x.dtype(), "pow", {x, y}, tir::CallNode::PureIntrinsic); + static auto op = Op::Get("tir.pow"); + return tir::Call(x.dtype(), op, {x, y}, tir::CallNode::PureIntrinsic); } +TVM_REGISTER_OP("tir.pow").set_num_inputs(2).set_attr("TVectorizable", true); + +// abs PrimExpr abs(PrimExpr x) { if (x.dtype().is_int()) { using tir::IntImmNode; @@ -539,7
+574,8 @@ PrimExpr abs(PrimExpr x) { if (fx) { return FloatImm(x.dtype(), std::fabs(fx->value)); } - return tir::Call(x.dtype(), "fabs", {x}, tir::CallNode::PureIntrinsic); + static auto op = Op::Get("tir.fabs"); + return tir::Call(x.dtype(), op, {x}, tir::CallNode::PureIntrinsic); } else if (x.dtype().is_uint()) { return x; } else { @@ -549,6 +585,9 @@ PrimExpr abs(PrimExpr x) { } } +TIR_REGISTER_PURE_UNARY_OP("tir.fabs").set_attr("TVectorizable", true); + +// isnan PrimExpr isnan(PrimExpr x) { DataType t = DataType::Bool(x.dtype().lanes()); if (x.dtype().is_int() || x.dtype().is_uint()) { @@ -559,12 +598,12 @@ PrimExpr isnan(PrimExpr x) { if (fx) { return make_const(t, std::isnan(fx->value)); } + static auto op = Op::Get("tir.isnan"); if (x.dtype().bits() == 16) { - return tir::Call(t, tir::CallNode::isnan, - {cast(DataType::Float(32, t.lanes()), std::move(x))}, + return tir::Call(t, op, {cast(DataType::Float(32, t.lanes()), std::move(x))}, tir::CallNode::PureIntrinsic); } else { - return tir::Call(t, tir::CallNode::isnan, {x}, tir::CallNode::PureIntrinsic); + return tir::Call(t, op, {x}, tir::CallNode::PureIntrinsic); } } else { LOG(FATAL) << "Data type " << x.dtype() << " not supported for isnan op. Skipping isnan op..."; @@ -572,6 +611,9 @@ PrimExpr isnan(PrimExpr x) { } } +TIR_REGISTER_PURE_UNARY_OP("tir.isnan"); + +// isinf PrimExpr isinf(PrimExpr x) { DataType t = DataType::Bool(x.dtype().lanes()); if (x.dtype().is_int() || x.dtype().is_uint()) { @@ -585,6 +627,7 @@ PrimExpr isinf(PrimExpr x) { } } +// isfinite PrimExpr isfinite(PrimExpr x) { return !isinf(x) && !isnan(x); } PrimExpr sum(PrimExpr source, Array rdom) { @@ -637,12 +680,17 @@ PrimExpr prod(PrimExpr source, Array rdom) { return tir::Reduce(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0); } +// fmod PrimExpr fmod(PrimExpr x, PrimExpr y) { BinaryOpMatchTypes(x, y); CHECK(x.dtype().is_float()) << "fmod only applies to float"; - return tir::Call(x.dtype(), "fmod", {x, y}, tir::CallNode::PureIntrinsic); + static auto op = Op::Get("tir.fmod"); + return tir::Call(x.dtype(), op, {x, y}, tir::CallNode::PureIntrinsic); } +TIR_REGISTER_PURE_BINARY_OP("tir.fmod"); + +// floor PrimExpr floor(PrimExpr x) { if (x.dtype().is_int() || x.dtype().is_uint()) { return x; @@ -650,9 +698,13 @@ PrimExpr floor(PrimExpr x) { using tir::FloatImmNode; const FloatImmNode* fx = x.as(); if (fx) return FloatImm(x.dtype(), std::floor(fx->value)); - return tir::Call(x.dtype(), "floor", {x}, tir::CallNode::PureIntrinsic); + static auto op = Op::Get("tir.floor"); + return tir::Call(x.dtype(), op, {x}, tir::CallNode::PureIntrinsic); } +TIR_REGISTER_PURE_UNARY_OP("tir.floor").set_attr("TVectorizable", true); + +// ceil PrimExpr ceil(PrimExpr x) { if (x.dtype().is_int() || x.dtype().is_uint()) { return x; @@ -660,9 +712,13 @@ PrimExpr ceil(PrimExpr x) { using tir::FloatImmNode; const FloatImmNode* fx = x.as(); if (fx) return FloatImm(x.dtype(), std::ceil(fx->value)); - return tir::Call(x.dtype(), "ceil", {x}, tir::CallNode::PureIntrinsic); + static auto op = Op::Get("tir.ceil"); + return tir::Call(x.dtype(), op, {x}, tir::CallNode::PureIntrinsic); } +TIR_REGISTER_PURE_UNARY_OP("tir.ceil").set_attr("TVectorizable", true); + +// round PrimExpr round(PrimExpr x) { if (x.dtype().is_int() || x.dtype().is_uint()) { return x; @@ -670,9 +726,13 @@ PrimExpr round(PrimExpr x) { using tir::FloatImmNode; const FloatImmNode* fx = x.as(); if (fx) return FloatImm(x.dtype(), std::nearbyint(fx->value)); - return tir::Call(x.dtype(), "round", {x},
tir::CallNode::PureIntrinsic); + static auto op = Op::Get("tir.round"); + return tir::Call(x.dtype(), op, {x}, tir::CallNode::PureIntrinsic); } +TIR_REGISTER_PURE_UNARY_OP("tir.round").set_attr("TVectorizable", true); + +// nearbyint PrimExpr nearbyint(PrimExpr x) { if (x.dtype().is_int() || x.dtype().is_uint()) { return x; @@ -680,9 +740,13 @@ PrimExpr nearbyint(PrimExpr x) { using tir::FloatImmNode; const FloatImmNode* fx = x.as(); if (fx) return FloatImm(x.dtype(), std::nearbyint(fx->value)); - return tir::Call(x.dtype(), "nearbyint", {x}, tir::CallNode::PureIntrinsic); + static auto op = Op::Get("tir.nearbyint"); + return tir::Call(x.dtype(), op, {x}, tir::CallNode::PureIntrinsic); } +TIR_REGISTER_PURE_UNARY_OP("tir.nearbyint"); + +// trunc PrimExpr trunc(PrimExpr x) { if (x.dtype().is_int() || x.dtype().is_uint()) { return x; @@ -692,9 +756,72 @@ PrimExpr trunc(PrimExpr x) { if (fx) { return FloatImm(x.dtype(), (fx->value < 0 ? std::ceil(fx->value) : std::floor(fx->value))); } - return tir::Call(x.dtype(), "trunc", {x}, tir::CallNode::PureIntrinsic); + static auto op = Op::Get("tir.trunc"); + return tir::Call(x.dtype(), op, {x}, tir::CallNode::PureIntrinsic); } +TIR_REGISTER_PURE_UNARY_OP("tir.trunc").set_attr("TVectorizable", true); + +// unary op registration. +TIR_REGISTER_PURE_UNARY_OP("tir.exp").set_attr("TVectorizable", true); + +TIR_REGISTER_PURE_UNARY_OP("tir.exp2").set_attr("TVectorizable", true); + +TIR_REGISTER_PURE_UNARY_OP("tir.exp10").set_attr("TVectorizable", true); + +TIR_REGISTER_PURE_UNARY_OP("tir.erf"); + +TIR_REGISTER_PURE_UNARY_OP("tir.tanh").set_attr("TVectorizable", true); + +TIR_REGISTER_PURE_UNARY_OP("tir.sigmoid"); + +TIR_REGISTER_PURE_UNARY_OP("tir.sqrt").set_attr("TVectorizable", true); + +TIR_REGISTER_PURE_UNARY_OP("tir.rsqrt"); + +TIR_REGISTER_PURE_UNARY_OP("tir.log").set_attr("TVectorizable", true); + +TIR_REGISTER_PURE_UNARY_OP("tir.log2").set_attr("TVectorizable", true); + +TIR_REGISTER_PURE_UNARY_OP("tir.log1p"); + +TIR_REGISTER_PURE_UNARY_OP("tir.log10").set_attr("TVectorizable", true); + +TIR_REGISTER_PURE_UNARY_OP("tir.popcount").set_attr("TVectorizable", true); + +TIR_REGISTER_PURE_UNARY_OP("tir.tan").set_attr("TVectorizable", true); + +TIR_REGISTER_PURE_UNARY_OP("tir.cos").set_attr("TVectorizable", true); + +TIR_REGISTER_PURE_UNARY_OP("tir.cosh").set_attr("TVectorizable", true); + +TIR_REGISTER_PURE_UNARY_OP("tir.sin").set_attr("TVectorizable", true); + +TIR_REGISTER_PURE_UNARY_OP("tir.sinh").set_attr("TVectorizable", true); + +TIR_REGISTER_PURE_UNARY_OP("tir.asin"); + +TIR_REGISTER_PURE_UNARY_OP("tir.acos"); + +TIR_REGISTER_PURE_UNARY_OP("tir.atan"); + +TIR_REGISTER_PURE_UNARY_OP("tir.acosh"); + +TIR_REGISTER_PURE_UNARY_OP("tir.asinh"); + +TIR_REGISTER_PURE_UNARY_OP("tir.atanh"); + +// binary intrinsics +TIR_REGISTER_PURE_BINARY_OP("tir.atan2"); + +TIR_REGISTER_PURE_BINARY_OP("tir.nextafter"); + +TIR_REGISTER_PURE_BINARY_OP("tir.hypot"); + +TIR_REGISTER_PURE_BINARY_OP("tir.copysign"); + +TIR_REGISTER_PURE_BINARY_OP("tir.ldexp"); + // expose basic functions to node namespace TVM_REGISTER_GLOBAL("node._const").set_body([](TVMArgs args, TVMRetValue* ret) { if (args[0].type_code() == kDLInt) { @@ -783,4 +910,5 @@ TVM_REGISTER_GLOBAL("tir._OpIfThenElse") .set_body_typed([](PrimExpr cond, PrimExpr true_value, PrimExpr false_value) { return if_then_else(cond, true_value, false_value); }); + } // namespace tvm diff --git a/src/tir/op/runtime.cc b/src/tir/op/runtime.cc new file mode 100644 index 000000000000..1c540e3a650a --- /dev/null +++ 
b/src/tir/op/runtime.cc @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tir/op/runtime.cc + * \brief TIR ops for runtime functions. + */ +#include +#include + +namespace tvm { +namespace tir { + +TVM_REGISTER_OP("tir.TVMBackendAllocWorkspace") + .set_num_inputs(5) + .set_attr("TGlobalSymbol", "TVMBackendAllocWorkspace"); + +TVM_REGISTER_OP("tir.TVMBackendFreeWorkspace") + .set_num_inputs(3) + .set_attr("TGlobalSymbol", "TVMBackendFreeWorkspace"); + +} // namespace tir +} // namespace tvm diff --git a/src/tir/transforms/arg_binder.cc b/src/tir/transforms/arg_binder.cc index ae7065d94d80..80c526827ad5 100644 --- a/src/tir/transforms/arg_binder.cc +++ b/src/tir/transforms/arg_binder.cc @@ -24,6 +24,7 @@ #include "arg_binder.h" #include +#include #include #include @@ -141,7 +142,7 @@ void ArgBinder::BindBuffer(const Buffer& arg, const Buffer& value, const std::st } } -inline PrimExpr TVMArrayGet(DataType t, Var arr, intrinsic::TVMStructFieldKind kind) { +inline PrimExpr TVMArrayGet(DataType t, Var arr, builtin::TVMStructFieldKind kind) { return TVMStructGet(t, arr, 0, kind); } @@ -152,7 +153,7 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, const PrimExpr& device_type, const DataType tvm_ndim_type = DataType::Int(32); const Stmt nop = Evaluate(0); // dimension checks - PrimExpr v_ndim = TVMArrayGet(tvm_ndim_type, handle, intrinsic::kArrNDim); + PrimExpr v_ndim = TVMArrayGet(tvm_ndim_type, handle, builtin::kArrNDim); PrimExpr a_ndim = make_const(tvm_ndim_type, static_cast(buffer->shape.size())); std::ostringstream ndim_err_msg; ndim_err_msg << arg_name << ".ndim is expected to equal " << buffer->shape.size(); @@ -162,11 +163,11 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, const PrimExpr& device_type, DataType dtype = buffer->dtype; std::ostringstream type_err_msg; type_err_msg << arg_name << ".dtype is expected to be " << dtype; - PrimExpr cond = (TVMArrayGet(DataType::UInt(8), handle, intrinsic::kArrTypeCode) == + PrimExpr cond = (TVMArrayGet(DataType::UInt(8), handle, builtin::kArrTypeCode) == IntImm(DataType::UInt(8), dtype.code()) && - TVMArrayGet(DataType::UInt(8), handle, intrinsic::kArrTypeBits) == + TVMArrayGet(DataType::UInt(8), handle, builtin::kArrTypeBits) == IntImm(DataType::UInt(8), dtype.bits()) && - TVMArrayGet(DataType::UInt(16), handle, intrinsic::kArrTypeLanes) == + TVMArrayGet(DataType::UInt(16), handle, builtin::kArrTypeLanes) == IntImm(DataType::UInt(16), dtype.lanes())); if (!(dtype == DataType::Int(4) || dtype == DataType::UInt(4) || dtype == DataType::Int(1))) { auto type_msg = tvm::tir::StringImm(type_err_msg.str()); @@ -174,7 +175,7 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, const PrimExpr& device_type, asserts_.emplace_back(AssertStmt(cond, type_msg, 
nop)); } // data field - if (Bind_(buffer->data, TVMArrayGet(DataType::Handle(), handle, intrinsic::kArrData), + if (Bind_(buffer->data, TVMArrayGet(DataType::Handle(), handle, builtin::kArrData), arg_name + ".data", true)) { Var vptr(buffer->data); def_handle_dtype_.Set(vptr, tir::TypeAnnotation(buffer->dtype)); @@ -186,7 +187,7 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, const PrimExpr& device_type, Var v_shape(arg_name + ".shape", DataType::Handle()); def_handle_dtype_.Set(v_shape, make_const(tvm_shape_type, 0)); init_nest_.emplace_back( - LetStmt(v_shape, TVMArrayGet(DataType::Handle(), handle, intrinsic::kArrShape), nop)); + LetStmt(v_shape, TVMArrayGet(DataType::Handle(), handle, builtin::kArrShape), nop)); for (size_t k = 0; k < buffer->shape.size(); ++k) { if (dtype == DataType::Int(4) || dtype == DataType::UInt(4) || dtype == DataType::Int(1)) { break; @@ -202,9 +203,9 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, const PrimExpr& device_type, Var v_strides(arg_name + ".strides", DataType::Handle()); def_handle_dtype_.Set(v_strides, tir::TypeAnnotation(tvm_shape_type)); init_nest_.emplace_back( - LetStmt(v_strides, TVMArrayGet(DataType::Handle(), handle, intrinsic::kArrStrides), nop)); + LetStmt(v_strides, TVMArrayGet(DataType::Handle(), handle, builtin::kArrStrides), nop)); PrimExpr is_null = - Call(DataType::Bool(1), intrinsic::tvm_handle_is_null, {v_strides}, CallNode::PureIntrinsic); + Call(DataType::Bool(1), builtin::isnullptr(), {v_strides}, CallNode::PureIntrinsic); if (buffer->strides.size() == 0) { // Assert the buffer is compact DataType stype = buffer->DefaultIndexType(); @@ -262,12 +263,12 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, const PrimExpr& device_type, if (const auto* const_offset = buffer->elem_offset.as()) { Bind_(make_const(DataType::UInt(64), const_offset->value * data_bytes), - TVMArrayGet(DataType::UInt(64), handle, intrinsic::kArrByteOffset), + TVMArrayGet(DataType::UInt(64), handle, builtin::kArrByteOffset), arg_name + ".byte_offset", true); } else { if (Bind_(buffer->elem_offset, cast(buffer->elem_offset.dtype(), - (TVMArrayGet(DataType::UInt(64), handle, intrinsic::kArrByteOffset) / + (TVMArrayGet(DataType::UInt(64), handle, builtin::kArrByteOffset) / make_const(DataType::UInt(64), data_bytes))), arg_name + ".elem_offset", true)) { if (buffer->offset_factor > 1) { @@ -280,9 +281,9 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, const PrimExpr& device_type, } } // device info. - Bind_(device_type, TVMArrayGet(DataType::Int(32), handle, intrinsic::kArrDeviceType), + Bind_(device_type, TVMArrayGet(DataType::Int(32), handle, builtin::kArrDeviceType), arg_name + ".device_type", true); - Bind_(device_id, TVMArrayGet(DataType::Int(32), handle, intrinsic::kArrDeviceId), + Bind_(device_id, TVMArrayGet(DataType::Int(32), handle, builtin::kArrDeviceId), arg_name + ".device_id", true); } diff --git a/src/tir/transforms/bf16_legalize.cc b/src/tir/transforms/bf16_legalize.cc index 445ac1cf60cd..9722d1100a7e 100644 --- a/src/tir/transforms/bf16_legalize.cc +++ b/src/tir/transforms/bf16_legalize.cc @@ -23,6 +23,7 @@ */ #include +#include #include #include @@ -188,13 +189,13 @@ class BF16LowerRewriter : StmtExprMutator { auto uint32_dtype = DataType(kDLUInt, 32, op_val->dtype.lanes()); auto uint32_v = Cast(uint32_dtype, op_val); // to be endian invariant. 
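For reference, a scalar C++ sketch of the two conversions this pass lowers (an illustration under the same bit-level scheme as the TIR emitted below; the helper names are made up): bf16 widens to fp32 by moving its bit pattern into the high 16 bits and reinterpreting, and fp32 narrows to bf16 with a round-to-nearest-even bias before truncating back to 16 bits.

    #include <cstdint>
    #include <cstring>

    inline float BF16BitsToF32(uint16_t v) {
      uint32_t bits = static_cast<uint32_t>(v) << 16;  // bf16 payload becomes the high 16 bits
      float out;
      std::memcpy(&out, &bits, sizeof(out));  // reinterpret the widened bit pattern as fp32
      return out;
    }

    inline uint16_t F32ToBF16Bits(float v) {
      uint32_t bits;
      std::memcpy(&bits, &v, sizeof(bits));
      uint32_t rounding_bias = ((bits >> 16) & 1) + UINT32_C(0x7FFF);  // round to nearest even
      return static_cast<uint16_t>((bits + rounding_bias) >> 16);
    }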
- return Call(op->dtype, CallNode::reinterpret, {uint32_v << 16}, CallNode::PureIntrinsic); + return Call(op->dtype, builtin::reinterpret(), {uint32_v << 16}, CallNode::PureIntrinsic); } else if (op->dtype.is_bfloat16()) { // if is cast_to_bf16, check if op->value is fp32 CHECK(op->value->dtype.is_float() && op->value->dtype.bits() == 32); auto uint32_dtype = DataType(kDLUInt, 32, op_val->dtype.lanes()); - auto uint32_v = Call(uint32_dtype, CallNode::reinterpret, {op_val}, CallNode::PureIntrinsic); + auto uint32_v = Call(uint32_dtype, builtin::reinterpret(), {op_val}, CallNode::PureIntrinsic); auto uint16_dtype = DataType(kDLUInt, 16, op_val->dtype.lanes()); /* the following TIR is equivalent to the C++ code below: uint32_t rounding_bias = ((U32 >> 16) & 1) + UINT32_C(0x7FFF); diff --git a/src/tir/transforms/bound_checker.cc b/src/tir/transforms/bound_checker.cc index 94464a04f912..3b6af0644fc9 100644 --- a/src/tir/transforms/bound_checker.cc +++ b/src/tir/transforms/bound_checker.cc @@ -24,6 +24,7 @@ #include #include +#include #include #include #include @@ -66,7 +67,7 @@ class BoundChecker : public StmtExprMutator { } PrimExpr VisitExpr_(const CallNode* op) final { - if (process_store_ && op->is_intrinsic(intrinsic::tvm_if_then_else)) { + if (process_store_ && op->op.same_as(builtin::if_then_else())) { unsafe_rewritten_ = true; } return StmtExprMutator::VisitExpr_(op); diff --git a/src/tir/transforms/combine_context_call.cc b/src/tir/transforms/combine_context_call.cc index 73bf4c6f6db2..0485bb1f7613 100644 --- a/src/tir/transforms/combine_context_call.cc +++ b/src/tir/transforms/combine_context_call.cc @@ -25,6 +25,7 @@ #include #include #include +#include #include #include #include @@ -40,7 +41,7 @@ namespace tir { class ContextCallCombiner final : public StmtExprMutator { public: PrimExpr VisitExpr_(const CallNode* op) final { - if (op->is_intrinsic(intrinsic::tvm_thread_context)) { + if (op->op.same_as(builtin::tvm_thread_context())) { CHECK_EQ(op->args.size(), 1U); PrimExpr ctx = op->args[0]; auto it = ctx_map_.find(ctx); @@ -48,13 +49,7 @@ class ContextCallCombiner final : public StmtExprMutator { return it->second; } else { CHECK(ctx.dtype().is_handle()); - std::string name; - if (const CallNode* call = ctx.as()) { - name = call->name + "_cache"; - } else { - name = "ctx_cache_"; - } - Var ctx_var(name, ctx.dtype()); + Var ctx_var("ctx_cache_", ctx.dtype()); ctx_map_[ctx] = ctx_var; return std::move(ctx_var); } diff --git a/src/tir/transforms/coproc_sync.cc b/src/tir/transforms/coproc_sync.cc index 384dbcb0caee..092a7cdeca98 100644 --- a/src/tir/transforms/coproc_sync.cc +++ b/src/tir/transforms/coproc_sync.cc @@ -21,6 +21,7 @@ * \file coproc_sync.cc */ #include +#include #include #include #include @@ -54,7 +55,7 @@ class CoProcTouchedBuffer : public StmtExprVisitor { StmtExprVisitor::VisitStmt_(op); } void VisitExpr_(const CallNode* op) final { - if (op->is_intrinsic(intrinsic::tvm_access_ptr)) { + if (op->op.same_as(builtin::tvm_access_ptr())) { const VarNode* buffer = op->args[1].as(); if (in_scope_) { touched_[buffer].coproc = true; @@ -195,7 +196,8 @@ class CoProcSyncPlanner : public StorageAccessVisitor { } std::vector GetSync(std::string sync_name) { - return {Evaluate(Call(DataType::Int(32), sync_name, {}, CallNode::Intrinsic))}; + return { + Evaluate(Call(DataType::Int(32), Op::Get("tir." 
+ sync_name), {}, CallNode::Intrinsic))}; } const std::unordered_set& touched_; @@ -208,8 +210,8 @@ class CoProcBarrierDetector : public StorageAccessVisitor { explicit CoProcBarrierDetector(const std::unordered_set& touched, const std::string& coproc_name) : touched_(touched) { - read_barrier_name_ = coproc_name + ".coproc_read_barrier"; - write_barrier_name_ = coproc_name + ".coproc_write_barrier"; + read_barrier_name_ = "tir." + coproc_name + ".coproc_read_barrier"; + write_barrier_name_ = "tir." + coproc_name + ".coproc_write_barrier"; } void PlanReadBarrier(const Stmt& stmt) { @@ -331,7 +333,7 @@ class CoProcBarrierDetector : public StorageAccessVisitor { CHECK(r.defined()) << "Cannot deduce write range of " << wvec[0].buffer; PrimExpr min = r->min; PrimExpr extent = r->extent; - return Evaluate(Call(DataType::Int(32), func, + return Evaluate(Call(DataType::Int(32), Op::Get(func), {wvec[0].buffer, wvec[0].dtype.bits(), r->min, r->extent}, CallNode::Intrinsic)); } @@ -346,8 +348,8 @@ class CoProcInstDepDetector : public StmtVisitor { public: explicit CoProcInstDepDetector(const IterVar& coproc_axis, const std::string& coproc_name) : coproc_axis_(coproc_axis) { - sync_push_name_ = coproc_name + ".coproc_dep_push"; - sync_pop_name_ = coproc_name + ".coproc_dep_pop"; + sync_push_op_ = Op::Get("tir." + coproc_name + ".coproc_dep_push"); + sync_pop_op_ = Op::Get("tir." + coproc_name + ".coproc_dep_pop"); } void Plan(const Stmt& stmt) { @@ -555,12 +557,12 @@ class CoProcInstDepDetector : public StmtVisitor { } Stmt MakePush(int from, int to) { - return Evaluate(Call(DataType::Int(32), sync_push_name_, + return Evaluate(Call(DataType::Int(32), sync_push_op_, {make_const(DataType::Int(32), from), make_const(DataType::Int(32), to)}, CallNode::Intrinsic)); } Stmt MakePop(int from, int to) { - return Evaluate(Call(DataType::Int(32), sync_pop_name_, + return Evaluate(Call(DataType::Int(32), sync_pop_op_, {make_const(DataType::Int(32), from), make_const(DataType::Int(32), to)}, CallNode::Intrinsic)); } @@ -568,7 +570,7 @@ class CoProcInstDepDetector : public StmtVisitor { SyncState first_state_, last_state_, curr_state_; // Variables IterVar coproc_axis_; - std::string sync_push_name_, sync_pop_name_; + Op sync_push_op_, sync_pop_op_; }; class CoProcSyncInserter : public StmtMutator { diff --git a/src/tir/transforms/inject_virtual_thread.cc b/src/tir/transforms/inject_virtual_thread.cc index 042ddab15a2f..7180dd29d903 100644 --- a/src/tir/transforms/inject_virtual_thread.cc +++ b/src/tir/transforms/inject_virtual_thread.cc @@ -21,6 +21,7 @@ * \file inject_virtual_thread.cc */ #include +#include #include #include #include @@ -54,7 +55,7 @@ class ExprTouched final : public StmtExprVisitor { } void VisitExpr_(const VarNode* op) final { HandleUseVar(op); } void VisitExpr_(const CallNode* op) final { - if (op->is_intrinsic(intrinsic::tvm_access_ptr)) { + if (op->op.same_as(builtin::tvm_access_ptr())) { const auto* rw_mask = op->args[4].as(); const VarNode* buffer_var = op->args[1].as(); CHECK(buffer_var); @@ -219,7 +220,7 @@ class VTInjector : public StmtExprMutator { } // Expression. 
PrimExpr VisitExpr_(const CallNode* op) final { - if (op->is_intrinsic(intrinsic::tvm_access_ptr)) { + if (op->op.same_as(builtin::tvm_access_ptr())) { CHECK_EQ(op->args.size(), 5U); DataType dtype = op->args[0].dtype(); const VarNode* buffer = op->args[1].as(); @@ -230,9 +231,9 @@ class VTInjector : public StmtExprMutator { PrimExpr extent = this->VisitExpr(op->args[3]); PrimExpr stride = it->second / make_const(offset.dtype(), dtype.lanes()); offset = stride * var_ + offset; - return Call(op->dtype, op->name, {op->args[0], op->args[1], offset, extent, op->args[4]}, + return Call(op->dtype, op->op, {op->args[0], op->args[1], offset, extent, op->args[4]}, op->call_type); - } else if (op->is_intrinsic(intrinsic::tvm_context_id)) { + } else if (op->op.same_as(builtin::tvm_context_id())) { return allow_share_ ? GetRef(op) : var_; } else { return StmtExprMutator::VisitExpr_(op); diff --git a/src/tir/transforms/ir_util.h b/src/tir/transforms/ir_util.h index 6c0eeea97278..758923b15af9 100644 --- a/src/tir/transforms/ir_util.h +++ b/src/tir/transforms/ir_util.h @@ -25,6 +25,7 @@ #define TVM_TIR_TRANSFORMS_IR_UTIL_H_ #include +#include #include #include @@ -83,10 +84,10 @@ inline Array UpdateArray(Array arr, F fupdate) { * \return the get expression. */ inline PrimExpr TVMStructGet(DataType dtype, Var handle, int index, - intrinsic::TVMStructFieldKind kind) { + builtin::TVMStructFieldKind kind) { Array args = {handle, make_const(DataType::Int(32), index), make_const(DataType::Int(32), static_cast(kind))}; - return Call(dtype, intrinsic::tvm_struct_get, args, CallNode::PureIntrinsic); + return Call(dtype, builtin::tvm_struct_get(), args, CallNode::PureIntrinsic); } /*! @@ -96,7 +97,7 @@ inline PrimExpr TVMStructGet(DataType dtype, Var handle, int index, * \param offset the offset index. */ inline PrimExpr AddressOffset(Var handle, DataType dtype, int offset) { - return Call(DataType::Handle(), intrinsic::tvm_address_of, + return Call(DataType::Handle(), builtin::address_of(), {Load(dtype, handle, make_const(DataType::Int(32), offset * dtype.lanes()), const_true(dtype.lanes()))}, CallNode::PureIntrinsic); @@ -113,7 +114,7 @@ inline PrimExpr AddressOffset(Var handle, DataType dtype, PrimExpr offset) { offset = offset * make_const(offset.dtype(), dtype.lanes()); offset = Ramp(offset, make_const(offset.dtype(), 1), dtype.lanes()); } - return Call(DataType::Handle(), intrinsic::tvm_address_of, + return Call(DataType::Handle(), builtin::address_of(), {Load(dtype, handle, offset, const_true(dtype.lanes()))}, CallNode::PureIntrinsic); } @@ -125,11 +126,10 @@ inline PrimExpr AddressOffset(Var handle, DataType dtype, PrimExpr offset) { * \param value The value to be set. * \return the set stmt. */ -inline Stmt TVMStructSet(Var handle, int index, intrinsic::TVMStructFieldKind kind, - PrimExpr value) { +inline Stmt TVMStructSet(Var handle, int index, builtin::TVMStructFieldKind kind, PrimExpr value) { Array args = {handle, make_const(DataType::Int(32), index), make_const(DataType::Int(32), static_cast(kind)), value}; - return Evaluate(Call(DataType::Int(32), intrinsic::tvm_struct_set, args, CallNode::Intrinsic)); + return Evaluate(Call(DataType::Int(32), builtin::tvm_struct_set(), args, CallNode::Intrinsic)); } /*! 
diff --git a/src/tir/transforms/loop_partition.cc b/src/tir/transforms/loop_partition.cc index 3b2580c60074..2fb8003486f1 100644 --- a/src/tir/transforms/loop_partition.cc +++ b/src/tir/transforms/loop_partition.cc @@ -23,6 +23,7 @@ #include #include #include +#include #include #include #include @@ -140,11 +141,11 @@ class CandidateSelector final : public StmtExprVisitor { } void VisitExpr_(const CallNode* op) final { - if (op->is_intrinsic(CallNode::likely)) { + if (op->op.same_as(builtin::likely())) { in_likely_ = true; StmtExprVisitor::VisitExpr_(op); in_likely_ = false; - } else if (op->is_intrinsic(intrinsic::tvm_thread_allreduce)) { + } else if (op->op.same_as(builtin::tvm_thread_allreduce())) { // no split if the body contains allreduce. no_split_ = true; return; @@ -214,7 +215,7 @@ class PartitionFinder : public StmtExprVisitor { } void VisitExpr_(const CallNode* op) final { - if (op->is_intrinsic(CallNode::likely)) { + if (op->op.same_as(builtin::likely())) { PrimExpr cond = op->args[0]; if (ExprUseVars(cond, std::unordered_set({current_var_.get()}))) { // For cond, find out the interval, if exists, in which we can prove that cond is @@ -596,7 +597,7 @@ inline Stmt LoopPartitioner::MakeFor(const Object* node, PrimExpr extent, Stmt b class RemoveLikelyTags : public StmtExprMutator { public: PrimExpr VisitExpr_(const CallNode* op) final { - if (op->is_intrinsic(CallNode::likely)) { + if (op->op.same_as(builtin::likely())) { CHECK_EQ(op->args.size(), 1); return StmtExprMutator::VisitExpr(op->args[0]); } else { diff --git a/src/tir/transforms/lower_device_storage_access_info.cc b/src/tir/transforms/lower_device_storage_access_info.cc index 9d6b47a1ca37..fac50a08a9b7 100644 --- a/src/tir/transforms/lower_device_storage_access_info.cc +++ b/src/tir/transforms/lower_device_storage_access_info.cc @@ -25,6 +25,7 @@ #include #include #include +#include #include #include @@ -79,7 +80,7 @@ class StorageAccessInfoLower : public StmtExprMutator { } PrimExpr VisitExpr_(const CallNode* op) final { - if (op->is_intrinsic(intrinsic::tvm_access_ptr)) { + if (op->op.same_as(builtin::tvm_access_ptr())) { return MakeAccessPtr(op); } else { return StmtExprMutator::VisitExpr_(op); diff --git a/src/tir/transforms/lower_intrin.cc b/src/tir/transforms/lower_intrin.cc index c7aa949924d7..d38cb7b36042 100644 --- a/src/tir/transforms/lower_intrin.cc +++ b/src/tir/transforms/lower_intrin.cc @@ -51,9 +51,17 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { } PrimExpr VisitExpr_(const CallNode* op) final { + // NOTE: call_type will eventually be deprecated and the information + // will be folded into Op's attr if (op->call_type == CallNode::Intrinsic || op->call_type == CallNode::PureIntrinsic) { - PrimExpr r = ApplyPattern(op->name, GetRef(op)); - if (r.defined()) return r; + if (auto* ptr_op = op->op.as()) { + // Still use legacy string based rewriting + // TODO(tvm-team): migrate the pattern application from global function look up + // to an OpAttrMap + std::string name = ptr_op->name; + PrimExpr r = ApplyPattern(name, GetRef(op)); + if (r.defined()) return r; + } } return IRMutatorWithAnalyzer::VisitExpr_(op); } @@ -230,7 +238,7 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { PrimExpr rhs = SwapBroadcastCast(b); if (fma_ != nullptr && op->dtype.is_float()) { - PrimExpr r = (*fma_)(Call(op->dtype, "fma", {lhs, rhs, c}, CallNode::PureIntrinsic)); + PrimExpr r = (*fma_)(Call(op->dtype, builtin::fma(), {lhs, rhs, c}, CallNode::PureIntrinsic)); if (r.defined()) return 
this->VisitExpr(r); } else { if (!lhs.same_as(a) || !rhs.same_as(b)) { @@ -241,7 +249,11 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { return IRMutatorWithAnalyzer::VisitExpr_(op); } - PrimExpr ApplyPattern(const std::string& name, const PrimExpr& e) { + PrimExpr ApplyPattern(std::string name, const PrimExpr& e) { + if (name.compare(0, 4, "tir.") == 0) { + name = name.substr(4); + } + for (size_t i = 0; i < patterns_.size(); ++i) { std::string& p = patterns_[i]; size_t psize = p.length(); diff --git a/src/tir/transforms/lower_thread_allreduce.cc b/src/tir/transforms/lower_thread_allreduce.cc index ee17f081c6d8..dab8d5a78d02 100644 --- a/src/tir/transforms/lower_thread_allreduce.cc +++ b/src/tir/transforms/lower_thread_allreduce.cc @@ -24,6 +24,7 @@ #include #include #include +#include #include #include #include @@ -71,7 +72,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { Stmt stmt = StmtExprMutator::VisitStmt_(op); op = stmt.as(); const CallNode* call = op->value.as(); - if (call && call->is_intrinsic(intrinsic::tvm_thread_allreduce)) { + if (call && call->op.same_as(builtin::tvm_thread_allreduce())) { return MakeAllreduce(call); } else { return stmt; @@ -242,7 +243,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { { PrimExpr pred = const_true(1); PrimExpr mask = - Call(DataType::UInt(32), intrinsic::tvm_warp_activemask, {}, CallNode::Intrinsic); + Call(DataType::UInt(32), builtin::tvm_warp_activemask(), {}, CallNode::Intrinsic); seq.emplace_back(Store(mask_var, mask, index, pred)); // Push allocation with an empty body. Later this will be fixed // when the entire body is ready. @@ -273,8 +274,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { // The former may cause dead lock as there is a divergent // branch with a warp sync call inside. // - const char* shfl_func = intrinsic::tvm_warp_shuffle_down; - PrimExpr other = WarpShuffle(shfl_func, mask_var, val, offset); + PrimExpr other = WarpShuffle(builtin::tvm_warp_shuffle_down(), mask_var, val, offset); const AllocateNode* repl = local_vars[i].as(); Stmt s = Store(repl->buffer_var, other, index, pred); seq.push_back(s); @@ -303,9 +303,8 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { for (size_t i = 0; i < size; ++i) { Var var = shared_bufs[i]; PrimExpr pred = const_true(types[i].lanes()); - const char* shfl_func = intrinsic::tvm_warp_shuffle; PrimExpr val = Load(types[i], var, index, pred); - PrimExpr splat = WarpShuffle(shfl_func, mask_var, val, 0); + PrimExpr splat = WarpShuffle(builtin::tvm_warp_shuffle(), mask_var, val, 0); seq.push_back(Store(var, splat, index, pred)); } @@ -465,18 +464,18 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { } // sync thread op. static Stmt SyncThread(const std::string& sync) { - return Evaluate(Call(DataType::Int(32), intrinsic::tvm_storage_sync, {StringImm(sync)}, + return Evaluate(Call(DataType::Int(32), builtin::tvm_storage_sync(), {StringImm(sync)}, CallNode::Intrinsic)); } - // Emit warp shuffle intrinsic calls. - PrimExpr WarpShuffle(const char* name, Var mask_var, PrimExpr val, int delta_or_lane) { + // Emit warp shuffle calls. 
+ PrimExpr WarpShuffle(const Op& op, Var mask_var, PrimExpr val, int delta_or_lane) { PrimExpr pred = const_true(1); PrimExpr index(0); PrimExpr mask = Load(DataType::UInt(32), mask_var, index, pred); PrimExpr width = IntImm(DataType::Int(32), warp_size_); Array args{mask, val, IntImm(DataType::Int(32), delta_or_lane), width, width}; - return Call(val.dtype(), name, args, CallNode::Intrinsic); + return Call(val.dtype(), op, args, CallNode::Intrinsic); } // Check if this is a reduction on threadIdx.x and its extent matches diff --git a/src/tir/transforms/lower_tvm_builtin.cc b/src/tir/transforms/lower_tvm_builtin.cc index 7611e0fcc8b3..e6182301a335 100644 --- a/src/tir/transforms/lower_tvm_builtin.cc +++ b/src/tir/transforms/lower_tvm_builtin.cc @@ -22,6 +22,7 @@ * \file tir/transforms/lower_tvm_buildin.cc */ #include +#include #include #include #include @@ -40,7 +41,7 @@ inline PrimExpr ConstInt32(size_t index) { inline PrimExpr StackAlloca(std::string type, size_t num) { Array args = {StringImm(type), ConstInt32(num)}; - return Call(DataType::Handle(), intrinsic::tvm_stack_alloca, args, CallNode::Intrinsic); + return Call(DataType::Handle(), builtin::tvm_stack_alloca(), args, CallNode::Intrinsic); } // Calculate the statistics of packed function. @@ -103,23 +104,22 @@ class BuiltinLower : public StmtExprMutator { CHECK(device_type_.defined()) << "Unknown device type in current IR"; CHECK(device_id_.defined()) << "Unknown device id in current IR"; Stmt throw_last_error = - Evaluate(Call(DataType::Int(32), intrinsic::tvm_throw_last_error, {}, CallNode::Intrinsic)); + Evaluate(Call(DataType::Int(32), builtin::tvm_throw_last_error(), {}, CallNode::Intrinsic)); - Stmt body = SeqStmt({IfThenElse(Call(DataType::Bool(1), intrinsic::tvm_handle_is_null, - {op->buffer_var}, CallNode::PureIntrinsic), + Stmt body = SeqStmt({IfThenElse(Call(DataType::Bool(1), builtin::isnullptr(), {op->buffer_var}, + CallNode::PureIntrinsic), throw_last_error), op->body}); - Stmt alloca = LetStmt( op->buffer_var, - Call(op->buffer_var.dtype(), "TVMBackendAllocWorkspace", + Call(op->buffer_var.dtype(), Op::Get("tir.TVMBackendAllocWorkspace"), {cast(DataType::Int(32), device_type_), cast(DataType::Int(32), device_id_), cast(DataType::UInt(64), total_bytes), IntImm(DataType::Int(32), op->dtype.code()), IntImm(DataType::Int(32), op->dtype.bits())}, CallNode::Extern), body); - PrimExpr free_op = Call(DataType::Int(32), "TVMBackendFreeWorkspace", + PrimExpr free_op = Call(DataType::Int(32), Op::Get("tir.TVMBackendFreeWorkspace"), {cast(DataType::Int(32), device_type_), cast(DataType::Int(32), device_id_), op->buffer_var}, CallNode::Extern); @@ -144,15 +144,15 @@ class BuiltinLower : public StmtExprMutator { } } PrimExpr VisitExpr_(const CallNode* op) final { - if (op->is_intrinsic(intrinsic::tvm_call_packed)) { + if (op->op.same_as(builtin::tvm_call_packed())) { return MakeCallPacked(op); - } else if (op->is_intrinsic(intrinsic::tvm_call_trace_packed)) { + } else if (op->op.same_as(builtin::tvm_call_trace_packed())) { return MakeCallTracePacked(op); - } else if (op->is_intrinsic(intrinsic::tvm_stack_make_shape)) { + } else if (op->op.same_as(builtin::tvm_stack_make_shape())) { return MakeShape(op); - } else if (op->is_intrinsic(intrinsic::tvm_stack_make_array)) { + } else if (op->op.same_as(builtin::tvm_stack_make_array())) { return MakeArray(op); - } else if (op->is_intrinsic(intrinsic::tvm_context_id)) { + } else if (op->op.same_as(builtin::tvm_context_id())) { return make_zero(op->dtype); } else { return 
StmtExprMutator::VisitExpr_(op); @@ -176,21 +176,21 @@ class BuiltinLower : public StmtExprMutator { run_array_stack_ += 1; PrimExpr expr = StmtExprMutator::VisitExpr_(op); op = expr.as(); - prep_seq_.emplace_back(TVMStructSet(stack_array_, idx, intrinsic::kArrData, op->args[0])); - prep_seq_.emplace_back(TVMStructSet(stack_array_, idx, intrinsic::kArrShape, op->args[1])); + prep_seq_.emplace_back(TVMStructSet(stack_array_, idx, builtin::kArrData, op->args[0])); + prep_seq_.emplace_back(TVMStructSet(stack_array_, idx, builtin::kArrShape, op->args[1])); PrimExpr strides = op->args[2]; if (!strides.defined() || is_zero(strides)) { strides = make_zero(DataType::Handle()); } - prep_seq_.emplace_back(TVMStructSet(stack_array_, idx, intrinsic::kArrStrides, strides)); - prep_seq_.emplace_back(TVMStructSet(stack_array_, idx, intrinsic::kArrNDim, op->args[3])); + prep_seq_.emplace_back(TVMStructSet(stack_array_, idx, builtin::kArrStrides, strides)); + prep_seq_.emplace_back(TVMStructSet(stack_array_, idx, builtin::kArrNDim, op->args[3])); DataType dtype = op->args[4].dtype(); prep_seq_.emplace_back( - TVMStructSet(stack_array_, idx, intrinsic::kArrTypeCode, + TVMStructSet(stack_array_, idx, builtin::kArrTypeCode, make_const(DataType::UInt(8), static_cast(dtype.code())))); - prep_seq_.emplace_back(TVMStructSet(stack_array_, idx, intrinsic::kArrTypeBits, + prep_seq_.emplace_back(TVMStructSet(stack_array_, idx, builtin::kArrTypeBits, make_const(DataType::UInt(8), dtype.bits()))); - prep_seq_.emplace_back(TVMStructSet(stack_array_, idx, intrinsic::kArrTypeLanes, + prep_seq_.emplace_back(TVMStructSet(stack_array_, idx, builtin::kArrTypeLanes, make_const(DataType::UInt(16), dtype.lanes()))); // set byte offset int data_bytes = GetVectorBytes(dtype); @@ -198,15 +198,15 @@ class BuiltinLower : public StmtExprMutator { if (!is_zero(byte_offset)) { byte_offset = byte_offset * make_const(byte_offset.dtype(), data_bytes); } - prep_seq_.emplace_back(TVMStructSet(stack_array_, idx, intrinsic::kArrByteOffset, + prep_seq_.emplace_back(TVMStructSet(stack_array_, idx, builtin::kArrByteOffset, cast(DataType::UInt(64), byte_offset))); CHECK(device_type_.defined()) << "Unknown device type in current IR"; CHECK(device_id_.defined()) << "Unknown device id in current IR"; - prep_seq_.emplace_back(TVMStructSet(stack_array_, idx, intrinsic::kArrDeviceId, + prep_seq_.emplace_back(TVMStructSet(stack_array_, idx, builtin::kArrDeviceId, cast(DataType::Int(32), device_id_))); - prep_seq_.emplace_back(TVMStructSet(stack_array_, idx, intrinsic::kArrDeviceType, + prep_seq_.emplace_back(TVMStructSet(stack_array_, idx, builtin::kArrDeviceType, cast(DataType::Int(32), device_type_))); - return TVMStructGet(DataType::Handle(), stack_array_, idx, intrinsic::kArrAddr); + return TVMStructGet(DataType::Handle(), stack_array_, idx, builtin::kArrAddr); } // call packed. 
PrimExpr MakeCallPacked(const CallNode* op) { @@ -226,7 +226,7 @@ class BuiltinLower : public StmtExprMutator { arg = Cast(api_type, arg); } prep_seq_.emplace_back(TVMStructSet(stack_value_, static_cast(arg_stack_begin + i - 1), - intrinsic::kTVMValueContent, arg)); + builtin::kTVMValueContent, arg)); int arg_tcode = api_type.code(); if (api_type.is_handle() && arg.as()) { arg_tcode = kTVMStr; @@ -245,7 +245,7 @@ class BuiltinLower : public StmtExprMutator { Array packed_args = {op->args[0], stack_value_, stack_tcode_, ConstInt32(arg_stack_begin), ConstInt32(arg_stack_begin + op->args.size() - 1)}; - return Call(DataType::Int(32), intrinsic::tvm_call_packed_lowered, packed_args, + return Call(DataType::Int(32), builtin::tvm_call_packed_lowered(), packed_args, CallNode::Intrinsic); } @@ -267,7 +267,7 @@ class BuiltinLower : public StmtExprMutator { arg = Cast(api_type, arg); } prep_seq_.emplace_back(TVMStructSet(stack_value_, static_cast(arg_stack_begin + i - 1), - intrinsic::kTVMValueContent, arg)); + builtin::kTVMValueContent, arg)); int arg_tcode = api_type.code(); CHECK(!IsArrayHandle(arg)) << "Trace does not support Buffers"; prep_seq_.emplace_back( @@ -287,7 +287,7 @@ class BuiltinLower : public StmtExprMutator { ConstInt32(arg_stack_begin + op->args.size() - 1), // Pass traced value. op->args[args_size - 1]}; - return Call(op->dtype, intrinsic::tvm_call_trace_packed_lowered, packed_args, + return Call(op->dtype, builtin::tvm_call_trace_packed_lowered(), packed_args, CallNode::Intrinsic); } @@ -295,8 +295,8 @@ class BuiltinLower : public StmtExprMutator { bool IsArrayHandle(const PrimExpr& arg) { // specially set array handle. if (const CallNode* buf = arg.as()) { - if (buf->is_intrinsic(intrinsic::tvm_struct_get) && - buf->args[2].as()->value == intrinsic::kArrAddr) { + if (buf->op.same_as(builtin::tvm_struct_get()) && + buf->args[2].as()->value == builtin::kArrAddr) { return true; } } diff --git a/src/tir/transforms/lower_warp_memory.cc b/src/tir/transforms/lower_warp_memory.cc index 92f9ab54adb4..3e7d13b2ff6e 100644 --- a/src/tir/transforms/lower_warp_memory.cc +++ b/src/tir/transforms/lower_warp_memory.cc @@ -30,6 +30,7 @@ #include #include #include +#include #include #include #include @@ -250,8 +251,8 @@ class WarpAccessRewriter : protected StmtExprMutator { << " local_index=" << local_index; PrimExpr load_value = Load(op->dtype, op->buffer_var, local_index, op->predicate); PrimExpr mask = - Call(DataType::UInt(32), intrinsic::tvm_warp_activemask, {}, CallNode::Intrinsic); - return Call(load_value.dtype(), intrinsic::tvm_warp_shuffle, + Call(DataType::UInt(32), builtin::tvm_warp_activemask(), {}, CallNode::Intrinsic); + return Call(load_value.dtype(), builtin::tvm_warp_shuffle(), {mask, load_value, group, width_, warp_size_}, CallNode::Intrinsic); } else { return StmtExprMutator::VisitExpr_(op); diff --git a/src/tir/transforms/make_packed_api.cc b/src/tir/transforms/make_packed_api.cc index a91e350e6b22..9bb5fc6b5971 100644 --- a/src/tir/transforms/make_packed_api.cc +++ b/src/tir/transforms/make_packed_api.cc @@ -26,6 +26,7 @@ #include #include #include +#include #include #include #include @@ -82,10 +83,10 @@ PrimFunc MakePackedAPI(PrimFunc&& func, int num_unpacked_args) { // load i-th argument as type t auto f_arg_value = [&](DataType t, int i) { Array call_args{v_packed_args, IntImm(DataType::Int(32), i), - IntImm(DataType::Int(32), intrinsic::kTVMValueContent)}; + IntImm(DataType::Int(32), builtin::kTVMValueContent)}; // load 64 bit version DataType api_type = APIType(t); - 
PrimExpr res = Call(api_type, intrinsic::tvm_struct_get, call_args, CallNode::PureIntrinsic); + PrimExpr res = Call(api_type, builtin::tvm_struct_get(), call_args, CallNode::PureIntrinsic); // cast to the target version. if (api_type != t) { res = Cast(t, res); @@ -189,7 +190,7 @@ PrimFunc MakePackedAPI(PrimFunc&& func, int num_unpacked_args) { if (runtime::DeviceAPI::NeedSetDeviceContext(target_device_type)) { Stmt set_device = - Evaluate(Call(DataType::Int(32), intrinsic::tvm_call_packed, + Evaluate(Call(DataType::Int(32), builtin::tvm_call_packed(), {StringImm(runtime::symbol::tvm_set_device), device_type, device_id}, CallNode::Intrinsic)); body = SeqStmt({set_device, body}); diff --git a/src/tir/transforms/narrow_datatype.cc b/src/tir/transforms/narrow_datatype.cc index 07b0ea29a52a..a14fd02e7700 100644 --- a/src/tir/transforms/narrow_datatype.cc +++ b/src/tir/transforms/narrow_datatype.cc @@ -23,6 +23,7 @@ */ #include +#include #include #include @@ -318,6 +319,8 @@ class DataTypeRewriter : public StmtExprMutator { std::unordered_map ivmap_; // indicator of LoadNode::index and StoreNode::index bool is_index_{false}; + // cached ops + const Op& builtin_pow_ = Op::Get("tir.pow"); }; #define DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(OP, FUNC) \ @@ -352,23 +355,23 @@ PrimExpr DataTypeRewriter::VisitExpr_(const CallNode* op) { op = e.as(); CHECK(op != nullptr) << "Expected type to be CallNode" << ", but get " << e->GetTypeKey(); - if (op->call_type == CallNode::PureIntrinsic) { - if (op->name == intrinsic::tvm_if_then_else) { - return if_then_else(op->args[0], op->args[1], op->args[2]); - } else if (op->name == CallNode::shift_right) { - return op->args[0] >> op->args[1]; - } else if (op->name == CallNode::shift_left) { - return op->args[0] << op->args[1]; - } else if (op->name == CallNode::bitwise_and) { - return op->args[0] & op->args[1]; - } else if (op->name == CallNode::bitwise_or) { - return op->args[0] | op->args[1]; - } else if (op->name == CallNode::bitwise_xor) { - return op->args[0] ^ op->args[1]; - } else if (op->name == "pow") { - return pow(op->args[0], op->args[1]); - } + + if (op->op.same_as(builtin::if_then_else())) { + return if_then_else(op->args[0], op->args[1], op->args[2]); + } else if (op->op.same_as(builtin::shift_right())) { + return op->args[0] >> op->args[1]; + } else if (op->op.same_as(builtin::shift_left())) { + return op->args[0] << op->args[1]; + } else if (op->op.same_as(builtin::bitwise_and())) { + return op->args[0] & op->args[1]; + } else if (op->op.same_as(builtin::bitwise_or())) { + return op->args[0] | op->args[1]; + } else if (op->op.same_as(builtin::bitwise_xor())) { + return op->args[0] ^ op->args[1]; + } else if (op->op.same_as(builtin_pow_)) { + return pow(op->args[0], op->args[1]); } + return e; } diff --git a/src/tir/transforms/rewrite_unsafe_select.cc b/src/tir/transforms/rewrite_unsafe_select.cc index 701f0cea1bfa..e5535369c39e 100644 --- a/src/tir/transforms/rewrite_unsafe_select.cc +++ b/src/tir/transforms/rewrite_unsafe_select.cc @@ -22,6 +22,7 @@ * \brief Rewrite uinsafe select expression. */ #include +#include #include #include #include @@ -37,9 +38,9 @@ class UnsafeExprDetector : public ExprFunctor { // Because we will issue guard to make sure it is. 
bool VisitExpr_(const SelectNode* op) { return VisitExpr(op->condition); } bool VisitExpr_(const CallNode* op) { - if (op->is_intrinsic(intrinsic::tvm_if_then_else)) { + if (op->op.same_as(builtin::if_then_else())) { return VisitExpr(op->args[0]); - } else if (op->is_intrinsic(intrinsic::tvm_address_of)) { + } else if (op->op.same_as(builtin::address_of())) { const LoadNode* l = op->args[0].as(); return this->VisitExpr(l->index); } else if (op->is_pure()) { @@ -104,7 +105,7 @@ class UnsafeSelectRewriter : public StmtExprMutator { bool cond_is_scalar_bool = op->condition.dtype().is_bool() && op->condition.dtype().is_scalar(); if ((unsafe.VisitExpr(op->true_value) || unsafe.VisitExpr(op->false_value)) && cond_is_scalar_bool) { - return Call(op->dtype, intrinsic::tvm_if_then_else, + return Call(op->dtype, builtin::if_then_else(), {op->condition, op->true_value, op->false_value}, CallNode::Intrinsic); } else { return expr; diff --git a/src/tir/transforms/split_host_device.cc b/src/tir/transforms/split_host_device.cc index 0684189c88e8..c35caf54db4a 100644 --- a/src/tir/transforms/split_host_device.cc +++ b/src/tir/transforms/split_host_device.cc @@ -26,6 +26,7 @@ #include #include #include +#include #include #include #include @@ -238,7 +239,7 @@ class HostDeviceSplitter : public StmtMutator { call_args.push_back(ext); } return Evaluate( - Call(DataType::Int(32), intrinsic::tvm_call_packed, call_args, CallNode::Intrinsic)); + Call(DataType::Int(32), builtin::tvm_call_packed(), call_args, CallNode::Intrinsic)); } // target ir module diff --git a/src/tir/transforms/storage_access.cc b/src/tir/transforms/storage_access.cc index 20cc6402135f..24f8b756974c 100644 --- a/src/tir/transforms/storage_access.cc +++ b/src/tir/transforms/storage_access.cc @@ -181,10 +181,10 @@ void StorageAccessVisitor::VisitStmt_(const IfThenElseNode* op) { } void StorageAccessVisitor::VisitExpr_(const CallNode* op) { - if (op->is_intrinsic(intrinsic::tvm_address_of)) { + if (op->op.same_as(builtin::address_of())) { const LoadNode* l = op->args[0].as(); StmtExprVisitor::VisitExpr_(l); - } else if (op->is_intrinsic(intrinsic::tvm_access_ptr)) { + } else if (op->op.same_as(builtin::tvm_access_ptr())) { CHECK_EQ(op->args.size(), 5U); DataType dtype = op->args[0].dtype(); const VarNode* buffer = op->args[1].as(); @@ -211,7 +211,7 @@ void StorageAccessVisitor::VisitExpr_(const CallNode* op) { } } StmtExprVisitor::VisitExpr_(op); - } else if (op->is_intrinsic(intrinsic::tvm_storage_sync)) { + } else if (op->op.same_as(builtin::tvm_storage_sync())) { CHECK(allow_append_); const std::string& s = op->args[0].as()->value; if (s != "warp") { diff --git a/src/tir/transforms/storage_flatten.cc b/src/tir/transforms/storage_flatten.cc index e29d978e0d42..30805508144d 100644 --- a/src/tir/transforms/storage_flatten.cc +++ b/src/tir/transforms/storage_flatten.cc @@ -29,6 +29,7 @@ #include #include #include +#include #include #include #include @@ -45,7 +46,6 @@ namespace tvm { namespace tir { -using intrinsic::tvm_address_of; using runtime::StorageRank; using runtime::StorageScope; using runtime::ThreadScope; @@ -101,7 +101,7 @@ class StorageFlattener : public StmtExprMutator { } else if (op->attr_key == attr::buffer_dim_align) { auto buffer = Downcast(op->node); const CallNode* tuple = op->value.as(); - CHECK(tuple && tuple->is_intrinsic(intrinsic::tvm_tuple)); + CHECK(tuple && tuple->op.same_as(builtin::tvm_tuple())); auto& vinfo = dim_align_[buffer]; int dim = tuple->args[0].as()->value; if (static_cast(dim) >= vinfo.size()) { @@ -322,9 
+322,9 @@ class StorageFlattener : public StmtExprMutator { } else { PrimExpr load = e.buffer.vload(e.RelIndex(args), e.buffer->dtype); PrimExpr address = - Call(DataType::Handle(), tvm_address_of, {load}, CallNode::PureIntrinsic); + Call(DataType::Handle(), builtin::address_of(), {load}, CallNode::PureIntrinsic); PrimExpr prefetch = - Call(op->buffer->dtype, CallNode::prefetch, {address, 0, 3, 1}, CallNode::Intrinsic); + Call(op->buffer->dtype, builtin::prefetch(), {address, 0, 3, 1}, CallNode::Intrinsic); stmt = Evaluate(prefetch); PrimExpr extent = (op->bounds[i]->extent - 1) / stride + 1; stmt = For(vars[i], 0, extent, ForType::Serial, DeviceAPI::None, stmt); @@ -392,7 +392,7 @@ class StorageFlattener : public StmtExprMutator { const BufferNode* target = arr[1].as(); const CallNode* tuple = op->value.as(); CHECK(buffer && target); - CHECK(tuple && tuple->is_intrinsic(intrinsic::tvm_tuple)); + CHECK(tuple && tuple->op.same_as(builtin::tvm_tuple())); auto key = GetRef(target); auto it = buf_map_.find(key); diff --git a/src/tir/transforms/storage_rewrite.cc b/src/tir/transforms/storage_rewrite.cc index 283ab0f6f703..d7a258cffe30 100644 --- a/src/tir/transforms/storage_rewrite.cc +++ b/src/tir/transforms/storage_rewrite.cc @@ -26,6 +26,7 @@ #include #include #include +#include #include #include #include @@ -131,7 +132,7 @@ class LinearAccessPatternFinder final : public StmtExprVisitor { } } void VisitExpr_(const CallNode* op) final { - if (op->is_intrinsic(intrinsic::tvm_address_of)) { + if (op->op.same_as(builtin::address_of())) { const LoadNode* l = op->args[0].as(); this->VisitExpr(l->index); } else { @@ -387,7 +388,7 @@ class StoragePlanRewriter : public StmtExprMutator { } } PrimExpr VisitExpr_(const CallNode* op) final { - if (op->is_intrinsic(intrinsic::tvm_access_ptr)) { + if (op->op.same_as(builtin::tvm_access_ptr())) { CHECK_EQ(op->args.size(), 5U); DataType dtype = op->args[0].dtype(); const VarNode* buffer = op->args[1].as(); @@ -403,7 +404,7 @@ class StoragePlanRewriter : public StmtExprMutator { if (se->bits_offset != 0) { offset = make_const(offset.dtype(), se->bits_offset / elem_bits) + offset; } - return Call(op->dtype, op->name, {op->args[0], se->alloc_var, offset, extent, op->args[4]}, + return Call(op->dtype, op->op, {op->args[0], se->alloc_var, offset, extent, op->args[4]}, op->call_type); } else { return StmtExprMutator::VisitExpr_(op); @@ -911,7 +912,7 @@ class VectorAllocRewriter : public StmtExprMutator { return StmtExprMutator::VisitStmt_(op); } PrimExpr VisitExpr_(const CallNode* op) final { - if (op->is_intrinsic(intrinsic::tvm_access_ptr)) { + if (op->op.same_as(builtin::tvm_access_ptr())) { DataType dtype = op->args[0].dtype(); const VarNode* buffer = op->args[1].as(); UpdateTypeMap(buffer, dtype); diff --git a/src/tir/transforms/tensorcore_infer_fragment.cc b/src/tir/transforms/tensorcore_infer_fragment.cc index 493aa516fbd7..1b3b3c44ff9c 100644 --- a/src/tir/transforms/tensorcore_infer_fragment.cc +++ b/src/tir/transforms/tensorcore_infer_fragment.cc @@ -53,8 +53,8 @@ class FragmentGetter : public StmtExprVisitor { void VisitExpr_(const CallNode* op) final { StmtExprVisitor::VisitExpr_(op); - if (op->is_intrinsic(intrinsic::tvm_load_matrix_sync) || - op->is_intrinsic(intrinsic::tvm_store_matrix_sync)) { + if (op->op.same_as(builtin::tvm_load_matrix_sync()) || + op->op.same_as(builtin::tvm_store_matrix_sync())) { // Get shape and layout information from load and store intrinsic CHECK_EQ(op->args.size(), 8U); const VarNode* buffer_var = op->args[0].as(); @@ 
-89,7 +89,7 @@ class FragmentGetter : public StmtExprVisitor { } fragments[buffer_var] = info; } - } else if (op->is_intrinsic(intrinsic::tvm_fill_fragment)) { + } else if (op->op.same_as(builtin::tvm_fill_fragment())) { // Get shape information from fill intrinsic CHECK_EQ(op->args.size(), 6U); const VarNode* buffer_var = op->args[0].as(); @@ -141,7 +141,7 @@ class FragmentChecker : public StmtExprVisitor { void VisitExpr_(const CallNode* op) final { StmtExprVisitor::VisitExpr_(op); // Check shape when calling tvm_mma_sync - if (op->is_intrinsic(intrinsic::tvm_mma_sync) || op->is_intrinsic(intrinsic::tvm_bmma_sync)) { + if (op->op.same_as(builtin::tvm_mma_sync()) || op->op.same_as(builtin::tvm_bmma_sync())) { CHECK_EQ(op->args.size(), 8U); const VarNode* buffer_var_d = op->args[0].as(); const VarNode* buffer_var_a = op->args[2].as(); diff --git a/src/tir/transforms/thread_storage_sync.cc b/src/tir/transforms/thread_storage_sync.cc index 612efb092395..cdd9377e00d6 100644 --- a/src/tir/transforms/thread_storage_sync.cc +++ b/src/tir/transforms/thread_storage_sync.cc @@ -22,6 +22,7 @@ */ #include #include +#include #include #include #include @@ -209,7 +210,7 @@ class ThreadSyncInserter : public StmtExprMutator { if (sync_scope_.rank == StorageRank::kGlobal) { barrier = MakeGlobalBarrier(); } else { - barrier = Evaluate(Call(DataType::Int(32), intrinsic::tvm_storage_sync, + barrier = Evaluate(Call(DataType::Int(32), builtin::tvm_storage_sync(), {StringImm(sync_scope_.to_string())}, CallNode::Intrinsic)); } // Mutate after query, to avoid stmt change. @@ -259,7 +260,7 @@ class ThreadSyncInserter : public StmtExprMutator { } PrimExpr VisitExpr_(const CallNode* op) final { - if (op->is_intrinsic(intrinsic::tvm_access_ptr)) { + if (op->op.same_as(builtin::tvm_access_ptr())) { PrimExpr expr = StmtExprMutator::VisitExpr_(op); op = expr.as(); CHECK_EQ(op->args.size(), 5U); @@ -299,7 +300,7 @@ class ThreadSyncInserter : public StmtExprMutator { CHECK(op != nullptr); Array pargs = {StringImm(runtime::symbol::tvm_prepare_global_barrier)}; Stmt prep = - Evaluate(Call(DataType::Int(32), intrinsic::tvm_call_packed, pargs, CallNode::Intrinsic)); + Evaluate(Call(DataType::Int(32), builtin::tvm_call_packed(), pargs, CallNode::Intrinsic)); Stmt body = op->body; for (const auto& kv : rw_stats_) { const auto& e = kv.second; @@ -309,7 +310,7 @@ class ThreadSyncInserter : public StmtExprMutator { } rw_stats_.clear(); Stmt kinit = Evaluate( - Call(DataType::Int(32), intrinsic::tvm_global_barrier_kinit, {}, CallNode::Intrinsic)); + Call(DataType::Int(32), builtin::tvm_global_barrier_kinit(), {}, CallNode::Intrinsic)); body = SeqStmt({kinit, body}); body = AttrStmt(op->node, op->attr_key, op->value, body); return SeqStmt({prep, body}); @@ -332,7 +333,7 @@ class ThreadSyncInserter : public StmtExprMutator { } else { CHECK_EQ(num_work_dim_, thread_extents_.size()); } - return Evaluate(Call(DataType::Int(32), intrinsic::tvm_storage_sync, + return Evaluate(Call(DataType::Int(32), builtin::tvm_storage_sync(), {StringImm(sync_scope_.to_string()), is_lead_, num_blocks_}, CallNode::Intrinsic)); } diff --git a/src/tir/transforms/vectorize_loop.cc b/src/tir/transforms/vectorize_loop.cc index 227aea2eb575..1a2ec502f605 100644 --- a/src/tir/transforms/vectorize_loop.cc +++ b/src/tir/transforms/vectorize_loop.cc @@ -23,8 +23,10 @@ // Loop vectorizer as in Halide pipeline. 
#include #include +#include #include #include +#include #include #include @@ -212,15 +214,18 @@ class Vectorizer : public StmtExprMutator { int lanes = std::max(t.dtype().lanes(), f.dtype().lanes()); t = BroadcastTo(t, lanes); f = BroadcastTo(f, lanes); - return Call(op->dtype.with_lanes(lanes), op->name, {cond, t, f}, op->call_type); + return Call(op->dtype.with_lanes(lanes), op->op, {cond, t, f}, op->call_type); } } // Call PrimExpr VisitExpr_(const CallNode* op) final { - if (op->name == intrinsic::tvm_if_then_else) { + if (op->op.same_as(builtin::if_then_else())) { return MutateIfThenElseExpr_(op); } - if (!op->is_vectorizable()) { + auto* op_ptr = op->op.as(); + bool vectorizable = op_ptr && op_vectorizable_.get(GetRef(op_ptr), false); + + if (!vectorizable) { // Cannot vectorize this op Array new_args; for (auto arg : op->args) { @@ -234,7 +239,7 @@ class Vectorizer : public StmtExprMutator { if (op->args.same_as(new_args)) { return GetRef(op); } else { - return Call(op->dtype, op->name, new_args, op->call_type); + return Call(op->dtype, op->op, new_args, op->call_type); } } else { int lane = 0; @@ -243,7 +248,7 @@ class Vectorizer : public StmtExprMutator { if (op->args.same_as(new_args)) { return GetRef(op); } else { - return Call(op->dtype.with_lanes(lane), op->name, new_args, op->call_type); + return Call(op->dtype.with_lanes(lane), op->op, new_args, op->call_type); } } } @@ -380,6 +385,9 @@ class Vectorizer : public StmtExprMutator { bool need_scalarize_{false}; // The lets std::unordered_map lets_; + // vectorizable property + OpAttrMap op_vectorizable_ = Op::GetAttrMap("TVectorizable"); + // mutate array, with given lane requirement // when finished, p_lane updates the lane requirement. Array MutateArray(Array arr, int* p_lanes) { diff --git a/tests/cpp/ir_functor_test.cc b/tests/cpp/ir_functor_test.cc index 8dae79929fe8..ce50ed0c45f7 100644 --- a/tests/cpp/ir_functor_test.cc +++ b/tests/cpp/ir_functor_test.cc @@ -20,6 +20,7 @@ #include #include #include +#include #include #include #include @@ -192,9 +193,10 @@ TEST(IRF, StmtMutator) { } { - auto body = Evaluate(Call(DataType::Int(32), "xyz", {x + 1}, CallNode::Extern)); + auto body = Evaluate(Call(DataType::Int(32), builtin::call_extern(), {StringImm("xyz"), x + 1}, + CallNode::Extern)); auto res = v(std::move(body)); - CHECK(res.as()->value.as()->args[0].same_as(x)); + CHECK(res.as()->value.as()->args[1].same_as(x)); } { Stmt body = fmakealloc(); diff --git a/tests/python/relay/test_ir_parser.py b/tests/python/relay/test_ir_parser.py index c4ac042bdb22..1e4fe6b66830 100644 --- a/tests/python/relay/test_ir_parser.py +++ b/tests/python/relay/test_ir_parser.py @@ -853,8 +853,8 @@ def test_duplicate_adt_cons_defn(): def test_duplicate_global_var(): parse_text( """ - def @id[A](%x: A) -> A { x } - def @id[A](%x: A) -> A { x } + def @id[A](%%x: A) -> A { x } + def @id[A](%%x: A) -> A { x } """ ) diff --git a/tests/python/unittest/test_arith_canonical_simplify.py b/tests/python/unittest/test_arith_canonical_simplify.py index 525cd6c30736..2fbc82f06ccf 100644 --- a/tests/python/unittest/test_arith_canonical_simplify.py +++ b/tests/python/unittest/test_arith_canonical_simplify.py @@ -202,7 +202,8 @@ def test_reduce_combiner_simplify(): assert tvm.ir.structural_equal(lhs, rhs) # Test that components with side effects are not removed - side_effect = lambda *xs: tvm.tir.Call("int32", "dummy", xs, tvm.tir.Call.Intrinsic) + dummy = tvm.ir.GlobalVar("dummy") + side_effect = lambda *xs: tvm.tir.Call("int32", dummy, xs, 
tvm.tir.Call.Intrinsic) ck.verify(sum_and_prod((A[k], side_effect(A[10-k])), k)[0], sum_and_prod((A[k], side_effect(A[10-k])), k)[0]) ck.verify(sum_and_prod((side_effect(A[k]), A[10-k]), k)[0], diff --git a/tests/python/unittest/test_target_codegen_c_host.py b/tests/python/unittest/test_target_codegen_c_host.py index 0f00e08f9192..18a98eed0673 100644 --- a/tests/python/unittest/test_target_codegen_c_host.py +++ b/tests/python/unittest/test_target_codegen_c_host.py @@ -98,7 +98,7 @@ def test_reinterpret(): nn = 1024 n = tvm.runtime.convert(nn) A = te.placeholder((n,), name='A', dtype="int32") - B = te.compute(A.shape, lambda *i: tvm.tir.call_pure_intrin("float32", "reinterpret", A(*i)), name='B') + B = te.compute(A.shape, lambda *i: tvm.tir.call_pure_intrin("float32", "tir.reinterpret", A(*i)), name='B') s = te.create_schedule(B.op) def check_c(): diff --git a/tests/python/unittest/test_target_codegen_llvm.py b/tests/python/unittest/test_target_codegen_llvm.py index 0b415b0de6ba..a6a231564033 100644 --- a/tests/python/unittest/test_target_codegen_llvm.py +++ b/tests/python/unittest/test_target_codegen_llvm.py @@ -29,12 +29,12 @@ def test_llvm_intrin(): n = tvm.runtime.convert(4) A = ib.pointer("float32", name="A") args = [ - tvm.tir.call_pure_intrin("handle", "tvm_address_of", A[0]), + tvm.tir.call_pure_intrin("handle", "tir.address_of", A[0]), 0, 3, 1 ] ib.emit(tvm.tir.Evaluate( tvm.tir.Call( - "int32", "prefetch", args, tvm.tir.Call.Intrinsic))) + "int32", "tir.prefetch", args, tvm.tir.Call.Intrinsic))) body = ib.get() mod = tvm.IRModule.from_expr( @@ -738,20 +738,20 @@ def _transform(f, *_): tvm.testing.assert_allclose(c_.asnumpy(), (a_.asnumpy() * 2).astype('int32')) def np_float2np_bf16(arr): - ''' Convert a numpy array of float to a numpy array + ''' Convert a numpy array of float to a numpy array of bf16 in uint16''' orig = arr.view('> y) == '@shift_right(x: int32, y: int32, dtype=int32, type="pure_intrin")' - assert str(x & y) == '@bitwise_and(x: int32, y: int32, dtype=int32, type="pure_intrin")' - assert str(x | y) == '@bitwise_or(x: int32, y: int32, dtype=int32, type="pure_intrin")' - assert str(x ^ y) == '@bitwise_xor(x: int32, y: int32, dtype=int32, type="pure_intrin")' - assert str(10 & x) == '@bitwise_and(10, x: int32, dtype=int32, type="pure_intrin")' - assert str(10 | x) == '@bitwise_or(10, x: int32, dtype=int32, type="pure_intrin")' - assert str(10 ^ x) == '@bitwise_xor(10, x: int32, dtype=int32, type="pure_intrin")' - assert str(10 >> x) == '@shift_right(10, x: int32, dtype=int32, type="pure_intrin")' - assert str(10 << x) == '@shift_left(10, x: int32, dtype=int32, type="pure_intrin")' + assert str(x << y) == '@tir.shift_left(x: int32, y: int32, dtype=int32, type="pure_intrin")' + assert str(x >> y) == '@tir.shift_right(x: int32, y: int32, dtype=int32, type="pure_intrin")' + assert str(x & y) == '@tir.bitwise_and(x: int32, y: int32, dtype=int32, type="pure_intrin")' + assert str(x | y) == '@tir.bitwise_or(x: int32, y: int32, dtype=int32, type="pure_intrin")' + assert str(x ^ y) == '@tir.bitwise_xor(x: int32, y: int32, dtype=int32, type="pure_intrin")' + assert str(10 & x) == '@tir.bitwise_and(10, x: int32, dtype=int32, type="pure_intrin")' + assert str(10 | x) == '@tir.bitwise_or(10, x: int32, dtype=int32, type="pure_intrin")' + assert str(10 ^ x) == '@tir.bitwise_xor(10, x: int32, dtype=int32, type="pure_intrin")' + assert str(10 >> x) == '@tir.shift_right(10, x: int32, dtype=int32, type="pure_intrin")' + assert str(10 << x) == '@tir.shift_left(10, x: int32, dtype=int32, 
type="pure_intrin")' assert str(10 % x) == 'floormod(10, x: int32)' - assert str(~x) == '@bitwise_not(x: int32, dtype=int32, type="pure_intrin")' + + assert str(~x) == '@tir.bitwise_not(x: int32, dtype=int32, type="pure_intrin")' assert(tvm.tir.const(1, "int8x2") >> 1).dtype == "int8x2" assert(x >> tvm.tir.const(1, "int32x2")).dtype == "int32x2" assert(te.var("z", "int8x2") << tvm.tir.const(1, "int8x2")).dtype == "int8x2" @@ -239,10 +240,10 @@ def test_divide_by_zero(): def test_isnan(): x = te.var('x', 'float32') - assert str(tvm.tir.isnan(x)) == '@isnan(x: float32, dtype=bool, type="pure_intrin")' + assert str(tvm.tir.isnan(x)) == '@tir.isnan(x: float32, dtype=bool, type="pure_intrin")' assert str(tvm.tir.isnan(x).dtype) == 'bool' y = te.var('y', 'float16') - assert str(tvm.tir.isnan(y)) == '@isnan(cast(float32, y: float16), dtype=bool, type="pure_intrin")' + assert str(tvm.tir.isnan(y)) == '@tir.isnan(cast(float32, y: float16), dtype=bool, type="pure_intrin")' z = te.var('z', 'int32') assert str(tvm.tir.isnan(z)) == 'False' k = te.var('k', 'int8x2') diff --git a/tests/python/unittest/test_tir_stmt_functor_ir_transform.py b/tests/python/unittest/test_tir_stmt_functor_ir_transform.py index 38529e927d52..61accf271631 100644 --- a/tests/python/unittest/test_tir_stmt_functor_ir_transform.py +++ b/tests/python/unittest/test_tir_stmt_functor_ir_transform.py @@ -26,20 +26,21 @@ def test_ir_transform(): ib.emit(tvm.tir.call_extern("int32", "TestB", x)) ib.emit(tvm.tir.call_extern("int32", "TestC", x)) body = ib.get() + builtin_call_extern = tvm.ir.Op.get("tir.call_extern") def preorder(op): - if op.name == "TestC": + if op.op.same_as(builtin_call_extern) and op.args[0].value == "TestC": return tvm.tir.const(0, "int32") return None def postorder(op): assert isinstance(op, tvm.tir.Call) - if op.name == "TestA": - return tvm.tir.call_extern("int32", "TestB", op.args[0] + 1) + if op.op.same_as(builtin_call_extern) and op.args[0].value == "TestA": + return tvm.tir.call_extern("int32", "TestB", op.args[1] + 1) return op body = tvm.tir.stmt_functor.ir_transform(body, preorder, postorder, ["tir.Call"]) stmt_list = tvm.tir.stmt_list(body.body.body) - assert stmt_list[0].value.args[0].name == "TestB" + assert stmt_list[0].value.args[1].args[0].value == "TestB" assert stmt_list[1].value.value == 0 if __name__ == "__main__": diff --git a/tests/python/unittest/test_tir_transform_bf16_legalize.py b/tests/python/unittest/test_tir_transform_bf16_legalize.py index 77a06022ac70..55a6819aeced 100644 --- a/tests/python/unittest/test_tir_transform_bf16_legalize.py +++ b/tests/python/unittest/test_tir_transform_bf16_legalize.py @@ -116,19 +116,19 @@ def test_legalize(): def to32(v): uint32_v = topi.cast(v, "uint32") uint32_v = tvm.tir.call_pure_intrin( - "uint32", "shift_left", uint32_v, tvm.tir.const(16, "uint32")) - return tvm.tir.call_pure_intrin("float32", "reinterpret", uint32_v) + "uint32", "tir.shift_left", uint32_v, tvm.tir.const(16, "uint32")) + return tvm.tir.call_pure_intrin("float32", "tir.reinterpret", uint32_v) def to16(v): - uint32_v = tvm.tir.call_pure_intrin("uint32", "reinterpret", v) + uint32_v = tvm.tir.call_pure_intrin("uint32", "tir.reinterpret", v) rounding_bias = tvm.tir.call_pure_intrin( - "uint32", "shift_right", uint32_v, tvm.tir.const(16, "uint32")) + "uint32", "tir.shift_right", uint32_v, tvm.tir.const(16, "uint32")) rounding_bias = tvm.tir.call_pure_intrin( - "uint32", "bitwise_and", rounding_bias, tvm.tir.const(1, "uint32")) + "uint32", "tir.bitwise_and", rounding_bias, tvm.tir.const(1, 
"uint32")) rounding_bias = rounding_bias + tvm.tir.const(0x7FFF, "uint16") uint32_v = uint32_v + rounding_bias uint32_v = tvm.tir.call_pure_intrin( - "uint32", "shift_right", uint32_v, tvm.tir.const(16, "uint32")) + "uint32", "tir.shift_right", uint32_v, tvm.tir.const(16, "uint32")) return topi.cast(uint32_v, 'uint16') def check(fcompute_before, fcompute_after): diff --git a/tests/python/unittest/test_tir_transform_combine_context_call.py b/tests/python/unittest/test_tir_transform_combine_context_call.py index 29a330319622..d7a25ca0156e 100644 --- a/tests/python/unittest/test_tir_transform_combine_context_call.py +++ b/tests/python/unittest/test_tir_transform_combine_context_call.py @@ -22,7 +22,7 @@ def test_for(): def device_context(dev_id): ctx = tvm.tir.call_extern("handle", "device_context", dev_type, dev_id) return tvm.tir.Call( - "handle", "tvm_thread_context", [ctx], tvm.tir.Call.Intrinsic) + "handle", "tir.tvm_thread_context", [ctx], tvm.tir.Call.Intrinsic) ib = tvm.tir.ir_builder.create() n = te.var("n") diff --git a/tests/python/unittest/test_tir_transform_coproc_sync.py b/tests/python/unittest/test_tir_transform_coproc_sync.py index f6583493d646..8469bc953b71 100644 --- a/tests/python/unittest/test_tir_transform_coproc_sync.py +++ b/tests/python/unittest/test_tir_transform_coproc_sync.py @@ -17,6 +17,14 @@ import tvm from tvm import te +# register the ops +tvm.ir.register_op_attr("tir.cop.coproc_sync", "TGlobalSymbol", "coproc_sync") +tvm.ir.register_op_attr("tir.cop.coproc_read_barrier", "TGlobalSymbol", "coproc_readb") +tvm.ir.register_op_attr("tir.cop.coproc_write_barrier", "TGlobalSymbol", "coproc_writeb") +tvm.ir.register_op_attr("tir.cop.coproc_dep_push", "TGlobalSymbol", "coproc_dep_push") +tvm.ir.register_op_attr("tir.cop.coproc_dep_pop", "TGlobalSymbol", "coproc_dep_pop") + + def test_coproc_sync(): @tvm.register_func("tvm.info.mem.global.cache") def meminfo_cache(): @@ -26,6 +34,7 @@ def meminfo_cache(): max_simd_bits=32, max_num_bits=128, head_address=tvm.tir.call_extern("handle", "global_cache")) + ib = tvm.tir.ir_builder.create() n = te.size_var("n") cp = te.thread_axis((0, 1), "cop") @@ -43,10 +52,11 @@ def meminfo_cache(): body = stmt.body.body.body blist = tvm.tir.stmt_list(body) - assert(blist[1].value.name == "cop.coproc_read_barrier") + + assert(blist[1].value.op.same_as(tvm.ir.Op.get("tir.cop.coproc_read_barrier"))) assert(blist[1].value.args[3].value == 80) - assert(blist[-2].value.name == "cop.coproc_sync") - assert(blist[-1].value.name == "cop.coproc_write_barrier") + assert(blist[-2].value.op.same_as(tvm.ir.Op.get("tir.cop.coproc_sync"))) + assert(blist[-1].value.op.same_as(tvm.ir.Op.get("tir.cop.coproc_write_barrier"))) assert(blist[-1].value.args[3].value == 10) @@ -106,9 +116,9 @@ def __check_list(tvm_array, py_list): slist = tvm.tir.stmt_list(slist[-1]) pop_st = slist[0].body[0] - assert(push_st.value.name == "cop.coproc_dep_push") + assert(push_st.value.op.same_as(tvm.ir.Op.get("tir.cop.coproc_dep_push"))) assert(__check_list(push_st.value.args, [2,3])) - assert(pop_st.value.name == "cop.coproc_dep_pop") + assert(pop_st.value.op.same_as(tvm.ir.Op.get("tir.cop.coproc_dep_pop"))) assert(__check_list(pop_st.value.args, [2,3])) diff --git a/tests/python/unittest/test_tir_transform_inject_double_buffer.py b/tests/python/unittest/test_tir_transform_inject_double_buffer.py index 0b6b167c8660..cf5863204bfe 100644 --- a/tests/python/unittest/test_tir_transform_inject_double_buffer.py +++ b/tests/python/unittest/test_tir_transform_inject_double_buffer.py @@ 
-56,7 +56,7 @@ def test_double_buffer(): f = tvm.tir.transform.ThreadSync("shared")(mod)["db"] count = [0] def count_sync(op): - if isinstance(op, tvm.tir.Call) and op.name == "tvm_storage_sync": + if isinstance(op, tvm.tir.Call) and op.op.same_as(tvm.ir.Op.get("tir.tvm_storage_sync")): count[0] += 1 tvm.tir.stmt_functor.post_order_visit(f.body, count_sync) assert count[0] == 4 diff --git a/tests/python/unittest/test_tir_transform_inject_virtual_thread.py b/tests/python/unittest/test_tir_transform_inject_virtual_thread.py index c0789c654fbf..4964039a4c14 100644 --- a/tests/python/unittest/test_tir_transform_inject_virtual_thread.py +++ b/tests/python/unittest/test_tir_transform_inject_virtual_thread.py @@ -36,7 +36,7 @@ def get_vthread(name): bbuffer = tvm.tir.decl_buffer((m,), dtype=B.dtype, data=B.asobject()) ib.emit(tvm.tir.call_extern("int32", "Run", bbuffer.access_ptr("r"), - tvm.tir.call_pure_intrin("int32", "tvm_context_id"))) + tvm.tir.call_pure_intrin("int32", "tir.tvm_context_id"))) C[i * nthread + tx] = B[i] + 1 return ib.get() diff --git a/tests/python/unittest/test_tir_transform_rewrite_unsafe_select.py b/tests/python/unittest/test_tir_transform_rewrite_unsafe_select.py index 229c11b783a6..9f1104dcc512 100644 --- a/tests/python/unittest/test_tir_transform_rewrite_unsafe_select.py +++ b/tests/python/unittest/test_tir_transform_rewrite_unsafe_select.py @@ -39,8 +39,10 @@ def test_rewrite_Select(): mod = tvm.IRModule.from_expr( tvm.tir.PrimFunc([i], tvm.tir.Evaluate(a))) aa = tvm.tir.transform.RewriteUnsafeSelect()(mod)["main"].body.value - assert yy.name == "tvm_if_then_else" - assert zz.name == "tvm_if_then_else" + builtin_if_then_else = tvm.ir.Op.get("tir.if_then_else") + + assert yy.op.same_as(builtin_if_then_else) + assert zz.op.same_as(builtin_if_then_else) assert isinstance(aa, tvm.tir.Select) diff --git a/tests/python/unittest/test_tir_transform_storage_flatten.py b/tests/python/unittest/test_tir_transform_storage_flatten.py index 5fea580fbf5c..468867a425cd 100644 --- a/tests/python/unittest/test_tir_transform_storage_flatten.py +++ b/tests/python/unittest/test_tir_transform_storage_flatten.py @@ -125,7 +125,7 @@ def test_flatten_double_buffer(): count = [0] def count_sync(op): - if isinstance(op, tvm.tir.Call) and op.name == "tvm_storage_sync": + if isinstance(op, tvm.tir.Call) and op.op.same_as(tvm.ir.Op.get("tir.tvm_storage_sync")): count[0] += 1 tvm.tir.stmt_functor.post_order_visit(f.body, count_sync) assert count[0] == 4 diff --git a/tests/python/unittest/test_tir_transform_thread_sync.py b/tests/python/unittest/test_tir_transform_thread_sync.py index 783b66983c48..3ff6804cf7e0 100644 --- a/tests/python/unittest/test_tir_transform_thread_sync.py +++ b/tests/python/unittest/test_tir_transform_thread_sync.py @@ -49,7 +49,7 @@ def test_thread_storage_sync(): cuda_target = tvm.target.create("cuda") f = tvm.tir.transform.ThreadSync("shared")(mod)["test_kernel0"] body_list = tvm.tir.stmt_list(f.body.body.body.body) - assert(body_list[1].value.name == "tvm_storage_sync") + assert(body_list[1].value.op.same_as(tvm.ir.Op.get("tir.tvm_storage_sync"))) diff --git a/tests/python/unittest/test_tir_transform_vectorize.py b/tests/python/unittest/test_tir_transform_vectorize.py index d7124b6b7e89..a69c9d36c693 100644 --- a/tests/python/unittest/test_tir_transform_vectorize.py +++ b/tests/python/unittest/test_tir_transform_vectorize.py @@ -117,7 +117,7 @@ def test_vectorize_if_then_else(): ib = tvm.tir.ir_builder.create() A = ib.pointer("float32", name="A") with ib.for_range(0, 4,
for_type="vectorize") as i: - A[i] = tvm.tir.call_intrin("float32", "tvm_if_then_else", + A[i] = tvm.tir.call_intrin("float32", "tir.if_then_else", i > 0, A[i] + 1, A[i]) stmt = ib.get() @@ -132,7 +132,7 @@ def test_vectorize_if_then_else(): A = ib.pointer("float32", name="A") with ib.for_range(0, n) as k: with ib.for_range(0, 4, for_type="vectorize") as i: - A[k * 4 + i] = tvm.tir.call_intrin("float32", "tvm_if_then_else", + A[k * 4 + i] = tvm.tir.call_intrin("float32", "tir.if_then_else", k > 0, A[k * 4 + i], 0) stmt = ib.get() diff --git a/topi/include/topi/detail/extern.h b/topi/include/topi/detail/extern.h index b84fbc7722a1..7068b95bec6c 100644 --- a/topi/include/topi/detail/extern.h +++ b/topi/include/topi/detail/extern.h @@ -25,6 +25,7 @@ #define TOPI_DETAIL_EXTERN_H_ #include +#include #include #include @@ -111,11 +112,11 @@ inline Array make_extern(const Array >& out_shapes, */ inline PrimExpr pack_buffer(Buffer buf) { CHECK_GT(buf->shape.size(), 0) << "buf shape must have at least one element"; - auto shape = tvm::tir::Call(DataType::Handle(), tvm::tir::intrinsic::tvm_stack_make_shape, + auto shape = tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_stack_make_shape(), buf->shape, tvm::tir::CallNode::CallType::Intrinsic); PrimExpr strides; if (buf->strides.size() > 0) { - strides = tvm::tir::Call(DataType::Handle(), tvm::tir::intrinsic::tvm_stack_make_shape, + strides = tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_stack_make_shape(), buf->shape, tvm::tir::CallNode::CallType::Intrinsic); } else { strides = 0; @@ -126,7 +127,7 @@ inline PrimExpr pack_buffer(Buffer buf) { make_const(DataType::Int(32), static_cast(buf->shape.size())), make_const(buf->dtype, 0), buf->elem_offset}; - return tvm::tir::Call(DataType::Handle(), tvm::tir::intrinsic::tvm_stack_make_array, pack_args, + return tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_stack_make_array(), pack_args, tvm::tir::CallNode::CallType::Intrinsic); } @@ -140,7 +141,7 @@ inline PrimExpr pack_buffer(Buffer buf) { * \return An expression representing the invocation */ inline PrimExpr call_packed(Array args) { - return tvm::tir::Call(DataType::Int(32), tvm::tir::intrinsic::tvm_call_packed, args, + return tvm::tir::Call(DataType::Int(32), tvm::tir::builtin::tvm_call_packed(), args, tvm::tir::CallNode::CallType::Intrinsic); } diff --git a/topi/include/topi/elemwise.h b/topi/include/topi/elemwise.h index a92d21c27afe..0ec7e4d212bf 100644 --- a/topi/include/topi/elemwise.h +++ b/topi/include/topi/elemwise.h @@ -25,6 +25,7 @@ #define TOPI_ELEMWISE_H_ #include +#include #include #include @@ -309,7 +310,8 @@ inline Tensor reinterpret(const Tensor& x, DataType type, std::string name = "te return compute( x->shape, [&](const Array& i) { - return tvm::tir::Call(type, "reinterpret", {x(i)}, tvm::tir::CallNode::PureIntrinsic); + return tvm::tir::Call(type, tvm::tir::builtin::reinterpret(), {x(i)}, + tvm::tir::CallNode::PureIntrinsic); }, name, tag); } diff --git a/topi/python/topi/arm_cpu/bitserial_conv2d.py b/topi/python/topi/arm_cpu/bitserial_conv2d.py index ac1ac45c1b38..f035251a8c29 100644 --- a/topi/python/topi/arm_cpu/bitserial_conv2d.py +++ b/topi/python/topi/arm_cpu/bitserial_conv2d.py @@ -231,8 +231,10 @@ def _instr(index): cnts = tvm.tir.popcount(w_ & x_) - tvm.tir.popcount(~w_ & x_) else: cnts = tvm.tir.popcount(w_ & x_) - upper_half = tvm.tir.call_pure_intrin(half_dtype, 'vectorhigh', cnts) - lower_half = tvm.tir.call_pure_intrin(half_dtype, 'vectorlow', cnts) + upper_half = tvm.tir.call_pure_intrin( 
+ half_dtype, 'tir.vectorhigh', cnts) + lower_half = tvm.tir.call_pure_intrin( + half_dtype, 'tir.vectorlow', cnts) cnts8[i] = upper_half + lower_half for i in range(m//2): cnts4[i] = tvm.tir.call_llvm_intrin(half_dtype, vpadd, @@ -241,7 +243,7 @@ def _instr(index): cnts2[i] = tvm.tir.call_llvm_intrin(half_dtype, vpadd, args_2, cnts4[i*2], cnts4[i*2+1]) cnts = tvm.tir.call_pure_intrin( - full_dtype, 'vectorcombine', cnts2[0], cnts2[1]) + full_dtype, 'tir.vectorcombine', cnts2[0], cnts2[1]) shifted_cnts = cnts << tvm.tir.const(bw+bx, pack_dtype) out = tvm.tir.call_llvm_intrin( return_dtype, vpadalu, @@ -261,7 +263,7 @@ def _instr(index): cnts2[i] = tvm.tir.call_llvm_intrin(half_dtype, vpadd, args_2, cnts4[i*2], cnts4[i*2+1]) cnts = tvm.tir.call_pure_intrin( - full_dtype, 'vectorcombine', cnts2[0], cnts2[1]) + full_dtype, 'tir.vectorcombine', cnts2[0], cnts2[1]) shifted_cnts = cnts << tvm.tir.const(bw+bx, pack_dtype) out = tvm.tir.call_llvm_intrin( return_dtype, vpadalu, diff --git a/topi/python/topi/arm_cpu/tensor_intrin.py b/topi/python/topi/arm_cpu/tensor_intrin.py index bab91578e77e..da9c71a5346b 100644 --- a/topi/python/topi/arm_cpu/tensor_intrin.py +++ b/topi/python/topi/arm_cpu/tensor_intrin.py @@ -86,11 +86,11 @@ def _instr(index): dtype_c = '%s32x%d' % (dtype, int32_lanes) a_int8 = ins[0].vload([0], dtype_a) - re_int32 = tvm.tir.call_pure_intrin('%s32' % dtype, 'reinterpret', a_int8) + re_int32 = tvm.tir.call_pure_intrin('%s32' % dtype, 'tir.reinterpret', a_int8) # broadcast a vec_ai32 = re_int32.astype(dtype_c) - vec_a = tvm.tir.call_pure_intrin(dtype_b, 'reinterpret', vec_ai32) + vec_a = tvm.tir.call_pure_intrin(dtype_b, 'tir.reinterpret', vec_ai32) vec_b = ins[1].vload([0, 0], dtype_b) vec_c = outs[0].vload([0], dtype_c) diff --git a/topi/python/topi/cuda/nms.py b/topi/python/topi/cuda/nms.py index f2c1143b5fb8..c98d7e99d3ee 100644 --- a/topi/python/topi/cuda/nms.py +++ b/topi/python/topi/cuda/nms.py @@ -38,9 +38,10 @@ def cuda_atomic_add_rule(op): tvm.target.intrin.register_intrin_rule( "cuda", "atomic_add", cuda_atomic_add_rule, override=True) +tvm.ir.register_op_attr("tir.atomic_add", "TVectorizable", False) def atomic_add(x, y): - return tvm.tir.call_pure_intrin(y.dtype, "atomic_add", x, y) + return tvm.tir.call_pure_intrin(y.dtype, "tir.atomic_add", x, y) def get_valid_counts_ir(data, valid_count, out, out_indices, @@ -113,7 +114,7 @@ def get_valid_counts_ir(data, valid_count, out, out_indices, with ib.if_scope( tvm.tir.all(data[tid * elem_length + score_index] > score_threshold, tvm.tir.any(id_index < 0, data[tid * elem_length + id_index] >= 0))): - atomic_add_return[0] = atomic_add(tvm.tir.call_pure_intrin("handle", "tvm_address_of", + atomic_add_return[0] = atomic_add(tvm.tir.call_pure_intrin("handle", "tir.address_of", valid_count[i]), one_count) with ib.for_range(0, elem_length) as k: out[tid * elem_length + k] = data[tid * elem_length + k] diff --git a/topi/python/topi/cuda/rcnn/proposal.py b/topi/python/topi/cuda/rcnn/proposal.py index f713bb216808..5b7e0905de63 100644 --- a/topi/python/topi/cuda/rcnn/proposal.py +++ b/topi/python/topi/cuda/rcnn/proposal.py @@ -185,7 +185,7 @@ def argsort_ir(data_buf, out_index_buf): temp_index[0] = index_out[offset] index_out[offset] = index_out[offset + 1] index_out[offset + 1] = temp_index[0] - ib.emit(tvm.tir.Call(None, 'tvm_storage_sync', + ib.emit(tvm.tir.Call(None, 'tir.tvm_storage_sync', tvm.runtime.convert(['shared']), tvm.tir.Call.Intrinsic)) return ib.get() @@ -246,7 +246,7 @@ def calculate_overlap(out_tensor, box_a_idx, 
box_b_idx): iou = calculate_overlap(p_data, (base_idx + l) * 5, (base_idx + i) * 5) with ib.if_scope(iou > nms_threshold): p_out[base_idx + i] = True - ib.emit(tvm.tir.Call(None, 'tvm_storage_sync', + ib.emit(tvm.tir.Call(None, 'tir.tvm_storage_sync', tvm.runtime.convert(['shared']), tvm.tir.Call.Intrinsic)) return ib.get() diff --git a/topi/python/topi/cuda/sort.py b/topi/python/topi/cuda/sort.py index ddae2bd96135..7181d5721684 100644 --- a/topi/python/topi/cuda/sort.py +++ b/topi/python/topi/cuda/sort.py @@ -115,7 +115,7 @@ def sort_ir(data, values_out, axis, is_ascend, indices_out=None): if indices_out is not None: indices_out[base_idx + tid * axis_mul_after] = \ tvm.tir.generic.cast(tid, indices_out.dtype) - ib.emit(tvm.tir.Call(None, 'tvm_storage_sync', + ib.emit(tvm.tir.Call(None, 'tir.tvm_storage_sync', tvm.runtime.convert(['shared']), tvm.tir.Call.Intrinsic)) idxd = tvm.tir.indexdiv @@ -143,7 +143,7 @@ def sort_ir(data, values_out, axis, is_ascend, indices_out=None): temp_index[0] = indices_out[offset] indices_out[offset] = indices_out[offset + axis_mul_after] indices_out[offset + axis_mul_after] = temp_index[0] - ib.emit(tvm.tir.Call(None, 'tvm_storage_sync', + ib.emit(tvm.tir.Call(None, 'tir.tvm_storage_sync', tvm.runtime.convert(['shared']), tvm.tir.Call.Intrinsic)) @@ -235,7 +235,7 @@ def sort_nms_ir(data, valid_count, output, axis, is_ascend): temp_index[0] = output[offset] output[offset] = output[offset + axis_mul_after] output[offset + axis_mul_after] = temp_index[0] - ib.emit(tvm.tir.Call(None, 'tvm_storage_sync', + ib.emit(tvm.tir.Call(None, 'tir.tvm_storage_sync', tvm.runtime.convert(['shared']), tvm.tir.Call.Intrinsic)) diff --git a/topi/python/topi/cuda/tensor_intrin.py b/topi/python/topi/cuda/tensor_intrin.py index 3941c00cc464..c2b7d250293d 100644 --- a/topi/python/topi/cuda/tensor_intrin.py +++ b/topi/python/topi/cuda/tensor_intrin.py @@ -100,7 +100,7 @@ def intrin_func(ins, outs): BC = outs[0] row = wmma_m * wmma_k warp_index = BC.elem_offset // row + BC.elem_offset % row // wmma_k - ib.emit(tvm.tir.call_intrin('handle', 'tvm_load_matrix_sync', + ib.emit(tvm.tir.call_intrin('handle', 'tir.tvm_load_matrix_sync', BC.data, wmma_m, wmma_n, wmma_k, warp_index, BA.access_ptr('r'), strides_from[0], layout)) return ib.get() @@ -128,7 +128,7 @@ def intrin_func(ins, outs): BC = outs[0] row = wmma_n * wmma_k warp_index = BC.elem_offset // row + BC.elem_offset % row // wmma_n - ib.emit(tvm.tir.call_intrin('handle', 'tvm_load_matrix_sync', + ib.emit(tvm.tir.call_intrin('handle', 'tir.tvm_load_matrix_sync', BC.data, wmma_m, wmma_n, wmma_k, warp_index, BA.access_ptr('r'), strides_from[0], layout)) return ib.get() @@ -156,7 +156,7 @@ def intrin_func(ins, outs): BC = outs[0] row = wmma_m * wmma_n warp_index = BA.elem_offset // row + BA.elem_offset % row // wmma_n - ib.emit(tvm.tir.call_intrin('handle', 'tvm_store_matrix_sync', + ib.emit(tvm.tir.call_intrin('handle', 'tir.tvm_store_matrix_sync', BA.data, wmma_m, wmma_n, wmma_k, warp_index, BC.access_ptr('w'), strides_dst[0], 'row_major')) return ib.get() @@ -207,13 +207,14 @@ def warp_idnex(offset, row, col): def init(): ib = tvm.tir.ir_builder.create() ib.emit( - tvm.tir.call_intrin('handle', 'tvm_fill_fragment', BC.data, wmma_m, wmma_n, wmma_k, + tvm.tir.call_intrin('handle', 'tir.tvm_fill_fragment', + BC.data, wmma_m, wmma_n, wmma_k, warp_index_C, 0.0)) return ib.get() def update(): ib = tvm.tir.ir_builder.create() - ib.emit(tvm.tir.call_intrin('handle', 'tvm_mma_sync', + ib.emit(tvm.tir.call_intrin('handle', 'tir.tvm_mma_sync', 
                                        BC.data, warp_index_C,
                                        BA.data, warp_index_A,
                                        BB.data, warp_index_B,
diff --git a/topi/python/topi/x86/tensor_intrin.py b/topi/python/topi/x86/tensor_intrin.py
index ee8d83dbef07..31de70e92f18 100644
--- a/topi/python/topi/x86/tensor_intrin.py
+++ b/topi/python/topi/x86/tensor_intrin.py
@@ -88,9 +88,9 @@ def _instr(index):
             return ib.get()
 
         a_int8 = ins[0].vload([0], "uint8x4")
-        re_int32 = tvm.tir.call_pure_intrin('int32', 'reinterpret', a_int8)
+        re_int32 = tvm.tir.call_pure_intrin('int32', 'tir.reinterpret', a_int8)
         vec_ai32 = re_int32.astype('int32x16')
-        vec_a = tvm.tir.call_pure_intrin('int8x64', 'reinterpret', vec_ai32)
+        vec_a = tvm.tir.call_pure_intrin('int8x64', 'tir.reinterpret', vec_ai32)
         vec_b = ins[1].vload([0, 0], "int8x64")
         vec_one = tvm.tir.const(1, "int16x32")
         pair_reduction = tvm.tir.call_llvm_intrin('int16x32',
@@ -174,9 +174,9 @@ def _instr(index):
             return ib.get()
 
         a_int8 = ins[0].vload([0], "uint8x2")
-        re_int16 = tvm.tir.call_pure_intrin('int16', 'reinterpret', a_int8)
+        re_int16 = tvm.tir.call_pure_intrin('int16', 'tir.reinterpret', a_int8)
         vec_ai16 = re_int16.astype('int16x32')
-        vec_a = tvm.tir.call_pure_intrin('int8x64', 'reinterpret', vec_ai16)
+        vec_a = tvm.tir.call_pure_intrin('int8x64', 'tir.reinterpret', vec_ai16)
 
         for i in range(4):
             vec_b = ins[1].vload([i*32, 0], "int8x64")
@@ -254,7 +254,7 @@ def _instr(index):
             return ib.get()
 
         a_int8 = ins[0].vload([0], "uint8x4")
-        re_int32 = tvm.tir.call_pure_intrin('int32', 'reinterpret', a_int8)
+        re_int32 = tvm.tir.call_pure_intrin('int32', 'tir.reinterpret', a_int8)
         vec_ai32 = re_int32.astype('int32x16')
         vec_b = ins[1].vload([0, 0], "int8x64")
@@ -262,7 +262,7 @@ def _instr(index):
         llvm_id = tvm.target.codegen.llvm_lookup_intrinsic_id(vnni_inst_name)
         if llvm_id != 0: # VNNI is available for current LLVM version
-            vec_bi32 = tvm.tir.call_pure_intrin('int32x16', 'reinterpret', vec_b)
+            vec_bi32 = tvm.tir.call_pure_intrin('int32x16', 'tir.reinterpret', vec_b)
             vec_zero = tvm.tir.const(0, "int32x16")
             quad_reduction = tvm.tir.call_llvm_intrin('int32x16',
                                                       'llvm.x86.avx512.vpdpbusd.512',
@@ -270,7 +270,7 @@ def _instr(index):
                                                       vec_zero,
                                                       vec_ai32, vec_bi32)
         else: # Fall back to the normal AVX512
-            vec_a = tvm.tir.call_pure_intrin('int8x64', 'reinterpret', vec_ai32)
+            vec_a = tvm.tir.call_pure_intrin('int8x64', 'tir.reinterpret', vec_ai32)
             vec_one = tvm.tir.const(1, "int16x32")
             pair_reduction = tvm.tir.call_llvm_intrin('int16x32',
                                                       'llvm.x86.avx512.pmaddubs.w.512',
diff --git a/topi/tests/python/test_topi_basic.py b/topi/tests/python/test_topi_basic.py
index 13f1463da7ff..a83ff50bd5b1 100644
--- a/topi/tests/python/test_topi_basic.py
+++ b/topi/tests/python/test_topi_basic.py
@@ -34,7 +34,7 @@ def test_ewise():
     def test_apply(func, name):
         B = func(A)
         assert tuple(B.shape) == tuple(A.shape)
-        assert B.op.body[0].name == name
+        assert B.op.body[0].op.name == "tir." + name
 
     test_apply(topi.exp, "exp")
     test_apply(topi.erf, "erf")
diff --git a/topi/tests/python/test_topi_math.py b/topi/tests/python/test_topi_math.py
index ea980833ae20..6f1e8588fd7c 100644
--- a/topi/tests/python/test_topi_math.py
+++ b/topi/tests/python/test_topi_math.py
@@ -50,11 +50,11 @@ def test_apply(
     B = func(A)
     assert tuple(B.shape) == tuple(A.shape)
     if not skip_name_check:
-        assert B.op.body[0].name == name
+        assert B.op.body[0].op.name == "tir." + name
     a_np = np.random.uniform(low=low, high=high, size=shape).astype(A.dtype) * 10
     # avoid round check too close to boundary
     if check_round:
-        a_np += ((np.abs(np.fmod(a_np, 1)) - 0.5) < 1e-6) * 1e-5
+        a_np += ((np.abs(np.fmod(a_np, 1)) - 0.5) < 1e-6) * 1e-4
     b_np = f_numpy(a_np)
 
     def check_device(device):
@@ -89,7 +89,7 @@ def test_isnan(
     B = topi.isnan(A)
     assert tuple(B.shape) == tuple(A.shape)
     if not skip_name_check:
-        assert B.op.body[0].name == "isnan"
+        assert B.op.body[0].op.name == "tir.isnan"
     a_np = np.random.uniform(low=low, high=high, size=shape).astype(A.dtype) * 10
     a_np.ravel()[np.random.choice(a_np.size, int(a_np.size * 0.5), replace=False)] = np.nan
     # avoid round check too close to boundary
diff --git a/tutorials/language/intrin_math.py b/tutorials/language/intrin_math.py
index 146263dab5d7..65bfd4c38681 100644
--- a/tutorials/language/intrin_math.py
+++ b/tutorials/language/intrin_math.py
@@ -100,12 +100,15 @@ def my_cuda_math_rule(op):
     """Customized CUDA intrinsic lowering rule"""
     assert isinstance(op, tvm.tir.Call)
+    name = op.op.name
+    assert name.startswith("tir.")
+    dispatch_name = name[4:]
     if op.dtype == "float32":
         # call float function
-        return tvm.tir.call_pure_extern("float32", "%sf" % op.name, op.args[0])
+        return tvm.tir.call_pure_extern("float32", "%sf" % dispatch_name, op.args[0])
     elif op.dtype == "float64":
         # call double function
-        return tvm.tir.call_pure_extern("float32", op.name, op.args[0])
+        return tvm.tir.call_pure_extern("float32", dispatch_name, op.args[0])
     else:
         # cannot do translation, return self.
         return op
@@ -132,7 +135,7 @@ def my_cuda_math_rule(op):
 
 def mylog(x):
     """customized log intrinsic function"""
-    return tvm.tir.call_pure_intrin(x.dtype, "mylog", x)
+    return tvm.tir.call_pure_intrin(x.dtype, "tir.mylog", x)
 
 
 def my_cuda_mylog_rule(op):
@@ -144,7 +147,8 @@ def my_cuda_mylog_rule(op):
     else:
         return op
 
-
+# new op registration is triggered by registering an attribute of the op
+tvm.ir.register_op_attr("tir.mylog", "TVectorizable", True)
 tvm.target.register_intrin_rule("cuda", "mylog", my_cuda_mylog_rule, override=True)
 
 n = te.var("n")
diff --git a/tutorials/optimize/opt_conv_tensorcore.py b/tutorials/optimize/opt_conv_tensorcore.py
index cd40a91ac6c8..4b2823c08d03 100644
--- a/tutorials/optimize/opt_conv_tensorcore.py
+++ b/tutorials/optimize/opt_conv_tensorcore.py
@@ -163,7 +163,7 @@ def intrin_func(ins, outs):
 
         BA = ins[0]
         BC = outs[0]
-        ib.emit(tvm.tir.call_intrin('handle', 'tvm_load_matrix_sync',
+        ib.emit(tvm.tir.call_intrin('handle', 'tir.tvm_load_matrix_sync',
                                     BC.data, n, n, n, BC.elem_offset // 256,
                                     BA.access_ptr('r'), n, 'row_major'))
         return ib.get()
@@ -190,12 +190,12 @@ def intrin_func(ins, outs):
 
         def init():
             ib = tvm.tir.ir_builder.create()
-            ib.emit(tvm.tir.call_intrin('handle', 'tvm_fill_fragment', BC.data, n, n, n, BC.elem_offset // 256, 0.0))
+            ib.emit(tvm.tir.call_intrin('handle', 'tir.tvm_fill_fragment', BC.data, n, n, n, BC.elem_offset // 256, 0.0))
             return ib.get()
 
         def update():
             ib = tvm.tir.ir_builder.create()
-            ib.emit(tvm.tir.call_intrin('handle', 'tvm_mma_sync',
+            ib.emit(tvm.tir.call_intrin('handle', 'tir.tvm_mma_sync',
                                         BC.data, BC.elem_offset // 256,
                                         BA.data, BA.elem_offset // 256,
                                         BB.data, BB.elem_offset // 256,
@@ -218,7 +218,7 @@ def intrin_func(ins, outs):
         ib = tvm.tir.ir_builder.create()
         BA = ins[0]
         BC = outs[0]
-        ib.emit(tvm.tir.call_intrin('handle', 'tvm_store_matrix_sync',
+        ib.emit(tvm.tir.call_intrin('handle', 'tir.tvm_store_matrix_sync',
                                     BA.data, n, n, n, BA.elem_offset // 256,
                                     BC.access_ptr('w'), n, 'row_major'))
         return ib.get()
diff --git a/vta/python/vta/environment.py b/vta/python/vta/environment.py
index e68f098ba53f..947c583ed55f 100644
--- a/vta/python/vta/environment.py
+++ b/vta/python/vta/environment.py
@@ -77,9 +77,9 @@ class DevContext(object):
     def __init__(self, env):
         self.vta_axis = te.thread_axis("vta")
         self.vta_push_uop = tvm.tir.StringImm("VTAPushGEMMOp")
-        ctx = tvm.tir.call_extern("handle", "VTATLSCommandHandle")
+        ctx = tvm.tir.call_intrin("handle", "tir.vta.command_handle")
         self.command_handle = tvm.tir.Call(
-            "handle", "tvm_thread_context", [ctx],
+            "handle", "tir.tvm_thread_context", [ctx],
             tvm.tir.Call.Intrinsic)
         self.DEBUG_NO_SYNC = False
         env._dev_ctx = self
@@ -298,6 +298,7 @@ def coproc_sync(op):
         tvm.runtime.const(1<<31, dtype="uint32"))
 
 
+
 @tvm.register_func("tvm.intrin.rule.default.vta.coproc_dep_push")
 def coproc_dep_push(op):
     return tvm.tir.call_extern(
@@ -313,6 +314,15 @@ def coproc_dep_pop(op):
         get_env().dev.command_handle,
         op.args[0], op.args[1])
 
+
+# register a dummy attr to trigger registration of the ops;
+# change the attr to a lowering rule later.
+tvm.ir.register_op_attr("tir.vta.coproc_sync", "TVectorizable", False)
+tvm.ir.register_op_attr("tir.vta.coproc_dep_push", "TVectorizable", False)
+tvm.ir.register_op_attr("tir.vta.coproc_dep_pop", "TVectorizable", False)
+
+tvm.ir.register_op_attr("tir.vta.uop_push", "TGlobalSymbol", "VTAUopPush")
+tvm.ir.register_op_attr("tir.vta.command_handle", "TGlobalSymbol", "VTATLSCommandHandle")
+
 def _init_env():
     """Initialize the default global env"""
diff --git a/vta/python/vta/intrin.py b/vta/python/vta/intrin.py
index 8532ffa318b5..897bbcba4cd3 100644
--- a/vta/python/vta/intrin.py
+++ b/vta/python/vta/intrin.py
@@ -82,16 +82,16 @@ def instr(index):
         irb.scope_attr(dev.vta_axis, "coproc_uop_scope",
                        dev.vta_push_uop)
         if index in (0, 2):
-            irb.emit(tvm.tir.call_extern(
-                "int32", "VTAUopPush",
+            irb.emit(tvm.tir.call_intrin(
+                "int32", "tir.vta.uop_push",
                 0, 0,
                 dout.access_ptr("rw", "int32"),
                 dinp.access_ptr("r", "int32"),
                 dwgt.access_ptr("r", "int32"),
                 0, 0, 0))
         else:
-            irb.emit(tvm.tir.call_extern(
-                "int32", "VTAUopPush",
+            irb.emit(tvm.tir.call_intrin(
+                "int32", "tir.vta.uop_push",
                 0, 1,
                 dout.access_ptr("rw", "int32"),
                 0,
diff --git a/vta/python/vta/transform.py b/vta/python/vta/transform.py
index 207f784b5885..e92b178a5be6 100644
--- a/vta/python/vta/transform.py
+++ b/vta/python/vta/transform.py
@@ -59,11 +59,12 @@ def _fold_outermost_loop(body):
     loop_var = stmt.loop_var
     gemm_offsets = [None, None, None]
     fail = [False]
+    builtin_uop_push = tvm.ir.Op.get("tir.vta.uop_push")
 
     def _post_order(op):
         assert isinstance(op, tvm.tir.Call)
         base_args = 2
-        if op.name == "VTAUopPush":
+        if op.op.same_as(builtin_uop_push):
             args = []
             args += op.args[:base_args]
             for i in range(3):
@@ -81,8 +82,8 @@ def _post_order(op):
                 gemm_offsets[i] = m[0]
                 args.append(m[1])
             args += op.args[base_args+3:]
-            return tvm.tir.call_extern("int32", "VTAUopPush", *args)
-        if op.name not in ("VTATLSCommandHandle", "tvm_thread_context"):
+            return tvm.tir.call_intrin("int32", builtin_uop_push, *args)
+        if op.op.name not in ("tir.vta.command_handle", "tir.tvm_thread_context"):
             raise RuntimeError("unexpected op %s" % op)
         return op
@@ -643,7 +644,7 @@ def _do_fold(op):
             dev = env.dev
             irb.scope_attr(dev.vta_axis, "coproc_scope", dev.get_task_qid(dev.QID_COMPUTE))
             irb.scope_attr(dev.vta_axis, "coproc_uop_scope", dev.vta_push_uop)
-            irb.emit(tvm.tir.call_extern("int32", "VTAUopPush",
+            irb.emit(tvm.tir.call_intrin("int32", "tir.vta.uop_push",
                                          0, 1,
                                         dout.access_ptr("rw", "int32"),
                                         0, 0,
@@ -658,7 +659,7 @@ def _do_fold(op):
             tpl = (args[0], 1, args[1], 1, args[2], 1, args[3], 1, 0, 1, 0, env.BLOCK_OUT)
             inner = tvm.tir.AttrStmt(
                 [dout, res_buffer], 'buffer_bind_scope',
-                tvm.tir.call_intrin('handle', 'tvm_tuple', *tpl), inner)
+                tvm.tir.call_intrin('handle', 'tir.tvm_tuple', *tpl), inner)
             return inner
         else:
             conv_call, data_call, kernel_call = calls[-3:]
@@ -678,7 +679,7 @@ def _do_fold(op):
             irb.scope_attr(
                 dev.vta_axis, "coproc_scope", dev.get_task_qid(dev.QID_COMPUTE))
             irb.scope_attr(dev.vta_axis, "coproc_uop_scope", dev.vta_push_uop)
-            irb.emit(tvm.tir.call_extern("int32", "VTAUopPush",
+            irb.emit(tvm.tir.call_intrin("int32", "tir.vta.uop_push",
                                          0, 0,
                                          dout.access_ptr("rw", "int32"),
                                          dinp.access_ptr("r", "int32"),
@@ -691,19 +692,19 @@ def _do_fold(op):
                    1, 0, 1, 0, env.BLOCK_OUT)
             inner = tvm.tir.AttrStmt(
                 [dout, res_tensor], 'buffer_bind_scope',
-                tvm.tir.call_intrin('handle', 'tvm_tuple', *tpl), inner)
+                tvm.tir.call_intrin('handle', 'tir.tvm_tuple', *tpl), inner)
             args = kernel_call.indices
             tpl = (args[0], 1, args[1], 1, args[2], 1, args[3], 1,
                    0, env.BLOCK_OUT, 0, env.BLOCK_IN)
             inner = tvm.tir.AttrStmt(
                 [dwgt, kernel_tensor], 'buffer_bind_scope',
-                tvm.tir.call_intrin('handle', 'tvm_tuple', *tpl), inner)
+                tvm.tir.call_intrin('handle', 'tir.tvm_tuple', *tpl), inner)
             args = data_call.indices
             tpl = (args[0], 1, args[1], 1, args[2], 1, args[3], 1,
                    0, 1, 0, env.BLOCK_IN)
             inner = tvm.tir.AttrStmt(
                 [dinp, pad_data_tensor], 'buffer_bind_scope',
-                tvm.tir.call_intrin('handle', 'tvm_tuple', *tpl), inner)
+                tvm.tir.call_intrin('handle', 'tir.tvm_tuple', *tpl), inner)
             return inner
         return None
@@ -833,11 +834,11 @@ def _flatten_loop(src_coeff, dst_coeff, extents):
                     lhs = loop_body.value.a
                     rhs = loop_body.value.b
                 elif isinstance(loop_body.value, tvm.tir.Call):
-                    if loop_body.value.name == 'shift_left':
+                    if loop_body.value.op.name == 'tir.shift_left':
                         alu_opcode = env.dev.ALU_OPCODE_SHR
                         lhs = loop_body.value.args[0]
                         rhs = analyzer.simplify(-loop_body.value.args[1])
-                    elif loop_body.value.name == 'shift_right':
+                    elif loop_body.value.op.name == 'tir.shift_right':
                         alu_opcode = env.dev.ALU_OPCODE_SHR
                         lhs = loop_body.value.args[0]
                         rhs = loop_body.value.args[1]
@@ -942,8 +943,8 @@ def _flatten_loop(src_coeff, dst_coeff, extents):
                     "int32", "VTAUopLoopBegin", extent, dst_coeff[idx], src_coeff[idx], 0))
             use_imm = int(use_imm)
-            irb.emit(tvm.tir.call_extern(
-                "int32", "VTAUopPush",
+            irb.emit(tvm.tir.call_intrin(
+                "int32", "tir.vta.uop_push",
                 1, 0,
                 dst_coeff[len(dst_coeff)-1],
                 src_coeff[len(src_coeff)-1],