From d770741a262f8231f63ac0e4f8645874ba54d7bb Mon Sep 17 00:00:00 2001
From: Taikang Hu
Date: Tue, 23 Jun 2020 10:33:40 +0800
Subject: [PATCH] Revert "[TIR][REFACTOR][API-CHANGE] Change Call.name to
 Call.op(RelayExpr) (#5863)"

This reverts commit 82d157f0b83ae17fde7bbfca14110aa2f2b80b61.
---
 include/tvm/relay/expr.h                      |   2 +-
 include/tvm/tir/builtin.h                     | 540 ------------------
 include/tvm/tir/expr.h                        | 399 ++++++++++++-
 include/tvm/tir/function.h                    |   4 +-
 include/tvm/tir/op.h                          |   8 +-
 include/tvm/tir/op_attr_types.h               |  48 --
 include/tvm/tir/stmt.h                        |   4 +-
 python/tvm/contrib/nvcc.py                    |   3 +-
 python/tvm/target/datatype.py                 |  11 +-
 python/tvm/target/intrin.py                   |  10 +-
 python/tvm/te/hybrid/calls.py                 |   2 +-
 python/tvm/tir/expr.py                        |  22 +-
 python/tvm/tir/ir_builder.py                  |   2 +-
 python/tvm/tir/op.py                          |  81 ++-
 src/arith/const_int_bound.cc                  |   6 +-
 src/arith/ir_mutator_with_analyzer.cc         |   9 +-
 src/arith/modular_set.cc                      |   3 +-
 src/arith/pattern_match.h                     |  45 +-
 src/arith/rewrite_simplify.cc                 |  10 +-
 src/contrib/hybrid/codegen_hybrid.cc          |  31 +-
 src/ir/op.cc                                  |   2 +-
 src/printer/tir_text_printer.cc               |  15 +-
 src/relay/transforms/pass_util.h              |   4 +-
 src/target/intrin_rule.h                      |  16 +-
 src/target/llvm/codegen_arm.cc                |  15 +-
 src/target/llvm/codegen_cpu.cc                |  66 +--
 src/target/llvm/codegen_cpu.h                 |   5 +-
 src/target/llvm/codegen_llvm.cc               |  76 ++-
 src/target/llvm/codegen_llvm.h                |  10 +-
 src/target/llvm/codegen_nvptx.cc              |   8 +-
 src/target/llvm/codegen_x86_64.cc             |   4 +-
 src/target/llvm/intrin_rule_llvm.cc           |  20 +-
 src/target/llvm/intrin_rule_llvm.h            |   6 +-
 src/target/llvm/intrin_rule_nvptx.cc          |  17 +-
 src/target/llvm/intrin_rule_rocm.cc           |  34 +-
 src/target/source/codegen_c.cc                | 186 +++---
 src/target/source/codegen_c.h                 |  17 -
 src/target/source/codegen_c_host.cc           |   6 +-
 src/target/source/codegen_cuda.cc             | 123 ++--
 src/target/source/codegen_cuda.h              |   7 -
 src/target/source/codegen_metal.cc            |   2 +-
 src/target/source/intrin_rule_cuda.cc         |  58 +-
 src/target/source/intrin_rule_opencl.cc       |   4 +-
 src/target/spirv/codegen_spirv.cc             |  29 +-
 src/target/spirv/intrin_rule_spirv.cc         |   4 +-
 src/target/stackvm/codegen_stackvm.cc         |  49 +-
 src/target/stackvm/codegen_stackvm.h          |   4 -
 src/te/autodiff/jacobian.cc                   |  34 +-
 src/te/operation/compute_op.cc                |   3 +-
 src/te/operation/cross_thread_reduction.cc    |   4 +-
 src/te/operation/extern_op.cc                 |   2 +-
 src/te/operation/tensor_compute_op.cc         |   5 +-
 src/te/operation/tensorize.cc                 |   4 +-
 ...hedule_postproc_rewrite_for_tensor_core.cc |  33 +-
 src/tir/analysis/verify_memory.cc             |   3 +-
 src/tir/ir/buffer.cc                          |   3 +-
 src/tir/ir/expr.cc                            |  45 +-
 src/tir/ir/expr_functor.cc                    |   2 +-
 src/tir/{op => ir}/op.cc                      | 172 +-----
 src/tir/ir/stmt.cc                            |   8 -
 src/tir/op/builtin.cc                         | 155 -----
 src/tir/op/runtime.cc                         |  39 --
 src/tir/transforms/arg_binder.cc              |  27 +-
 src/tir/transforms/bf16_legalize.cc           |   5 +-
 src/tir/transforms/bound_checker.cc           |   3 +-
 src/tir/transforms/combine_context_call.cc    |  11 +-
 src/tir/transforms/coproc_sync.cc             |  22 +-
 src/tir/transforms/inject_virtual_thread.cc   |   9 +-
 src/tir/transforms/ir_util.h                  |  14 +-
 src/tir/transforms/loop_partition.cc          |   9 +-
 .../lower_device_storage_access_info.cc       |   3 +-
 src/tir/transforms/lower_intrin.cc            |  20 +-
 src/tir/transforms/lower_thread_allreduce.cc  |  19 +-
 src/tir/transforms/lower_tvm_builtin.cc       |  58 +-
 src/tir/transforms/lower_warp_memory.cc       |   5 +-
 src/tir/transforms/make_packed_api.cc         |   7 +-
 src/tir/transforms/narrow_datatype.cc         |  35 +-
 src/tir/transforms/rewrite_unsafe_select.cc   |   7 +-
 src/tir/transforms/split_host_device.cc       |   3 +-
 src/tir/transforms/storage_access.cc          |   6 +-
 src/tir/transforms/storage_flatten.cc         |  10 +-
 src/tir/transforms/storage_rewrite.cc         |   9 +-
 .../transforms/tensorcore_infer_fragment.cc   |   8 +-
 src/tir/transforms/thread_storage_sync.cc     |  11 +-
 src/tir/transforms/vectorize_loop.cc          |  18 +-
 tests/cpp/ir_functor_test.cc                  |   6 +-
 tests/python/relay/test_ir_parser.py          |   4 +-
 .../unittest/test_arith_canonical_simplify.py |   3 +-
 .../unittest/test_target_codegen_c_host.py    |   2 +-
 .../unittest/test_target_codegen_llvm.py      |  12 +-
 .../test_target_codegen_static_init.py        |   2 +-
 .../unittest/test_te_schedule_tensor_core.py  |   8 +-
 tests/python/unittest/test_tir_constructor.py |   6 +-
 tests/python/unittest/test_tir_nodes.py       |  27 +-
 .../test_tir_stmt_functor_ir_transform.py     |   9 +-
 .../test_tir_transform_bf16_legalize.py       |  12 +-
 ...test_tir_transform_combine_context_call.py |   2 +-
 .../test_tir_transform_coproc_sync.py         |  20 +-
 ...test_tir_transform_inject_double_buffer.py |   2 +-
 ...est_tir_transform_inject_virtual_thread.py |   2 +-
 ...est_tir_transform_rewrite_unsafe_select.py |   6 +-
 .../test_tir_transform_storage_flatten.py     |   2 +-
 .../test_tir_transform_thread_sync.py         |   2 +-
 .../unittest/test_tir_transform_vectorize.py  |   4 +-
 topi/include/topi/detail/extern.h             |   9 +-
 topi/include/topi/elemwise.h                  |   4 +-
 topi/python/topi/arm_cpu/bitserial_conv2d.py  |  10 +-
 topi/python/topi/arm_cpu/tensor_intrin.py     |   4 +-
 topi/python/topi/cuda/nms.py                  |   5 +-
 topi/python/topi/cuda/rcnn/proposal.py        |   4 +-
 topi/python/topi/cuda/sort.py                 |   6 +-
 topi/python/topi/cuda/tensor_intrin.py        |  11 +-
 topi/python/topi/x86/tensor_intrin.py         |  14 +-
 topi/tests/python/test_topi_basic.py          |   2 +-
 topi/tests/python/test_topi_math.py           |   6 +-
 tutorials/language/intrin_math.py             |  12 +-
 tutorials/optimize/opt_conv_tensorcore.py     |   8 +-
 vta/python/vta/environment.py                 |  14 +-
 vta/python/vta/intrin.py                      |   8 +-
 vta/python/vta/transform.py                   |  27 +-
 120 files changed, 1168 insertions(+), 1965 deletions(-)
 delete mode 100644 include/tvm/tir/builtin.h
 delete mode 100644 include/tvm/tir/op_attr_types.h
 rename src/tir/{op => ir}/op.cc (82%)
 delete mode 100644 src/tir/op/builtin.cc
 delete mode 100644 src/tir/op/runtime.cc

diff --git a/include/tvm/relay/expr.h b/include/tvm/relay/expr.h
index 3c156dfd74812..779bcc34272f6 100644
--- a/include/tvm/relay/expr.h
+++ b/include/tvm/relay/expr.h
@@ -226,7 +226,7 @@ class CallNode : public ExprNode {
   /*!
    * \brief The operator(function) being invoked
    *
-   *  - It can be tvm::Op which corresponds to the primitive operators.
+   *  - It can be relay::Op which corresponds to the primitive operators.
    *  - It can also be user defined functions (Function, GlobalVar, Var).
    */
   Expr op;
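
For context: the relay-side hunk above is documentation only; relay::CallNode::op keeps accepting either a primitive operator or a function value. A minimal sketch of that field from Python, assuming a TVM build from around this commit:

    import tvm
    from tvm import relay

    x = relay.var("x", shape=(1, 4))
    call = relay.nn.relu(x)        # CallNode whose op is the primitive Op "nn.relu"
    print(call.op)                 # an operator, as the corrected comment says

    f = relay.Function([x], call)
    outer = relay.Call(f, [x])     # op may also be a function value (Function/GlobalVar/Var)
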
diff --git a/include/tvm/tir/builtin.h b/include/tvm/tir/builtin.h
deleted file mode 100644
index 96526ccfcfb26..0000000000000
--- a/include/tvm/tir/builtin.h
+++ /dev/null
@@ -1,540 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \file tvm/tir/builtin.h
- * \brief TIR builtin intrinsics.
- *
- * TIR builtin intrinsics are stored as tvm::Op.
- * They are processed in the same way as we process Ops.
- *
- * It is not necessary to create a function for every Op,
- * as we can obtain them through Op::Get.
- *
- * This file contains the most commonly used intrinsics or
- * those that have special semantics and need compiler support.
- */
-#ifndef TVM_TIR_BUILTIN_H_
-#define TVM_TIR_BUILTIN_H_
-
-#include <tvm/ir/op.h>
-#include <tvm/tir/expr.h>
-
-namespace tvm {
-namespace tir {
-
-/*! \brief Collection of builtin intrinsics as ops */
-namespace builtin {
-/*!
- * \brief Reinterpret the value using the target type.
- */
-TVM_DLL const Op& reinterpret();
-
-/*!
- * \brief Marks that a condition is likely going to happen.
- */
-TVM_DLL const Op& likely();
-
-/*!
- * \brief Bitwise and operator.
- */
-TVM_DLL const Op& bitwise_and();
-
-/*!
- * \brief Bitwise or operator.
- */
-TVM_DLL const Op& bitwise_or();
-
-/*!
- * \brief Bitwise xor operator.
- */
-TVM_DLL const Op& bitwise_xor();
-
-/*!
- * \brief Bitwise not operator.
- */
-TVM_DLL const Op& bitwise_not();
-
-/*!
- * \brief Left shift
- */
-TVM_DLL const Op& shift_left();
-
-/*!
- * \brief Right shift
- */
-TVM_DLL const Op& shift_right();
-
-/*!
- * \brief See pseudo code
- *
- *  Construct a big uint that may not be representable by int64
- *
- *  Expr large_uint_imm(uint32_t v0, uint32_t v1) {
- *    return (v1 << 32) | v0;
- *  }
- */
-TVM_DLL const Op& large_uint_imm();
-
-/*!
- * \brief See pseudo code
- *
- *  Handle address_of(Load *op) {
- *     return &op->buffer_var[index];
- *  }
- */
-TVM_DLL const Op& address_of();
-
-/*!
- * \brief Same as select, used for unsafe memory access.
- *
- *  Type tvm_if_then_else(cond, a, b) {
- *    return cond ? a : b;
- *  }
- */
-TVM_DLL const Op& if_then_else();
-
-/*!
- * \brief See pseudo code
- *
- *  bool isnullptr(void* handle) {
- *     return handle == nullptr
- *  }
- */
-TVM_DLL const Op& isnullptr();
-
-/*!
- * \brief Check if value is nan
- */
-TVM_DLL const Op& isnan();
-
-/*!
- * \brief Popcount
- */
-TVM_DLL const Op& popcount();
-
-/*!
- * \brief Fused multiply add
- *
- *  Type fma(a, b, c) {
- *    return a * b + c;
- *  }
- */
-TVM_DLL const Op& fma();
-
-/*!
- * \brief Call an extern C function with given name
- *        and signature from the types of args in the runtime environment.
- *
- *  Type call_extern(name, args...) {
- *     return dlsym(name)(args...);
- *  }
- *
- * \note This intrinsic does not provide any type checking,
- *       and is mainly used for backward compatibility reasons.
- *       Always consider using pre-registered and typed tvm::Op first.
- */
-TVM_DLL const Op& call_extern();
-
-/*!
- * \brief Call an LLVM intrinsic with a given intrinsic id
- *        and signature from the types of args in the runtime environment.
- *
- *  Type call_llvm_intrin(intrin_id, args...) {
- *     return dlsym(name)(args...);
- *  }
- *
- * \note This op does not provide any type checking.
- */
-TVM_DLL const Op& call_llvm_intrin();
-
-/*!
- * \brief Call an SPIRV GLSL450 intrinsic.
- *
- *  Type call_spirv_glsl450(intrin_id, args...) {
- *     return dlsym(name)(args...);
- *  }
- *
- * \note This op does not provide any type checking.
- */
-TVM_DLL const Op& call_spirv_glsl450();
-
-// TODO(tvm-team) revisit the builtins below
-// some of them can simply become ops with special codegen attr.
-/*!
- * \brief Prefetch a cacheline
- */
-TVM_DLL const Op& prefetch();
-
-/*!
- * \brief Get head access address with memory access pattern info.
- *
- *  This operator also marks range of the memory access
- *  The offset and extent are in unit of the DType(including vectorization factor).
- *  rw_mask is a bit_mask setting whether the access is a read(1) or write(2).
- *  The access is assumed to happen in the current expression.
- *
- *  PtrType tvm_access_ptr(Expr dtype, DType* data,
- *                         int offset, int extent,
- *                         int rw_mask) {
- *    // DType == dtype.type();
- *    return &data[offset];
- *  }
- */
-TVM_DLL const Op& tvm_access_ptr();
-
-/*!
- * \brief Create a function local static handle that initializes to nullptr.
- *  can be used to cache function local static resources.
- */
-TVM_DLL const Op& tvm_static_handle();
-
-/*!
- * \brief Return a unique context id, used for hint of workspace separation.
- *  Different context id guarantees not having overlapping workspace.
- */
-TVM_DLL const Op& tvm_context_id();
-
-/*!
- * \brief tvm_tuple is not an actual function and cannot codegen.
- *  It is used to represent tuple structure in value field of AttrStmt,
- *  for the sake of giving hint to optimization.
- *
- *  Handle tvm_tuple(value0, value1, ..., value_n);
- */
-TVM_DLL const Op& tvm_tuple();
-
-/*!
- * \brief See pseudo code
- *
- *  Type tvm_struct_get(StructType* arr, int index, int field_id) {
- *     return arr[index]->field;
- *  }
- * \sa TVMStructFieldKind
- */
-TVM_DLL const Op& tvm_struct_get();
-
-/*!
- * \brief See pseudo code
- *
- *  Handle tvm_struct_set(StructType* arr, int index, int field_id, value) {
- *     arr[index]->field = value;
- *  }
- * \sa TVMStructFieldKind
- */
-TVM_DLL const Op& tvm_struct_set();
-
-/*!
- * \brief See pseudo code
- *
- *  void tvm_throw_last_error() {
- *    throw TVMGetLastError();
- *  }
- */
-TVM_DLL const Op& tvm_throw_last_error();
-
-/*!
- * \brief See pseudo code
- *
- *  dtype in {shape, array, arg_value, arg_tcode}
- *
- *  Handle tvm_stack_alloca(string dtype, int num) {
- *     return new on stack dtype[num];
- *  }
- */
-TVM_DLL const Op& tvm_stack_alloca();
-
-/*!
- * \brief Allocate a shape tuple on stack, return the handle.
- *
- *  Handle tvm_stack_make_shape(list args) {
- *     ret = alloca stack int64_t[len(args)];
- *     for i in range(len(args)):
- *        ret[i] = args[i]
- *     return &ret[0];
- *  }
- */
-TVM_DLL const Op& tvm_stack_make_shape();
-
-/*!
- * \brief Allocate a NDArray(DLTensor) on stack, return the handle.
- *
- *  Type tvm_stack_make_array(Expr data,
- *                            Expr shape,
- *                            Expr strides,
- *                            Expr ndim,
- *                            Expr dtype,
- *                            Expr elem_offset) {
- *     ret = alloca stack DLTensor();
- *     ret->data = data;
- *     ret->shape = shape;
- *     ret->strides = strides != 0 ? strides : nullptr;
- *     ret->ndim = ndim;
- *     ret->dtype = dtype.type();
- *     ret->byte_offset = elem_offset * sizeof(dtype);
- *     return ret;
- *  }
- */
-TVM_DLL const Op& tvm_stack_make_array();
-
-/*!
- * \brief See pseudo code
- *
- *  int tvm_call_packed(name, TVMValue* args) {
- *     ModuleNode* env = GetCurrentEnv();
- *     const PackedFunc* f = env->GetFuncFromEnv(name);
- *     (*f)(args, type_code_of(args), len(args));
- *     return 0;
- *  }
- */
-TVM_DLL const Op& tvm_call_packed();
-
-/*!
- * \brief See pseudo code
- *
- *  int tvm_call_trace_packed(name, TVMValue* args) {
- *     ModuleNode* env = GetCurrentEnv();
- *     const PackedFunc* f = env->GetFuncFromEnv(name);
- *     (*f)(args, type_code_of(args), len(args));
- *     return 0;
- *  }
- */
-TVM_DLL const Op& tvm_call_trace_packed();
-
-/*!
- * \brief See pseudo code
- *  Mark the content as thread local context, which can be optimized
- *  by only making the call once at thread start.
- *
- *  Do not allow nesting(getting a thread context from another).
- *
- *  Handle tvm_thread_context(Expr call) {
- *     return call;
- *  }
- */
-TVM_DLL const Op& tvm_thread_context();
-
-/*!
- * \brief Lowered version of call packed, the space of value and
- *  type codes are explicitly allocated.
- *
- *  int tvm_call_packed_lowered(name,
- *                              TVMValue* value_stack,
- *                              int* tcode_stack,
- *                              int begin,
- *                              int end) {
- *     ModuleNode* env = GetCurrentEnv();
- *     const PackedFunc* f = env->GetFuncFromEnv(name);
- *     f->CallPacked(TVMArgs(value_stack[begin:end],
- *                           tcode_stack[begin:end]),
- *                   TVMRetValue(value_stack + end, tcode_stack + end));
- *  }
- */
-TVM_DLL const Op& tvm_call_packed_lowered();
-
-/*!
- * \brief Lowered version of trace intrinsic, the space of value and
- *  type codes are explicitly allocated. The return value is the
- *  (end - 1) value on the stack.
- *
- *  int tvm_call_trace_packed_lowered(name,
- *                                    TVMValue* value_stack,
- *                                    int* tcode_stack,
- *                                    int begin,
- *                                    int end) {
- *     ModuleNode* env = GetCurrentEnv();
- *     const PackedFunc* f = env->GetFuncFromEnv(name);
- *     f->CallPacked(TVMArgs(value_stack[begin:end],
- *                           tcode_stack[begin:end]),
- *                   TVMRetValue(value_stack + end, tcode_stack + end));
- *  }
- */
-TVM_DLL const Op& tvm_call_trace_packed_lowered();
-
-/*!
- * \brief See pseudo code
- *
- *  int tvm_storage_sync(std::string storage_scope) {
- *     __sync(storage_scope);
- *     return 0;
- *  }
- */
-TVM_DLL const Op& tvm_storage_sync();
-
-/*!
- * \brief See pseudo code
- *
- *  Type tvm_warp_shuffle(mask, Type value, warp_id, width, warp_size) {
- *    return (value passed in by warp indicated by this_warp_id);
- *  }
- *
- *  Type tvm_warp_shuffle_up(mask, Type value, offset, width, warp_size) {
- *    return (value passed in by warp indicated by this_warp_id - offset);
- *  }
- *
- *  Type tvm_warp_shuffle_down(mask, Type value, offset, width, warp_size) {
- *    return (value passed in by warp indicated by this_warp_id + offset);
- *  }
- *
- *  unsigned tvm_warp_activemask() {
- *    return (32-bit mask of currently active threads in the calling warp);
- *  }
- *
- *  Parameter warp_id indicates the source thread ID in a warp.
- *
- *  Parameter offset indicates the relative distance to this_warp_id.
- *
- *  Parameter width indicates the number of threads involved in one
- *  shuffle. See CUDA document for __shfl_sync, __shfl_up_sync,
- *  __shfl_down_sync and __activemask.
- *
- *  Parameter warp_size is the size of a warp, which helps a backend
- *  to determine whether the width parameter is legal.
- *
- */
-TVM_DLL const Op& tvm_warp_shuffle();
-TVM_DLL const Op& tvm_warp_shuffle_up();
-TVM_DLL const Op& tvm_warp_shuffle_down();
-TVM_DLL const Op& tvm_warp_activemask();
-
-/*!
- * \brief Initialize the global barrier.
- *  Call this at the beginning of a kernel that needs the global barrier.
- */
-TVM_DLL const Op& tvm_global_barrier_kinit();
-
-/*!
- * \brief See pseudo code
- *
- *  void tvm_thread_allreduce(UIntImm size, Expr source0, ..., Expr cond,
- *                            Var reduce_temp0, .., Var thread_idx1, ...) {
- *     // constraint by the other thread_idx remain the same.
- *     // reduce_temp is used to save intermediate result.
- *     reduce_temp0, ... = reduce(combiner, source0, ..., cond
- *       over [thread_idx1, thread_idx2] passed by any caller)
- *  }
- */
-TVM_DLL const Op& tvm_thread_allreduce();
-
-// TODO(tvm-team) TensorCore specific intrinsics should be directly registered under
-//                cuda. namespace and used through op.
-/*!
- * \brief tvm intrinsic for tensor core load operators.
- *
- *  void tvm_load_matrix_sync(Var fragment, UIntImm m, UIntImm n, UIntImm k,
- *                            Expr index, Expr buffer_ptr, Expr stride,
- *                            StringImm layout) {
- *    // m, n, k are the shape of wmma fragment.
- *    // Determine fragment layout(column-major or row-major) by layout.
- *    // fragments must be in 'wmma.matrix_a' or 'wmma.matrix_b' scope.
- *    nvcuda::wmma::load_matrix_sync(fragment[index], buffer_ptr, stride);
- *  }
- */
-TVM_DLL const Op& tvm_load_matrix_sync();
-
-/*!
- * \brief tvm intrinsic for tensor core mma_sync operators.
- *
- *  void tvm_mma_sync(Var fragment_d, Expr index_d,
- *                    Var fragment_a, Expr index_a,
- *                    Var fragment_b, Expr index_b,
- *                    Var fragment_c, Expr index_c) {
- *    nvcuda::wmma::mma_sync(fragment_d[index_d], fragment_a[index_a],
- *                           fragment_b[index_b], fragment_c[index_c]);
- *  }
- */
-TVM_DLL const Op& tvm_mma_sync();
-
-/*!
- * \brief tvm intrinsic for tensor core bmma_sync operators.
- *
- *  void tvm_bmma_sync(Var fragment_d, Expr index_d,
- *                     Var fragment_a, Expr index_a,
- *                     Var fragment_b, Expr index_b,
- *                     Var fragment_c, Expr index_c) {
- *    nvcuda::wmma::bmma_sync(fragment_d[index_d], fragment_a[index_a],
- *                            fragment_b[index_b], fragment_c[index_c]);
- *  }
- */
-TVM_DLL const Op& tvm_bmma_sync();
-
-/*!
- * \brief tvm intrinsic for tensor core fill_fragment operators.
- *
- *  void tvm_fill_fragment(Var fragment, UIntImm m, UIntImm n, UIntImm k,
- *                         Expr index, Expr value) {
- *    // m, n, k are the shape of wmma fragment
- *    // fragments must be in 'wmma.accumulator' scope.
- *    nvcuda::wmma::fill_fragment(fragment[index], value);
- *  }
- */
-TVM_DLL const Op& tvm_fill_fragment();
-
-/*!
- * \brief tvm intrinsic for tensor core store operators.
- *
- *  void tvm_store_matrix_sync(Var fragment, UIntImm m, UIntImm n, UIntImm k,
- *                             Expr index, Expr buffer_ptr, Expr stride,
- *                             StringImm layout) {
- *    // m, n, k are the shape of wmma fragment
- *    // fragments must be in 'wmma.accumulator' scope.
- *    nvcuda::wmma::store_matrix_sync(fragment[index], buffer_ptr, stride, layout);
- *  }
- */
-TVM_DLL const Op& tvm_store_matrix_sync();
-
-// TODO(tvm-team) replace the usage of the vector operations by Shuffle.
-/*!
- * \brief Get the high half of the vector
- */
-TVM_DLL const Op& vectorhigh();
-
-/*!
- * \brief Get the low half of the vector
- */
-TVM_DLL const Op& vectorlow();
-
-/*!
- * \brief Concat two vectors.
- */
-TVM_DLL const Op& vectorcombine();
-
-/*! \brief The kind of structure field info used in intrinsic */
-enum TVMStructFieldKind : int {
-  // array head address
-  kArrAddr,
-  kArrData,
-  kArrShape,
-  kArrStrides,
-  kArrNDim,
-  kArrTypeCode,
-  kArrTypeBits,
-  kArrTypeLanes,
-  kArrByteOffset,
-  kArrDeviceId,
-  kArrDeviceType,
-  kArrKindBound_,
-  // TVMValue field
-  kTVMValueContent,
-  kTVMValueKindBound_
-};
-}  // namespace builtin
-}  // namespace tir
-}  // namespace tvm
-#endif  // TVM_TIR_BUILTIN_H_
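
The header deleted above is the heart of what this commit reverts: intrinsics such as popcount or if_then_else were registered as first-class tvm::Op objects under the "tir." namespace and fetched through Op::Get. A minimal sketch of the two calling conventions from Python, assuming a TVM checkout from around this commit (the two forms belong to the builds before and after the revert, respectively, and cannot coexist in one build):

    import tvm
    from tvm import tir

    x = tir.Var("x", "int32")

    # Pre-revert (op-based): the intrinsic is a tvm.ir.Op registered as "tir.popcount".
    op_based = tir.Call("int32", tvm.ir.Op.get("tir.popcount"), [x],
                        tir.Call.PureIntrinsic)

    # Post-revert (name-based): the intrinsic is identified by a bare string.
    name_based = tir.Call("int32", "popcount", [x], tir.Call.PureIntrinsic)
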
diff --git a/include/tvm/tir/expr.h b/include/tvm/tir/expr.h
index a51f709840111..1518d1ff548ea 100644
--- a/include/tvm/tir/expr.h
+++ b/include/tvm/tir/expr.h
@@ -888,14 +888,8 @@ class CallNode : public PrimExprNode {
     /*! \brief Intrinsic functions that are pure. */
     PureIntrinsic = 5
   };
-  /*!
-   * \brief The operator(function) being invoked
-   *
-   *  - It can be tvm::Op which corresponds to the primitive operators(intrinsics).
-   *  - It can also be another function in the IRModule (GlobalVar).
-   */
-  RelayExpr op;
-
+  /*! \brief The name of the function/intrinsic. */
+  String name;
   /*! \brief The arguments. */
   Array<PrimExpr> args;
   /*! \brief Type of calls. */
@@ -903,19 +897,19 @@ class CallNode : public PrimExprNode {
 
   void VisitAttrs(AttrVisitor* v) {
     v->Visit("dtype", &dtype);
-    v->Visit("op", &op);
+    v->Visit("name", &name);
     v->Visit("args", &args);
     v->Visit("call_type", &call_type);
   }
 
   bool SEqualReduce(const CallNode* other, SEqualReducer equal) const {
-    return equal(dtype, other->dtype) && equal(op, other->op) && equal(args, other->args) &&
+    return equal(dtype, other->dtype) && equal(name, other->name) && equal(args, other->args) &&
           equal(call_type, other->call_type);
   }
 
   void SHashReduce(SHashReducer hash_reduce) const {
     hash_reduce(dtype);
-    hash_reduce(op);
+    hash_reduce(name);
     hash_reduce(args);
     hash_reduce(call_type);
   }
@@ -923,8 +917,37 @@ class CallNode : public PrimExprNode {
   /*! \return Whether call node is pure. */
   bool is_pure() const { return (call_type == PureExtern || call_type == PureIntrinsic); }
 
+  /*!
+   * \return Whether call node corresponds to a defined intrinsic.
+   * \param intrin_name The name of the intrinsic.
+   */
+  bool is_intrinsic(const char* intrin_name) const {
+    return ((call_type == Intrinsic || call_type == PureIntrinsic) && name == intrin_name);
+  }
+
+  /*! \return Whether call node can be vectorized. */
+  bool is_vectorizable() const;
+
   static constexpr const char* _type_key = "tir.Call";
   TVM_DECLARE_FINAL_OBJECT_INFO(CallNode, PrimExprNode);
+
+  // Built-in intrinsics
+  static constexpr const char* reinterpret = "reinterpret";
+  static constexpr const char* bitwise_and = "bitwise_and";
+  static constexpr const char* bitwise_not = "bitwise_not";
+  static constexpr const char* bitwise_xor = "bitwise_xor";
+  static constexpr const char* bitwise_or = "bitwise_or";
+  static constexpr const char* shift_left = "shift_left";
+  static constexpr const char* shift_right = "shift_right";
+  static constexpr const char* popcount = "popcount";
+  static constexpr const char* likely = "likely";
+  static constexpr const char* prefetch = "prefetch";
+  static constexpr const char* isnan = "isnan";
+  static constexpr const char* isfinite = "isfinite";
+  static constexpr const char* isinf = "isinf";
+
+  /*! \brief Vectorizable intrinsic list. */
+  static const char* vectorizable_intrinsics[];
 };
 
 /*!
@@ -935,7 +958,7 @@ class Call : public PrimExpr {
  public:
   using CallType = CallNode::CallType;
 
-  TVM_DLL Call(DataType dtype, RelayExpr op, Array<PrimExpr> args, CallType call_type);
+  TVM_DLL Call(DataType dtype, String name, Array<PrimExpr> args, CallType call_type);
 
   TVM_DEFINE_OBJECT_REF_METHODS(Call, PrimExpr, CallNode);
 };
@@ -1144,6 +1167,358 @@ inline std::unordered_map<K, V> as_unordered_map(const Map<K, V>& dmap) {
   }
   return ret;
 }
+
+/*! \brief namespace of TVM Intrinsic functions */
+namespace intrinsic {
+/*!
+ * \brief See pseudo code
+ *
+ *  Construct a big uint that may not be representable by int64
+ *
+ *  Expr tvm_large_uint_imm(uint32_t v0, uint32_t v1) {
+ *    return (v1 << 32) | v0;
+ *  }
+ */
+constexpr const char* tvm_large_uint_imm = "tvm_large_uint_imm";
+/*!
+ * \brief See pseudo code
+ *
+ *  Handle tvm_address_of(Load *op) {
+ *     return &op->buffer_var[index];
+ *  }
+ */
+constexpr const char* tvm_address_of = "tvm_address_of";
+/*!
+ * \brief Same as select, used for unsafe memory access.
+ *
+ *  Type tvm_if_then_else(cond, a, b) {
+ *    return cond ? a : b;
+ *  }
+ */
+constexpr const char* tvm_if_then_else = "tvm_if_then_else";
+/*!
+ * \brief Get head access address with memory access pattern info.
+ *
+ *  This operator also marks range of the memory access
+ *  The offset and extent are in unit of the DType(including vectorization factor).
+ *  rw_mask is a bit_mask setting whether the access is a read(1) or write(2).
+ *  The access is assumed to happen in the current expression.
+ *
+ *  PtrType tvm_access_ptr(Expr dtype, DType* data,
+ *                         int offset, int extent,
+ *                         int rw_mask) {
+ *    // DType == dtype.type();
+ *    return &data[offset];
+ *  }
+ */
+constexpr const char* tvm_access_ptr = "tvm_access_ptr";
+/*!
+ * \brief Create a function local static handle that initializes to nullptr.
+ *  can be used to cache function local static resources.
+ */
+constexpr const char* tvm_static_handle = "tvm_static_handle";
+/*!
+ * \brief Return a unique context id, used for hint of workspace separation.
+ *  Different context id guarantees not having overlapping workspace.
+ */
+constexpr const char* tvm_context_id = "tvm_context_id";
+/*!
+ * \brief tvm_tuple is not an actual function and cannot codegen.
+ *  It is used to represent tuple structure in value field of AttrStmt,
+ *  for the sake of giving hint to optimization.
+ *
+ *  Handle tvm_tuple(value0, value1, ..., value_n);
+ */
+constexpr const char* tvm_tuple = "tvm_tuple";
+/*!
+ * \brief See pseudo code
+ *
+ *  Type tvm_struct_get(StructType* arr, int index, int field_id) {
+ *     return arr[index]->field;
+ *  }
+ * \sa TVMStructFieldKind
+ */
+constexpr const char* tvm_struct_get = "tvm_struct_get";
+/*!
+ * \brief See pseudo code
+ *
+ *  Handle tvm_struct_set(StructType* arr, int index, int field_id, value) {
+ *     arr[index]->field = value;
+ *  }
+ * \sa TVMStructFieldKind
+ */
+constexpr const char* tvm_struct_set = "tvm_struct_set";
+/*!
+ * \brief See pseudo code
+ *
+ *  bool tvm_handle_is_null(void* handle) {
+ *     return handle == nullptr
+ *  }
+ */
+constexpr const char* tvm_handle_is_null = "tvm_handle_is_null";
+/*!
+ * \brief See pseudo code
+ *
+ *  void tvm_throw_last_error() {
+ *    throw TVMGetLastError();
+ *  }
+ */
+constexpr const char* tvm_throw_last_error = "tvm_throw_last_error";
+/*!
+ * \brief See pseudo code
+ *
+ *  dtype in {shape, array, arg_value, arg_tcode}
+ *
+ *  Handle tvm_stack_alloca(string dtype, int num) {
+ *     return new on stack dtype[num];
+ *  }
+ */
+constexpr const char* tvm_stack_alloca = "tvm_stack_alloca";
+/*!
+ * \brief Allocate a shape tuple on stack, return the handle.
+ *
+ *  Handle tvm_stack_make_shape(list args) {
+ *     ret = alloca stack int64_t[len(args)];
+ *     for i in range(len(args)):
+ *        ret[i] = args[i]
+ *     return &ret[0];
+ *  }
+ */
+constexpr const char* tvm_stack_make_shape = "tvm_stack_make_shape";
+/*!
+ * \brief Allocate a NDArray(DLTensor) on stack, return the handle.
+ *
+ *  Type tvm_stack_make_array(Expr data,
+ *                            Expr shape,
+ *                            Expr strides,
+ *                            Expr ndim,
+ *                            Expr dtype,
+ *                            Expr elem_offset) {
+ *     ret = alloca stack DLTensor();
+ *     ret->data = data;
+ *     ret->shape = shape;
+ *     ret->strides = strides != 0 ? strides : nullptr;
+ *     ret->ndim = ndim;
+ *     ret->dtype = dtype.type();
+ *     ret->byte_offset = elem_offset * sizeof(dtype);
+ *     return ret;
+ *  }
+ */
+constexpr const char* tvm_stack_make_array = "tvm_stack_make_array";
+/*!
+ * \brief See pseudo code
+ *
+ *  int tvm_call_packed(name, TVMValue* args) {
+ *     ModuleNode* env = GetCurrentEnv();
+ *     const PackedFunc* f = env->GetFuncFromEnv(name);
+ *     (*f)(args, type_code_of(args), len(args));
+ *     return 0;
+ *  }
+ */
+constexpr const char* tvm_call_packed = "tvm_call_packed";
+/*!
+ * \brief See pseudo code
+ *
+ *  int tvm_call_trace_packed(name, TVMValue* args) {
+ *     ModuleNode* env = GetCurrentEnv();
+ *     const PackedFunc* f = env->GetFuncFromEnv(name);
+ *     (*f)(args, type_code_of(args), len(args));
+ *     return 0;
+ *  }
+ */
+constexpr const char* tvm_call_trace_packed = "tvm_call_trace_packed";
+/*!
+ * \brief See pseudo code
+ *  Mark the content as thread local context, which can be optimized
+ *  by only making the call once at thread start.
+ *
+ *  Do not allow nesting(getting a thread context from another).
+ *
+ *  Handle tvm_thread_context(Expr call) {
+ *     return call;
+ *  }
+ */
+constexpr const char* tvm_thread_context = "tvm_thread_context";
+/*!
+ * \brief Lowered version of call packed, the space of value and
+ *  type codes are explicitly allocated.
+ *
+ *  int tvm_call_packed_lowered(name,
+ *                              TVMValue* value_stack,
+ *                              int* tcode_stack,
+ *                              int begin,
+ *                              int end) {
+ *     ModuleNode* env = GetCurrentEnv();
+ *     const PackedFunc* f = env->GetFuncFromEnv(name);
+ *     f->CallPacked(TVMArgs(value_stack[begin:end],
+ *                           tcode_stack[begin:end]),
+ *                   TVMRetValue(value_stack + end, tcode_stack + end));
+ *  }
+ */
+constexpr const char* tvm_call_packed_lowered = "tvm_call_packed_lowered";
+/*!
+ * \brief Lowered version of trace intrinsic, the space of value and
+ *  type codes are explicitly allocated. The return value is the
+ *  (end - 1) value on the stack.
+ *
+ *  int tvm_call_trace_packed_lowered(name,
+ *                                    TVMValue* value_stack,
+ *                                    int* tcode_stack,
+ *                                    int begin,
+ *                                    int end) {
+ *     ModuleNode* env = GetCurrentEnv();
+ *     const PackedFunc* f = env->GetFuncFromEnv(name);
+ *     f->CallPacked(TVMArgs(value_stack[begin:end],
+ *                           tcode_stack[begin:end]),
+ *                   TVMRetValue(value_stack + end, tcode_stack + end));
+ *  }
+ */
+constexpr const char* tvm_call_trace_packed_lowered = "tvm_call_trace_packed_lowered";
+/*!
+ * \brief See pseudo code
+ *
+ *  int tvm_storage_sync(std::string storage_scope) {
+ *     __sync(storage_scope);
+ *     return 0;
+ *  }
+ */
+constexpr const char* tvm_storage_sync = "tvm_storage_sync";
+
+/*!
+ * \brief See pseudo code
+ *
+ *  Type tvm_warp_shuffle(mask, Type value, warp_id, width, warp_size) {
+ *    return (value passed in by warp indicated by this_warp_id);
+ *  }
+ *
+ *  Type tvm_warp_shuffle_up(mask, Type value, offset, width, warp_size) {
+ *    return (value passed in by warp indicated by this_warp_id - offset);
+ *  }
+ *
+ *  Type tvm_warp_shuffle_down(mask, Type value, offset, width, warp_size) {
+ *    return (value passed in by warp indicated by this_warp_id + offset);
+ *  }
+ *
+ *  unsigned tvm_warp_activemask() {
+ *    return (32-bit mask of currently active threads in the calling warp);
+ *  }
+ *
+ *  Parameter warp_id indicates the source thread ID in a warp.
+ *
+ *  Parameter offset indicates the relative distance to this_warp_id.
+ *
+ *  Parameter width indicates the number of threads involved in one
+ *  shuffle. See CUDA document for __shfl_sync, __shfl_up_sync,
+ *  __shfl_down_sync and __activemask.
+ *
+ *  Parameter warp_size is the size of a warp, which helps a backend
+ *  to determine whether the width parameter is legal.
+ *
+ */
+constexpr const char* tvm_warp_shuffle = "tvm_warp_shuffle";
+constexpr const char* tvm_warp_shuffle_up = "tvm_warp_shuffle_up";
+constexpr const char* tvm_warp_shuffle_down = "tvm_warp_shuffle_down";
+constexpr const char* tvm_warp_activemask = "tvm_warp_activemask";
+
+/*!
+ * \brief Initialize the global barrier.
+ *  Call this at the beginning of a kernel that needs the global barrier.
+ */
+constexpr const char* tvm_global_barrier_kinit = "tvm_global_barrier_kinit";
+/*!
+ * \brief See pseudo code
+ *
+ *  void tvm_thread_allreduce(UIntImm size, Expr source0, ..., Expr cond,
+ *                            Var reduce_temp0, .., Var thread_idx1, ...) {
+ *     // constraint by the other thread_idx remain the same.
+ *     // reduce_temp is used to save intermediate result.
+ *     reduce_temp0, ... = reduce(combiner, source0, ..., cond
+ *       over [thread_idx1, thread_idx2] passed by any caller)
+ *  }
+ */
+constexpr const char* tvm_thread_allreduce = "tvm_thread_allreduce";
+/*!
+ * \brief tvm intrinsic for tensor core load operators.
+ *
+ *  void tvm_load_matrix_sync(Var fragment, UIntImm m, UIntImm n, UIntImm k,
+ *                            Expr index, Expr buffer_ptr, Expr stride,
+ *                            StringImm layout) {
+ *    // m, n, k are the shape of wmma fragment.
+ *    // Determine fragment layout(column-major or row-major) by layout.
+ *    // fragments must be in 'wmma.matrix_a' or 'wmma.matrix_b' scope.
+ *    nvcuda::wmma::load_matrix_sync(fragment[index], buffer_ptr, stride);
+ *  }
+ */
+constexpr const char* tvm_load_matrix_sync = "tvm_load_matrix_sync";
+/*!
+ * \brief tvm intrinsic for tensor core mma_sync operators.
+ *
+ *  void tvm_mma_sync(Var fragment_d, Expr index_d,
+ *                    Var fragment_a, Expr index_a,
+ *                    Var fragment_b, Expr index_b,
+ *                    Var fragment_c, Expr index_c) {
+ *    nvcuda::wmma::mma_sync(fragment_d[index_d], fragment_a[index_a],
+ *                           fragment_b[index_b], fragment_c[index_c]);
+ *  }
+ */
+constexpr const char* tvm_mma_sync = "tvm_mma_sync";
+/*!
+ * \brief tvm intrinsic for tensor core bmma_sync operators.
+ *
+ *  void tvm_bmma_sync(Var fragment_d, Expr index_d,
+ *                     Var fragment_a, Expr index_a,
+ *                     Var fragment_b, Expr index_b,
+ *                     Var fragment_c, Expr index_c) {
+ *    nvcuda::wmma::bmma_sync(fragment_d[index_d], fragment_a[index_a],
+ *                            fragment_b[index_b], fragment_c[index_c]);
+ *  }
+ */
+constexpr const char* tvm_bmma_sync = "tvm_bmma_sync";
+/*!
+ * \brief tvm intrinsic for tensor core fill_fragment operators.
+ *
+ *  void tvm_fill_fragment(Var fragment, UIntImm m, UIntImm n, UIntImm k,
+ *                         Expr index, Expr value) {
+ *    // m, n, k are the shape of wmma fragment
+ *    // fragments must be in 'wmma.accumulator' scope.
+ *    nvcuda::wmma::fill_fragment(fragment[index], value);
+ *  }
+ */
+constexpr const char* tvm_fill_fragment = "tvm_fill_fragment";
+/*!
+ * \brief tvm intrinsic for tensor core store operators.
+ *
+ *  void tvm_store_matrix_sync(Var fragment, UIntImm m, UIntImm n, UIntImm k,
+ *                             Expr index, Expr buffer_ptr, Expr stride,
+ *                             StringImm layout) {
+ *    // m, n, k are the shape of wmma fragment
+ *    // fragments must be in 'wmma.accumulator' scope.
+ *    nvcuda::wmma::store_matrix_sync(fragment[index], buffer_ptr, stride, layout);
+ *  }
+ */
+constexpr const char* tvm_store_matrix_sync = "tvm_store_matrix_sync";
+
+/*! \brief The kind of structure field info used in intrinsic */
+enum TVMStructFieldKind : int {
+  // array head address
+  kArrAddr,
+  kArrData,
+  kArrShape,
+  kArrStrides,
+  kArrNDim,
+  kArrTypeCode,
+  kArrTypeBits,
+  kArrTypeLanes,
+  kArrByteOffset,
+  kArrDeviceId,
+  kArrDeviceType,
+  kArrKindBound_,
+  // TVMValue field
+  kTVMValueContent,
+  kTVMValueKindBound_
+};
+}  // namespace intrinsic
+
 }  // namespace tir
 }  // namespace tvm
diff --git a/include/tvm/tir/function.h b/include/tvm/tir/function.h
index caddd99eeb2cf..919391e36b96c 100644
--- a/include/tvm/tir/function.h
+++ b/include/tvm/tir/function.h
@@ -87,6 +87,8 @@ class PrimFuncNode : public BaseFuncNode {
    * While we could have expressed parameter unpacking and constraint using
    * normal statements, making buffer_map as first class citizen of PrimFunc
    * will make program analysis much easier.
+   *
+   * \note This field can be nullptr
    */
   Map<tir::Var, Buffer> buffer_map;
 
@@ -142,7 +144,7 @@ class PrimFunc : public BaseFunc {
    * \param attrs Additional function attributes.
    */
   TVM_DLL PrimFunc(Array<tir::Var> params, Stmt body, Type ret_type = VoidType(),
-                   Map<tir::Var, Buffer> buffer_map = Map<tir::Var, Buffer>(),
+                   Map<tir::Var, Buffer> buffer_map = NullValue<Map<tir::Var, Buffer>>(),
                    DictAttrs attrs = NullValue<DictAttrs>());
 
   TVM_DEFINE_OBJECT_REF_METHODS(PrimFunc, BaseFunc, PrimFuncNode);
diff --git a/include/tvm/tir/op.h b/include/tvm/tir/op.h
index 286b6d75cb828..2948bb2cc20e1 100644
--- a/include/tvm/tir/op.h
+++ b/include/tvm/tir/op.h
@@ -28,7 +28,6 @@
 #ifndef TVM_TIR_OP_H_
 #define TVM_TIR_OP_H_
 
-#include <tvm/tir/builtin.h>
 #include
 #include
 #include
@@ -553,10 +552,9 @@ TVM_DLL PrimExpr trunc(PrimExpr x);
 TVM_DLL PrimExpr LargeUIntImm(DataType dtype, int64_t low, int64_t high);
 
 // Intrinsic operators
-#define TVM_DECLARE_INTRIN_UNARY(OpName)                                 \
-  inline PrimExpr OpName(PrimExpr x) {                                   \
-    static const Op& op = Op::Get("tir." #OpName);                       \
-    return tir::Call(x.dtype(), op, {x}, tir::CallNode::PureIntrinsic);  \
+#define TVM_DECLARE_INTRIN_UNARY(OpName)                                \
+  inline PrimExpr OpName(PrimExpr x) {                                  \
+    return tir::Call(x.dtype(), #OpName, {x}, tir::CallNode::PureIntrinsic); \
   }
 
 TVM_DECLARE_INTRIN_UNARY(exp);
diff --git a/include/tvm/tir/op_attr_types.h b/include/tvm/tir/op_attr_types.h
deleted file mode 100644
index d7c13500d90ef..0000000000000
--- a/include/tvm/tir/op_attr_types.h
+++ /dev/null
@@ -1,48 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \file tvm/tir/op_attr_types.h
- * \brief Attribute types in the Op registry for TIR ops.
- *
- * These attributes can be set via OpRegEntry::set_attr
- *
- * \sa tvm/ir/op.h
- */
-#ifndef TVM_TIR_OP_ATTR_TYPES_H_
-#define TVM_TIR_OP_ATTR_TYPES_H_
-
-#include <tvm/runtime/container.h>
-
-namespace tvm {
-namespace tir {
-
-/*!
- * \brief Global symbol of the op after lowering.
- */
-using TGlobalSymbol = String;
-
-/*!
- * \brief Whether the op is overloaded for vector form.
- */
-using TVectorizable = bool;
-
-}  // namespace tir
-}  // namespace tvm
-#endif  // TVM_TIR_OP_ATTR_TYPES_H_
diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h
index b928aec7bcf20..be1c567198d91 100644
--- a/include/tvm/tir/stmt.h
+++ b/include/tvm/tir/stmt.h
@@ -1004,7 +1004,9 @@ inline bool IsPragmaKey(const std::string& attr_key) {
  * \param dtype The data type
  * \return Expr an expression with dtype.
  */
-TVM_DLL PrimExpr TypeAnnotation(DataType dtype);
+inline PrimExpr TypeAnnotation(DataType dtype) {
+  return tir::Call(dtype, "type_annotation", {}, tir::CallNode::PureIntrinsic);
+}
 
 // overload printing of for type.
 TVM_DLL std::ostream& operator<<(std::ostream& os, ForType for_type);
diff --git a/python/tvm/contrib/nvcc.py b/python/tvm/contrib/nvcc.py
index 8c3d34af174e0..fc8232053b5f7 100644
--- a/python/tvm/contrib/nvcc.py
+++ b/python/tvm/contrib/nvcc.py
@@ -98,8 +98,7 @@ def compile_cuda(code,
     (out, _) = proc.communicate()
 
     if proc.returncode != 0:
-        msg = code
-        msg += "\nCompilation error:\n"
+        msg = "Compilation error:\n"
         msg += py_str(out)
         raise RuntimeError(msg)
 
diff --git a/python/tvm/target/datatype.py b/python/tvm/target/datatype.py
index f93a943cd9cf3..e42ac6b378061 100644
--- a/python/tvm/target/datatype.py
+++ b/python/tvm/target/datatype.py
@@ -18,9 +18,8 @@
 import tvm._ffi
 import tvm.runtime._ffi_api
-from tvm.runtime import DataType
-import tvm.tir
-from tvm.tir.expr import Cast as _Cast, FloatImm as _FloatImm
+from tvm.runtime import convert, DataType
+from tvm.tir.expr import Call as _Call, Cast as _Cast, FloatImm as _FloatImm
 
 
 def register(type_name, type_code):
@@ -136,7 +135,9 @@ def lower(op):
         if t.lanes > 1:
             dtype += "x" + str(t.lanes)
         if isinstance(op, (_Cast, _FloatImm)):
-            return tvm.tir.call_pure_extern(dtype, extern_func_name, op.value)
-        return tvm.tir.call_pure_extern(dtype, extern_func_name, op.a, op.b)
+            return _Call(dtype, extern_func_name, convert([op.value]),
+                         _Call.Extern)
+        return _Call(dtype, extern_func_name, convert([op.a, op.b]),
+                     _Call.Extern)
 
     return lower
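
The lower(op) closure above turns a Cast/FloatImm (or a binary op) on a custom datatype into an Extern call. A hedged sketch of how such a lowering gets wired up, assuming the create_lower_func/register_op helpers of this era; the type name "myfloat", the type code 131, and the extern symbol "FloatToMyFloat" are placeholders, not real registrations:

    from tvm.target import datatype

    datatype.register("myfloat", 131)

    # create_lower_func produces a lower(op) like the one above, emitting an
    # Extern Call to the named C function for the matched node kind.
    lower_cast = datatype.create_lower_func("FloatToMyFloat")
    datatype.register_op(lower_cast, "Cast", "llvm", "myfloat", "float")
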
""" if str(op.dtype).startswith("float"): - return call_pure_extern(op.dtype, op.op.name[4:], *op.args) + return call_pure_extern(op.dtype, op.name, *op.args) return None # opencl pattern for exp diff --git a/python/tvm/te/hybrid/calls.py b/python/tvm/te/hybrid/calls.py index a119c20754f47..dfbb185a7eb47 100644 --- a/python/tvm/te/hybrid/calls.py +++ b/python/tvm/te/hybrid/calls.py @@ -148,7 +148,7 @@ def likely(func_id, args): _internal_assert(args.__len__() == 1, \ "Only one expression can be likely") _internal_assert(func_id == "likely", "This function cannot be directly invoked!") - return call_pure_intrin(args[0].dtype, 'tir.likely', *args) + return call_pure_intrin(args[0].dtype, 'likely', *args) def max_num_threads(func_id, args): diff --git a/python/tvm/tir/expr.py b/python/tvm/tir/expr.py index 386badf3e8aab..3b580efe2b628 100644 --- a/python/tvm/tir/expr.py +++ b/python/tvm/tir/expr.py @@ -30,7 +30,7 @@ import tvm._ffi from tvm.runtime import Object, ObjectGeneric, DataType, DataTypeCode, const -from tvm.ir import PrimExpr, Op +from tvm.ir import PrimExpr import tvm.ir._ffi_api from . import generic as _generic from . import _ffi_api @@ -144,7 +144,7 @@ def __rxor__(self, other): def __invert__(self): if _dtype_is_float(self): raise RuntimeError("Cannot use ~ operator on float type Expr.") - return _ffi_api.bitwise_not(self) + return _ffi_api.Call(self.dtype, "bitwise_not", [self], Call.PureIntrinsic) def __lt__(self, other): return _ffi_api._OpLT(self, other) @@ -968,9 +968,8 @@ class Call(PrimExprWithOp): dtype : str The return data type - op : Union[RelayExpr, str] - The function to be called, or the name - to the global tvm.Op + name : str + The name of the function args : list of Expr The input arguments to the call @@ -983,16 +982,9 @@ class Call(PrimExprWithOp): PureExtern = 2 Intrinsic = 4 PureIntrinsic = 5 - def __init__(self, dtype, op, args, call_type): - if isinstance(op, str): - if not op.startswith("tir."): - raise ValueError( - ("Cannot handle str op argument %s. This function only handles str " + - "argument with the tir namespace. If you are " + - "certain about the intrinsic name, pass in Op.get(name) instead") % op) - op = Op.get(op) - self.__init_handle_by_constructor__( - _ffi_api.Call, dtype, op, args, call_type) + def __init__(self, dtype, name, args, call_type): + self.__init_handle_by_constructor__( + _ffi_api.Call, dtype, name, args, call_type) @tvm._ffi.register_object("tir.Let") diff --git a/python/tvm/tir/ir_builder.py b/python/tvm/tir/ir_builder.py index 089127c6f0ff0..47ba2e2c805c9 100644 --- a/python/tvm/tir/ir_builder.py +++ b/python/tvm/tir/ir_builder.py @@ -379,7 +379,7 @@ def likely(self, expr): expr : Expr The expression will likely tag. """ - return _expr.Call(expr.dtype, "tir.likely", [expr], + return _expr.Call(expr.dtype, "likely", [expr], _expr.Call.PureIntrinsic) def get(self): diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py index 6826241ac1a6e..929d422ccc431 100644 --- a/python/tvm/tir/op.py +++ b/python/tvm/tir/op.py @@ -18,10 +18,10 @@ """Operators used in TIR expression.""" import tvm._ffi from tvm.runtime import convert, const -from tvm.ir import Array, Op +from tvm.ir import Array from .buffer import Buffer -from .expr import Call, StringImm, Var, CommReducer +from .expr import Call, Var, CommReducer from . import _ffi_api @@ -29,9 +29,9 @@ def _pack_buffer(buf): """Build intrinsics that packs the buffer. 
""" assert buf.shape - shape = Call("handle", "tir.tvm_stack_make_shape", buf.shape, + shape = Call("handle", "tvm_stack_make_shape", buf.shape, Call.Intrinsic) - strides = Call("handle", "tir.tvm_stack_make_shape", buf.strides, + strides = Call("handle", "tvm_stack_make_shape", buf.strides, Call.Intrinsic) if buf.strides else 0 pack_args = [buf.data, shape, @@ -39,7 +39,7 @@ def _pack_buffer(buf): len(buf.shape), const(0, dtype=buf.dtype), buf.elem_offset] - return Call("handle", Op.get("tir.tvm_stack_make_array"), + return Call("handle", "tvm_stack_make_array", pack_args, Call.Intrinsic) def call_packed(*args): @@ -68,7 +68,7 @@ def call_packed(*args): """ call_args = [_pack_buffer(x) if isinstance(x, Buffer) else x for x in args] return Call( - "int32", Op.get("tir.tvm_call_packed"), call_args, Call.Intrinsic) + "int32", "tvm_call_packed", call_args, Call.Intrinsic) def call_pure_intrin(dtype, func_name, *args): @@ -145,7 +145,7 @@ def call_pure_extern(dtype, func_name, *args): The call expression. """ return Call( - dtype, Op.get("tir.call_extern"), convert((StringImm(func_name),) + args), Call.PureExtern) + dtype, func_name, convert(args), Call.PureExtern) def call_extern(dtype, func_name, *args): @@ -168,7 +168,7 @@ def call_extern(dtype, func_name, *args): The call expression. """ return Call( - dtype, Op.get("tir.call_extern"), convert((StringImm(func_name),) + args), Call.Extern) + dtype, func_name, convert(args), Call.Extern) def call_llvm_intrin(dtype, name, *args): @@ -194,8 +194,7 @@ def call_llvm_intrin(dtype, name, *args): from tvm.target import codegen llvm_id = codegen.llvm_lookup_intrinsic_id(name) assert llvm_id != 0, "%s is not an LLVM intrinsic" % name - return call_pure_intrin(dtype, Op.get("tir.call_llvm_intrin"), - tvm.tir.const(llvm_id, 'uint32'), *args) + return call_pure_intrin(dtype, 'llvm_intrin', tvm.tir.const(llvm_id, 'uint32'), *args) def any(*args): @@ -279,7 +278,7 @@ def trace(args, trace_action="tvm.default_trace_action"): call_args = [_pack_buffer(x) if isinstance(x, Buffer) else x for x in args] call_args.insert(0, trace_action) return tvm.tir.Call( - args[-1].dtype, Op.get("tir.tvm_call_trace_packed"), call_args, tvm.tir.Call.Intrinsic) + args[-1].dtype, "tvm_call_trace_packed", call_args, tvm.tir.Call.Intrinsic) @@ -328,7 +327,7 @@ def exp(x): y : PrimExpr The result. """ - return call_pure_intrin(x.dtype, "tir.exp", x) + return call_pure_intrin(x.dtype, "exp", x) def exp2(x): @@ -344,7 +343,7 @@ def exp2(x): y : PrimExpr The result. """ - return call_pure_intrin(x.dtype, "tir.exp2", x) + return call_pure_intrin(x.dtype, "exp2", x) def exp10(x): @@ -360,7 +359,7 @@ def exp10(x): y : PrimExpr The result. """ - return call_pure_intrin(x.dtype, "tir.exp10", x) + return call_pure_intrin(x.dtype, "exp10", x) def erf(x): @@ -376,7 +375,7 @@ def erf(x): y : PrimExpr The result. """ - return call_pure_intrin(x.dtype, "tir.erf", x) + return call_pure_intrin(x.dtype, "erf", x) def tanh(x): @@ -392,7 +391,7 @@ def tanh(x): y : PrimExpr The result. """ - return call_pure_intrin(x.dtype, "tir.tanh", x) + return call_pure_intrin(x.dtype, "tanh", x) def sigmoid(x): @@ -408,7 +407,7 @@ def sigmoid(x): y : PrimExpr The result. """ - return call_pure_intrin(x.dtype, "tir.sigmoid", x) + return call_pure_intrin(x.dtype, "sigmoid", x) def log(x): @@ -424,7 +423,7 @@ def log(x): y : PrimExpr The result. """ - return call_pure_intrin(x.dtype, "tir.log", x) + return call_pure_intrin(x.dtype, "log", x) def log2(x): @@ -440,7 +439,7 @@ def log2(x): y : PrimExpr The result. 
""" - return call_pure_intrin(x.dtype, "tir.log2", x) + return call_pure_intrin(x.dtype, "log2", x) def log10(x): @@ -456,7 +455,7 @@ def log10(x): y : PrimExpr The result. """ - return call_pure_intrin(x.dtype, "tir.log10", x) + return call_pure_intrin(x.dtype, "log10", x) def log1p(x): @@ -472,7 +471,7 @@ def log1p(x): y : PrimExpr The result. """ - return call_pure_intrin(x.dtype, "tir.log1p", x) + return call_pure_intrin(x.dtype, "log1p", x) def tan(x): @@ -488,7 +487,7 @@ def tan(x): y : PrimExpr The result. """ - return call_pure_intrin(x.dtype, "tir.tan", x) + return call_pure_intrin(x.dtype, "tan", x) def cos(x): @@ -504,7 +503,7 @@ def cos(x): y : PrimExpr The result. """ - return call_pure_intrin(x.dtype, "tir.cos", x) + return call_pure_intrin(x.dtype, "cos", x) def cosh(x): @@ -520,7 +519,7 @@ def cosh(x): y : PrimExpr The result. """ - return call_pure_intrin(x.dtype, "tir.cosh", x) + return call_pure_intrin(x.dtype, "cosh", x) def acos(x): @@ -536,7 +535,7 @@ def acos(x): y : PrimExpr The result. """ - return call_pure_intrin(x.dtype, "tir.acos", x) + return call_pure_intrin(x.dtype, "acos", x) def acosh(x): @@ -552,7 +551,7 @@ def acosh(x): y : PrimExpr The result. """ - return call_pure_intrin(x.dtype, "tir.acosh", x) + return call_pure_intrin(x.dtype, "acosh", x) def sin(x): @@ -568,7 +567,7 @@ def sin(x): y : PrimExpr The result. """ - return call_pure_intrin(x.dtype, "tir.sin", x) + return call_pure_intrin(x.dtype, "sin", x) def sinh(x): @@ -584,7 +583,7 @@ def sinh(x): y : PrimExpr The result. """ - return call_pure_intrin(x.dtype, "tir.sinh", x) + return call_pure_intrin(x.dtype, "sinh", x) def asin(x): @@ -600,7 +599,7 @@ def asin(x): y : PrimExpr The result. """ - return call_pure_intrin(x.dtype, "tir.asin", x) + return call_pure_intrin(x.dtype, "asin", x) def asinh(x): @@ -616,7 +615,7 @@ def asinh(x): y : PrimExpr The result. """ - return call_pure_intrin(x.dtype, "tir.asinh", x) + return call_pure_intrin(x.dtype, "asinh", x) def atan(x): @@ -632,7 +631,7 @@ def atan(x): y : PrimExpr The result. """ - return call_pure_intrin(x.dtype, "tir.atan", x) + return call_pure_intrin(x.dtype, "atan", x) def atanh(x): @@ -648,7 +647,7 @@ def atanh(x): y : PrimExpr The result. """ - return call_pure_intrin(x.dtype, "tir.atanh", x) + return call_pure_intrin(x.dtype, "atanh", x) def atan2(x1, x2): @@ -667,7 +666,7 @@ def atan2(x1, x2): y : PrimExpr The result. """ - return call_pure_intrin(x1.dtype, "tir.atan2", x1, x2) + return call_pure_intrin(x1.dtype, "atan2", x1, x2) def sqrt(x): @@ -683,7 +682,7 @@ def sqrt(x): y : PrimExpr The result. """ - return call_pure_intrin(x.dtype, "tir.sqrt", x) + return call_pure_intrin(x.dtype, "sqrt", x) def rsqrt(x): @@ -699,7 +698,7 @@ def rsqrt(x): y : PrimExpr The result. """ - return call_pure_intrin(x.dtype, "tir.rsqrt", x) + return call_pure_intrin(x.dtype, "rsqrt", x) def floor(x): @@ -824,7 +823,7 @@ def nextafter(x1, x2): y : PrimExpr The result. """ - return call_pure_intrin(x1.dtype, "tir.nextafter", x1, x2) + return call_pure_intrin(x1.dtype, "nextafter", x1, x2) def hypot(x1, x2): @@ -843,7 +842,7 @@ def hypot(x1, x2): y : PrimExpr The result. """ - return call_pure_intrin(x1.dtype, "tir.hypot", x1, x2) + return call_pure_intrin(x1.dtype, "hypot", x1, x2) def copysign(x1, x2): @@ -862,7 +861,7 @@ def copysign(x1, x2): y : PrimExpr The result. 
""" - return call_pure_intrin(x1.dtype, "tir.copysign", x1, x2) + return call_pure_intrin(x1.dtype, "copysign", x1, x2) def ldexp(x1, x2): @@ -881,7 +880,7 @@ def ldexp(x1, x2): y : PrimExpr The result. """ - return call_pure_intrin(x1.dtype, "tir.ldexp", x1, x2) + return call_pure_intrin(x1.dtype, "ldexp", x1, x2) def isnan(x): @@ -964,7 +963,7 @@ def popcount(x): y : PrimExpr The result. """ - return call_pure_intrin(x.dtype, "tir.popcount", x) + return call_pure_intrin(x.dtype, "popcount", x) def fmod(x, y): """Return the remainder of x divided by y with the same sign as x. @@ -981,7 +980,7 @@ def fmod(x, y): z : PrimExpr The result. """ - return call_pure_intrin(x.dtype, "tir.fmod", x, y) + return call_pure_intrin(x.dtype, "fmod", x, y) def if_then_else(cond, t, f): diff --git a/src/arith/const_int_bound.cc b/src/arith/const_int_bound.cc index 8c90249f4f174..c33990cd1f4fd 100644 --- a/src/arith/const_int_bound.cc +++ b/src/arith/const_int_bound.cc @@ -22,7 +22,6 @@ */ #include #include -#include #include #include @@ -285,10 +284,9 @@ class ConstIntBoundAnalyzer::Impl Entry VisitExpr_(const CallNode* op) final { // only special handle >> and & which can be // used for index calculation. - - if (op->op.same_as(tir::builtin::shift_right())) { + if (op->is_intrinsic(CallNode::shift_right)) { return VisitRightShift(op); - } else if (op->op.same_as(tir::builtin::bitwise_and())) { + } else if (op->is_intrinsic(CallNode::bitwise_and)) { return VisitBitwiseAnd(op); } else { return Everything(op->dtype); diff --git a/src/arith/ir_mutator_with_analyzer.cc b/src/arith/ir_mutator_with_analyzer.cc index c367d0c9f9d84..84e2093dcf98f 100644 --- a/src/arith/ir_mutator_with_analyzer.cc +++ b/src/arith/ir_mutator_with_analyzer.cc @@ -56,10 +56,8 @@ Stmt IRMutatorWithAnalyzer::VisitStmt_(const LetStmtNode* op) { Stmt IRMutatorWithAnalyzer::VisitStmt_(const IfThenElseNode* op) { PrimExpr condition = this->VisitExpr(op->condition); PrimExpr real_condition = condition; - static auto op_likely = Op::Get("tir.likely"); - if (auto call = condition.as()) { - if (call->op.same_as(op_likely)) { + if (call->is_intrinsic(CallNode::likely)) { real_condition = call->args[0]; } } @@ -124,8 +122,7 @@ Stmt IRMutatorWithAnalyzer::VisitStmt_(const AssertStmtNode* op) { PrimExpr IRMutatorWithAnalyzer::VisitExpr_(const CallNode* op) { // add condition context to if_then_else - static auto op_if_then_else = Op::Get("tir.if_then_else"); - if (op->op.same_as(op_if_then_else)) { + if (op->is_intrinsic(tir::intrinsic::tvm_if_then_else)) { PrimExpr cond = this->VisitExpr(op->args[0]); PrimExpr true_value, false_value; { @@ -146,7 +143,7 @@ PrimExpr IRMutatorWithAnalyzer::VisitExpr_(const CallNode* op) { false_value.same_as(op->args[2])) { return GetRef(op); } else { - return Call(op->dtype, op->op, {cond, true_value, false_value}, op->call_type); + return Call(op->dtype, op->name, {cond, true_value, false_value}, op->call_type); } } return StmtExprMutator::VisitExpr_(op); diff --git a/src/arith/modular_set.cc b/src/arith/modular_set.cc index 108f08c4f78f4..3457674d4ed30 100644 --- a/src/arith/modular_set.cc +++ b/src/arith/modular_set.cc @@ -23,7 +23,6 @@ */ #include #include -#include #include #include @@ -204,7 +203,7 @@ class ModularSetAnalyzer::Impl : public ExprFunctor> which can be // used for index calculation. 
diff --git a/src/arith/const_int_bound.cc b/src/arith/const_int_bound.cc
index 8c90249f4f174..c33990cd1f4fd 100644
--- a/src/arith/const_int_bound.cc
+++ b/src/arith/const_int_bound.cc
@@ -22,7 +22,6 @@
  */
 #include
 #include
-#include <tvm/tir/builtin.h>
 #include
 #include
 
@@ -285,10 +284,9 @@ class ConstIntBoundAnalyzer::Impl
   Entry VisitExpr_(const CallNode* op) final {
     // only special handle >> and & which can be
     // used for index calculation.
-
-    if (op->op.same_as(tir::builtin::shift_right())) {
+    if (op->is_intrinsic(CallNode::shift_right)) {
       return VisitRightShift(op);
-    } else if (op->op.same_as(tir::builtin::bitwise_and())) {
+    } else if (op->is_intrinsic(CallNode::bitwise_and)) {
       return VisitBitwiseAnd(op);
     } else {
       return Everything(op->dtype);
diff --git a/src/arith/ir_mutator_with_analyzer.cc b/src/arith/ir_mutator_with_analyzer.cc
index c367d0c9f9d84..84e2093dcf98f 100644
--- a/src/arith/ir_mutator_with_analyzer.cc
+++ b/src/arith/ir_mutator_with_analyzer.cc
@@ -56,10 +56,8 @@ Stmt IRMutatorWithAnalyzer::VisitStmt_(const LetStmtNode* op) {
 Stmt IRMutatorWithAnalyzer::VisitStmt_(const IfThenElseNode* op) {
   PrimExpr condition = this->VisitExpr(op->condition);
   PrimExpr real_condition = condition;
-  static auto op_likely = Op::Get("tir.likely");
-
   if (auto call = condition.as<CallNode>()) {
-    if (call->op.same_as(op_likely)) {
+    if (call->is_intrinsic(CallNode::likely)) {
       real_condition = call->args[0];
     }
   }
@@ -124,8 +122,7 @@ Stmt IRMutatorWithAnalyzer::VisitStmt_(const AssertStmtNode* op) {
 
 PrimExpr IRMutatorWithAnalyzer::VisitExpr_(const CallNode* op) {
   // add condition context to if_then_else
-  static auto op_if_then_else = Op::Get("tir.if_then_else");
-  if (op->op.same_as(op_if_then_else)) {
+  if (op->is_intrinsic(tir::intrinsic::tvm_if_then_else)) {
     PrimExpr cond = this->VisitExpr(op->args[0]);
     PrimExpr true_value, false_value;
     {
@@ -146,7 +143,7 @@ PrimExpr IRMutatorWithAnalyzer::VisitExpr_(const CallNode* op) {
         false_value.same_as(op->args[2])) {
       return GetRef<PrimExpr>(op);
     } else {
-      return Call(op->dtype, op->op, {cond, true_value, false_value}, op->call_type);
+      return Call(op->dtype, op->name, {cond, true_value, false_value}, op->call_type);
     }
   }
   return StmtExprMutator::VisitExpr_(op);
diff --git a/src/arith/modular_set.cc b/src/arith/modular_set.cc
index 108f08c4f78f4..3457674d4ed30 100644
--- a/src/arith/modular_set.cc
+++ b/src/arith/modular_set.cc
@@ -23,7 +23,6 @@
  */
 #include
 #include
-#include <tvm/tir/builtin.h>
 #include
 #include
 
@@ -204,7 +203,7 @@ class ModularSetAnalyzer::Impl : public ExprFunctor
     // only special handle >> which can be
     // used for index calculation.
-    if (op->op.same_as(tir::builtin::shift_right())) {
+    if (op->is_intrinsic(CallNode::shift_right)) {
       return VisitRightShift(op);
     } else {
       return Everything();
diff --git a/src/arith/pattern_match.h b/src/arith/pattern_match.h
index de8425146bbfb..ff01941e4acf3 100644
--- a/src/arith/pattern_match.h
+++ b/src/arith/pattern_match.h
@@ -66,7 +66,6 @@
 #define TVM_ARITH_PATTERN_MATCH_H_
 
 #include
-#include <tvm/tir/builtin.h>
 #include
 #include
 
@@ -656,7 +655,7 @@ class PCallExpr : public Pattern<PCallExpr<Op, TArgs...>> {
   bool Match_(const ObjectRef& node) const {
     if (const tir::CallNode* ptr = node.as<tir::CallNode>()) {
       if (ptr->args.size() != sizeof...(TArgs)) return false;
-      if (!ptr->op.same_as(Op::GetOp())) return false;
+      if (ptr->name != Op::kName) return false;
       detail::PCallExprMatchFunctor fmatch(ptr);
       detail::tuple_for_each(fmatch, args_);
       return fmatch.matched_;
@@ -676,45 +675,45 @@ class PCallExpr : public Pattern<PCallExpr<Op, TArgs...>> {
 };
 
 // arithmetic intrinsics
-#define TVM_PATTERN_BINARY_INTRIN(FuncName, OpName, IntrinOpName)                        \
+#define TVM_PATTERN_BINARY_INTRIN(FuncName, OpName, IntrinStr)                           \
   struct OpName {                                                                        \
     static PrimExpr Eval(Array<PrimExpr> args) {                                         \
-      return tir::Call(args[0].dtype(), GetOp(), args, tir::CallNode::PureIntrinsic);    \
+      return tir::Call(args[0].dtype(), kName, args, tir::CallNode::PureIntrinsic);      \
     }                                                                                    \
-    static const Op& GetOp() { return tir::builtin::IntrinOpName(); }                    \
+    static constexpr const char* kName = IntrinStr;                                      \
   };                                                                                     \
   template <typename TA, typename TB>                                                    \
  inline PCallExpr<OpName, TA, TB> FuncName(const Pattern<TA>& a, const Pattern<TB>& b) { \
    return PCallExpr<OpName, TA, TB>(a.derived(), b.derived());                           \
  }
 
-TVM_PATTERN_BINARY_INTRIN(operator<<, PLeftShiftOp, shift_left);
-TVM_PATTERN_BINARY_INTRIN(operator>>, PRightShiftOp, shift_right);
-TVM_PATTERN_BINARY_INTRIN(operator&, PBitwiseAndOp, bitwise_and);
-TVM_PATTERN_BINARY_INTRIN(operator|, PBitwiseOrOp, bitwise_or);
-TVM_PATTERN_BINARY_INTRIN(operator^, PBitwiseXorOp, bitwise_xor);
+TVM_PATTERN_BINARY_INTRIN(operator<<, PLeftShiftOp, "shift_left");
+TVM_PATTERN_BINARY_INTRIN(operator>>, PRightShiftOp, "shift_right");
+TVM_PATTERN_BINARY_INTRIN(operator&, PBitwiseAndOp, "bitwise_and");
+TVM_PATTERN_BINARY_INTRIN(operator|, PBitwiseOrOp, "bitwise_or");
+TVM_PATTERN_BINARY_INTRIN(operator^, PBitwiseXorOp, "bitwise_xor");
 
 // unary intrinsics
-#define TVM_PATTERN_UNARY_INTRIN(FuncName, OpName, IntrinOpName)                      \
-  struct OpName {                                                                     \
-    static PrimExpr Eval(Array<PrimExpr> args) {                                      \
-      return tir::Call(args[0].dtype(), GetOp(), args, tir::CallNode::PureIntrinsic); \
-    }                                                                                 \
-    static const Op& GetOp() { return tir::builtin::IntrinOpName(); }                 \
-  };                                                                                  \
-  template <typename TA>                                                              \
-  inline PCallExpr<OpName, TA> FuncName(const Pattern<TA>& a) {                       \
-    return PCallExpr<OpName, TA>(a.derived());                                        \
+#define TVM_PATTERN_UNARY_INTRIN(FuncName, OpName, IntrinStr)                         \
+  struct OpName {                                                                     \
+    static PrimExpr Eval(Array<PrimExpr> args) {                                      \
+      return tir::Call(args[0].dtype(), kName, args, tir::CallNode::PureIntrinsic);   \
+    }                                                                                 \
+    static constexpr const char* kName = IntrinStr;                                   \
+  };                                                                                  \
+  template <typename TA>                                                              \
+  inline PCallExpr<OpName, TA> FuncName(const Pattern<TA>& a) {                       \
+    return PCallExpr<OpName, TA>(a.derived());                                        \
  }
 
-TVM_PATTERN_UNARY_INTRIN(operator~, PBitwiseNotOp, bitwise_not);
+TVM_PATTERN_UNARY_INTRIN(operator~, PBitwiseNotOp, "bitwise_not");
 
 // if_then_else
 struct PIfThenElseOp {
   static PrimExpr Eval(Array<PrimExpr> args) {
-    return tir::Call(args[1].dtype(), GetOp(), args, tir::CallNode::PureIntrinsic);
+    return tir::Call(args[1].dtype(), kName, args, tir::CallNode::PureIntrinsic);
   }
-  static const Op& GetOp() { return tir::builtin::if_then_else(); }
+  static constexpr const char* kName = "tvm_if_then_else";
 };
 
 /*!
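
The pattern helpers above are what RewriteSimplifier (next diff) matches against, and the practical effect is visible from Python: shift and bitwise calls on constant operands constant-fold. A small sketch, assuming a build from this era:

    import tvm

    ana = tvm.arith.Analyzer()
    a = tvm.tir.const(10, "int32")
    b = tvm.tir.const(6, "int32")

    # Folded via the bitwise_and / shift_right handling in the next hunk:
    print(ana.rewrite_simplify(a & b))                           # 2
    print(ana.rewrite_simplify(a >> tvm.tir.const(1, "int32")))  # 5
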
diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc index 6758c9b569a81..4887ef0ee47d2 100644 --- a/src/arith/rewrite_simplify.cc +++ b/src/arith/rewrite_simplify.cc @@ -25,7 +25,6 @@ #include "rewrite_simplify.h" #include -#include #include #include @@ -1509,22 +1508,21 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const CallNode* op) { PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as(); if (op == nullptr) return ret; - - if (op->op.same_as(tir::builtin::likely()) && is_const(op->args[0])) { + if (op->is_intrinsic(CallNode::likely) && is_const(op->args[0])) { return op->args[0]; - } else if (op->op.same_as(tir::builtin::shift_right())) { + } else if (op->is_intrinsic(CallNode::shift_right)) { if (op->args[0].as() && op->args[1].as()) { // the operator overload will eagerly constant fold. return op->args[0] >> op->args[1]; } - } else if (op->op.same_as(tir::builtin::shift_left())) { + } else if (op->is_intrinsic(CallNode::bitwise_and)) { if (op->args[0].as() && op->args[1].as()) { // the operator overload will eagerly constant fold. return op->args[0] & op->args[1]; } } ExprDeepEqual expr_equal; - if (op->op.same_as(tir::builtin::likely())) { + if (op->is_intrinsic(CallNode::likely)) { for (const auto& constraint : literal_constraints_) { // Cases such as for (i, 0, bound) {if (likely(iter_var < bound)) { .. } } if (expr_equal(constraint, op->args[0])) { diff --git a/src/contrib/hybrid/codegen_hybrid.cc b/src/contrib/hybrid/codegen_hybrid.cc index 0d5d654c3f6e2..e08f39f8135db 100644 --- a/src/contrib/hybrid/codegen_hybrid.cc +++ b/src/contrib/hybrid/codegen_hybrid.cc @@ -23,7 +23,6 @@ #include "codegen_hybrid.h" #include -#include #include #include @@ -217,43 +216,29 @@ void CodeGenHybrid::VisitExpr_(const ProducerLoadNode* op, std::ostream& os) { os << "]"; } void CodeGenHybrid::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*) - if (op->op.same_as(builtin::bitwise_and())) { + if (op->is_intrinsic(CallNode::bitwise_and)) { PrintBinaryIntrinsitc(op, "&", os, this); - } else if (op->op.same_as(builtin::bitwise_xor())) { + } else if (op->is_intrinsic(CallNode::bitwise_xor)) { PrintBinaryIntrinsitc(op, "^", os, this); - } else if (op->op.same_as(builtin::bitwise_or())) { + } else if (op->is_intrinsic(CallNode::bitwise_or)) { PrintBinaryIntrinsitc(op, "|", os, this); - } else if (op->op.same_as(builtin::shift_left())) { + } else if (op->is_intrinsic(CallNode::shift_left)) { PrintBinaryIntrinsitc(op, "<<", os, this); - } else if (op->op.same_as(builtin::shift_right())) { + } else if (op->is_intrinsic(CallNode::shift_right)) { PrintBinaryIntrinsitc(op, ">>", os, this); - } else if (op->op.same_as(builtin::bitwise_not())) { + } else if (op->is_intrinsic(CallNode::bitwise_not)) { CHECK_EQ(op->args.size(), 1U); os << "(~"; PrintExpr(op->args[0], os); os << ')'; - } else if (op->op.same_as(builtin::if_then_else())) { + } else if (op->is_intrinsic(intrinsic::tvm_if_then_else)) { PrintExpr(op->args[1], os); os << " if "; PrintExpr(op->args[0], os); os << " else "; PrintExpr(op->args[2], os); - } else if (op->op.same_as(builtin::call_extern())) { - StringImm fname = Downcast(op->args[0]); - os << fname << "("; - for (size_t i = 1; i < op->args.size(); i++) { - PrintExpr(op->args[i], os); - if (i < op->args.size() - 1) { - os << ", "; - } - } - os << ")"; } else { - auto* ptr_op = op->op.as(); - CHECK(ptr_op != nullptr); - std::string name = ptr_op->name; - CHECK_EQ(name.compare(0, 4, "tir."), 0); - os << name.substr(4) << "("; + os << op->name << 
"("; for (size_t i = 0; i < op->args.size(); i++) { PrintExpr(op->args[i], os); if (i < op->args.size() - 1) { diff --git a/src/ir/op.cc b/src/ir/op.cc index 45c31963695cd..63d223050ff5e 100644 --- a/src/ir/op.cc +++ b/src/ir/op.cc @@ -42,7 +42,7 @@ using OpRegistry = AttrRegistry; // find operator by name const Op& Op::Get(const String& name) { const OpRegEntry* reg = OpRegistry::Global()->Get(name); - CHECK(reg != nullptr) << "AttributeError: Operator " << name << " is not registered"; + CHECK(reg != nullptr) << "Operator " << name << " is not registered"; return reg->op(); } diff --git a/src/printer/tir_text_printer.cc b/src/printer/tir_text_printer.cc index 233a73954c93b..29927379f17d7 100644 --- a/src/printer/tir_text_printer.cc +++ b/src/printer/tir_text_printer.cc @@ -345,14 +345,7 @@ inline const char* CallType2String(CallNode::CallType t) { Doc TIRTextPrinter::VisitExpr_(const CallNode* op) { Doc doc; - if (auto* ptr_op = op->op.as()) { - doc << "@" << Doc::Text(ptr_op->name) << "("; - } else { - // TODO(bohan): Print out the name by he global var in the module. - auto* op_gvar = op->op.as(); - CHECK(op_gvar != nullptr); - doc << "@" << Doc::Text(op_gvar->name_hint) << "("; - } + doc << "@" << Doc::Text(op->name) << "("; std::vector args; for (const auto& arg : op->args) { args.push_back(Print(arg)); @@ -377,7 +370,7 @@ Doc TIRTextPrinter::VisitExpr_(const ReduceNode* op) { Doc TIRTextPrinter::VisitStmt_(const LetStmtNode* op) { Doc doc; - doc << "let " << Print(op->var) << " = " << Print(op->value) << Doc::NewLine() << Print(op->body); + doc << "let " << Print(op->var) << " = " << Print(op->value) << PrintBody(op->body); return doc; } @@ -396,8 +389,8 @@ Doc TIRTextPrinter::VisitStmt_(const AttrStmtNode* op) { Doc TIRTextPrinter::VisitStmt_(const AssertStmtNode* op) { Doc doc; - doc << "assert(" << Print(op->condition) << ", " << Print(op->message) << ")" << Doc::NewLine() - << Print(op->body); + doc << "assert(" << Print(op->condition) << ", " << Print(op->message) << ")" + << PrintBody(op->body); return doc; } diff --git a/src/relay/transforms/pass_util.h b/src/relay/transforms/pass_util.h index 5f5876212b62d..35bbb234dbc15 100644 --- a/src/relay/transforms/pass_util.h +++ b/src/relay/transforms/pass_util.h @@ -121,7 +121,7 @@ inline bool IsAtomic(const Expr& e) { * \return compiler_begin op */ inline const Op& CompilerBeginOp() { - static auto op = Op::Get("annotation.compiler_begin"); + static Op op = Op::Get("annotation.compiler_begin"); return op; } @@ -131,7 +131,7 @@ inline const Op& CompilerBeginOp() { * \return compiler_end op */ inline const Op& CompilerEndOp() { - static auto op = Op::Get("annotation.compiler_end"); + static Op op = Op::Get("annotation.compiler_end"); return op; } diff --git a/src/target/intrin_rule.h b/src/target/intrin_rule.h index 36e553900d007..5a23e83af2198 100644 --- a/src/target/intrin_rule.h +++ b/src/target/intrin_rule.h @@ -25,7 +25,6 @@ #define TVM_TARGET_INTRIN_RULE_H_ #include -#include #include #include @@ -59,20 +58,9 @@ inline void DispatchExtern(const TVMArgs& args, TVMRetValue* rv) { PrimExpr e = args[0]; const CallNode* call = e.as(); CHECK(call != nullptr); - // Use string based dispatch to extern for backward compact - // TODO(tvm-team) replace once the new dispatching system is inplace. 
- const OpNode* op = call->op.as(); - CHECK(op != nullptr); - std::string name = op->name; - CHECK_EQ(name.substr(0, 4), "tir."); - name = T()(call->dtype, name.substr(4)); - + std::string name = T()(call->dtype, call->name); if (name.length() != 0) { - Array new_args = {StringImm(name)}; - for (auto arg : call->args) { - new_args.push_back(arg); - } - *rv = Call(call->dtype, tir::builtin::call_extern(), new_args, CallNode::PureExtern); + *rv = Call(call->dtype, name, call->args, CallNode::PureExtern); } else { *rv = e; } diff --git a/src/target/llvm/codegen_arm.cc b/src/target/llvm/codegen_arm.cc index 13ce59d54b82b..991d4730a136d 100644 --- a/src/target/llvm/codegen_arm.cc +++ b/src/target/llvm/codegen_arm.cc @@ -46,7 +46,7 @@ class CodeGenARM final : public CodeGenCPU { }; llvm::Value* CodeGenARM::CreateIntrinsic(const CallNode* op) { - if (op->op.same_as(builtin_call_llvm_intrin_)) { + if (op->is_intrinsic("llvm_intrin")) { llvm::Intrinsic::ID id = static_cast(Downcast(op->args[0])->value); if (id == ::llvm::Intrinsic::ctpop) { PrimExpr e = ARMPopcount(op); @@ -70,7 +70,7 @@ PrimExpr CodeGenARM::ARMPopcount(const CallNode* call) { vcnt_args.push_back(IntImm(DataType::UInt(32), ctpop_id)); vcnt_args.push_back(IntImm(DataType::UInt(32), 1)); vcnt_args.push_back(e); - return tir::Call(call->dtype, builtin_call_llvm_intrin_, vcnt_args, CallNode::PureIntrinsic); + return tir::Call(call->dtype, "llvm_intrin", vcnt_args, CallNode::PureIntrinsic); } // Popcount lowering rule: @@ -94,16 +94,14 @@ PrimExpr CodeGenARM::ARMPopcount(const CallNode* call) { vcnt8_args.push_back(IntImm(DataType::UInt(32), ctpop_id)); vcnt8_args.push_back(IntImm(DataType::UInt(32), 1)); vcnt8_args.push_back(input8); - PrimExpr vcnt8 = - tir::Call(uint8_type, builtin_call_llvm_intrin_, vcnt8_args, CallNode::PureIntrinsic); + PrimExpr vcnt8 = tir::Call(uint8_type, "llvm_intrin", vcnt8_args, CallNode::PureIntrinsic); // Accumulation 8->16bit Array vcnt16_args; vcnt16_args.push_back(IntImm(DataType::UInt(32), vpaddlu_id)); vcnt16_args.push_back(IntImm(DataType::UInt(32), 1)); vcnt16_args.push_back(vcnt8); - PrimExpr vcnt16 = - tir::Call(uint16_type, builtin_call_llvm_intrin_, vcnt16_args, CallNode::PureIntrinsic); + PrimExpr vcnt16 = tir::Call(uint16_type, "llvm_intrin", vcnt16_args, CallNode::PureIntrinsic); if (call->dtype.bits() == 16) { return vcnt16; } @@ -113,8 +111,7 @@ PrimExpr CodeGenARM::ARMPopcount(const CallNode* call) { vcnt32_args.push_back(IntImm(DataType::UInt(32), vpaddlu_id)); vcnt32_args.push_back(IntImm(DataType::UInt(32), 1)); vcnt32_args.push_back(vcnt16); - PrimExpr vcnt32 = - tir::Call(uint32_type, builtin_call_llvm_intrin_, vcnt32_args, CallNode::PureIntrinsic); + PrimExpr vcnt32 = tir::Call(uint32_type, "llvm_intrin", vcnt32_args, CallNode::PureIntrinsic); if (call->dtype.bits() == 32) { return vcnt32; } @@ -124,7 +121,7 @@ PrimExpr CodeGenARM::ARMPopcount(const CallNode* call) { vcnt64_args.push_back(IntImm(DataType::UInt(32), vpaddlu_id)); vcnt64_args.push_back(IntImm(DataType::UInt(32), 1)); vcnt64_args.push_back(vcnt32); - return tir::Call(call->dtype, builtin_call_llvm_intrin_, vcnt64_args, CallNode::PureIntrinsic); + return tir::Call(call->dtype, "llvm_intrin", vcnt64_args, CallNode::PureIntrinsic); } TVM_REGISTER_GLOBAL("tvm.codegen.llvm.target_arm") diff --git a/src/target/llvm/codegen_cpu.cc b/src/target/llvm/codegen_cpu.cc index f855dd5b83b24..6ad050ace9a30 100644 --- a/src/target/llvm/codegen_cpu.cc +++ b/src/target/llvm/codegen_cpu.cc @@ -226,7 +226,7 @@ std::unique_ptr 
CodeGenCPU::Finish() { } llvm::Value* CodeGenCPU::CreateStructRefPtr(DataType t, llvm::Value* buf, llvm::Value* index, int kind) { - if (kind < builtin::kArrKindBound_) { + if (kind < intrinsic::kArrKindBound_) { if (buf->getType() == t_void_p_) { buf = builder_->CreatePointerCast(buf, t_tvm_array_->getPointerTo()); } else { @@ -234,40 +234,40 @@ llvm::Value* CodeGenCPU::CreateStructRefPtr(DataType t, llvm::Value* buf, llvm:: } } switch (kind) { - case builtin::kArrAddr: { + case intrinsic::kArrAddr: { return builder_->CreateInBoundsGEP(buf, index); } - case builtin::kArrData: { + case intrinsic::kArrData: { return builder_->CreateInBoundsGEP(buf, {index, ConstInt32(0)}); } - case builtin::kArrShape: { + case intrinsic::kArrShape: { return builder_->CreateInBoundsGEP(buf, {index, ConstInt32(4)}); } - case builtin::kArrStrides: { + case intrinsic::kArrStrides: { return builder_->CreateInBoundsGEP(buf, {index, ConstInt32(5)}); } - case builtin::kArrNDim: { + case intrinsic::kArrNDim: { return builder_->CreateInBoundsGEP(buf, {index, ConstInt32(2)}); } - case builtin::kArrTypeCode: { + case intrinsic::kArrTypeCode: { return builder_->CreateInBoundsGEP(buf, {index, ConstInt32(3), ConstInt32(0)}); } - case builtin::kArrTypeBits: { + case intrinsic::kArrTypeBits: { return builder_->CreateInBoundsGEP(buf, {index, ConstInt32(3), ConstInt32(1)}); } - case builtin::kArrTypeLanes: { + case intrinsic::kArrTypeLanes: { return builder_->CreateInBoundsGEP(buf, {index, ConstInt32(3), ConstInt32(2)}); } - case builtin::kArrByteOffset: { + case intrinsic::kArrByteOffset: { return builder_->CreateInBoundsGEP(buf, {index, ConstInt32(6)}); } - case builtin::kArrDeviceId: { + case intrinsic::kArrDeviceId: { return builder_->CreateInBoundsGEP(buf, {index, ConstInt32(1), ConstInt32(1)}); } - case builtin::kArrDeviceType: { + case intrinsic::kArrDeviceType: { return builder_->CreateInBoundsGEP(buf, {index, ConstInt32(1), ConstInt32(0)}); } - case builtin::kTVMValueContent: { + case intrinsic::kTVMValueContent: { CHECK_EQ(t.lanes(), 1); CHECK(t.is_handle() || t.bits() == 64); if (t.is_int()) { @@ -289,23 +289,23 @@ llvm::Value* CodeGenCPU::CreateStructRefPtr(DataType t, llvm::Value* buf, llvm:: } } -llvm::Value* CodeGenCPU::CreateCallExtern(Type ret_type, String global_symbol, - const Array& args, bool skip_first_arg) { - std::vector arg_values; - for (size_t i = static_cast(skip_first_arg); i < args.size(); ++i) { - arg_values.push_back(MakeValue(args[i])); +llvm::Value* CodeGenCPU::CreateCallExtern(const CallNode* op) { + std::vector arg_values(op->args.size()); + for (size_t i = 0; i < op->args.size(); ++i) { + arg_values[i] = MakeValue(op->args[i]); } std::vector arg_types; for (llvm::Value* v : arg_values) { arg_types.push_back(v->getType()); } - llvm::FunctionType* ftype = llvm::FunctionType::get(GetLLVMType(ret_type), arg_types, false); + llvm::FunctionType* ftype = + llvm::FunctionType::get(GetLLVMType(GetRef(op)), arg_types, false); // Check if it is available in global function table as injected function. 
- auto it = gv_func_map_.find(global_symbol); + auto it = gv_func_map_.find(op->name); if (it != gv_func_map_.end()) { if (it->second == nullptr) { - gv_func_map_[global_symbol] = InitContextPtr(ftype->getPointerTo(), "__" + global_symbol); - it = gv_func_map_.find(global_symbol); + gv_func_map_[op->name] = InitContextPtr(ftype->getPointerTo(), "__" + op->name); + it = gv_func_map_.find(op->name); } #if TVM_LLVM_VERSION >= 90 auto ext_callee = llvm::FunctionCallee(ftype, GetContextPtr(it->second)); @@ -314,10 +314,10 @@ llvm::Value* CodeGenCPU::CreateCallExtern(Type ret_type, String global_symbol, #endif return builder_->CreateCall(ext_callee, arg_values); } else { - llvm::Function* f = module_->getFunction(global_symbol); + llvm::Function* f = module_->getFunction(op->name); if (f == nullptr) { f = llvm::Function::Create(ftype, llvm::Function::ExternalLinkage, - global_symbol.operator llvm::StringRef(), module_.get()); + op->name.operator llvm::StringRef(), module_.get()); } #if TVM_LLVM_VERSION >= 90 auto ext_callee = llvm::FunctionCallee(f); @@ -773,38 +773,38 @@ void CodeGenCPU::AddStartupFunction() { } llvm::Value* CodeGenCPU::CreateIntrinsic(const CallNode* op) { - if (op->op.same_as(builtin::tvm_call_packed_lowered())) { + if (op->is_intrinsic(intrinsic::tvm_call_packed_lowered)) { return CreateCallPacked(op); - } else if (op->op.same_as(builtin::tvm_call_trace_packed_lowered())) { + } else if (op->is_intrinsic(intrinsic::tvm_call_trace_packed_lowered)) { return CreateCallTracePacked(op); - } else if (op->op.same_as(builtin::tvm_static_handle())) { + } else if (op->is_intrinsic(intrinsic::tvm_static_handle)) { return CreateStaticHandle(); - } else if (op->op.same_as(builtin::tvm_throw_last_error())) { + } else if (op->is_intrinsic(intrinsic::tvm_throw_last_error)) { builder_->CreateRet(ConstInt32(-1)); return ConstInt32(-1); - } else if (op->op.same_as(builtin::tvm_struct_get())) { + } else if (op->is_intrinsic(intrinsic::tvm_struct_get)) { CHECK_EQ(op->args.size(), 3U); int kind = op->args[2].as()->value; llvm::Value* ref = this->CreateStructRefPtr(op->dtype, MakeValue(op->args[0]), MakeValue(op->args[1]), kind); - if (kind == builtin::kArrAddr) { + if (kind == intrinsic::kArrAddr) { return builder_->CreatePointerCast(ref, t_void_p_); } else { return builder_->CreateLoad(ref); } - } else if (op->op.same_as(builtin::tvm_struct_set())) { + } else if (op->is_intrinsic(intrinsic::tvm_struct_set)) { CHECK_EQ(op->args.size(), 4U); int kind = op->args[2].as()->value; llvm::Value* value = MakeValue(op->args[3]); llvm::Value* ref = this->CreateStructRefPtr(op->args[3].dtype(), MakeValue(op->args[0]), MakeValue(op->args[1]), kind); - CHECK(kind != builtin::kArrAddr); + CHECK(kind != intrinsic::kArrAddr); if (value->getType()->isPointerTy()) { value = builder_->CreatePointerCast(value, ref->getType()->getPointerElementType()); } builder_->CreateStore(value, ref); return ConstInt32(0); - } else if (op->op.same_as(builtin::tvm_stack_alloca())) { + } else if (op->is_intrinsic(intrinsic::tvm_stack_alloca)) { CHECK_EQ(op->args.size(), 2U); const std::string& type = op->args[0].as()->value; return WithFunctionEntry([&]() -> llvm::AllocaInst* { diff --git a/src/target/llvm/codegen_cpu.h b/src/target/llvm/codegen_cpu.h index fdeab41307822..7a14b8fdc959b 100644 --- a/src/target/llvm/codegen_cpu.h +++ b/src/target/llvm/codegen_cpu.h @@ -47,8 +47,7 @@ class CodeGenCPU : public CodeGenLLVM { void VisitStmt_(const AttrStmtNode* op) override; void VisitStmt_(const ForNode* op) override; llvm::Value* 
CreateIntrinsic(const CallNode* op) override; - llvm::Value* CreateCallExtern(Type ret_type, String global_symbol, const Array& args, - bool skip_first_arg) override; + llvm::Value* CreateCallExtern(const CallNode* op) override; protected: void AddStartupFunction() final; @@ -123,7 +122,7 @@ class CodeGenCPU : public CodeGenLLVM { llvm::GlobalVariable* gv_tvm_api_set_last_error_{nullptr}; llvm::GlobalVariable* gv_tvm_parallel_launch_{nullptr}; llvm::GlobalVariable* gv_tvm_parallel_barrier_{nullptr}; - std::unordered_map gv_func_map_; + std::unordered_map gv_func_map_; // context for direct dynamic lookup llvm::Function* f_tvm_func_call_{nullptr}; llvm::Function* f_tvm_get_func_from_env_{nullptr}; diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index 49f14c31d07ff..85e3de5844fd5 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -653,19 +653,19 @@ llvm::Value* CodeGenLLVM::GetVarValue(const VarNode* v) const { return it->second; } -llvm::Value* CodeGenLLVM::CreateCallExtern(Type ret_type, String global_symbol, - const Array& args, bool skip_first_arg) { +llvm::Value* CodeGenLLVM::CreateCallExtern(const CallNode* op) { std::vector arg_value; std::vector arg_type; - for (size_t i = static_cast(skip_first_arg); i < args.size(); ++i) { - arg_value.push_back(MakeValue(args[i])); + for (size_t i = 0; i < op->args.size(); ++i) { + arg_value.push_back(MakeValue(op->args[i])); arg_type.push_back(arg_value.back()->getType()); } - llvm::FunctionType* ftype = llvm::FunctionType::get(GetLLVMType(ret_type), arg_type, false); - llvm::Function* f = module_->getFunction(global_symbol); + llvm::FunctionType* ftype = + llvm::FunctionType::get(GetLLVMType(GetRef(op)), arg_type, false); + llvm::Function* f = module_->getFunction(op->name); if (f == nullptr) { f = llvm::Function::Create(ftype, llvm::Function::ExternalLinkage, - global_symbol.operator llvm::StringRef(), module_.get()); + op->name.operator llvm::StringRef(), module_.get()); } llvm::CallInst* call = builder_->CreateCall(f, arg_value); return call; @@ -738,7 +738,7 @@ llvm::Function* CodeGenLLVM::GetIntrinsicDecl(llvm::Intrinsic::ID id, llvm::Type } llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) { - if (op->op.same_as(builtin_call_llvm_intrin_)) { + if (op->is_intrinsic("llvm_intrin")) { CHECK_GE(op->args.size(), 2U); llvm::Intrinsic::ID id = static_cast(Downcast(op->args[0])->value); int64_t num_signature = Downcast(op->args[1])->value; @@ -759,29 +759,30 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) { // type as LLVM. llvm::Type* return_type = (id != llvm::Intrinsic::prefetch) ? 
GetLLVMType(GetRef(op)) : llvm::Type::getVoidTy(*ctx_); + llvm::Function* f = GetIntrinsicDecl(id, return_type, arg_type); CHECK(f) << "Cannot find intrinsic declaration, possible type mismatch: " << llvm::Intrinsic::getName(id, {}); return builder_->CreateCall(f, arg_value); - } else if (op->op.same_as(builtin::bitwise_and())) { + } else if (op->is_intrinsic(CallNode::bitwise_and)) { return builder_->CreateAnd(MakeValue(op->args[0]), MakeValue(op->args[1])); - } else if (op->op.same_as(builtin::bitwise_or())) { + } else if (op->is_intrinsic(CallNode::bitwise_or)) { return builder_->CreateOr(MakeValue(op->args[0]), MakeValue(op->args[1])); - } else if (op->op.same_as(builtin::bitwise_not())) { + } else if (op->is_intrinsic(CallNode::bitwise_not)) { return builder_->CreateNot(MakeValue(op->args[0])); - } else if (op->op.same_as(builtin::bitwise_xor())) { + } else if (op->is_intrinsic(CallNode::bitwise_xor)) { return builder_->CreateXor(MakeValue(op->args[0]), MakeValue(op->args[1])); - } else if (op->op.same_as(builtin::shift_left())) { + } else if (op->is_intrinsic(CallNode::shift_left)) { return builder_->CreateShl(MakeValue(op->args[0]), MakeValue(op->args[1])); - } else if (op->op.same_as(builtin::shift_right())) { + } else if (op->is_intrinsic(CallNode::shift_right)) { if (op->args[0].dtype().is_int()) { return builder_->CreateAShr(MakeValue(op->args[0]), MakeValue(op->args[1])); } else { return builder_->CreateLShr(MakeValue(op->args[0]), MakeValue(op->args[1])); } - } else if (op->op.same_as(builtin::tvm_storage_sync())) { + } else if (op->is_intrinsic(intrinsic::tvm_storage_sync)) { return CreateStorageSync(op); - } else if (op->op.same_as(builtin::address_of())) { + } else if (op->is_intrinsic(intrinsic::tvm_address_of)) { const LoadNode* l = op->args[0].as(); CHECK(op->args.size() == 1 && l); const RampNode* r = l->index.as(); @@ -796,17 +797,17 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) { addrspace = llvm::dyn_cast(ptr->getType())->getAddressSpace(); } return builder_->CreatePointerCast(ptr, t_char_->getPointerTo(addrspace)); - } else if (op->op.same_as(builtin::reinterpret()) && is_zero(op->args[0])) { + } else if (op->is_intrinsic(CallNode::reinterpret) && is_zero(op->args[0])) { return llvm::Constant::getNullValue(t_void_p_); - } else if (op->op.same_as(builtin::isnullptr())) { + } else if (op->is_intrinsic(intrinsic::tvm_handle_is_null)) { return builder_->CreateIsNull(MakeValue(op->args[0])); - } else if (op->op.same_as(builtin::large_uint_imm())) { + } else if (op->is_intrinsic(intrinsic::tvm_large_uint_imm)) { CHECK_EQ(op->args.size(), 2U); uint64_t low = static_cast(Downcast(op->args[0])->value); uint64_t high = static_cast(Downcast(op->args[1])->value); uint64_t val = (high << 32U) | low; return llvm::ConstantInt::get(DTypeToLLVMType(op->dtype), val); - } else if (op->op.same_as(builtin::if_then_else())) { + } else if (op->is_intrinsic(intrinsic::tvm_if_then_else)) { CHECK_EQ(op->args[0].dtype().lanes(), 1) << "if_then_else can only take scalar condition"; using llvm::BasicBlock; BasicBlock* then_block = BasicBlock::Create(*ctx_, "if_then", function_); @@ -826,22 +827,22 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) { value->addIncoming(then_value, then_value_block); value->addIncoming(else_value, else_value_block); return value; - } else if (op->op.same_as(builtin::reinterpret())) { + } else if (op->is_intrinsic(CallNode::reinterpret)) { llvm::Type* target = DTypeToLLVMType(op->dtype); return 
builder_->CreateBitCast(MakeValue(op->args[0]), target); - } else if (op->op.same_as(builtin::isnan())) { + } else if (op->is_intrinsic(CallNode::isnan)) { // TODO(hgt312): set fast math flag llvm::Value* a = MakeValue(op->args[0]); return builder_->CreateFCmpUNO(a, a); - } else if (op->op.same_as(builtin::vectorlow())) { + } else if (op->is_intrinsic("vectorlow")) { llvm::Value* v = MakeValue(op->args[0]); int l = llvm::cast(v->getType())->getNumElements(); return CreateVecSlice(v, 0, l / 2); - } else if (op->op.same_as(builtin::vectorhigh())) { + } else if (op->is_intrinsic("vectorhigh")) { llvm::Value* v = MakeValue(op->args[0]); int l = llvm::cast(v->getType())->getNumElements(); return CreateVecSlice(v, l / 2, l / 2); - } else if (op->op.same_as(builtin::vectorcombine())) { + } else if (op->is_intrinsic("vectorcombine")) { llvm::Value* v0 = MakeValue(op->args[0]); llvm::Value* v1 = MakeValue(op->args[1]); int num_elems = llvm::cast(v0->getType())->getNumElements() * 2; @@ -855,7 +856,7 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) { } return builder_->CreateShuffleVector(v0, v1, indices); } else { - LOG(FATAL) << "unknown intrinsic " << op->op; + LOG(FATAL) << "unknown intrinsic " << op->name; return nullptr; } } @@ -1075,24 +1076,13 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const LoadNode* op) { } llvm::Value* CodeGenLLVM::VisitExpr_(const CallNode* op) { - if (auto* ptr_op = op->op.as()) { - auto call_op = GetRef(ptr_op); - if (op->op.same_as(builtin_call_extern_)) { - // call extern intrinsic - CHECK_GE(op->args.size(), 1U); - auto global_symbol = Downcast(op->args[0]); - return this->CreateCallExtern(GetType(GetRef(op)), global_symbol->value, op->args, - true); - } else if (op_attr_global_symbol_.count(call_op)) { - // call extern if the op itself have a global symbol. - return this->CreateCallExtern(GetType(GetRef(op)), op_attr_global_symbol_[call_op], - op->args, false); - } else { - return CreateIntrinsic(op); - } + if (op->call_type == CallNode::Intrinsic || op->call_type == CallNode::PureIntrinsic) { + return CreateIntrinsic(op); + } else if (op->call_type == CallNode::Extern || op->call_type == CallNode::PureExtern) { + return CreateCallExtern(op); } else { - CHECK(op->op.as()); - LOG(FATAL) << "Do not yet support cross function call"; + LOG(FATAL) << "Unknown call type " + << "name= " << op->name << " call_type= " << op->call_type; return nullptr; } } diff --git a/src/target/llvm/codegen_llvm.h b/src/target/llvm/codegen_llvm.h index 2bfe047038b06..0bca2a169ba43 100644 --- a/src/target/llvm/codegen_llvm.h +++ b/src/target/llvm/codegen_llvm.h @@ -32,7 +32,6 @@ #include #include #include -#include #include #include @@ -176,9 +175,7 @@ class CodeGenLLVM : public ExprFunctor, // create intrinstic given call virtual llvm::Value* CreateIntrinsic(const CallNode* op); // create extern function call - // skip first arg mode used for call extern intrinsic. - virtual llvm::Value* CreateCallExtern(Type ret_type, String global_symbol, - const Array& args, bool skip_first_arg); + virtual llvm::Value* CreateCallExtern(const CallNode* op); // Get the corresponding thread index virtual llvm::Value* GetThreadIndex(const IterVar& iv); // Get the corresponding thread index @@ -322,11 +319,6 @@ class CodeGenLLVM : public ExprFunctor, std::unordered_set alias_var_set_; // set of volatile buffer. std::unordered_set volatile_buf_; - // Cache potential common path ops to slightly improve lookup time. - // global symbol table. 
- OpAttrMap op_attr_global_symbol_ = Op::GetAttrMap("TGlobalSymbol"); - const Op& builtin_call_extern_ = builtin::call_extern(); - const Op& builtin_call_llvm_intrin_ = builtin::call_llvm_intrin(); /*! \brief Helper struct for debug infos. */ struct DebugInfo { std::unique_ptr di_builder_; diff --git a/src/target/llvm/codegen_nvptx.cc b/src/target/llvm/codegen_nvptx.cc index 71c8e78030c20..bc47ce1b10149 100644 --- a/src/target/llvm/codegen_nvptx.cc +++ b/src/target/llvm/codegen_nvptx.cc @@ -197,11 +197,11 @@ static bool GetWarpShuffleIntrinsic(const CallNode* op, llvm::Intrinsic::ID* id) llvm::Intrinsic::nvvm_shfl_down_i32, llvm::Intrinsic::nvvm_shfl_down_f32}; int offset = 0; - if (op->op.same_as(builtin::tvm_warp_shuffle())) { + if (op->is_intrinsic(intrinsic::tvm_warp_shuffle)) { offset = 0; - } else if (op->op.same_as(builtin::tvm_warp_shuffle_up())) { + } else if (op->is_intrinsic(intrinsic::tvm_warp_shuffle_up)) { offset = 2; - } else if (op->op.same_as(builtin::tvm_warp_shuffle_down())) { + } else if (op->is_intrinsic(intrinsic::tvm_warp_shuffle_down)) { offset = 4; } else { return false; @@ -226,7 +226,7 @@ llvm::Value* CodeGenNVPTX::CreateIntrinsic(const CallNode* op) { llvm::Type* return_type = arg_type[0]; llvm::Function* func = GetIntrinsicDecl(id, return_type, arg_type); return builder_->CreateCall(func, arg_value); - } else if (op->op.same_as(builtin::tvm_warp_activemask())) { + } else if (op->is_intrinsic(intrinsic::tvm_warp_activemask)) { // Only nvptx target may keep this intrinsic at this point. // PTX assembly: asm "activemask.b32 r1;" auto fty = llvm::FunctionType::get(t_int32_, false); diff --git a/src/target/llvm/codegen_x86_64.cc b/src/target/llvm/codegen_x86_64.cc index 5d269fa4d513c..edffda287c7b3 100644 --- a/src/target/llvm/codegen_x86_64.cc +++ b/src/target/llvm/codegen_x86_64.cc @@ -89,7 +89,7 @@ llvm::Value* CodeGenX86_64::VisitExpr_(const CastNode* op) { ::llvm::Intrinsic::x86_avx512_mask_vcvtph2ps_512, 16, DTypeToLLVMType(DataType::Float(32, from.lanes())), { - MakeValue(tir::Call(DataType::Int(16, from.lanes()), tir::builtin::reinterpret(), + MakeValue(tir::Call(DataType::Int(16, from.lanes()), tir::CallNode::reinterpret, {op->value}, tir::CallNode::PureIntrinsic)), MakeValue(tir::Broadcast(FloatImm(DataType::Float(32), 0), from.lanes())), /*mask=*/MakeValue(IntImm(DataType::Int(16), -1)), @@ -105,7 +105,7 @@ llvm::Value* CodeGenX86_64::VisitExpr_(const CastNode* op) { return CallVectorIntrin( ::llvm::Intrinsic::x86_vcvtph2ps_256, 8, DTypeToLLVMType(DataType::Float(32, from.lanes())), - {MakeValue(tir::Call(DataType::Int(16, from.lanes()), tir::builtin::reinterpret(), + {MakeValue(tir::Call(DataType::Int(16, from.lanes()), tir::CallNode::reinterpret, {op->value}, tir::CallNode::PureIntrinsic))}); } #endif diff --git a/src/target/llvm/intrin_rule_llvm.cc b/src/target/llvm/intrin_rule_llvm.cc index abf350e2208a6..8804b1e45a6f1 100644 --- a/src/target/llvm/intrin_rule_llvm.cc +++ b/src/target/llvm/intrin_rule_llvm.cc @@ -39,8 +39,6 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.exp") TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.exp2") .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::exp2, 1>); -// TODO(tvm-team): migrate the legalization transformations as a separate -// set of rules in TIR that can be shared across backends. 
TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.exp10") .set_body([](const TVMArgs& targs, TVMRetValue* rv) { using tir::make_const; @@ -50,7 +48,7 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.exp10") CHECK(call != nullptr); const PrimExpr& x = call->args[0]; PrimExpr ln10 = make_const(x.dtype(), 2.302585093); - PrimExpr ret = exp(x * ln10); + PrimExpr ret = tir::Call(x.dtype(), "exp", {x * ln10}, tir::CallNode::PureIntrinsic); *rv = ret; }); @@ -99,8 +97,8 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.tanh") PrimExpr two = make_const(x.dtype(), 2); PrimExpr neg_two = make_const(x.dtype(), -2); - PrimExpr exp_neg2x = exp(neg_two * x); - PrimExpr exp_pos2x = exp(two * x); + PrimExpr exp_neg2x = tir::Call(x.dtype(), "exp", {neg_two * x}, tir::CallNode::PureIntrinsic); + PrimExpr exp_pos2x = tir::Call(x.dtype(), "exp", {two * x}, tir::CallNode::PureIntrinsic); PrimExpr tanh_pos = (one - exp_neg2x) / (one + exp_neg2x); PrimExpr tanh_neg = (exp_pos2x - one) / (exp_pos2x + one); @@ -118,7 +116,9 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.tan").set_body([](const TVMArgs& targs const tir::CallNode* call = e.as(); CHECK(call != nullptr); const PrimExpr& x = call->args[0]; - PrimExpr tan_x = sin(x) / cos(x); + PrimExpr sin_x = tir::Call(x.dtype(), "sin", {x}, tir::CallNode::PureIntrinsic); + PrimExpr cos_x = tir::Call(x.dtype(), "cos", {x}, tir::CallNode::PureIntrinsic); + PrimExpr tan_x = sin_x / cos_x; *rv = tan_x; }); @@ -135,8 +135,8 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.cosh") const PrimExpr& x = call->args[0]; PrimExpr two = make_const(x.dtype(), 2); PrimExpr neg_one = make_const(x.dtype(), -1); - PrimExpr exp_negx = exp(neg_one * x); - PrimExpr exp_posx = exp(x); + PrimExpr exp_negx = tir::Call(x.dtype(), "exp", {neg_one * x}, tir::CallNode::PureIntrinsic); + PrimExpr exp_posx = tir::Call(x.dtype(), "exp", {x}, tir::CallNode::PureIntrinsic); PrimExpr ret = (exp_posx + exp_negx) / two; *rv = ret; }); @@ -154,8 +154,8 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.sinh") const PrimExpr& x = call->args[0]; PrimExpr two = make_const(x.dtype(), 2); PrimExpr neg_one = make_const(x.dtype(), -1); - PrimExpr exp_negx = exp(neg_one * x); - PrimExpr exp_posx = exp(x); + PrimExpr exp_negx = tir::Call(x.dtype(), "exp", {neg_one * x}, tir::CallNode::PureIntrinsic); + PrimExpr exp_posx = tir::Call(x.dtype(), "exp", {x}, tir::CallNode::PureIntrinsic); PrimExpr ret = (exp_posx - exp_negx) / two; *rv = ret; }); diff --git a/src/target/llvm/intrin_rule_llvm.h b/src/target/llvm/intrin_rule_llvm.h index cc9437d25b7e1..5613621d77fbc 100644 --- a/src/target/llvm/intrin_rule_llvm.h +++ b/src/target/llvm/intrin_rule_llvm.h @@ -27,7 +27,6 @@ #include #include -#include #include #include @@ -50,8 +49,7 @@ inline void DispatchLLVMPureIntrin(const TVMArgs& targs, TVMRetValue* rv) { for (PrimExpr arg : call->args) { cargs.push_back(arg); } - *rv = - tir::Call(call->dtype, tir::builtin::call_llvm_intrin(), cargs, tir::CallNode::PureIntrinsic); + *rv = tir::Call(call->dtype, "llvm_intrin", cargs, tir::CallNode::PureIntrinsic); } template @@ -66,7 +64,7 @@ inline void DispatchLLVMIntrin(const TVMArgs& targs, TVMRetValue* rv) { for (PrimExpr arg : call->args) { cargs.push_back(arg); } - *rv = tir::Call(call->dtype, tir::builtin::call_llvm_intrin(), cargs, tir::CallNode::Intrinsic); + *rv = tir::Call(call->dtype, "llvm_intrin", cargs, tir::CallNode::Intrinsic); } } // namespace codegen diff --git a/src/target/llvm/intrin_rule_nvptx.cc b/src/target/llvm/intrin_rule_nvptx.cc index a0ffe11da27a6..49c2224932a5d 100644 --- 
a/src/target/llvm/intrin_rule_nvptx.cc +++ b/src/target/llvm/intrin_rule_nvptx.cc @@ -23,9 +23,7 @@ #ifdef TVM_LLVM_VERSION #include -#include #include -#include #include @@ -38,21 +36,10 @@ inline void DispatchExternLibDevice(const TVMArgs& args, TVMRetValue* rv) { const CallNode* call = e.as(); CHECK(call != nullptr); CHECK(call->dtype.bits() == 32 || call->dtype.bits() == 64) << "Only support float32 or float64."; - - const OpNode* op = call->op.as(); - CHECK(op != nullptr); - std::string name = op->name; - CHECK_EQ(name.substr(0, 4), "tir."); - std::ostringstream intrinsic_name; - intrinsic_name << "__nv_" << name.substr(4); + intrinsic_name << "__nv_" << call->name; if (call->dtype.bits() == 32) intrinsic_name << "f"; - - Array new_args = {StringImm(intrinsic_name.str())}; - for (auto arg : call->args) { - new_args.push_back(arg); - } - *rv = Call(call->dtype, builtin::call_extern(), new_args, CallNode::PureExtern); + *rv = Call(call->dtype, intrinsic_name.str(), call->args, CallNode::PureExtern); } namespace llvm { diff --git a/src/target/llvm/intrin_rule_rocm.cc b/src/target/llvm/intrin_rule_rocm.cc index 07520ae08cc88..3a2b8ac77f82d 100644 --- a/src/target/llvm/intrin_rule_rocm.cc +++ b/src/target/llvm/intrin_rule_rocm.cc @@ -23,7 +23,6 @@ #ifdef TVM_LLVM_VERSION #include -#include #include #include @@ -37,21 +36,9 @@ inline void DispatchExternOCML(const TVMArgs& args, TVMRetValue* rv) { using namespace tir; const CallNode* call = e.as(); CHECK(call != nullptr); - - const OpNode* op = call->op.as(); - CHECK(op != nullptr); - std::string name = op->name; - CHECK_EQ(name.substr(0, 4), "tir."); - std::ostringstream intrinsic_name; - intrinsic_name << "__ocml_" << name.substr(4) << "_f" << call->dtype.bits(); - - Array new_args = {StringImm(intrinsic_name.str())}; - for (auto arg : call->args) { - new_args.push_back(arg); - } - - *rv = Call(call->dtype, builtin::call_extern(), new_args, CallNode::PureExtern); + intrinsic_name << "__ocml_" << call->name << "_f" << call->dtype.bits(); + *rv = Call(call->dtype, intrinsic_name.str(), call->args, CallNode::PureExtern); } inline void DispatchShuffle(const TVMArgs& targs, TVMRetValue* rv) { @@ -66,30 +53,29 @@ inline void DispatchShuffle(const TVMArgs& targs, TVMRetValue* rv) { // get own lane in self (__lane_id) PrimExpr minus_one = tir::make_const(DataType::Int(32), -1); PrimExpr zero = tir::make_zero(DataType::Int(32)); - PrimExpr lo = Call(DataType::Int(32), builtin::call_extern(), - {StringImm("llvm.amdgcn.mbcnt.lo"), minus_one, zero}, CallNode::PureExtern); - PrimExpr self = Call(DataType::Int(32), builtin::call_extern(), - {StringImm("llvm.amdgcn.mbcnt.hi"), minus_one, lo}, CallNode::PureExtern); + PrimExpr lo = + Call(DataType::Int(32), "llvm.amdgcn.mbcnt.lo", {minus_one, zero}, CallNode::PureExtern); + PrimExpr self = + Call(DataType::Int(32), "llvm.amdgcn.mbcnt.hi", {minus_one, lo}, CallNode::PureExtern); // compute lane to get from PrimExpr width = call->args[3]; PrimExpr index; - if (call->op.same_as(builtin::tvm_warp_shuffle())) { + if (call->name == "tvm_warp_shuffle") { PrimExpr src_lane = call->args[2]; index = src_lane + (self & ~(width - 1)); - } else if (call->op.same_as(builtin::tvm_warp_shuffle_up())) { + } else if (call->name == "tvm_warp_shuffle_up") { PrimExpr delta = call->args[2]; index = self - delta; index = Select(index < (self & ~(width - 1)), self, index); } else { - CHECK(call->op.same_as(builtin::tvm_warp_shuffle_down())); + CHECK_EQ(call->name, "tvm_warp_shuffle_down"); PrimExpr delta = call->args[2]; index = 
self + delta; index = Select((self & (width - 1)) + delta >= width, self, index); } PrimExpr res = - Call(var.dtype(), builtin::call_extern(), - {StringImm("llvm.amdgcn.ds.bpermute"), index << 2, var}, CallNode::PureExtern); + Call(var.dtype(), "llvm.amdgcn.ds.bpermute", {index << 2, var}, CallNode::PureExtern); *rv = res; } diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc index ffeaba06d7010..9255d7c80c46d 100644 --- a/src/target/source/codegen_c.cc +++ b/src/target/source/codegen_c.cc @@ -223,12 +223,12 @@ std::string CodeGenC::GetBufferRef(DataType t, const VarNode* buffer, PrimExpr i // Print a reference expression to a buffer. std::string CodeGenC::GetStructRef(DataType t, const PrimExpr& buffer, const PrimExpr& index, int kind) { - if (kind < builtin::kArrKindBound_) { + if (kind < intrinsic::kArrKindBound_) { std::ostringstream os; os << "(((DLTensor*)"; this->PrintExpr(buffer, os); os << ")"; - if (kind == builtin::kArrAddr) { + if (kind == intrinsic::kArrAddr) { os << " + "; this->PrintExpr(index, os); os << ")"; @@ -239,34 +239,34 @@ std::string CodeGenC::GetStructRef(DataType t, const PrimExpr& buffer, const Pri os << "]."; // other case: get fields. switch (kind) { - case builtin::kArrData: + case intrinsic::kArrData: os << "data"; break; - case builtin::kArrShape: + case intrinsic::kArrShape: os << "shape"; break; - case builtin::kArrStrides: + case intrinsic::kArrStrides: os << "strides"; break; - case builtin::kArrNDim: + case intrinsic::kArrNDim: os << "ndim"; break; - case builtin::kArrTypeCode: + case intrinsic::kArrTypeCode: os << "dtype.code"; break; - case builtin::kArrTypeBits: + case intrinsic::kArrTypeBits: os << "dtype.bits"; break; - case builtin::kArrByteOffset: + case intrinsic::kArrByteOffset: os << "byte_offset"; break; - case builtin::kArrTypeLanes: + case intrinsic::kArrTypeLanes: os << "dtype.lanes"; break; - case builtin::kArrDeviceId: + case intrinsic::kArrDeviceId: os << "ctx.device_id"; break; - case builtin::kArrDeviceType: + case intrinsic::kArrDeviceType: os << "ctx.device_type"; break; default: @@ -275,7 +275,7 @@ std::string CodeGenC::GetStructRef(DataType t, const PrimExpr& buffer, const Pri os << ')'; return os.str(); } else { - CHECK_LT(kind, builtin::kTVMValueKindBound_); + CHECK_LT(kind, intrinsic::kTVMValueKindBound_); std::ostringstream os; os << "(((TVMValue*)"; this->PrintExpr(buffer, os); @@ -559,94 +559,80 @@ void CodeGenC::VisitExpr_(const NotNode* op, std::ostream& os) { // NOLINT(*) PrintExpr(op->a, os); } -void CodeGenC::PrintCallExtern(Type ret_type, String global_symbol, const Array& args, - bool skip_first_arg, std::ostream& os) { // NOLINT(*) - os << global_symbol << "("; - for (size_t i = static_cast(skip_first_arg); i < args.size(); ++i) { - this->PrintExpr(args[i], os); - if (i < args.size() - 1) { - os << ", "; - } - } - os << ")"; -} - void CodeGenC::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*) - if (auto* ptr_op = op->op.as()) { - auto call_op = GetRef(ptr_op); - - if (op->op.same_as(builtin_call_extern_)) { - CHECK_GE(op->args.size(), 1U); - auto func = Downcast(op->args[0]); - this->PrintCallExtern(GetType(GetRef(op)), func->value, op->args, true, os); - } else if (op_attr_global_symbol_.count(call_op)) { - // call extern if the op itself have a global symbol. 
- this->PrintCallExtern(GetType(GetRef(op)), op_attr_global_symbol_[call_op], - op->args, false, os); - } else if (op->op.same_as(builtin::bitwise_and())) { - PrintBinaryIntrinsic(op, " & ", os, this); - } else if (op->op.same_as(builtin::large_uint_imm())) { - CHECK_EQ(op->args.size(), 2U); - uint64_t low = static_cast(Downcast(op->args[0])->value); - uint64_t high = static_cast(Downcast(op->args[1])->value); - uint64_t val = (high << 32U) | low; - PrintUIntConst(op->dtype, val, os, this); - } else if (op->op.same_as(builtin::bitwise_xor())) { - PrintBinaryIntrinsic(op, " ^ ", os, this); - } else if (op->op.same_as(builtin::bitwise_or())) { - PrintBinaryIntrinsic(op, " | ", os, this); - } else if (op->op.same_as(builtin::bitwise_not())) { - CHECK_EQ(op->args.size(), 1U); - os << "(~"; - this->PrintExpr(op->args[0], os); - os << ')'; - } else if (op->op.same_as(builtin::shift_left())) { - PrintBinaryIntrinsic(op, " << ", os, this); - } else if (op->op.same_as(builtin::shift_right())) { - PrintBinaryIntrinsic(op, " >> ", os, this); - } else if (op->op.same_as(builtin::if_then_else())) { - os << "("; - PrintExpr(op->args[0], os); - os << " ? "; - PrintExpr(op->args[1], os); - os << " : "; - PrintExpr(op->args[2], os); - os << ")"; - } else if (op->op.same_as(builtin::address_of())) { - const LoadNode* l = op->args[0].as(); - CHECK(op->args.size() == 1 && l); - os << "(("; - this->PrintType(l->dtype.element_of(), os); - os << " *)" << this->GetVarID(l->buffer_var.get()) << " + "; - this->PrintExpr(l->index, os); - os << ')'; - } else if (op->op.same_as(builtin::tvm_struct_get())) { - CHECK_EQ(op->args.size(), 3U); - os << GetStructRef(op->dtype, op->args[0], op->args[1], op->args[2].as()->value); - } else if (op->op.same_as(builtin::isnullptr())) { - CHECK_EQ(op->args.size(), 1U); - os << "("; - this->PrintExpr(op->args[0], os); - os << " == NULL)"; - } else if (op->op.same_as(builtin::reinterpret())) { - // generate (*( TYPE *)(&(ARG))) - os << "(*("; - this->PrintType(op->dtype, os); - os << " *)(&("; - this->PrintExpr(op->args[0], os); - os << ")))"; - } else if (op->op.same_as(builtin::isnan())) { - os << "("; - this->PrintExpr(op->args[0], os); - os << " != "; - this->PrintExpr(op->args[0], os); - os << ")"; - } else { - LOG(FATAL) << "Unresolved call " << op->op; + if (op->call_type == CallNode::Extern || op->call_type == CallNode::PureExtern) { + os << op->name << "("; + for (size_t i = 0; i < op->args.size(); i++) { + this->PrintExpr(op->args[i], os); + if (i < op->args.size() - 1) { + os << ", "; + } } + os << ")"; + } else if (op->is_intrinsic(CallNode::bitwise_and)) { + PrintBinaryIntrinsic(op, " & ", os, this); + } else if (op->is_intrinsic(intrinsic::tvm_large_uint_imm)) { + CHECK_EQ(op->args.size(), 2U); + uint64_t low = static_cast(Downcast(op->args[0])->value); + uint64_t high = static_cast(Downcast(op->args[1])->value); + uint64_t val = (high << 32U) | low; + PrintUIntConst(op->dtype, val, os, this); + } else if (op->is_intrinsic(CallNode::bitwise_xor)) { + PrintBinaryIntrinsic(op, " ^ ", os, this); + } else if (op->is_intrinsic(CallNode::bitwise_or)) { + PrintBinaryIntrinsic(op, " | ", os, this); + } else if (op->is_intrinsic(CallNode::bitwise_not)) { + CHECK_EQ(op->args.size(), 1U); + os << "(~"; + this->PrintExpr(op->args[0], os); + os << ')'; + } else if (op->is_intrinsic(CallNode::shift_left)) { + PrintBinaryIntrinsic(op, " << ", os, this); + } else if (op->is_intrinsic(CallNode::shift_right)) { + PrintBinaryIntrinsic(op, " >> ", os, this); + } else if 
(op->is_intrinsic(intrinsic::tvm_if_then_else)) { + os << "("; + PrintExpr(op->args[0], os); + os << " ? "; + PrintExpr(op->args[1], os); + os << " : "; + PrintExpr(op->args[2], os); + os << ")"; + } else if (op->is_intrinsic(intrinsic::tvm_address_of)) { + const LoadNode* l = op->args[0].as(); + CHECK(op->args.size() == 1 && l); + os << "(("; + this->PrintType(l->dtype.element_of(), os); + os << " *)" << this->GetVarID(l->buffer_var.get()) << " + "; + this->PrintExpr(l->index, os); + os << ')'; + } else if (op->is_intrinsic(intrinsic::tvm_struct_get)) { + CHECK_EQ(op->args.size(), 3U); + os << GetStructRef(op->dtype, op->args[0], op->args[1], op->args[2].as()->value); + } else if (op->is_intrinsic(intrinsic::tvm_handle_is_null)) { + CHECK_EQ(op->args.size(), 1U); + os << "("; + this->PrintExpr(op->args[0], os); + os << " == NULL)"; + } else if (op->is_intrinsic(CallNode::reinterpret)) { + // generate (*( TYPE *)(&(ARG))) + os << "(*("; + this->PrintType(op->dtype, os); + os << " *)(&("; + this->PrintExpr(op->args[0], os); + os << ")))"; + } else if (op->is_intrinsic(CallNode::isnan)) { + os << "("; + this->PrintExpr(op->args[0], os); + os << " != "; + this->PrintExpr(op->args[0], os); + os << ")"; } else { - CHECK(op->op.as()); - LOG(FATAL) << "Do not yet support cross function call"; + if (op->call_type == CallNode::Intrinsic || op->call_type == CallNode::PureIntrinsic) { + LOG(FATAL) << "Unresolved intrinsic " << op->name << " with return type " << op->dtype; + } else { + LOG(FATAL) << "Unresolved call type " << op->call_type; + } } } @@ -917,10 +903,10 @@ void CodeGenC::VisitStmt_(const EvaluateNode* op) { if (is_const(op->value)) return; const CallNode* call = op->value.as(); if (call) { - if (call->op.same_as(builtin::tvm_storage_sync())) { + if (call->is_intrinsic(intrinsic::tvm_storage_sync)) { this->PrintStorageSync(call); return; - } else if (call->op.same_as(builtin::tvm_struct_set())) { + } else if (call->is_intrinsic(intrinsic::tvm_struct_set)) { CHECK_EQ(call->args.size(), 4); std::string value = PrintExpr(call->args[3]); std::string ref = GetStructRef(call->args[3].dtype(), call->args[0], call->args[1], diff --git a/src/target/source/codegen_c.h b/src/target/source/codegen_c.h index 9346f87cb3bbc..309eb06816076 100644 --- a/src/target/source/codegen_c.h +++ b/src/target/source/codegen_c.h @@ -24,13 +24,10 @@ #ifndef TVM_TARGET_SOURCE_CODEGEN_C_H_ #define TVM_TARGET_SOURCE_CODEGEN_C_H_ -#include #include #include -#include #include #include -#include #include #include @@ -222,16 +219,6 @@ class CodeGenC : public ExprFunctor, */ virtual bool IsScopePartOfType() const { return true; } - /*! - * \brief Print external function call. - * \param ret_type The return type. - * \param global_symbol The symbolc of the target function. - * \param args The arguments to the function. - * \param skip_first_arg Whether to skip the first arguments. - * \param os The output stream. - */ - virtual void PrintCallExtern(Type ret_type, String global_symbol, const Array& args, - bool skip_first_arg, std::ostream& os); // NOLINT(*) /*! * \brief If buffer is allocated as type t. * \param buf_var The buffer variable. @@ -258,10 +245,6 @@ class CodeGenC : public ExprFunctor, std::unordered_map alloc_storage_scope_; /*! \brief the data type of allocated buffers */ std::unordered_map handle_data_type_; - /*! \brief Record of ops that have pre-defined global symbol. 
*/ - OpAttrMap op_attr_global_symbol_ = Op::GetAttrMap("TGlobalSymbol"); - // cache commonly used ops - const Op& builtin_call_extern_ = builtin::call_extern(); private: /*! \brief whether to print in SSA form */ diff --git a/src/target/source/codegen_c_host.cc b/src/target/source/codegen_c_host.cc index 839962a8c733f..b11b3d8fc5f98 100644 --- a/src/target/source/codegen_c_host.cc +++ b/src/target/source/codegen_c_host.cc @@ -175,7 +175,7 @@ void CodeGenCHost::PrintFuncCall(const std::string& packed_func_name, int num_ar } void CodeGenCHost::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*) - if (op->op.same_as(builtin::tvm_stack_alloca())) { + if (op->is_intrinsic(intrinsic::tvm_stack_alloca)) { std::string stack_name = GetUniqueName("stack"); const std::string& type = op->args[0].as()->value; const IntImmNode* num = op->args[1].as(); @@ -197,7 +197,7 @@ void CodeGenCHost::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT this->PrintIndent(); this->stream << "TVMValue " << stack_name << "[" << size << "];\n"; os << stack_name; - } else if (op->op.same_as(builtin::tvm_call_packed_lowered())) { + } else if (op->is_intrinsic(intrinsic::tvm_call_packed_lowered)) { const StringImmNode* s = op->args[0].as(); CHECK(s != nullptr) << "tvm_call_packed_lowered expects first argument as function name"; int64_t begin = op->args[3].as()->value; @@ -216,7 +216,7 @@ void CodeGenCHost::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT } this->PrintGetFuncFromBackend(func_name, packed_func_name); this->PrintFuncCall(packed_func_name, num_args); - } else if (op->op.same_as(builtin::tvm_throw_last_error())) { + } else if (op->is_intrinsic(intrinsic::tvm_throw_last_error)) { this->PrintIndent(); this->stream << "return -1;\n"; } else { diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc index ae5e40acd8f5b..cf7a74f1dcc05 100644 --- a/src/target/source/codegen_cuda.cc +++ b/src/target/source/codegen_cuda.cc @@ -429,71 +429,15 @@ void CodeGenCUDA::VisitExpr_(const CastNode* op, std::ostream& os) { os << sret; } -void CodeGenCUDA::PrintCallExtern(Type ret_type, String global_symbol, const Array& args, - bool skip_first_arg, std::ostream& os) { // NOLINT(*) - DataType ret_dtype = GetRuntimeDataType(ret_type); - if (ret_dtype.is_vector()) { - // - // Emit an unsupported vector call - // - // v = intrin_f((float4*)A[0], (float4*)B[0]) - // - // as - // - // float4 __ret; - // { - // float4 __arg0 = ((float4*)A)[0]; - // float4 __arg1 = ((float4*)B)[0]; - // __ret.x = intrin_f(__arg0.x, __arg1.x); - // __ret.y = intrin_f(__arg0.y, __arg1.y); - // __ret.z = intrin_f(__arg0.z, __arg1.z); - // __ret.w = intrin_f(__arg0.w, __arg1.w); - // } - // v = __ret; - // - // Declare the result vector. - std::string sret = GetUniqueName("_"); - this->PrintIndent(); - this->PrintType(ret_dtype, stream); - stream << ' ' << sret << ";\n"; - { - // Load arguments. - std::vector sargs; - size_t arg_begin = static_cast(skip_first_arg); - for (size_t i = arg_begin; i < args.size(); ++i) { - std::string val = SSAGetID(PrintExpr(args[i]), args[i].dtype()); - sargs.push_back(std::move(val)); - } - - // Emit a scalar call for each lane. 
- for (int i = 0; i < ret_dtype.lanes(); ++i) { - std::ostringstream scall; - scall << global_symbol << "("; - for (size_t j = 0; j < sargs.size(); ++j) { - if (j > 0) scall << ", "; - PrintVecElemLoad(sargs[j], args[arg_begin + j].dtype(), i, scall); - } - scall << ")"; - PrintVecElemStore(sret, ret_dtype, i, scall.str()); - } - } - os << sret; - } else { - CodeGenC::PrintCallExtern(ret_type, global_symbol, args, skip_first_arg, os); - } -} - void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { - if (auto* ptr_op = op->op.as()) { - Op call_op = GetRef(ptr_op); - // This is only for backward compatibility with __shfl_{up/down}. - // A macro will be used to replace *_sync calls to legacy ones. - if (op_need_warp_shuffle_.get(call_op, false)) { - enable_warp_shuffle_ = true; - } + // This is only for backward compatibility with __shfl_{up/down}. + // A macro will be used to replace *_sync calls to legacy ones. + if (op->is_intrinsic("__shfl_sync") || op->is_intrinsic("__shfl_up_sync") || + op->is_intrinsic("__shfl_down_sync")) { + enable_warp_shuffle_ = true; } - if (op->op.same_as(builtin::tvm_fill_fragment())) { + if (op->is_intrinsic(intrinsic::tvm_fill_fragment)) { need_mma_h_ = true; CHECK_EQ(op->args.size(), 6U); os << "nvcuda::wmma::fill_fragment("; @@ -503,7 +447,7 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { os << "], "; this->PrintExpr(op->args[5], os); os << ")"; - } else if (op->op.same_as(builtin::tvm_load_matrix_sync())) { + } else if (op->is_intrinsic(intrinsic::tvm_load_matrix_sync)) { need_mma_h_ = true; CHECK_EQ(op->args.size(), 8U); os << "nvcuda::wmma::load_matrix_sync("; @@ -515,7 +459,7 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { os << ", "; this->PrintExpr(op->args[6], os); os << ")"; - } else if (op->op.same_as(builtin::tvm_store_matrix_sync())) { + } else if (op->is_intrinsic(intrinsic::tvm_store_matrix_sync)) { need_mma_h_ = true; CHECK_EQ(op->args.size(), 8U); os << "nvcuda::wmma::store_matrix_sync("; @@ -532,7 +476,7 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { LOG(FATAL) << "Invalid parameters"; } os << ")"; - } else if (op->op.same_as(builtin::tvm_mma_sync())) { + } else if (op->is_intrinsic(intrinsic::tvm_mma_sync)) { need_mma_h_ = true; CHECK_EQ(op->args.size(), 8U); os << "nvcuda::wmma::mma_sync("; @@ -542,7 +486,7 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { this->PrintExpr(op->args[i * 2 + 1], os); os << "]" << ((i < 3) ? ", " : ")"); } - } else if (op->op.same_as(builtin::tvm_bmma_sync())) { + } else if (op->is_intrinsic(intrinsic::tvm_bmma_sync)) { need_mma_h_ = true; CHECK_EQ(op->args.size(), 8U); os << "nvcuda::wmma::bmma_sync("; @@ -552,6 +496,51 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { this->PrintExpr(op->args[i * 2 + 1], os); os << "]" << ((i < 3) ? ", " : ")"); } + } else if (op->call_type == CallNode::PureExtern && op->dtype.is_vector()) { + // + // Emit an unsupported vector call + // + // v = intrin_f((float4*)A[0], (float4*)B[0]) + // + // as + // + // float4 __ret; + // { + // float4 __arg0 = ((float4*)A)[0]; + // float4 __arg1 = ((float4*)B)[0]; + // __ret.x = intrin_f(__arg0.x, __arg1.x); + // __ret.y = intrin_f(__arg0.y, __arg1.y); + // __ret.z = intrin_f(__arg0.z, __arg1.z); + // __ret.w = intrin_f(__arg0.w, __arg1.w); + // } + // v = __ret; + // + // Declare the result vector. 
+ std::string sret = GetUniqueName("_"); + this->PrintIndent(); + this->PrintType(op->dtype, stream); + stream << ' ' << sret << ";\n"; + { + // Load arguments. + std::vector sargs; + for (size_t i = 0; i < op->args.size(); ++i) { + std::string val = SSAGetID(PrintExpr(op->args[i]), op->args[i].dtype()); + sargs.push_back(std::move(val)); + } + + // Emit a scalar call for each lane. + for (int i = 0; i < op->dtype.lanes(); ++i) { + std::ostringstream scall; + scall << op->name << "("; + for (size_t j = 0; j < op->args.size(); ++j) { + if (j > 0) scall << ", "; + PrintVecElemLoad(sargs[j], op->args[j].dtype(), i, scall); + } + scall << ")"; + PrintVecElemStore(sret, op->dtype, i, scall.str()); + } + } + os << sret; } else { CodeGenC::VisitExpr_(op, os); } @@ -611,7 +600,7 @@ void CodeGenCUDA::VisitStmt_(const AllocateNode* op) { void CodeGenCUDA::VisitStmt_(const EvaluateNode* op) { if (is_const(op->value)) return; const CallNode* call = op->value.as(); - if (call && call->op.same_as(builtin::tvm_global_barrier_kinit())) { + if (call && call->is_intrinsic(intrinsic::tvm_global_barrier_kinit)) { PrintIndent(); stream << "__shared__ unsigned " << vid_global_barrier_expect_ << ";\n"; PrintIndent(); diff --git a/src/target/source/codegen_cuda.h b/src/target/source/codegen_cuda.h index 3cde8e379eb4c..f9ab0ade2cf20 100644 --- a/src/target/source/codegen_cuda.h +++ b/src/target/source/codegen_cuda.h @@ -26,7 +26,6 @@ #include #include -#include #include #include @@ -69,10 +68,6 @@ class CodeGenCUDA final : public CodeGenC { void VisitStmt_(const AllocateNode* op) final; void VisitStmt_(const AttrStmtNode* op) final; - protected: - void PrintCallExtern(Type ret_type, String global_symbol, const Array& args, - bool skip_first_arg, std::ostream& os) final; // NOLINT(*) - private: // Handle volatile loads void HandleVolatileLoads(const std::string& value, const LoadNode* op, std::ostream& os) final; @@ -96,8 +91,6 @@ class CodeGenCUDA final : public CodeGenC { bool need_math_constants_h_{false}; // whether need mma.h bool need_mma_h_{false}; - // Op attribute map - OpAttrMap op_need_warp_shuffle_ = Op::GetAttrMap("cuda.need_warp_shuffle"); std::unordered_map fragment_shapes; std::unordered_map fragment_layouts; diff --git a/src/target/source/codegen_metal.cc b/src/target/source/codegen_metal.cc index 1c4256c5a1661..2c26ee977639d 100644 --- a/src/target/source/codegen_metal.cc +++ b/src/target/source/codegen_metal.cc @@ -270,7 +270,7 @@ void CodeGenMetal::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // N } void CodeGenMetal::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*) - if (op->op.same_as(builtin::reinterpret())) { + if (op->is_intrinsic(CallNode::reinterpret)) { // generate as_type(ARG) os << "(as_type<"; this->PrintType(op->dtype, os); diff --git a/src/target/source/intrin_rule_cuda.cc b/src/target/source/intrin_rule_cuda.cc index 53a2799e2725b..45746b8ef721f 100644 --- a/src/target/source/intrin_rule_cuda.cc +++ b/src/target/source/intrin_rule_cuda.cc @@ -21,9 +21,6 @@ * \file intrin_rule_cuda.cc * \brief CUDA intrinsic rules. 
*/ -#include -#include - #include "../intrin_rule.h" namespace tvm { @@ -96,23 +93,23 @@ struct CUDAPopcount { }; struct CUDAWarpIntrinsic { - const Op operator()(DataType t, const Op& orig_op) const { - if (orig_op.same_as(builtin::tvm_warp_shuffle())) { - return Op::Get("tir.cuda.__shfl_sync"); - } else if (orig_op.same_as(builtin::tvm_warp_shuffle_up())) { - return Op::Get("tir.cuda.__shfl_up_sync"); - } else { - CHECK(orig_op.same_as(builtin::tvm_warp_shuffle_down())); - return Op::Get("tir.cuda.__shfl_down_sync"); + const char* operator()(DataType t, const std::string& name) const { + if (name == intrinsic::tvm_warp_shuffle) { + return "__shfl_sync"; + } + if (name == intrinsic::tvm_warp_shuffle_up) { + return "__shfl_up_sync"; + } + if (name == intrinsic::tvm_warp_shuffle_down) { + return "__shfl_down_sync"; + } + if (name == intrinsic::tvm_warp_activemask) { + return "__activemask"; } + return ""; } }; -static void DispatchCUDAWarpActiveMask(const TVMArgs& args, TVMRetValue* rv) { - Call call = args[0]; - *rv = Call(call->dtype, Op::Get("tir.cuda.__activemask"), call->args, CallNode::PureExtern); -} - template static void DispatchCUDAShuffle(const TVMArgs& args, TVMRetValue* rv) { PrimExpr e = args[0]; @@ -120,9 +117,8 @@ static void DispatchCUDAShuffle(const TVMArgs& args, TVMRetValue* rv) { CHECK(call != nullptr); CHECK_EQ(call->args.size(), 5); // mask, value, warp_id, width, warp_size Array cuda_args{{call->args[0], call->args[1], call->args[2], call->args[3]}}; - - *rv = - Call(call->dtype, T()(call->dtype, Downcast(call->op)), cuda_args, CallNode::PureExtern); + const char* name = T()(call->dtype, call->name); + *rv = Call(call->dtype, name, cuda_args, CallNode::PureExtern); } TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.floor").set_body(DispatchExtern); @@ -179,32 +175,10 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.tvm_warp_shuffle_down") .set_body(DispatchCUDAShuffle); TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.tvm_warp_activemask") - .set_body(DispatchCUDAWarpActiveMask); + .set_body(DispatchExtern); TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.fmod").set_body(DispatchExtern); -// Register low-level builtin ops. -// TODO(tvm-team): consider make CUDA its own subfolder and create a file for low-level builtins. 
-TVM_REGISTER_OP("tir.cuda.__shfl_sync") - .set_num_inputs(4) - .set_attr("TGlobalSymbol", "__shfl_sync") - .set_attr("cuda.need_warp_shuffle", true); - -TVM_REGISTER_OP("tir.cuda.__shfl_up_sync") - .set_num_inputs(4) - .set_attr("TGlobalSymbol", "__shfl_up_sync") - .set_attr("cuda.need_warp_shuffle", true); - -TVM_REGISTER_OP("tir.cuda.__shfl_down_sync") - .set_num_inputs(4) - .set_attr("TGlobalSymbol", "__shfl_down_sync") - .set_attr("cuda.need_warp_shuffle", true); - -TVM_REGISTER_OP("tir.cuda.__activemask") - .set_num_inputs(0) - .set_attr("TGlobalSymbol", "__activemask") - .set_attr("cuda.need_warp_shuffle", true); - } // namespace intrin } // namespace codegen } // namespace tvm diff --git a/src/target/source/intrin_rule_opencl.cc b/src/target/source/intrin_rule_opencl.cc index 82eabdd96dfe6..8453b33f8a431 100644 --- a/src/target/source/intrin_rule_opencl.cc +++ b/src/target/source/intrin_rule_opencl.cc @@ -79,8 +79,8 @@ static void DispatchIntelShuffle(const TVMArgs& args, TVMRetValue* rv) { arith::Analyzer analyzer; CHECK(analyzer.CanProve(call->args[3] == call->args[4])) << "Intel warp shuffle dose not support width != warp_size"; - Array opencl_args{{StringImm("intel_sub_group_shuffle"), call->args[1], call->args[2]}}; - *rv = Call(call->dtype, builtin::call_extern(), opencl_args, CallNode::PureExtern); + Array opencl_args{{call->args[1], call->args[2]}}; + *rv = Call(call->dtype, "intel_sub_group_shuffle", opencl_args, CallNode::PureExtern); } TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.tvm_warp_shuffle").set_body(DispatchIntelShuffle); diff --git a/src/target/spirv/codegen_spirv.cc b/src/target/spirv/codegen_spirv.cc index 6c12343c81ecc..699d3953f04c9 100644 --- a/src/target/spirv/codegen_spirv.cc +++ b/src/target/spirv/codegen_spirv.cc @@ -24,7 +24,6 @@ #include "codegen_spirv.h" #include -#include #include #include @@ -237,7 +236,7 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const LetNode* op) { } spirv::Value CodeGenSPIRV::VisitExpr_(const CallNode* op) { - if (op->op.same_as(builtin::call_spirv_glsl450())) { + if (op->is_intrinsic("spirv_glsl450")) { CHECK_GE(op->args.size(), 2U); uint32_t inst_id = static_cast(op->args[0].as()->value); std::vector values; @@ -245,31 +244,31 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const CallNode* op) { values.push_back(MakeValue(op->args[i])); } return builder_->CallGLSL450(builder_->GetSType(op->dtype), inst_id, values); - } else if (op->op.same_as(builtin::bitwise_and())) { + } else if (op->is_intrinsic(CallNode::bitwise_and)) { CHECK_EQ(op->args.size(), 2U); spirv::Value a = MakeValue(op->args[0]); spirv::Value b = MakeValue(op->args[1]); return builder_->MakeValue(spv::OpBitwiseAnd, a.stype, a, b); - } else if (op->op.same_as(builtin::bitwise_xor())) { + } else if (op->is_intrinsic(CallNode::bitwise_xor)) { CHECK_EQ(op->args.size(), 2U); spirv::Value a = MakeValue(op->args[0]); spirv::Value b = MakeValue(op->args[1]); return builder_->MakeValue(spv::OpBitwiseXor, a.stype, a, b); - } else if (op->op.same_as(builtin::bitwise_or())) { + } else if (op->is_intrinsic(CallNode::bitwise_or)) { CHECK_EQ(op->args.size(), 2U); spirv::Value a = MakeValue(op->args[0]); spirv::Value b = MakeValue(op->args[1]); return builder_->MakeValue(spv::OpBitwiseOr, a.stype, a, b); - } else if (op->op.same_as(builtin::bitwise_not())) { + } else if (op->is_intrinsic(CallNode::bitwise_not)) { CHECK_EQ(op->args.size(), 1U); spirv::Value a = MakeValue(op->args[0]); return builder_->MakeValue(spv::OpNot, a.stype, a); - } else if (op->op.same_as(builtin::shift_left())) 
{ + } else if (op->is_intrinsic(CallNode::shift_left)) { CHECK_EQ(op->args.size(), 2U); spirv::Value a = MakeValue(op->args[0]); spirv::Value b = MakeValue(op->args[1]); return builder_->MakeValue(spv::OpShiftLeftLogical, a.stype, a, b); - } else if (op->op.same_as(builtin::shift_right())) { + } else if (op->is_intrinsic(CallNode::shift_right)) { CHECK_EQ(op->args.size(), 2U); spirv::Value a = MakeValue(op->args[0]); spirv::Value b = MakeValue(op->args[1]); @@ -278,18 +277,18 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const CallNode* op) { } else { return builder_->MakeValue(spv::OpShiftRightLogical, a.stype, a, b); } - } else if (op->op.same_as(builtin::reinterpret())) { + } else if (op->is_intrinsic(CallNode::reinterpret)) { return builder_->MakeValue(spv::OpBitcast, builder_->GetSType(op->dtype), MakeValue(op->args[0])); - } else if (op->op.same_as(builtin::large_uint_imm())) { + } else if (op->is_intrinsic(intrinsic::tvm_large_uint_imm)) { CHECK_EQ(op->args.size(), 2U); uint64_t low = static_cast(Downcast(op->args[0])->value); uint64_t high = static_cast(Downcast(op->args[1])->value); uint64_t val = (high << 32U) | low; return builder_->UIntImm(builder_->GetSType(op->dtype), val); - } else if (op->op.same_as(builtin::tvm_storage_sync())) { + } else if (op->is_intrinsic(intrinsic::tvm_storage_sync)) { return this->CreateStorageSync(op); - } else if (op->op.same_as(builtin::if_then_else())) { + } else if (op->is_intrinsic(intrinsic::tvm_if_then_else)) { CHECK_EQ(op->args.size(), 3U); spirv::Value cond = MakeValue(op->args[0]); spirv::Label then_label = builder_->NewLabel(); @@ -313,14 +312,14 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const CallNode* op) { phi.SetIncoming(0, then_value, then_value_label); phi.SetIncoming(1, else_value, else_value_label); return phi; - } else if (op->op.same_as(builtin::popcount())) { + } else if (op->is_intrinsic("popcount")) { return builder_->MakeValue(spv::OpBitCount, builder_->GetSType(op->dtype), MakeValue(op->args[0])); } else { if (op->call_type == CallNode::Intrinsic || op->call_type == CallNode::PureIntrinsic) { - LOG(FATAL) << "Unresolved intrinsic " << op->op << " with return type " << op->dtype; + LOG(FATAL) << "Unresolved intrinsic " << op->name << " with return type " << op->dtype; } else if (op->call_type == CallNode::Extern || op->call_type == CallNode::PureExtern) { - LOG(FATAL) << "Unresolved extern " << op->op << " with return type " << op->dtype; + LOG(FATAL) << "Unresolved extern " << op->name << " with return type " << op->dtype; } else { LOG(FATAL) << "Unresolved call type " << op->call_type; } diff --git a/src/target/spirv/intrin_rule_spirv.cc b/src/target/spirv/intrin_rule_spirv.cc index 1b9d2e4e410d0..a6b254770daa2 100644 --- a/src/target/spirv/intrin_rule_spirv.cc +++ b/src/target/spirv/intrin_rule_spirv.cc @@ -22,7 +22,6 @@ */ #include #include -#include #include namespace tvm { @@ -44,8 +43,7 @@ inline void DispatchGLSLPureIntrin(const TVMArgs& targs, TVMRetValue* rv) { for (PrimExpr arg : call->args) { cargs.push_back(arg); } - *rv = tir::Call(call->dtype, tir::builtin::call_spirv_glsl450(), cargs, - tir::CallNode::PureIntrinsic); + *rv = tir::Call(call->dtype, "spirv_glsl450", cargs, tir::CallNode::PureIntrinsic); } TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.floor") diff --git a/src/target/stackvm/codegen_stackvm.cc b/src/target/stackvm/codegen_stackvm.cc index 84b14925877a6..6dd2ca0ecb6c4 100644 --- a/src/target/stackvm/codegen_stackvm.cc +++ b/src/target/stackvm/codegen_stackvm.cc @@ -25,7 +25,6 @@ #include #include #include 
-#include #include #include @@ -42,31 +41,31 @@ using namespace tir; // map struct field kind to runtime variants // We keep two separate enums to ensure runtime/compiler isolation. StackVM::StructFieldKind MapFieldKind(int64_t kind) { - auto val = static_cast(kind); + auto val = static_cast(kind); switch (val) { - case builtin::kArrData: + case intrinsic::kArrData: return StackVM::kArrData; - case builtin::kArrShape: + case intrinsic::kArrShape: return StackVM::kArrShape; - case builtin::kArrAddr: + case intrinsic::kArrAddr: return StackVM::kArrAddr; - case builtin::kArrStrides: + case intrinsic::kArrStrides: return StackVM::kArrStrides; - case builtin::kArrNDim: + case intrinsic::kArrNDim: return StackVM::kArrNDim; - case builtin::kArrTypeCode: + case intrinsic::kArrTypeCode: return StackVM::kArrTypeCode; - case builtin::kArrTypeBits: + case intrinsic::kArrTypeBits: return StackVM::kArrTypeBits; - case builtin::kArrTypeLanes: + case intrinsic::kArrTypeLanes: return StackVM::kArrTypeLanes; - case builtin::kArrByteOffset: + case intrinsic::kArrByteOffset: return StackVM::kArrByteOffset; - case builtin::kArrDeviceId: + case intrinsic::kArrDeviceId: return StackVM::kArrDeviceId; - case builtin::kArrDeviceType: + case intrinsic::kArrDeviceType: return StackVM::kArrDeviceType; - case builtin::kTVMValueContent: + case intrinsic::kTVMValueContent: return StackVM::kTVMValueContent; default: LOG(FATAL) << "Do not know how to map field " << kind; @@ -175,7 +174,7 @@ void CodeGenStackVM::VisitStmt_(const AllocateNode* op) { } void CodeGenStackVM::VisitExpr_(const CallNode* op) { - if (op->op.same_as(builtin::address_of())) { + if (op->is_intrinsic(intrinsic::tvm_address_of)) { const LoadNode* l = op->args[0].as(); CHECK(op->args.size() == 1 && l); this->PushOp(StackVM::LOAD_HEAP, GetVarID(l->buffer_var.get())); @@ -183,9 +182,9 @@ void CodeGenStackVM::VisitExpr_(const CallNode* op) { this->PushOp(StackVM::PUSH_I64, l->dtype.element_of().bytes()); this->PushOp(StackVM::MUL_I64); this->PushOp(StackVM::ADDR_ADD); - } else if (op->op.same_as(builtin::reinterpret())) { + } else if (op->is_intrinsic(CallNode::reinterpret)) { this->Push(op->args[0]); - } else if (op->op.same_as(builtin::tvm_struct_get())) { + } else if (op->is_intrinsic(intrinsic::tvm_struct_get)) { CHECK_EQ(op->args.size(), 3U); int kind = op->args[2].as()->value; this->Push(op->args[0]); @@ -198,7 +197,7 @@ void CodeGenStackVM::VisitExpr_(const CallNode* op) { vm_.code.push_back(code); code.v_int = MapFieldKind(kind); vm_.code.push_back(code); - } else if (op->op.same_as(builtin::tvm_call_packed_lowered())) { + } else if (op->is_intrinsic(intrinsic::tvm_call_packed_lowered)) { CHECK_GE(op->args.size(), 5U); const StringImmNode* s = op->args[0].as(); CHECK(s != nullptr) << "tvm_call_global expect first argument as function name"; @@ -227,7 +226,7 @@ void CodeGenStackVM::VisitExpr_(const CallNode* op) { vm_.code.push_back(code); code.v_int = end; vm_.code.push_back(code); - } else if (op->op.same_as(builtin::tvm_stack_alloca())) { + } else if (op->is_intrinsic(intrinsic::tvm_stack_alloca)) { CHECK_EQ(op->args.size(), 2U); const std::string& type = op->args[0].as()->value; const IntImmNode* num = op->args[1].as(); @@ -250,7 +249,7 @@ void CodeGenStackVM::VisitExpr_(const CallNode* op) { // add stack size to be safe. 
vm_.stack_size += size; this->PushOp(StackVM::TVM_STACK_ALLOCA_BY_8BYTE, static_cast(size)); - } else if (op->op.same_as(backend_alloc_workspace_op_)) { + } else if (op->name == "TVMBackendAllocWorkspace") { CHECK_EQ(op->args.size(), 5U); this->Push(op->args[0]); this->Push(op->args[1]); @@ -258,21 +257,21 @@ void CodeGenStackVM::VisitExpr_(const CallNode* op) { this->Push(op->args[3]); this->Push(op->args[4]); this->PushOp(StackVM::TVM_DEVICE_ALLOCA); - } else if (op->op.same_as(backend_free_workspace_op_)) { + } else if (op->name == "TVMBackendFreeWorkspace") { CHECK_EQ(op->args.size(), 3U); this->Push(op->args[0]); this->Push(op->args[1]); this->Push(op->args[2]); this->PushOp(StackVM::TVM_DEVICE_FREE); - } else if (op->op.same_as(builtin::tvm_throw_last_error())) { + } else if (op->is_intrinsic(intrinsic::tvm_throw_last_error)) { this->PushOp(StackVM::TVM_THROW_LAST_ERROR); - } else if (op->op.same_as(builtin::isnullptr())) { + } else if (op->is_intrinsic(intrinsic::tvm_handle_is_null)) { CHECK_EQ(op->args.size(), 1U); this->Push(op->args[0]); this->PushOp(StackVM::PUSH_I64, 0); this->PushOp(StackVM::EQ_HANDLE); } else { - LOG(FATAL) << "unknown function call " << op->op; + LOG(FATAL) << "unknown function call " << op->name; } } @@ -431,7 +430,7 @@ void CodeGenStackVM::VisitStmt_(const SeqStmtNode* op) { void CodeGenStackVM::VisitStmt_(const EvaluateNode* ev) { if (is_const(ev->value)) return; const CallNode* op = ev->value.as(); - if (op && op->op.same_as(builtin::tvm_struct_set())) { + if (op && op->is_intrinsic(intrinsic::tvm_struct_set)) { CHECK_EQ(op->args.size(), 4U); this->Push(op->args[0]); this->Push(op->args[3]); diff --git a/src/target/stackvm/codegen_stackvm.h b/src/target/stackvm/codegen_stackvm.h index 480ffc7eb8705..b77c40696de6c 100644 --- a/src/target/stackvm/codegen_stackvm.h +++ b/src/target/stackvm/codegen_stackvm.h @@ -27,7 +27,6 @@ #include #include #include -#include #include #include @@ -154,9 +153,6 @@ class CodeGenStackVM : public ExprFunctor, std::unordered_map str_idmap_; /*! 
\brief id of each global function */ std::unordered_map extern_fun_idmap_; - - Op backend_alloc_workspace_op_ = Op::Get("tir.TVMBackendAllocWorkspace"); - Op backend_free_workspace_op_ = Op::Get("tir.TVMBackendFreeWorkspace"); }; } // namespace codegen diff --git a/src/te/autodiff/jacobian.cc b/src/te/autodiff/jacobian.cc index f6254121b7cb7..1834aa3decf7b 100644 --- a/src/te/autodiff/jacobian.cc +++ b/src/te/autodiff/jacobian.cc @@ -96,30 +96,31 @@ class JacobianMutator : public ExprMutator { PrimExpr VisitExpr_(const CallNode* op) { PrimExpr expr = GetRef(op); if (op->call_type == CallNode::CallType::PureIntrinsic) { - if (op->op.same_as(op_exp_)) { + static std::unordered_set piecewise_const = {"floor", "ceil", "trunc", "round"}; + if (op->name == "exp") { return Mul(Mutate(op->args[0]), expr); - } else if (op->op.same_as(op_log_)) { + } else if (op->name == "log") { return Div(Mutate(op->args[0]), op->args[0]); - } else if (op->op.same_as(op_sigmoid_)) { + } else if (op->name == "sigmoid") { return Mul(Mutate(op->args[0]), Mul(expr, Sub(FloatImm(expr.dtype(), 1.0), expr))); - } else if (op->op.same_as(op_sqrt_)) { + } else if (op->name == "sqrt") { return Div(Mutate(op->args[0]), Mul(expr, FloatImm(expr.dtype(), 2.0))); - } else if (op->op.same_as(op_tanh_)) { + } else if (op->name == "tanh") { return Mul(Mutate(op->args[0]), Sub(FloatImm(expr.dtype(), 1.0), Mul(expr, expr))); - } else if (op->op.same_as(op_pow_)) { + } else if (op->name == "pow") { auto x = op->args[0], y = op->args[1]; return expr * (Mutate(y) * log(x) + Mutate(x) * y / x); - } else if (op->op.same_as(op_fabs_)) { + } else if (op->name == "fabs") { auto type = op->args[0].dtype(); return Mul(Mutate(op->args[0]), Select(GE(op->args[0], make_zero(type)), FloatImm(type, 1.0), FloatImm(type, -1.0))); - } else if (op->op.same_as(op_if_then_else_)) { + } else if (op->name == intrinsic::tvm_if_then_else) { Array new_args = {op->args[0], Mutate(op->args[1]), Mutate(op->args[2])}; - return Call(op->dtype, op->op, new_args, op->call_type); - } else if (piecewise_const.count(op->op)) { + return Call(op->dtype, op->name, new_args, op->call_type); + } else if (piecewise_const.count(op->name)) { return FloatImm(expr.dtype(), 0.0); } else { - LOG(FATAL) << "Derivative of this intrinsic is not implemented: " << op->op; + throw dmlc::Error("Derivative of this intrinsic is not implemented: " + op->name); } } NOT_IMPLEMENTED; @@ -280,17 +281,6 @@ class JacobianMutator : public ExprMutator { Array indices_; Var input_var_; arith::Analyzer analyzer_; - - const Op& op_exp_ = Op::Get("tir.exp"); - const Op& op_log_ = Op::Get("tir.log"); - const Op& op_sigmoid_ = Op::Get("tir.sigmoid"); - const Op& op_sqrt_ = Op::Get("tir.sqrt"); - const Op& op_tanh_ = Op::Get("tir.tanh"); - const Op& op_pow_ = Op::Get("tir.pow"); - const Op& op_fabs_ = Op::Get("tir.fabs"); - const Op& op_if_then_else_ = Op::Get("tir.if_then_else"); - std::unordered_set piecewise_const = { - Op::Get("tir.floor"), Op::Get("tir.ceil"), Op::Get("tir.trunc"), Op::Get("tir.round")}; }; PrimExpr Derivative(const PrimExpr& expr, const Var& var) { diff --git a/src/te/operation/compute_op.cc b/src/te/operation/compute_op.cc index b4725c571782c..1fc0520143fbf 100644 --- a/src/te/operation/compute_op.cc +++ b/src/te/operation/compute_op.cc @@ -27,7 +27,6 @@ #include #include #include -#include #include #include @@ -279,7 +278,7 @@ Stmt BaseComputeOpNode::BuildRealize(const Stage& stage, attr->dim_align_offset}; realize = tir::AttrStmt( t, tir::attr::buffer_dim_align, - 
Call(DataType::Handle(), tir::builtin::tvm_tuple(), tuple, CallNode::Intrinsic), + Call(DataType::Handle(), tir::intrinsic::tvm_tuple, tuple, CallNode::Intrinsic), realize); } } diff --git a/src/te/operation/cross_thread_reduction.cc b/src/te/operation/cross_thread_reduction.cc index eeaab301ad035..e834ff279d05c 100644 --- a/src/te/operation/cross_thread_reduction.cc +++ b/src/te/operation/cross_thread_reduction.cc @@ -21,8 +21,6 @@ * \brief Logics related to cross thread reduction, used by ComputeOpNode. * \file cross_thread_reduction.cc */ -#include - #include "compute_op.h" #include "op_util.h" @@ -196,7 +194,7 @@ Stmt MakeCrossThreadReduction(const ComputeOpNode* self, const Stage& stage, // Apply the existing input predicate if any. output_preds.push_back(input_pred); - Stmt reduce_body = Evaluate(Call(DataType::Handle(), tir::builtin::tvm_thread_allreduce(), + Stmt reduce_body = Evaluate(Call(DataType::Handle(), tir::intrinsic::tvm_thread_allreduce, freduce_args, CallNode::Intrinsic)); reduce_body = AttrStmt(reduces[0]->combiner, tir::attr::reduce_scope, make_zero(DataType::Handle()), reduce_body); diff --git a/src/te/operation/extern_op.cc b/src/te/operation/extern_op.cc index 01019e43e61c0..ef55c44241b04 100644 --- a/src/te/operation/extern_op.cc +++ b/src/te/operation/extern_op.cc @@ -153,7 +153,7 @@ Stmt ExternOpNode::BuildProvide(const Stage& stage, tuple.push_back(buffer->shape[k]); } ret = AttrStmt(bind_spec, tir::attr::buffer_bind_scope, - Call(DataType::Handle(), builtin::tvm_tuple(), tuple, CallNode::Intrinsic), ret); + Call(DataType::Handle(), intrinsic::tvm_tuple, tuple, CallNode::Intrinsic), ret); }; for (size_t i = output_placeholders.size(); i != 0; --i) { f_push_bind(output_placeholders[i - 1], stage->op.output(i - 1)); diff --git a/src/te/operation/tensor_compute_op.cc b/src/te/operation/tensor_compute_op.cc index 714e8859229d3..8d5265bcb14f0 100644 --- a/src/te/operation/tensor_compute_op.cc +++ b/src/te/operation/tensor_compute_op.cc @@ -24,7 +24,6 @@ #include #include #include -#include #include #include @@ -154,7 +153,7 @@ Stmt TensorComputeOpNode::BuildProvide(const Stage& stage, } input_bind_nest.emplace_back(AttrStmt( bind_spec, tir::attr::buffer_bind_scope, - Call(DataType::Handle(), tir::builtin::tvm_tuple(), tuple, CallNode::Intrinsic), nop)); + Call(DataType::Handle(), tir::intrinsic::tvm_tuple, tuple, CallNode::Intrinsic), nop)); } // output binding @@ -178,7 +177,7 @@ Stmt TensorComputeOpNode::BuildProvide(const Stage& stage, output_bind_nest.emplace_back(AttrStmt( bind_spec, tir::attr::buffer_bind_scope, - Call(DataType::Handle(), tir::builtin::tvm_tuple(), tuple, CallNode::Intrinsic), nop)); + Call(DataType::Handle(), tir::intrinsic::tvm_tuple, tuple, CallNode::Intrinsic), nop)); } // Check variable remap diff --git a/src/te/operation/tensorize.cc b/src/te/operation/tensorize.cc index dd978a430e4bf..82832c927785c 100644 --- a/src/te/operation/tensorize.cc +++ b/src/te/operation/tensorize.cc @@ -370,7 +370,7 @@ Stmt MakeTensorize(const ComputeOpNode* self, const Stage& stage, } input_bind_nest.emplace_back(AttrStmt( bind_spec, tir::attr::buffer_bind_scope, - Call(DataType::Handle(), tir::builtin::tvm_tuple(), tuple, CallNode::Intrinsic), nop)); + Call(DataType::Handle(), tir::intrinsic::tvm_tuple, tuple, CallNode::Intrinsic), nop)); } // output binding const ComputeOpNode* intrin_compute = intrin->op.as(); @@ -390,7 +390,7 @@ Stmt MakeTensorize(const ComputeOpNode* self, const Stage& stage, Array bind_spec{buffer, tensor}; 
output_bind_nest.emplace_back(AttrStmt( bind_spec, tir::attr::buffer_bind_scope, - Call(DataType::Handle(), tir::builtin::tvm_tuple(), tuple, CallNode::Intrinsic), nop)); + Call(DataType::Handle(), tir::intrinsic::tvm_tuple, tuple, CallNode::Intrinsic), nop)); } // Check variable remap std::unordered_map vmap; diff --git a/src/te/schedule/schedule_postproc_rewrite_for_tensor_core.cc b/src/te/schedule/schedule_postproc_rewrite_for_tensor_core.cc index 67121b881a33b..1ff569f29f1f0 100644 --- a/src/te/schedule/schedule_postproc_rewrite_for_tensor_core.cc +++ b/src/te/schedule/schedule_postproc_rewrite_for_tensor_core.cc @@ -29,7 +29,6 @@ #include #include #include -#include #include #include #include @@ -43,6 +42,7 @@ namespace tvm { namespace te { using namespace te; +using intrinsic::tvm_address_of; using runtime::StorageRank; using runtime::StorageScope; using runtime::ThreadScope; @@ -255,9 +255,9 @@ class BodyVisitor : public StmtExprVisitor { } } - void VisitExpr_(const ProducerLoadNode* op) final { + void VisitExpr_(const CallNode* op) final { StmtExprVisitor::VisitExpr_(op); - args_.insert(std::make_pair(op->producer->GetNameHint(), op->indices)); + args_.insert(std::make_pair(op->name, op->args)); } friend class ScheduleAnalyser; @@ -415,7 +415,7 @@ class BufferAnalyser : public StmtExprVisitor { } else if (op->attr_key == tir::attr::buffer_dim_align) { te::Tensor tensor = Downcast(op->node); const CallNode* tuple = op->value.as(); - CHECK(tuple && tuple->op.same_as(builtin::tvm_tuple())); + CHECK(tuple && tuple->is_intrinsic(intrinsic::tvm_tuple)); auto& vinfo = dim_align_[tensor]; size_t dim = tuple->args[0].as()->value; if (dim >= vinfo.size()) { @@ -848,13 +848,13 @@ class TensorCoreIRMutator : public StmtExprMutator { Buffer buffer_b(buffer_node_b); if (ca->dtype == DataType::Int(1) && cb->dtype == DataType::Int(1)) { return Evaluate( - Call(DataType::Handle(), builtin::tvm_bmma_sync(), + Call(DataType::Handle(), intrinsic::tvm_bmma_sync, {buffer->data, buffer->elem_offset, buffer_a->data, buffer_a->elem_offset, buffer_b->data, buffer_b->elem_offset, buffer->data, buffer->elem_offset}, CallNode::Intrinsic)); } else { return Evaluate( - Call(DataType::Handle(), builtin::tvm_mma_sync(), + Call(DataType::Handle(), intrinsic::tvm_mma_sync, {buffer->data, buffer->elem_offset, buffer_a->data, buffer_a->elem_offset, buffer_b->data, buffer_b->elem_offset, buffer->data, buffer->elem_offset}, CallNode::Intrinsic)); @@ -879,7 +879,7 @@ class TensorCoreIRMutator : public StmtExprMutator { auto pload = dst.as(); auto fill_fragment_call = [this, &op](const Buffer& buffer) { - return Evaluate(Call(DataType::Handle(), builtin::tvm_fill_fragment(), + return Evaluate(Call(DataType::Handle(), intrinsic::tvm_fill_fragment, {buffer->data, warp_tile_.m, warp_tile_.n, warp_tile_.k, buffer->elem_offset, op->value}, CallNode::Intrinsic)); @@ -889,11 +889,11 @@ class TensorCoreIRMutator : public StmtExprMutator { return add_buffer_bind_scope_(pload, buffer_node, fill_fragment_call); } - const ProducerLoadNode* value = op->value.as(); + const CallNode* value = op->value.as(); CHECK(value != nullptr) << "Can only load fragment from a buffer"; - auto it = strides_.find(value->producer->GetNameHint()); - CHECK(it != strides_.end()) << "Cannot find stride for " << value->producer->GetNameHint(); + auto it = strides_.find(value->name); + CHECK(it != strides_.end()) << "Cannot find stride for " << value->name; auto strides = it->second; CHECK_GE(strides.size(), 2); PrimExpr stride = strides[strides.size() - 2]; 
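// A minimal sketch of the two extern-call encodings that the hunks in this
// file toggle between (illustrative, not lines of this patch), for some
// PrimExpr `value` already in scope:
//
//   // op-based form: the callee is the generic builtin::call_extern op and
//   // the symbol travels as the first argument
//   PrimExpr a = Call(DataType::Handle(), builtin::call_extern(),
//                     {StringImm("&"), value}, CallNode::Extern);
//   // reverted name-based form: the symbol is the CallNode's name field
//   PrimExpr b = Call(DataType::Handle(), "&", {value}, CallNode::Extern);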
@@ -902,9 +902,7 @@ class TensorCoreIRMutator : public StmtExprMutator { PrimExpr warp_y = IntImm(DataType::Int(32), warp_threads_y_); ThreadIdxMutator thread_idx_mutator(warp_y); PrimExpr mutated_value = thread_idx_mutator(op->value); - // TODO(tvm-team) The extern function name seems to be a hack. - PrimExpr src = Call(value->dtype, builtin::call_extern(), {StringImm("&"), mutated_value}, - CallNode::Extern); + PrimExpr src = Call(value->dtype, "&", {mutated_value}, CallNode::Extern); auto pload = dst.as(); PrimExpr matrix_major; @@ -920,7 +918,7 @@ class TensorCoreIRMutator : public StmtExprMutator { } auto load_matrix_call = [this, &src, &stride, &matrix_major](const Buffer& buffer) { - return Evaluate(Call(DataType::Handle(), builtin::tvm_load_matrix_sync(), + return Evaluate(Call(DataType::Handle(), intrinsic::tvm_load_matrix_sync, {buffer->data, warp_tile_.m, warp_tile_.n, warp_tile_.k, buffer->elem_offset, src, stride, matrix_major}, CallNode::Intrinsic)); @@ -943,13 +941,12 @@ class TensorCoreIRMutator : public StmtExprMutator { PrimExpr warp_y = IntImm(DataType::Int(32), warp_threads_y_); ThreadIdxMutator thread_idx_mutator(warp_y); dst = thread_idx_mutator(dst); - dst = - Call(DataType::Handle(), builtin::call_extern(), {StringImm("&"), dst}, CallNode::Extern); + dst = Call(DataType::Handle(), "&", {dst}, CallNode::Extern); auto pload = op->value.as(); auto store_matrix_call = [this, &dst, &stride](const Buffer& buffer) { - return Evaluate(Call(DataType::Handle(), builtin::tvm_store_matrix_sync(), + return Evaluate(Call(DataType::Handle(), intrinsic::tvm_store_matrix_sync, {buffer->data, warp_tile_.m, warp_tile_.n, warp_tile_.k, buffer->elem_offset, dst, stride, StringImm("col_major")}, CallNode::Intrinsic)); @@ -1067,7 +1064,7 @@ class TensorCoreIRMutator : public StmtExprMutator { args.push_back(pload->indices[i]); args.push_back(shape[i]); } - auto tuple = Call(DataType::Handle(), builtin::tvm_tuple(), args, CallNode::Intrinsic); + auto tuple = Call(DataType::Handle(), intrinsic::tvm_tuple, args, CallNode::Intrinsic); Array node = {buffer, tensor}; return AttrStmt(node, "buffer_bind_scope", tuple, call_back(buffer)); } diff --git a/src/tir/analysis/verify_memory.cc b/src/tir/analysis/verify_memory.cc index 12ec270a53cc5..8eb846b7d6181 100644 --- a/src/tir/analysis/verify_memory.cc +++ b/src/tir/analysis/verify_memory.cc @@ -25,7 +25,6 @@ #include #include #include -#include #include #include @@ -121,7 +120,7 @@ class MemoryAccessVerifier final : protected StmtExprVisitor { const auto& iter = defs_.find(V); if (iter == defs_.end()) return false; const CallNode* C = iter->second.as(); - if (!C || !C->op.same_as(builtin::tvm_struct_get())) return false; + if (!C || C->name != intrinsic::tvm_struct_get) return false; V = C->args[0].as(); } return false; diff --git a/src/tir/ir/buffer.cc b/src/tir/ir/buffer.cc index 6cccfa0fcebff..4e433fc718b19 100644 --- a/src/tir/ir/buffer.cc +++ b/src/tir/ir/buffer.cc @@ -25,7 +25,6 @@ #include #include #include -#include #include #include @@ -377,7 +376,7 @@ PrimExpr Buffer::access_ptr(int access_mask, DataType ptr_type, int content_lane } Array acc_args{e_dtype, self->data, elem_offset, extent, make_const(DataType::Int(32), access_mask)}; - return tir::Call(ptr_type, tir::builtin::tvm_access_ptr(), acc_args, tir::CallNode::Intrinsic); + return tir::Call(ptr_type, tir::intrinsic::tvm_access_ptr, acc_args, tir::CallNode::Intrinsic); } Buffer::Buffer(Var data, DataType dtype, Array shape, Array strides, diff --git a/src/tir/ir/expr.cc 
b/src/tir/ir/expr.cc index 4b20351e2053d..9390feada456e 100644 --- a/src/tir/ir/expr.cc +++ b/src/tir/ir/expr.cc @@ -698,21 +698,50 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) }); // Call -Call::Call(DataType dtype, RelayExpr op, Array args, CallType call_type) { +Call::Call(DataType dtype, String name, Array args, CallType call_type) { for (size_t i = 0; i < args.size(); ++i) { CHECK(args[i].defined()); } ObjectPtr node = make_object(); node->dtype = dtype; - node->op = std::move(op); + node->name = std::move(name); node->args = std::move(args); node->call_type = call_type; data_ = std::move(node); } +const char* CallNode::vectorizable_intrinsics[] = {"floor", + "ceil", + "sign", + "trunc", + "fabs", + "round", + "exp", + "tanh", + "sqrt", + "log", + "sin", + "cos", + "pow", + "tan", + tir::CallNode::shift_left, + tir::CallNode::shift_right, + tir::CallNode::likely, + tir::CallNode::popcount}; + +bool CallNode::is_vectorizable() const { + size_t cnt = sizeof(CallNode::vectorizable_intrinsics) / sizeof(char*); + for (size_t i = 0; i < cnt; ++i) { + if (name == CallNode::vectorizable_intrinsics[i]) { + return true; + } + } + return false; +} + TVM_REGISTER_GLOBAL("tir.Call") - .set_body_typed([](DataType type, RelayExpr op, Array args, int call_type) { + .set_body_typed([](DataType type, String name, Array args, int call_type) { Array prim_expr_args; for (const auto& it : args) { CHECK(it->IsInstance() || it->IsInstance()); @@ -722,7 +751,7 @@ TVM_REGISTER_GLOBAL("tir.Call") prim_expr_args.push_back(Downcast(it)); } } - return Call(type, op, prim_expr_args, static_cast(call_type)); + return Call(type, name, prim_expr_args, static_cast(call_type)); }); TVM_REGISTER_NODE_TYPE(CallNode); @@ -730,13 +759,7 @@ TVM_REGISTER_NODE_TYPE(CallNode); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); - if (auto* ptr_op = op->op.as()) { - p->stream << ptr_op->name << "("; - } else { - auto* ptr_gvar = op->op.as(); - CHECK(ptr_gvar != nullptr); - p->stream << "@" << ptr_gvar->name_hint << "("; - } + p->stream << op->name << "("; for (size_t i = 0; i < op->args.size(); ++i) { p->Print(op->args[i]); if (i < op->args.size() - 1) { diff --git a/src/tir/ir/expr_functor.cc b/src/tir/ir/expr_functor.cc index 98b9fd02c09cf..b92127b24e2ba 100644 --- a/src/tir/ir/expr_functor.cc +++ b/src/tir/ir/expr_functor.cc @@ -166,7 +166,7 @@ PrimExpr ExprMutator::VisitExpr_(const CallNode* op) { if (args.same_as(op->args)) { return GetRef(op); } else { - return Call(op->dtype, op->op, args, op->call_type); + return Call(op->dtype, op->name, args, op->call_type); } } diff --git a/src/tir/op/op.cc b/src/tir/ir/op.cc similarity index 82% rename from src/tir/op/op.cc rename to src/tir/ir/op.cc index f8049eace356c..5ac9f5902c12b 100644 --- a/src/tir/op/op.cc +++ b/src/tir/ir/op.cc @@ -18,16 +18,12 @@ */ /*! - * \file tir/op/op.cc - * - * Common operator definitions for ops in tir/op.h + * \file expr_operator.cc */ #include -#include #include #include -#include #include // Centralized header for constant folders. 
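// Context for the hunks below (an illustrative sketch, not a line of this
// patch): in the reverted name-based scheme, a PureIntrinsic such as "floor"
// is lowered per target through a packed function keyed on the call's string
// name, matching the "tvm.intrin.rule.cuda.floor" style registrations shown
// in intrin_rule_cuda.cc above. Roughly:
//
//   const runtime::PackedFunc* f =
//       runtime::Registry::Get("tvm.intrin.rule." + target + "." + call->name);
//   if (f != nullptr) expr = (*f)(expr);  // the matched rule rewrites the call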
@@ -37,12 +33,6 @@ namespace tvm { using namespace tir; -// macro to register a unary op -#define TIR_REGISTER_PURE_UNARY_OP(OpName) TVM_REGISTER_OP(OpName).set_num_inputs(1) - -// macro to register a binary op -#define TIR_REGISTER_PURE_BINARY_OP(OpName) TVM_REGISTER_OP(OpName).set_num_inputs(2) - runtime::DataType GetRuntimeDataType(const Type& type) { if (auto* n = type.as()) { return n->dtype; @@ -80,9 +70,8 @@ inline PrimExpr SimpleCast(const DataType& t, PrimExpr value) { return tir::Cast(t, value); } -// LargeUIntImm PrimExpr LargeUIntImm(DataType t, int64_t low, int64_t high) { - return tir::Call(t, tir::builtin::large_uint_imm(), + return tir::Call(t, tir::intrinsic::tvm_large_uint_imm, {make_const(DataType::UInt(32), low), make_const(DataType::UInt(32), high)}, tir::CallNode::PureIntrinsic); } @@ -259,13 +248,11 @@ PrimExpr cast(const DataType& t, PrimExpr value) { } } -// reinterpret PrimExpr reinterpret(const DataType& t, PrimExpr value) { if (value.dtype() == t) return value; - return tir::Call(t, tir::builtin::reinterpret(), {value}, tir::CallNode::PureIntrinsic); + return tir::Call(t, tir::CallNode::reinterpret, {value}, tir::CallNode::PureIntrinsic); } -// operator+ PrimExpr operator+(PrimExpr a, PrimExpr b) { BinaryOpMatchTypes(a, b); PrimExpr ret = arith::TryConstFold(a, b); @@ -373,7 +360,6 @@ PrimExpr max(PrimExpr a, PrimExpr b) { return tir::Max(a, b); } -// if_then_else PrimExpr if_then_else(PrimExpr cond, PrimExpr true_value, PrimExpr false_value) { CHECK(cond.dtype() == DataType::Bool(1)) << "if_then_else only accepts the condition to be boolean type."; @@ -385,20 +371,15 @@ PrimExpr if_then_else(PrimExpr cond, PrimExpr true_value, PrimExpr false_value) return false_value; } } - - return tir::Call(true_value.dtype(), tir::builtin::if_then_else(), + return tir::Call(true_value.dtype(), tir::intrinsic::tvm_if_then_else, {cond, true_value, false_value}, tir::CallNode::PureIntrinsic); } -// likely PrimExpr likely(PrimExpr cond) { if (is_const(cond)) return cond; - return tir::Call(cond.dtype(), tir::builtin::likely(), {cond}, tir::CallNode::PureIntrinsic); + return tir::Call(cond.dtype(), tir::CallNode::likely, {cond}, tir::CallNode::PureIntrinsic); } -TVM_REGISTER_OP("tir.likely").set_num_inputs(1); - -// operator> PrimExpr operator>(PrimExpr a, PrimExpr b) { BinaryOpMatchTypes(a, b); PrimExpr ret = arith::TryConstFold(a, b); @@ -464,7 +445,6 @@ PrimExpr operator!(PrimExpr a) { return tir::Not(a); } -// shift right PrimExpr operator>>(PrimExpr a, PrimExpr b) { CHECK(a.dtype().is_int() || a.dtype().is_uint()); CHECK(b.dtype().is_int() || b.dtype().is_uint()); @@ -480,11 +460,9 @@ PrimExpr operator>>(PrimExpr a, PrimExpr b) { if (pb->value == 0) return a; } }); - - return tir::Call(a.dtype(), tir::builtin::shift_right(), {a, b}, tir::CallNode::PureIntrinsic); + return tir::Call(a.dtype(), tir::CallNode::shift_right, {a, b}, tir::CallNode::PureIntrinsic); } -// shift left PrimExpr operator<<(PrimExpr a, PrimExpr b) { CHECK(a.dtype().is_int() || a.dtype().is_uint()); CHECK(b.dtype().is_int() || b.dtype().is_uint()); @@ -500,10 +478,9 @@ PrimExpr operator<<(PrimExpr a, PrimExpr b) { if (pb->value == 0) return a; } }); - return tir::Call(a.dtype(), tir::builtin::shift_left(), {a, b}, tir::CallNode::PureIntrinsic); + return tir::Call(a.dtype(), tir::CallNode::shift_left, {a, b}, tir::CallNode::PureIntrinsic); } -// bitwise and PrimExpr operator&(PrimExpr a, PrimExpr b) { CHECK(a.dtype().is_int() || a.dtype().is_uint()); CHECK(b.dtype().is_int() || b.dtype().is_uint()); @@
-512,10 +489,9 @@ PrimExpr operator&(PrimExpr a, PrimExpr b) { const DataType& rtype = a.dtype(); if (pa && pb) return IntImm(rtype, (pa->value & pb->value)); }); - return tir::Call(a.dtype(), tir::builtin::bitwise_and(), {a, b}, tir::CallNode::PureIntrinsic); + return tir::Call(a.dtype(), tir::CallNode::bitwise_and, {a, b}, tir::CallNode::PureIntrinsic); } -// bitwise_or PrimExpr operator|(PrimExpr a, PrimExpr b) { CHECK(a.dtype().is_int() || a.dtype().is_uint()); CHECK(b.dtype().is_int() || b.dtype().is_uint()); @@ -524,10 +500,9 @@ PrimExpr operator|(PrimExpr a, PrimExpr b) { const DataType& rtype = a.dtype(); if (pa && pb) return IntImm(rtype, (pa->value | pb->value)); }); - return tir::Call(a.dtype(), tir::builtin::bitwise_or(), {a, b}, tir::CallNode::PureIntrinsic); + return tir::Call(a.dtype(), tir::CallNode::bitwise_or, {a, b}, tir::CallNode::PureIntrinsic); } -// bitwise_xor PrimExpr operator^(PrimExpr a, PrimExpr b) { CHECK(a.dtype().is_int() || a.dtype().is_uint()); CHECK(b.dtype().is_int() || b.dtype().is_uint()); @@ -536,30 +511,20 @@ PrimExpr operator^(PrimExpr a, PrimExpr b) { const DataType& rtype = a.dtype(); if (pa && pb) return IntImm(rtype, (pa->value ^ pb->value)); }); - return tir::Call(a.dtype(), tir::builtin::bitwise_xor(), {a, b}, tir::CallNode::PureIntrinsic); + return tir::Call(a.dtype(), tir::CallNode::bitwise_xor, {a, b}, tir::CallNode::PureIntrinsic); } -// bitwise_not PrimExpr operator~(PrimExpr a) { CHECK(a.dtype().is_int() || a.dtype().is_uint()); - return tir::Call(a.dtype(), tir::builtin::bitwise_not(), {a}, tir::CallNode::PureIntrinsic); + return tir::Call(a.dtype(), tir::CallNode::bitwise_not, {a}, tir::CallNode::PureIntrinsic); } -TVM_REGISTER_OP("tir.bitwise_not"); - -TVM_REGISTER_GLOBAL("tir.bitwise_not").set_body_typed([](PrimExpr a) { return ~a; }); - -// pow PrimExpr pow(PrimExpr x, PrimExpr y) { BinaryOpMatchTypes(x, y); CHECK(x.dtype().is_float()) << "power only applies to float"; - static auto op = Op::Get("tir.pow"); - return tir::Call(x.dtype(), op, {x, y}, tir::CallNode::PureIntrinsic); + return tir::Call(x.dtype(), "pow", {x, y}, tir::CallNode::PureIntrinsic); } -TVM_REGISTER_OP("tir.pow").set_num_inputs(2).set_attr("TVectorizable", true); - -// abs PrimExpr abs(PrimExpr x) { if (x.dtype().is_int()) { using tir::IntImmNode; @@ -574,8 +539,7 @@ PrimExpr abs(PrimExpr x) { if (fx) { return FloatImm(x.dtype(), std::fabs(fx->value)); } - static auto op = Op::Get("tir.fabs"); - return tir::Call(x.dtype(), op, {x}, tir::CallNode::PureIntrinsic); + return tir::Call(x.dtype(), "fabs", {x}, tir::CallNode::PureIntrinsic); } else if (x.dtype().is_uint()) { return x; } else { @@ -585,9 +549,6 @@ PrimExpr abs(PrimExpr x) { } } -TIR_REGISTER_PURE_UNARY_OP("tir.fabs").set_attr("TVectorizable", true); - -// isnan PrimExpr isnan(PrimExpr x) { DataType t = DataType::Bool(x.dtype().lanes()); if (x.dtype().is_int() || x.dtype().is_uint()) { @@ -598,12 +559,12 @@ PrimExpr isnan(PrimExpr x) { if (fx) { return make_const(t, std::isnan(fx->value)); } - static auto op = Op::Get("tir.isnan"); if (x.dtype().bits() == 16) { - return tir::Call(t, op, {cast(DataType::Float(32, t.lanes()), std::move(x))}, + return tir::Call(t, tir::CallNode::isnan, + {cast(DataType::Float(32, t.lanes()), std::move(x))}, tir::CallNode::PureIntrinsic); } else { - return tir::Call(t, op, {x}, tir::CallNode::PureIntrinsic); + return tir::Call(t, tir::CallNode::isnan, {x}, tir::CallNode::PureIntrinsic); } } else { LOG(FATAL) << "Data type " << x.dtype() << " not supported for isnan op.
Skipping isnan op..."; @@ -611,9 +572,6 @@ PrimExpr isnan(PrimExpr x) { } } -TIR_REGISTER_PURE_UNARY_OP("tir.isnan"); - -// isinf PrimExpr isinf(PrimExpr x) { DataType t = DataType::Bool(x.dtype().lanes()); if (x.dtype().is_int() || x.dtype().is_uint()) { @@ -627,7 +585,6 @@ PrimExpr isinf(PrimExpr x) { } } -// isfinite PrimExpr isfinite(PrimExpr x) { return !isinf(x) && !isnan(x); } PrimExpr sum(PrimExpr source, Array rdom) { @@ -680,17 +637,12 @@ PrimExpr prod(PrimExpr source, Array rdom) { return tir::Reduce(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0); } -// fmod PrimExpr fmod(PrimExpr x, PrimExpr y) { BinaryOpMatchTypes(x, y); CHECK(x.dtype().is_float()) << "fmod only applies to float"; - static auto op = Op::Get("tir.fmod"); - return tir::Call(x.dtype(), op, {x, y}, tir::CallNode::PureIntrinsic); + return tir::Call(x.dtype(), "fmod", {x, y}, tir::CallNode::PureIntrinsic); } -TIR_REGISTER_PURE_UNARY_OP("tir.fmod"); - -// floor PrimExpr floor(PrimExpr x) { if (x.dtype().is_int() || x.dtype().is_uint()) { return x; @@ -698,13 +650,9 @@ PrimExpr floor(PrimExpr x) { using tir::FloatImmNode; const FloatImmNode* fx = x.as(); if (fx) return FloatImm(x.dtype(), std::floor(fx->value)); - static auto op = Op::Get("tir.floor"); - return tir::Call(x.dtype(), op, {x}, tir::CallNode::PureIntrinsic); + return tir::Call(x.dtype(), "floor", {x}, tir::CallNode::PureIntrinsic); } -TIR_REGISTER_PURE_UNARY_OP("tir.floor").set_attr("TVectorizable", true); - -// ceil PrimExpr ceil(PrimExpr x) { if (x.dtype().is_int() || x.dtype().is_uint()) { return x; @@ -712,13 +660,9 @@ PrimExpr ceil(PrimExpr x) { using tir::FloatImmNode; const FloatImmNode* fx = x.as(); if (fx) return FloatImm(x.dtype(), std::ceil(fx->value)); - static auto op = Op::Get("tir.ceil"); - return tir::Call(x.dtype(), op, {x}, tir::CallNode::PureIntrinsic); + return tir::Call(x.dtype(), "ceil", {x}, tir::CallNode::PureIntrinsic); } -TIR_REGISTER_PURE_UNARY_OP("tir.ceil").set_attr("TVectorizable", true); - -// round PrimExpr round(PrimExpr x) { if (x.dtype().is_int() || x.dtype().is_uint()) { return x; @@ -726,13 +670,9 @@ PrimExpr round(PrimExpr x) { using tir::FloatImmNode; const FloatImmNode* fx = x.as(); if (fx) return FloatImm(x.dtype(), std::nearbyint(fx->value)); - static auto op = Op::Get("tir.round"); - return tir::Call(x.dtype(), op, {x}, tir::CallNode::PureIntrinsic); + return tir::Call(x.dtype(), "round", {x}, tir::CallNode::PureIntrinsic); } -TIR_REGISTER_PURE_UNARY_OP("tir.round").set_attr("TVectorizable", true); - -// nearbyint PrimExpr nearbyint(PrimExpr x) { if (x.dtype().is_int() || x.dtype().is_uint()) { return x; @@ -740,13 +680,9 @@ PrimExpr nearbyint(PrimExpr x) { using tir::FloatImmNode; const FloatImmNode* fx = x.as(); if (fx) return FloatImm(x.dtype(), std::nearbyint(fx->value)); - static auto op = Op::Get("tir.nearbyint"); - return tir::Call(x.dtype(), op, {x}, tir::CallNode::PureIntrinsic); + return tir::Call(x.dtype(), "nearbyint", {x}, tir::CallNode::PureIntrinsic); } -TIR_REGISTER_PURE_UNARY_OP("tir.nearbyint"); - -// trunc PrimExpr trunc(PrimExpr x) { if (x.dtype().is_int() || x.dtype().is_uint()) { return x; @@ -756,72 +692,9 @@ PrimExpr trunc(PrimExpr x) { if (fx) { return FloatImm(x.dtype(), (fx->value < 0 ? 
std::ceil(fx->value) : std::floor(fx->value))); } - static auto op = Op::Get("tir.trunc"); - return tir::Call(x.dtype(), op, {x}, tir::CallNode::PureIntrinsic); + return tir::Call(x.dtype(), "trunc", {x}, tir::CallNode::PureIntrinsic); } -TIR_REGISTER_PURE_UNARY_OP("tir.trunc").set_attr("TVectorizable", true); - -// unary op registration. -TIR_REGISTER_PURE_UNARY_OP("tir.exp").set_attr("TVectorizable", true); - -TIR_REGISTER_PURE_UNARY_OP("tir.exp2").set_attr("TVectorizable", true); - -TIR_REGISTER_PURE_UNARY_OP("tir.exp10").set_attr("TVectorizable", true); - -TIR_REGISTER_PURE_UNARY_OP("tir.erf"); - -TIR_REGISTER_PURE_UNARY_OP("tir.tanh").set_attr("TVectorizable", true); - -TIR_REGISTER_PURE_UNARY_OP("tir.sigmoid"); - -TIR_REGISTER_PURE_UNARY_OP("tir.sqrt").set_attr("TVectorizable", true); - -TIR_REGISTER_PURE_UNARY_OP("tir.rsqrt"); - -TIR_REGISTER_PURE_UNARY_OP("tir.log").set_attr("TVectorizable", true); - -TIR_REGISTER_PURE_UNARY_OP("tir.log2").set_attr("TVectorizable", true); - -TIR_REGISTER_PURE_UNARY_OP("tir.log1p"); - -TIR_REGISTER_PURE_UNARY_OP("tir.log10").set_attr("TVectorizable", true); - -TIR_REGISTER_PURE_UNARY_OP("tir.popcount").set_attr("TVectorizable", true); - -TIR_REGISTER_PURE_UNARY_OP("tir.tan").set_attr("TVectorizable", true); - -TIR_REGISTER_PURE_UNARY_OP("tir.cos").set_attr("TVectorizable", true); - -TIR_REGISTER_PURE_UNARY_OP("tir.cosh").set_attr("TVectorizable", true); - -TIR_REGISTER_PURE_UNARY_OP("tir.sin").set_attr("TVectorizable", true); - -TIR_REGISTER_PURE_UNARY_OP("tir.sinh").set_attr("TVectorizable", true); - -TIR_REGISTER_PURE_UNARY_OP("tir.asin"); - -TIR_REGISTER_PURE_UNARY_OP("tir.acos"); - -TIR_REGISTER_PURE_UNARY_OP("tir.atan"); - -TIR_REGISTER_PURE_UNARY_OP("tir.acosh"); - -TIR_REGISTER_PURE_UNARY_OP("tir.asinh"); - -TIR_REGISTER_PURE_UNARY_OP("tir.atanh"); - -// binary intrinsics -TIR_REGISTER_PURE_BINARY_OP("tir.atan2"); - -TIR_REGISTER_PURE_BINARY_OP("tir.nextafter"); - -TIR_REGISTER_PURE_BINARY_OP("tir.hypot"); - -TIR_REGISTER_PURE_BINARY_OP("tir.copysign"); - -TIR_REGISTER_PURE_BINARY_OP("tir.ldexp"); - // expose basic functions to node namespace TVM_REGISTER_GLOBAL("node._const").set_body([](TVMArgs args, TVMRetValue* ret) { if (args[0].type_code() == kDLInt) { @@ -910,5 +783,4 @@ TVM_REGISTER_GLOBAL("tir._OpIfThenElse") .set_body_typed([](PrimExpr cond, PrimExpr true_value, PrimExpr false_value) { return if_then_else(cond, true_value, false_value); }); - } // namespace tvm diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc index c3ddb6625d53d..66497755c88ab 100644 --- a/src/tir/ir/stmt.cc +++ b/src/tir/ir/stmt.cc @@ -582,13 +582,5 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->PrintIndent(); p->stream << "}\n"; }); - -PrimExpr TypeAnnotation(DataType dtype) { - static auto op = Op::Get("tir.type_annotation"); - return tir::Call(dtype, op, {}, tir::CallNode::PureIntrinsic); -} - -TVM_REGISTER_OP("tir.type_annotation"); - } // namespace tir } // namespace tvm diff --git a/src/tir/op/builtin.cc b/src/tir/op/builtin.cc deleted file mode 100644 index 8efcf3ff49254..0000000000000 --- a/src/tir/op/builtin.cc +++ /dev/null @@ -1,155 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file tir/op/builtin.cc - * - * builtin intrinsic operators. - */ -#include -#include -#include -#include - -namespace tvm { -namespace tir { -namespace builtin { - -#define TIR_DEFINE_BUILTIN_FUNC(OpName) \ - const Op& OpName() { \ - static const Op& op = Op::Get("tir." #OpName); \ - return op; \ - } \ - TVM_REGISTER_OP("tir." #OpName) - -TIR_DEFINE_BUILTIN_FUNC(reinterpret).set_num_inputs(1); - -TIR_DEFINE_BUILTIN_FUNC(likely).set_num_inputs(1).set_attr("TVectorizable", true); - -TIR_DEFINE_BUILTIN_FUNC(bitwise_and) - .set_num_inputs(2) - .set_attr("TVectorizable", true); - -TIR_DEFINE_BUILTIN_FUNC(bitwise_or) - .set_num_inputs(2) - .set_attr("TVectorizable", true); - -TIR_DEFINE_BUILTIN_FUNC(bitwise_xor) - .set_num_inputs(2) - .set_attr("TVectorizable", true); - -TIR_DEFINE_BUILTIN_FUNC(bitwise_not) - .set_num_inputs(1) - .set_attr("TVectorizable", true); - -TIR_DEFINE_BUILTIN_FUNC(shift_left) - .set_num_inputs(2) - .set_attr("TVectorizable", true); - -TIR_DEFINE_BUILTIN_FUNC(shift_right) - .set_num_inputs(2) - .set_attr("TVectorizable", true); - -TIR_DEFINE_BUILTIN_FUNC(large_uint_imm).set_num_inputs(2); - -TIR_DEFINE_BUILTIN_FUNC(address_of).set_num_inputs(1); - -TIR_DEFINE_BUILTIN_FUNC(if_then_else).set_num_inputs(3); - -TIR_DEFINE_BUILTIN_FUNC(isnullptr).set_num_inputs(1); - -TIR_DEFINE_BUILTIN_FUNC(isnan).set_num_inputs(1); - -TIR_DEFINE_BUILTIN_FUNC(popcount).set_num_inputs(1); - -TIR_DEFINE_BUILTIN_FUNC(fma).set_num_inputs(3).set_attr("TVectorizable", true); - -TIR_DEFINE_BUILTIN_FUNC(call_extern); - -TIR_DEFINE_BUILTIN_FUNC(call_llvm_intrin); - -TIR_DEFINE_BUILTIN_FUNC(call_spirv_glsl450); - -TIR_DEFINE_BUILTIN_FUNC(prefetch); - -TIR_DEFINE_BUILTIN_FUNC(tvm_access_ptr).set_num_inputs(5); - -TIR_DEFINE_BUILTIN_FUNC(tvm_static_handle).set_num_inputs(0); - -TIR_DEFINE_BUILTIN_FUNC(tvm_context_id).set_num_inputs(0); - -TIR_DEFINE_BUILTIN_FUNC(tvm_tuple); - -TIR_DEFINE_BUILTIN_FUNC(tvm_struct_get).set_num_inputs(3); - -TIR_DEFINE_BUILTIN_FUNC(tvm_struct_set).set_num_inputs(4); - -TIR_DEFINE_BUILTIN_FUNC(tvm_throw_last_error).set_num_inputs(0); - -TIR_DEFINE_BUILTIN_FUNC(tvm_stack_alloca).set_num_inputs(2); - -TIR_DEFINE_BUILTIN_FUNC(tvm_stack_make_shape); - -TIR_DEFINE_BUILTIN_FUNC(tvm_stack_make_array).set_num_inputs(6); - -// When num_inputs are not set, the function is assumed to be variable length. -TIR_DEFINE_BUILTIN_FUNC(tvm_call_packed); - -TIR_DEFINE_BUILTIN_FUNC(tvm_call_trace_packed); - -TIR_DEFINE_BUILTIN_FUNC(tvm_thread_context).set_num_inputs(1); - -TIR_DEFINE_BUILTIN_FUNC(tvm_call_packed_lowered); - -TIR_DEFINE_BUILTIN_FUNC(tvm_call_trace_packed_lowered); - -// TODO(tvm-team) revisit storage sync once we have a good memory hierarchy structure.
-TIR_DEFINE_BUILTIN_FUNC(tvm_storage_sync); - -TIR_DEFINE_BUILTIN_FUNC(tvm_warp_shuffle); - -TIR_DEFINE_BUILTIN_FUNC(tvm_warp_shuffle_up); - -TIR_DEFINE_BUILTIN_FUNC(tvm_warp_shuffle_down); - -TIR_DEFINE_BUILTIN_FUNC(tvm_warp_activemask); - -TIR_DEFINE_BUILTIN_FUNC(tvm_global_barrier_kinit); - -TIR_DEFINE_BUILTIN_FUNC(tvm_thread_allreduce); - -TIR_DEFINE_BUILTIN_FUNC(tvm_load_matrix_sync); - -TIR_DEFINE_BUILTIN_FUNC(tvm_mma_sync); - -TIR_DEFINE_BUILTIN_FUNC(tvm_bmma_sync); - -TIR_DEFINE_BUILTIN_FUNC(tvm_fill_fragment); - -TIR_DEFINE_BUILTIN_FUNC(tvm_store_matrix_sync); - -TIR_DEFINE_BUILTIN_FUNC(vectorhigh); - -TIR_DEFINE_BUILTIN_FUNC(vectorlow); - -TIR_DEFINE_BUILTIN_FUNC(vectorcombine); - -} // namespace builtin -} // namespace tir -} // namespace tvm diff --git a/src/tir/op/runtime.cc b/src/tir/op/runtime.cc deleted file mode 100644 index 1c540e3a650aa..0000000000000 --- a/src/tir/op/runtime.cc +++ /dev/null @@ -1,39 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file tir/op/runtime.cc - * \brief TIR ops for runtime functions. 
- */ -#include -#include - -namespace tvm { -namespace tir { - -TVM_REGISTER_OP("tir.TVMBackendAllocWorkspace") - .set_num_inputs(5) - .set_attr("TGlobalSymbol", "TVMBackendAllocWorkspace"); - -TVM_REGISTER_OP("tir.TVMBackendFreeWorkspace") - .set_num_inputs(3) - .set_attr("TGlobalSymbol", "TVMBackendFreeWorkspace"); - -} // namespace tir -} // namespace tvm diff --git a/src/tir/transforms/arg_binder.cc b/src/tir/transforms/arg_binder.cc index 80c526827ad50..ae7065d94d80e 100644 --- a/src/tir/transforms/arg_binder.cc +++ b/src/tir/transforms/arg_binder.cc @@ -24,7 +24,6 @@ #include "arg_binder.h" #include -#include #include #include @@ -142,7 +141,7 @@ void ArgBinder::BindBuffer(const Buffer& arg, const Buffer& value, const std::st } } -inline PrimExpr TVMArrayGet(DataType t, Var arr, builtin::TVMStructFieldKind kind) { +inline PrimExpr TVMArrayGet(DataType t, Var arr, intrinsic::TVMStructFieldKind kind) { return TVMStructGet(t, arr, 0, kind); } @@ -153,7 +152,7 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, const PrimExpr& device_type, const DataType tvm_ndim_type = DataType::Int(32); const Stmt nop = Evaluate(0); // dimension checks - PrimExpr v_ndim = TVMArrayGet(tvm_ndim_type, handle, builtin::kArrNDim); + PrimExpr v_ndim = TVMArrayGet(tvm_ndim_type, handle, intrinsic::kArrNDim); PrimExpr a_ndim = make_const(tvm_ndim_type, static_cast(buffer->shape.size())); std::ostringstream ndim_err_msg; ndim_err_msg << arg_name << ".ndim is expected to equal " << buffer->shape.size(); @@ -163,11 +162,11 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, const PrimExpr& device_type, DataType dtype = buffer->dtype; std::ostringstream type_err_msg; type_err_msg << arg_name << ".dtype is expected to be " << dtype; - PrimExpr cond = (TVMArrayGet(DataType::UInt(8), handle, builtin::kArrTypeCode) == + PrimExpr cond = (TVMArrayGet(DataType::UInt(8), handle, intrinsic::kArrTypeCode) == IntImm(DataType::UInt(8), dtype.code()) && - TVMArrayGet(DataType::UInt(8), handle, builtin::kArrTypeBits) == + TVMArrayGet(DataType::UInt(8), handle, intrinsic::kArrTypeBits) == IntImm(DataType::UInt(8), dtype.bits()) && - TVMArrayGet(DataType::UInt(16), handle, builtin::kArrTypeLanes) == + TVMArrayGet(DataType::UInt(16), handle, intrinsic::kArrTypeLanes) == IntImm(DataType::UInt(16), dtype.lanes())); if (!(dtype == DataType::Int(4) || dtype == DataType::UInt(4) || dtype == DataType::Int(1))) { auto type_msg = tvm::tir::StringImm(type_err_msg.str()); @@ -175,7 +174,7 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, const PrimExpr& device_type, asserts_.emplace_back(AssertStmt(cond, type_msg, nop)); } // data field - if (Bind_(buffer->data, TVMArrayGet(DataType::Handle(), handle, builtin::kArrData), + if (Bind_(buffer->data, TVMArrayGet(DataType::Handle(), handle, intrinsic::kArrData), arg_name + ".data", true)) { Var vptr(buffer->data); def_handle_dtype_.Set(vptr, tir::TypeAnnotation(buffer->dtype)); @@ -187,7 +186,7 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, const PrimExpr& device_type, Var v_shape(arg_name + ".shape", DataType::Handle()); def_handle_dtype_.Set(v_shape, make_const(tvm_shape_type, 0)); init_nest_.emplace_back( - LetStmt(v_shape, TVMArrayGet(DataType::Handle(), handle, builtin::kArrShape), nop)); + LetStmt(v_shape, TVMArrayGet(DataType::Handle(), handle, intrinsic::kArrShape), nop)); for (size_t k = 0; k < buffer->shape.size(); ++k) { if (dtype == DataType::Int(4) || dtype == DataType::UInt(4) || dtype == DataType::Int(1)) { break; @@ -203,9 +202,9 @@ void 
ArgBinder::BindDLTensor(const Buffer& buffer, const PrimExpr& device_type, Var v_strides(arg_name + ".strides", DataType::Handle()); def_handle_dtype_.Set(v_strides, tir::TypeAnnotation(tvm_shape_type)); init_nest_.emplace_back( - LetStmt(v_strides, TVMArrayGet(DataType::Handle(), handle, builtin::kArrStrides), nop)); + LetStmt(v_strides, TVMArrayGet(DataType::Handle(), handle, intrinsic::kArrStrides), nop)); PrimExpr is_null = - Call(DataType::Bool(1), builtin::isnullptr(), {v_strides}, CallNode::PureIntrinsic); + Call(DataType::Bool(1), intrinsic::tvm_handle_is_null, {v_strides}, CallNode::PureIntrinsic); if (buffer->strides.size() == 0) { // Assert the buffer is compact DataType stype = buffer->DefaultIndexType(); @@ -263,12 +262,12 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, const PrimExpr& device_type, if (const auto* const_offset = buffer->elem_offset.as()) { Bind_(make_const(DataType::UInt(64), const_offset->value * data_bytes), - TVMArrayGet(DataType::UInt(64), handle, builtin::kArrByteOffset), + TVMArrayGet(DataType::UInt(64), handle, intrinsic::kArrByteOffset), arg_name + ".byte_offset", true); } else { if (Bind_(buffer->elem_offset, cast(buffer->elem_offset.dtype(), - (TVMArrayGet(DataType::UInt(64), handle, builtin::kArrByteOffset) / + (TVMArrayGet(DataType::UInt(64), handle, intrinsic::kArrByteOffset) / make_const(DataType::UInt(64), data_bytes))), arg_name + ".elem_offset", true)) { if (buffer->offset_factor > 1) { @@ -281,9 +280,9 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, const PrimExpr& device_type, } } // device info. - Bind_(device_type, TVMArrayGet(DataType::Int(32), handle, builtin::kArrDeviceType), + Bind_(device_type, TVMArrayGet(DataType::Int(32), handle, intrinsic::kArrDeviceType), arg_name + ".device_type", true); - Bind_(device_id, TVMArrayGet(DataType::Int(32), handle, builtin::kArrDeviceId), + Bind_(device_id, TVMArrayGet(DataType::Int(32), handle, intrinsic::kArrDeviceId), arg_name + ".device_id", true); } diff --git a/src/tir/transforms/bf16_legalize.cc b/src/tir/transforms/bf16_legalize.cc index 9722d1100a7e9..445ac1cf60cd2 100644 --- a/src/tir/transforms/bf16_legalize.cc +++ b/src/tir/transforms/bf16_legalize.cc @@ -23,7 +23,6 @@ */ #include -#include #include #include @@ -189,13 +188,13 @@ class BF16LowerRewriter : StmtExprMutator { auto uint32_dtype = DataType(kDLUInt, 32, op_val->dtype.lanes()); auto uint32_v = Cast(uint32_dtype, op_val); // to be endian invariant. 
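// Illustrative C++ analogues of the two bf16 conversions this hunk lowers
// (sketches assuming IEEE float32 bit layout; not lines of this patch):
//
//   // bf16 -> fp32: widen by 16 bits and reinterpret, as emitted just below
//   float BF16ToFloat(uint16_t v) {
//     uint32_t bits = static_cast<uint32_t>(v) << 16;  // low mantissa bits are zero
//     float out;
//     std::memcpy(&out, &bits, sizeof(out));  // reinterpret, no value conversion
//     return out;
//   }
//
//   // fp32 -> bf16: round to nearest even, matching the rounding_bias
//   // expression quoted in the comment later in this hunk
//   uint16_t FloatToBF16(float f) {
//     uint32_t u32;
//     std::memcpy(&u32, &f, sizeof(u32));
//     uint32_t rounding_bias = ((u32 >> 16) & 1) + UINT32_C(0x7FFF);
//     return static_cast<uint16_t>((u32 + rounding_bias) >> 16);
//   }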
- return Call(op->dtype, builtin::reinterpret(), {uint32_v << 16}, CallNode::PureIntrinsic); + return Call(op->dtype, CallNode::reinterpret, {uint32_v << 16}, CallNode::PureIntrinsic); } else if (op->dtype.is_bfloat16()) { // if is cast_to_bf16, check if op->value is fp32 CHECK(op->value->dtype.is_float() && op->value->dtype.bits() == 32); auto uint32_dtype = DataType(kDLUInt, 32, op_val->dtype.lanes()); - auto uint32_v = Call(uint32_dtype, builtin::reinterpret(), {op_val}, CallNode::PureIntrinsic); + auto uint32_v = Call(uint32_dtype, CallNode::reinterpret, {op_val}, CallNode::PureIntrinsic); auto uint16_dtype = DataType(kDLUInt, 16, op_val->dtype.lanes()); /* the following TIR is equivalent to the C++ code below: uint32_t rounding_bias = ((U32 >> 16) & 1) + UINT32_C(0x7FFF); diff --git a/src/tir/transforms/bound_checker.cc b/src/tir/transforms/bound_checker.cc index 3b6af0644fc9d..94464a04f912c 100644 --- a/src/tir/transforms/bound_checker.cc +++ b/src/tir/transforms/bound_checker.cc @@ -24,7 +24,6 @@ #include #include -#include #include #include #include @@ -67,7 +66,7 @@ class BoundChecker : public StmtExprMutator { } PrimExpr VisitExpr_(const CallNode* op) final { - if (process_store_ && op->op.same_as(builtin::if_then_else())) { + if (process_store_ && op->is_intrinsic(intrinsic::tvm_if_then_else)) { unsafe_rewritten_ = true; } return StmtExprMutator::VisitExpr_(op); diff --git a/src/tir/transforms/combine_context_call.cc b/src/tir/transforms/combine_context_call.cc index 0485bb1f76138..73bf4c6f6db24 100644 --- a/src/tir/transforms/combine_context_call.cc +++ b/src/tir/transforms/combine_context_call.cc @@ -25,7 +25,6 @@ #include #include #include -#include #include #include #include @@ -41,7 +40,7 @@ namespace tir { class ContextCallCombiner final : public StmtExprMutator { public: PrimExpr VisitExpr_(const CallNode* op) final { - if (op->op.same_as(builtin::tvm_thread_context())) { + if (op->is_intrinsic(intrinsic::tvm_thread_context)) { CHECK_EQ(op->args.size(), 1U); PrimExpr ctx = op->args[0]; auto it = ctx_map_.find(ctx); @@ -49,7 +48,13 @@ class ContextCallCombiner final : public StmtExprMutator { return it->second; } else { CHECK(ctx.dtype().is_handle()); - Var ctx_var("ctx_cache_", ctx.dtype()); + std::string name; + if (const CallNode* call = ctx.as()) { + name = call->name + "_cache"; + } else { + name = "ctx_cache_"; + } + Var ctx_var(name, ctx.dtype()); ctx_map_[ctx] = ctx_var; return std::move(ctx_var); } diff --git a/src/tir/transforms/coproc_sync.cc b/src/tir/transforms/coproc_sync.cc index 092a7cdeca98b..384dbcb0caee0 100644 --- a/src/tir/transforms/coproc_sync.cc +++ b/src/tir/transforms/coproc_sync.cc @@ -21,7 +21,6 @@ * \file coproc_sync.cc */ #include -#include #include #include #include @@ -55,7 +54,7 @@ class CoProcTouchedBuffer : public StmtExprVisitor { StmtExprVisitor::VisitStmt_(op); } void VisitExpr_(const CallNode* op) final { - if (op->op.same_as(builtin::tvm_access_ptr())) { + if (op->is_intrinsic(intrinsic::tvm_access_ptr)) { const VarNode* buffer = op->args[1].as(); if (in_scope_) { touched_[buffer].coproc = true; @@ -196,8 +195,7 @@ class CoProcSyncPlanner : public StorageAccessVisitor { } std::vector GetSync(std::string sync_name) { - return { - Evaluate(Call(DataType::Int(32), Op::Get("tir." 
+ sync_name), {}, CallNode::Intrinsic))}; + return {Evaluate(Call(DataType::Int(32), sync_name, {}, CallNode::Intrinsic))}; } const std::unordered_set& touched_; @@ -210,8 +208,8 @@ class CoProcBarrierDetector : public StorageAccessVisitor { explicit CoProcBarrierDetector(const std::unordered_set& touched, const std::string& coproc_name) : touched_(touched) { - read_barrier_name_ = "tir." + coproc_name + ".coproc_read_barrier"; - write_barrier_name_ = "tir." + coproc_name + ".coproc_write_barrier"; + read_barrier_name_ = coproc_name + ".coproc_read_barrier"; + write_barrier_name_ = coproc_name + ".coproc_write_barrier"; } void PlanReadBarrier(const Stmt& stmt) { @@ -333,7 +331,7 @@ class CoProcBarrierDetector : public StorageAccessVisitor { CHECK(r.defined()) << "Cannot deduce write range of " << wvec[0].buffer; PrimExpr min = r->min; PrimExpr extent = r->extent; - return Evaluate(Call(DataType::Int(32), Op::Get(func), + return Evaluate(Call(DataType::Int(32), func, {wvec[0].buffer, wvec[0].dtype.bits(), r->min, r->extent}, CallNode::Intrinsic)); } @@ -348,8 +346,8 @@ class CoProcInstDepDetector : public StmtVisitor { public: explicit CoProcInstDepDetector(const IterVar& coproc_axis, const std::string& coproc_name) : coproc_axis_(coproc_axis) { - sync_push_op_ = Op::Get("tir." + coproc_name + ".coproc_dep_push"); - sync_pop_op_ = Op::Get("tir." + coproc_name + ".coproc_dep_pop"); + sync_push_name_ = coproc_name + ".coproc_dep_push"; + sync_pop_name_ = coproc_name + ".coproc_dep_pop"; } void Plan(const Stmt& stmt) { @@ -557,12 +555,12 @@ class CoProcInstDepDetector : public StmtVisitor { } Stmt MakePush(int from, int to) { - return Evaluate(Call(DataType::Int(32), sync_push_op_, + return Evaluate(Call(DataType::Int(32), sync_push_name_, {make_const(DataType::Int(32), from), make_const(DataType::Int(32), to)}, CallNode::Intrinsic)); } Stmt MakePop(int from, int to) { - return Evaluate(Call(DataType::Int(32), sync_pop_op_, + return Evaluate(Call(DataType::Int(32), sync_pop_name_, {make_const(DataType::Int(32), from), make_const(DataType::Int(32), to)}, CallNode::Intrinsic)); } @@ -570,7 +568,7 @@ class CoProcInstDepDetector : public StmtVisitor { SyncState first_state_, last_state_, curr_state_; // Variables IterVar coproc_axis_; - Op sync_push_op_, sync_pop_op_; + std::string sync_push_name_, sync_pop_name_; }; class CoProcSyncInserter : public StmtMutator { diff --git a/src/tir/transforms/inject_virtual_thread.cc b/src/tir/transforms/inject_virtual_thread.cc index 7180dd29d9039..042ddab15a2f8 100644 --- a/src/tir/transforms/inject_virtual_thread.cc +++ b/src/tir/transforms/inject_virtual_thread.cc @@ -21,7 +21,6 @@ * \file inject_virtual_thread.cc */ #include -#include #include #include #include @@ -55,7 +54,7 @@ class ExprTouched final : public StmtExprVisitor { } void VisitExpr_(const VarNode* op) final { HandleUseVar(op); } void VisitExpr_(const CallNode* op) final { - if (op->op.same_as(builtin::tvm_access_ptr())) { + if (op->is_intrinsic(intrinsic::tvm_access_ptr)) { const auto* rw_mask = op->args[4].as(); const VarNode* buffer_var = op->args[1].as(); CHECK(buffer_var); @@ -220,7 +219,7 @@ class VTInjector : public StmtExprMutator { } // Expression. 
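// An illustrative sketch (not part of this patch) of the offset rewrite the
// VTInjector performs for tvm_access_ptr just below: each virtual thread owns
// a contiguous slice of the buffer, so its accesses are shifted by
// vthread_id * stride. Names here are hypothetical.
#include <cstddef>

inline size_t RewriteVThreadOffset(size_t offset, size_t per_thread_extent,
                                   size_t vthread_id) {
  size_t stride = per_thread_extent;    // elements owned by one virtual thread
  return stride * vthread_id + offset;  // mirrors: offset = stride * var_ + offset
}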
PrimExpr VisitExpr_(const CallNode* op) final { - if (op->op.same_as(builtin::tvm_access_ptr())) { + if (op->is_intrinsic(intrinsic::tvm_access_ptr)) { CHECK_EQ(op->args.size(), 5U); DataType dtype = op->args[0].dtype(); const VarNode* buffer = op->args[1].as(); @@ -231,9 +230,9 @@ class VTInjector : public StmtExprMutator { PrimExpr extent = this->VisitExpr(op->args[3]); PrimExpr stride = it->second / make_const(offset.dtype(), dtype.lanes()); offset = stride * var_ + offset; - return Call(op->dtype, op->op, {op->args[0], op->args[1], offset, extent, op->args[4]}, + return Call(op->dtype, op->name, {op->args[0], op->args[1], offset, extent, op->args[4]}, op->call_type); - } else if (op->op.same_as(builtin::tvm_context_id())) { + } else if (op->is_intrinsic(intrinsic::tvm_context_id)) { return allow_share_ ? GetRef(op) : var_; } else { return StmtExprMutator::VisitExpr_(op); diff --git a/src/tir/transforms/ir_util.h b/src/tir/transforms/ir_util.h index 758923b15af9a..6c0eeea97278b 100644 --- a/src/tir/transforms/ir_util.h +++ b/src/tir/transforms/ir_util.h @@ -25,7 +25,6 @@ #define TVM_TIR_TRANSFORMS_IR_UTIL_H_ #include -#include #include #include @@ -84,10 +83,10 @@ inline Array UpdateArray(Array arr, F fupdate) { * \return the get expression. */ inline PrimExpr TVMStructGet(DataType dtype, Var handle, int index, - builtin::TVMStructFieldKind kind) { + intrinsic::TVMStructFieldKind kind) { Array args = {handle, make_const(DataType::Int(32), index), make_const(DataType::Int(32), static_cast(kind))}; - return Call(dtype, builtin::tvm_struct_get(), args, CallNode::PureIntrinsic); + return Call(dtype, intrinsic::tvm_struct_get, args, CallNode::PureIntrinsic); } /*! @@ -97,7 +96,7 @@ inline PrimExpr TVMStructGet(DataType dtype, Var handle, int index, * \param offset the offset index. */ inline PrimExpr AddressOffset(Var handle, DataType dtype, int offset) { - return Call(DataType::Handle(), builtin::address_of(), + return Call(DataType::Handle(), intrinsic::tvm_address_of, {Load(dtype, handle, make_const(DataType::Int(32), offset * dtype.lanes()), const_true(dtype.lanes()))}, CallNode::PureIntrinsic); @@ -114,7 +113,7 @@ inline PrimExpr AddressOffset(Var handle, DataType dtype, PrimExpr offset) { offset = offset * make_const(offset.dtype(), dtype.lanes()); offset = Ramp(offset, make_const(offset.dtype(), 1), dtype.lanes()); } - return Call(DataType::Handle(), builtin::address_of(), + return Call(DataType::Handle(), intrinsic::tvm_address_of, {Load(dtype, handle, offset, const_true(dtype.lanes()))}, CallNode::PureIntrinsic); } @@ -126,10 +125,11 @@ inline PrimExpr AddressOffset(Var handle, DataType dtype, PrimExpr offset) { * \param value The value to be set. * \return the set stmt. */ -inline Stmt TVMStructSet(Var handle, int index, builtin::TVMStructFieldKind kind, PrimExpr value) { +inline Stmt TVMStructSet(Var handle, int index, intrinsic::TVMStructFieldKind kind, + PrimExpr value) { Array args = {handle, make_const(DataType::Int(32), index), make_const(DataType::Int(32), static_cast(kind)), value}; - return Evaluate(Call(DataType::Int(32), builtin::tvm_struct_set(), args, CallNode::Intrinsic)); + return Evaluate(Call(DataType::Int(32), intrinsic::tvm_struct_set, args, CallNode::Intrinsic)); } /*! 
diff --git a/src/tir/transforms/loop_partition.cc b/src/tir/transforms/loop_partition.cc index 2fb8003486f17..3b2580c600743 100644 --- a/src/tir/transforms/loop_partition.cc +++ b/src/tir/transforms/loop_partition.cc @@ -23,7 +23,6 @@ #include #include #include -#include #include #include #include @@ -141,11 +140,11 @@ class CandidateSelector final : public StmtExprVisitor { } void VisitExpr_(const CallNode* op) final { - if (op->op.same_as(builtin::likely())) { + if (op->is_intrinsic(CallNode::likely)) { in_likely_ = true; StmtExprVisitor::VisitExpr_(op); in_likely_ = false; - } else if (op->op.same_as(builtin::tvm_thread_allreduce())) { + } else if (op->is_intrinsic(intrinsic::tvm_thread_allreduce)) { // no split if the body contains allreduce. no_split_ = true; return; @@ -215,7 +214,7 @@ class PartitionFinder : public StmtExprVisitor { } void VisitExpr_(const CallNode* op) final { - if (op->op.same_as(builtin::likely())) { + if (op->is_intrinsic(CallNode::likely)) { PrimExpr cond = op->args[0]; if (ExprUseVars(cond, std::unordered_set({current_var_.get()}))) { // For cond, find out the interval, if exists, in which we can prove that cond is @@ -597,7 +596,7 @@ inline Stmt LoopPartitioner::MakeFor(const Object* node, PrimExpr extent, Stmt b class RemoveLikelyTags : public StmtExprMutator { public: PrimExpr VisitExpr_(const CallNode* op) final { - if (op->op.same_as(builtin::likely())) { + if (op->is_intrinsic(CallNode::likely)) { CHECK_EQ(op->args.size(), 1); return StmtExprMutator::VisitExpr(op->args[0]); } else { diff --git a/src/tir/transforms/lower_device_storage_access_info.cc b/src/tir/transforms/lower_device_storage_access_info.cc index fac50a08a9b72..9d6b47a1ca37a 100644 --- a/src/tir/transforms/lower_device_storage_access_info.cc +++ b/src/tir/transforms/lower_device_storage_access_info.cc @@ -25,7 +25,6 @@ #include #include #include -#include #include #include @@ -80,7 +79,7 @@ class StorageAccessInfoLower : public StmtExprMutator { } PrimExpr VisitExpr_(const CallNode* op) final { - if (op->op.same_as(builtin::tvm_access_ptr())) { + if (op->is_intrinsic(intrinsic::tvm_access_ptr)) { return MakeAccessPtr(op); } else { return StmtExprMutator::VisitExpr_(op); diff --git a/src/tir/transforms/lower_intrin.cc b/src/tir/transforms/lower_intrin.cc index d38cb7b360422..c7aa949924d7b 100644 --- a/src/tir/transforms/lower_intrin.cc +++ b/src/tir/transforms/lower_intrin.cc @@ -51,17 +51,9 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { } PrimExpr VisitExpr_(const CallNode* op) final { - // NOTE: call_type will eventually be deprecated and the information - // will be folded into Op's attr if (op->call_type == CallNode::Intrinsic || op->call_type == CallNode::PureIntrinsic) { - if (auto* ptr_op = op->op.as()) { - // Still use legacy string based rewriting - // TODO(tvm-team): migrate the pattern application from global function look up - // to an OpAttrMap - std::string name = ptr_op->name; - PrimExpr r = ApplyPattern(name, GetRef(op)); - if (r.defined()) return r; - } + PrimExpr r = ApplyPattern(op->name, GetRef(op)); + if (r.defined()) return r; } return IRMutatorWithAnalyzer::VisitExpr_(op); } @@ -238,7 +230,7 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { PrimExpr rhs = SwapBroadcastCast(b); if (fma_ != nullptr && op->dtype.is_float()) { - PrimExpr r = (*fma_)(Call(op->dtype, builtin::fma(), {lhs, rhs, c}, CallNode::PureIntrinsic)); + PrimExpr r = (*fma_)(Call(op->dtype, "fma", {lhs, rhs, c}, CallNode::PureIntrinsic)); if (r.defined()) return 
this->VisitExpr(r); } else { if (!lhs.same_as(a) || !rhs.same_as(b)) { @@ -249,11 +241,7 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { return IRMutatorWithAnalyzer::VisitExpr_(op); } - PrimExpr ApplyPattern(std::string name, const PrimExpr& e) { - if (name.compare(0, 4, "tir.") == 0) { - name = name.substr(4); - } - + PrimExpr ApplyPattern(const std::string& name, const PrimExpr& e) { for (size_t i = 0; i < patterns_.size(); ++i) { std::string& p = patterns_[i]; size_t psize = p.length(); diff --git a/src/tir/transforms/lower_thread_allreduce.cc b/src/tir/transforms/lower_thread_allreduce.cc index dab8d5a78d029..ee17f081c6d83 100644 --- a/src/tir/transforms/lower_thread_allreduce.cc +++ b/src/tir/transforms/lower_thread_allreduce.cc @@ -24,7 +24,6 @@ #include #include #include -#include #include #include #include @@ -72,7 +71,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { Stmt stmt = StmtExprMutator::VisitStmt_(op); op = stmt.as(); const CallNode* call = op->value.as(); - if (call && call->op.same_as(builtin::tvm_thread_allreduce())) { + if (call && call->is_intrinsic(intrinsic::tvm_thread_allreduce)) { return MakeAllreduce(call); } else { return stmt; @@ -243,7 +242,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { { PrimExpr pred = const_true(1); PrimExpr mask = - Call(DataType::UInt(32), builtin::tvm_warp_activemask(), {}, CallNode::Intrinsic); + Call(DataType::UInt(32), intrinsic::tvm_warp_activemask, {}, CallNode::Intrinsic); seq.emplace_back(Store(mask_var, mask, index, pred)); // Push allocation with an empty body. Later this will be fixed // when the entire body is ready. @@ -274,7 +273,8 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { // The former may cause dead lock as there is a divergent // branch with a warp sync call inside. // - PrimExpr other = WarpShuffle(builtin::tvm_warp_shuffle_down(), mask_var, val, offset); + const char* shfl_func = intrinsic::tvm_warp_shuffle_down; + PrimExpr other = WarpShuffle(shfl_func, mask_var, val, offset); const AllocateNode* repl = local_vars[i].as(); Stmt s = Store(repl->buffer_var, other, index, pred); seq.push_back(s); @@ -303,8 +303,9 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { for (size_t i = 0; i < size; ++i) { Var var = shared_bufs[i]; PrimExpr pred = const_true(types[i].lanes()); + const char* shfl_func = intrinsic::tvm_warp_shuffle; PrimExpr val = Load(types[i], var, index, pred); - PrimExpr splat = WarpShuffle(builtin::tvm_warp_shuffle(), mask_var, val, 0); + PrimExpr splat = WarpShuffle(shfl_func, mask_var, val, 0); seq.push_back(Store(var, splat, index, pred)); } @@ -464,18 +465,18 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { } // sync thread op. static Stmt SyncThread(const std::string& sync) { - return Evaluate(Call(DataType::Int(32), builtin::tvm_storage_sync(), {StringImm(sync)}, + return Evaluate(Call(DataType::Int(32), intrinsic::tvm_storage_sync, {StringImm(sync)}, CallNode::Intrinsic)); } - // Emit warp shuffle calls. - PrimExpr WarpShuffle(const Op& op, Var mask_var, PrimExpr val, int delta_or_lane) { + // Emit warp shuffle intrinsic calls. 
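// Illustrative CUDA (not part of this patch) of the pattern the
// tvm_warp_shuffle_down lowering here ultimately targets: a butterfly
// reduction across one warp using __shfl_down_sync.
__device__ float WarpReduceSum(float val, unsigned mask = 0xffffffffu) {
  for (int offset = 16; offset > 0; offset >>= 1) {
    val += __shfl_down_sync(mask, val, offset);  // pull value from lane + offset
  }
  return val;  // lane 0 now holds the warp-wide sum
}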
+ PrimExpr WarpShuffle(const char* name, Var mask_var, PrimExpr val, int delta_or_lane) { PrimExpr pred = const_true(1); PrimExpr index(0); PrimExpr mask = Load(DataType::UInt(32), mask_var, index, pred); PrimExpr width = IntImm(DataType::Int(32), warp_size_); Array args{mask, val, IntImm(DataType::Int(32), delta_or_lane), width, width}; - return Call(val.dtype(), op, args, CallNode::Intrinsic); + return Call(val.dtype(), name, args, CallNode::Intrinsic); } // Check if this is a reduction on threadIdx.x and its extent matches diff --git a/src/tir/transforms/lower_tvm_builtin.cc b/src/tir/transforms/lower_tvm_builtin.cc index e6182301a3358..7611e0fcc8b38 100644 --- a/src/tir/transforms/lower_tvm_builtin.cc +++ b/src/tir/transforms/lower_tvm_builtin.cc @@ -22,7 +22,6 @@ * \file tir/transforms/lower_tvm_buildin.cc */ #include -#include #include #include #include @@ -41,7 +40,7 @@ inline PrimExpr ConstInt32(size_t index) { inline PrimExpr StackAlloca(std::string type, size_t num) { Array args = {StringImm(type), ConstInt32(num)}; - return Call(DataType::Handle(), builtin::tvm_stack_alloca(), args, CallNode::Intrinsic); + return Call(DataType::Handle(), intrinsic::tvm_stack_alloca, args, CallNode::Intrinsic); } // Calculate the statistics of packed function. @@ -104,22 +103,23 @@ class BuiltinLower : public StmtExprMutator { CHECK(device_type_.defined()) << "Unknown device type in current IR"; CHECK(device_id_.defined()) << "Unknown device id in current IR"; Stmt throw_last_error = - Evaluate(Call(DataType::Int(32), builtin::tvm_throw_last_error(), {}, CallNode::Intrinsic)); + Evaluate(Call(DataType::Int(32), intrinsic::tvm_throw_last_error, {}, CallNode::Intrinsic)); - Stmt body = SeqStmt({IfThenElse(Call(DataType::Bool(1), builtin::isnullptr(), {op->buffer_var}, - CallNode::PureIntrinsic), + Stmt body = SeqStmt({IfThenElse(Call(DataType::Bool(1), intrinsic::tvm_handle_is_null, + {op->buffer_var}, CallNode::PureIntrinsic), throw_last_error), op->body}); + Stmt alloca = LetStmt( op->buffer_var, - Call(op->buffer_var.dtype(), Op::Get("tir.TVMBackendAllocWorkspace"), + Call(op->buffer_var.dtype(), "TVMBackendAllocWorkspace", {cast(DataType::Int(32), device_type_), cast(DataType::Int(32), device_id_), cast(DataType::UInt(64), total_bytes), IntImm(DataType::Int(32), op->dtype.code()), IntImm(DataType::Int(32), op->dtype.bits())}, CallNode::Extern), body); - PrimExpr free_op = Call(DataType::Int(32), Op::Get("tir.TVMBackendFreeWorkspace"), + PrimExpr free_op = Call(DataType::Int(32), "TVMBackendFreeWorkspace", {cast(DataType::Int(32), device_type_), cast(DataType::Int(32), device_id_), op->buffer_var}, CallNode::Extern); @@ -144,15 +144,15 @@ class BuiltinLower : public StmtExprMutator { } } PrimExpr VisitExpr_(const CallNode* op) final { - if (op->op.same_as(builtin::tvm_call_packed())) { + if (op->is_intrinsic(intrinsic::tvm_call_packed)) { return MakeCallPacked(op); - } else if (op->op.same_as(builtin::tvm_call_trace_packed())) { + } else if (op->is_intrinsic(intrinsic::tvm_call_trace_packed)) { return MakeCallTracePacked(op); - } else if (op->op.same_as(builtin::tvm_stack_make_shape())) { + } else if (op->is_intrinsic(intrinsic::tvm_stack_make_shape)) { return MakeShape(op); - } else if (op->op.same_as(builtin::tvm_stack_make_array())) { + } else if (op->is_intrinsic(intrinsic::tvm_stack_make_array)) { return MakeArray(op); - } else if (op->op.same_as(builtin::tvm_context_id())) { + } else if (op->is_intrinsic(intrinsic::tvm_context_id)) { return make_zero(op->dtype); } else { return 
StmtExprMutator::VisitExpr_(op); @@ -176,21 +176,21 @@ class BuiltinLower : public StmtExprMutator { run_array_stack_ += 1; PrimExpr expr = StmtExprMutator::VisitExpr_(op); op = expr.as(); - prep_seq_.emplace_back(TVMStructSet(stack_array_, idx, builtin::kArrData, op->args[0])); - prep_seq_.emplace_back(TVMStructSet(stack_array_, idx, builtin::kArrShape, op->args[1])); + prep_seq_.emplace_back(TVMStructSet(stack_array_, idx, intrinsic::kArrData, op->args[0])); + prep_seq_.emplace_back(TVMStructSet(stack_array_, idx, intrinsic::kArrShape, op->args[1])); PrimExpr strides = op->args[2]; if (!strides.defined() || is_zero(strides)) { strides = make_zero(DataType::Handle()); } - prep_seq_.emplace_back(TVMStructSet(stack_array_, idx, builtin::kArrStrides, strides)); - prep_seq_.emplace_back(TVMStructSet(stack_array_, idx, builtin::kArrNDim, op->args[3])); + prep_seq_.emplace_back(TVMStructSet(stack_array_, idx, intrinsic::kArrStrides, strides)); + prep_seq_.emplace_back(TVMStructSet(stack_array_, idx, intrinsic::kArrNDim, op->args[3])); DataType dtype = op->args[4].dtype(); prep_seq_.emplace_back( - TVMStructSet(stack_array_, idx, builtin::kArrTypeCode, + TVMStructSet(stack_array_, idx, intrinsic::kArrTypeCode, make_const(DataType::UInt(8), static_cast(dtype.code())))); - prep_seq_.emplace_back(TVMStructSet(stack_array_, idx, builtin::kArrTypeBits, + prep_seq_.emplace_back(TVMStructSet(stack_array_, idx, intrinsic::kArrTypeBits, make_const(DataType::UInt(8), dtype.bits()))); - prep_seq_.emplace_back(TVMStructSet(stack_array_, idx, builtin::kArrTypeLanes, + prep_seq_.emplace_back(TVMStructSet(stack_array_, idx, intrinsic::kArrTypeLanes, make_const(DataType::UInt(16), dtype.lanes()))); // set byte offset int data_bytes = GetVectorBytes(dtype); @@ -198,15 +198,15 @@ class BuiltinLower : public StmtExprMutator { if (!is_zero(byte_offset)) { byte_offset = byte_offset * make_const(byte_offset.dtype(), data_bytes); } - prep_seq_.emplace_back(TVMStructSet(stack_array_, idx, builtin::kArrByteOffset, + prep_seq_.emplace_back(TVMStructSet(stack_array_, idx, intrinsic::kArrByteOffset, cast(DataType::UInt(64), byte_offset))); CHECK(device_type_.defined()) << "Unknown device type in current IR"; CHECK(device_id_.defined()) << "Unknown device id in current IR"; - prep_seq_.emplace_back(TVMStructSet(stack_array_, idx, builtin::kArrDeviceId, + prep_seq_.emplace_back(TVMStructSet(stack_array_, idx, intrinsic::kArrDeviceId, cast(DataType::Int(32), device_id_))); - prep_seq_.emplace_back(TVMStructSet(stack_array_, idx, builtin::kArrDeviceType, + prep_seq_.emplace_back(TVMStructSet(stack_array_, idx, intrinsic::kArrDeviceType, cast(DataType::Int(32), device_type_))); - return TVMStructGet(DataType::Handle(), stack_array_, idx, builtin::kArrAddr); + return TVMStructGet(DataType::Handle(), stack_array_, idx, intrinsic::kArrAddr); } // call packed. 
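// A hypothetical sketch (not from this patch) of the argument layout that
// MakeCallPacked below builds on the value stack: one TVMValue slot and one
// parallel type code per argument, handed to the lowered packed call.
#include <tvm/runtime/c_runtime_api.h>

void FillArgsExample(TVMValue* stack_value, int* stack_tcode) {
  stack_value[0].v_int64 = 42;     // kTVMValueContent slot for argument 0
  stack_tcode[0] = kDLInt;         // matching type code
  stack_value[1].v_float64 = 2.5;  // argument 1
  stack_tcode[1] = kDLFloat;
}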
PrimExpr MakeCallPacked(const CallNode* op) { @@ -226,7 +226,7 @@ class BuiltinLower : public StmtExprMutator { arg = Cast(api_type, arg); } prep_seq_.emplace_back(TVMStructSet(stack_value_, static_cast(arg_stack_begin + i - 1), - builtin::kTVMValueContent, arg)); + intrinsic::kTVMValueContent, arg)); int arg_tcode = api_type.code(); if (api_type.is_handle() && arg.as()) { arg_tcode = kTVMStr; @@ -245,7 +245,7 @@ class BuiltinLower : public StmtExprMutator { Array packed_args = {op->args[0], stack_value_, stack_tcode_, ConstInt32(arg_stack_begin), ConstInt32(arg_stack_begin + op->args.size() - 1)}; - return Call(DataType::Int(32), builtin::tvm_call_packed_lowered(), packed_args, + return Call(DataType::Int(32), intrinsic::tvm_call_packed_lowered, packed_args, CallNode::Intrinsic); } @@ -267,7 +267,7 @@ class BuiltinLower : public StmtExprMutator { arg = Cast(api_type, arg); } prep_seq_.emplace_back(TVMStructSet(stack_value_, static_cast(arg_stack_begin + i - 1), - builtin::kTVMValueContent, arg)); + intrinsic::kTVMValueContent, arg)); int arg_tcode = api_type.code(); CHECK(!IsArrayHandle(arg)) << "Trace does not support Buffers"; prep_seq_.emplace_back( @@ -287,7 +287,7 @@ class BuiltinLower : public StmtExprMutator { ConstInt32(arg_stack_begin + op->args.size() - 1), // Pass traced value. op->args[args_size - 1]}; - return Call(op->dtype, builtin::tvm_call_trace_packed_lowered(), packed_args, + return Call(op->dtype, intrinsic::tvm_call_trace_packed_lowered, packed_args, CallNode::Intrinsic); } @@ -295,8 +295,8 @@ class BuiltinLower : public StmtExprMutator { bool IsArrayHandle(const PrimExpr& arg) { // specially set array handle. if (const CallNode* buf = arg.as()) { - if (buf->op.same_as(builtin::tvm_struct_get()) && - buf->args[2].as()->value == builtin::kArrAddr) { + if (buf->is_intrinsic(intrinsic::tvm_struct_get) && + buf->args[2].as()->value == intrinsic::kArrAddr) { return true; } } diff --git a/src/tir/transforms/lower_warp_memory.cc b/src/tir/transforms/lower_warp_memory.cc index 3e7d13b2ff6ed..92f9ab54adb45 100644 --- a/src/tir/transforms/lower_warp_memory.cc +++ b/src/tir/transforms/lower_warp_memory.cc @@ -30,7 +30,6 @@ #include #include #include -#include #include #include #include @@ -251,8 +250,8 @@ class WarpAccessRewriter : protected StmtExprMutator { << " local_index=" << local_index; PrimExpr load_value = Load(op->dtype, op->buffer_var, local_index, op->predicate); PrimExpr mask = - Call(DataType::UInt(32), builtin::tvm_warp_activemask(), {}, CallNode::Intrinsic); - return Call(load_value.dtype(), builtin::tvm_warp_shuffle(), + Call(DataType::UInt(32), intrinsic::tvm_warp_activemask, {}, CallNode::Intrinsic); + return Call(load_value.dtype(), intrinsic::tvm_warp_shuffle, {mask, load_value, group, width_, warp_size_}, CallNode::Intrinsic); } else { return StmtExprMutator::VisitExpr_(op); diff --git a/src/tir/transforms/make_packed_api.cc b/src/tir/transforms/make_packed_api.cc index 9bb5fc6b59716..a91e350e6b22e 100644 --- a/src/tir/transforms/make_packed_api.cc +++ b/src/tir/transforms/make_packed_api.cc @@ -26,7 +26,6 @@ #include #include #include -#include #include #include #include @@ -83,10 +82,10 @@ PrimFunc MakePackedAPI(PrimFunc&& func, int num_unpacked_args) { // load i-th argument as type t auto f_arg_value = [&](DataType t, int i) { Array call_args{v_packed_args, IntImm(DataType::Int(32), i), - IntImm(DataType::Int(32), builtin::kTVMValueContent)}; + IntImm(DataType::Int(32), intrinsic::kTVMValueContent)}; // load 64 bit version DataType api_type = APIType(t); 
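// Illustrative C++ (not part of this patch) of why the load below goes through
// APIType(t) first: packed arguments live in 64-bit TVMValue slots, so an
// int32 parameter is read at 64-bit width and then narrowed, just as the TIR
// does with Cast(t, res). LoadInt32Arg is a hypothetical helper.
#include <cstdint>
#include <tvm/runtime/c_runtime_api.h>

int32_t LoadInt32Arg(const TVMValue* packed_args, int i) {
  int64_t wide = packed_args[i].v_int64;  // the 64-bit "API type" load
  return static_cast<int32_t>(wide);      // the cast back to the target type
}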
- PrimExpr res = Call(api_type, builtin::tvm_struct_get(), call_args, CallNode::PureIntrinsic); + PrimExpr res = Call(api_type, intrinsic::tvm_struct_get, call_args, CallNode::PureIntrinsic); // cast to the target version. if (api_type != t) { res = Cast(t, res); @@ -190,7 +189,7 @@ PrimFunc MakePackedAPI(PrimFunc&& func, int num_unpacked_args) { if (runtime::DeviceAPI::NeedSetDeviceContext(target_device_type)) { Stmt set_device = - Evaluate(Call(DataType::Int(32), builtin::tvm_call_packed(), + Evaluate(Call(DataType::Int(32), intrinsic::tvm_call_packed, {StringImm(runtime::symbol::tvm_set_device), device_type, device_id}, CallNode::Intrinsic)); body = SeqStmt({set_device, body}); diff --git a/src/tir/transforms/narrow_datatype.cc b/src/tir/transforms/narrow_datatype.cc index a14fd02e7700e..07b0ea29a52a0 100644 --- a/src/tir/transforms/narrow_datatype.cc +++ b/src/tir/transforms/narrow_datatype.cc @@ -23,7 +23,6 @@ */ #include -#include #include #include @@ -319,8 +318,6 @@ class DataTypeRewriter : public StmtExprMutator { std::unordered_map ivmap_; // indicator of LoadNode::index and StoreNode::index bool is_index_{false}; - // cached ops - const Op& builtin_pow_ = Op::Get("tir.pow"); }; #define DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(OP, FUNC) \ @@ -355,23 +352,23 @@ PrimExpr DataTypeRewriter::VisitExpr_(const CallNode* op) { op = e.as(); CHECK(op != nullptr) << "Expected type to be CallNode" << ", but get " << e->GetTypeKey(); - - if (op->op.same_as(builtin::if_then_else())) { - return if_then_else(op->args[0], op->args[1], op->args[2]); - } else if (op->op.same_as(builtin::shift_right())) { - return op->args[0] >> op->args[1]; - } else if (op->op.same_as(builtin::shift_left())) { - return op->args[0] << op->args[1]; - } else if (op->op.same_as(builtin::bitwise_and())) { - return op->args[0] & op->args[1]; - } else if (op->op.same_as(builtin::bitwise_or())) { - return op->args[0] | op->args[1]; - } else if (op->op.same_as(builtin::bitwise_xor())) { - return op->args[0] ^ op->args[1]; - } else if (op->op.same_as(builtin_pow_)) { - return pow(op->args[0], op->args[1]); + if (op->call_type == CallNode::PureIntrinsic) { + if (op->name == intrinsic::tvm_if_then_else) { + return if_then_else(op->args[0], op->args[1], op->args[2]); + } else if (op->name == CallNode::shift_right) { + return op->args[0] >> op->args[1]; + } else if (op->name == CallNode::shift_left) { + return op->args[0] << op->args[1]; + } else if (op->name == CallNode::bitwise_and) { + return op->args[0] & op->args[1]; + } else if (op->name == CallNode::bitwise_or) { + return op->args[0] | op->args[1]; + } else if (op->name == CallNode::bitwise_xor) { + return op->args[0] ^ op->args[1]; + } else if (op->name == "pow") { + return pow(op->args[0], op->args[1]); + } } - return e; } diff --git a/src/tir/transforms/rewrite_unsafe_select.cc b/src/tir/transforms/rewrite_unsafe_select.cc index e5535369c39e7..701f0cea1bfa8 100644 --- a/src/tir/transforms/rewrite_unsafe_select.cc +++ b/src/tir/transforms/rewrite_unsafe_select.cc @@ -22,7 +22,6 @@ * \brief Rewrite unsafe select expression. */ #include -#include #include #include #include @@ -38,9 +37,9 @@ class UnsafeExprDetector : public ExprFunctor { // Because we will issue guard to make sure it is.
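// A minimal illustration (not from this patch) of the hazard UnsafeExprDetector
// guards against below: a Select evaluates both arms eagerly, so an
// out-of-bounds load in the untaken arm still executes, whereas if_then_else
// leaves the untaken arm unevaluated. Both functions are hypothetical.
#include <vector>

int SelectLike(bool c, const std::vector<int>& a, int i) {
  // Unsafe: a[i + 1] is evaluated even when c is true.
  int t = a[i];
  int f = a[i + 1];
  return c ? t : f;
}

int IfThenElseLike(bool c, const std::vector<int>& a, int i) {
  // Safe: only the chosen arm is evaluated.
  return c ? a[i] : a[i + 1];
}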
bool VisitExpr_(const SelectNode* op) { return VisitExpr(op->condition); } bool VisitExpr_(const CallNode* op) { - if (op->op.same_as(builtin::if_then_else())) { + if (op->is_intrinsic(intrinsic::tvm_if_then_else)) { return VisitExpr(op->args[0]); - } else if (op->op.same_as(builtin::address_of())) { + } else if (op->is_intrinsic(intrinsic::tvm_address_of)) { const LoadNode* l = op->args[0].as(); return this->VisitExpr(l->index); } else if (op->is_pure()) { @@ -105,7 +104,7 @@ class UnsafeSelectRewriter : public StmtExprMutator { bool cond_is_scalar_bool = op->condition.dtype().is_bool() && op->condition.dtype().is_scalar(); if ((unsafe.VisitExpr(op->true_value) || unsafe.VisitExpr(op->false_value)) && cond_is_scalar_bool) { - return Call(op->dtype, builtin::if_then_else(), + return Call(op->dtype, intrinsic::tvm_if_then_else, {op->condition, op->true_value, op->false_value}, CallNode::Intrinsic); } else { return expr; diff --git a/src/tir/transforms/split_host_device.cc b/src/tir/transforms/split_host_device.cc index c35caf54db4a9..0684189c88e89 100644 --- a/src/tir/transforms/split_host_device.cc +++ b/src/tir/transforms/split_host_device.cc @@ -26,7 +26,6 @@ #include #include #include -#include #include #include #include @@ -239,7 +238,7 @@ class HostDeviceSplitter : public StmtMutator { call_args.push_back(ext); } return Evaluate( - Call(DataType::Int(32), builtin::tvm_call_packed(), call_args, CallNode::Intrinsic)); + Call(DataType::Int(32), intrinsic::tvm_call_packed, call_args, CallNode::Intrinsic)); } // target ir module diff --git a/src/tir/transforms/storage_access.cc b/src/tir/transforms/storage_access.cc index 24f8b756974c8..20cc6402135f8 100644 --- a/src/tir/transforms/storage_access.cc +++ b/src/tir/transforms/storage_access.cc @@ -181,10 +181,10 @@ void StorageAccessVisitor::VisitStmt_(const IfThenElseNode* op) { } void StorageAccessVisitor::VisitExpr_(const CallNode* op) { - if (op->op.same_as(builtin::address_of())) { + if (op->is_intrinsic(intrinsic::tvm_address_of)) { const LoadNode* l = op->args[0].as(); StmtExprVisitor::VisitExpr_(l); - } else if (op->op.same_as(builtin::tvm_access_ptr())) { + } else if (op->is_intrinsic(intrinsic::tvm_access_ptr)) { CHECK_EQ(op->args.size(), 5U); DataType dtype = op->args[0].dtype(); const VarNode* buffer = op->args[1].as(); @@ -211,7 +211,7 @@ void StorageAccessVisitor::VisitExpr_(const CallNode* op) { } } StmtExprVisitor::VisitExpr_(op); - } else if (op->op.same_as(builtin::tvm_storage_sync())) { + } else if (op->is_intrinsic(intrinsic::tvm_storage_sync)) { CHECK(allow_append_); const std::string& s = op->args[0].as()->value; if (s != "warp") { diff --git a/src/tir/transforms/storage_flatten.cc b/src/tir/transforms/storage_flatten.cc index 30805508144d3..e29d978e0d42b 100644 --- a/src/tir/transforms/storage_flatten.cc +++ b/src/tir/transforms/storage_flatten.cc @@ -29,7 +29,6 @@ #include #include #include -#include #include #include #include @@ -46,6 +45,7 @@ namespace tvm { namespace tir { +using intrinsic::tvm_address_of; using runtime::StorageRank; using runtime::StorageScope; using runtime::ThreadScope; @@ -101,7 +101,7 @@ class StorageFlattener : public StmtExprMutator { } else if (op->attr_key == attr::buffer_dim_align) { auto buffer = Downcast(op->node); const CallNode* tuple = op->value.as(); - CHECK(tuple && tuple->op.same_as(builtin::tvm_tuple())); + CHECK(tuple && tuple->is_intrinsic(intrinsic::tvm_tuple)); auto& vinfo = dim_align_[buffer]; int dim = tuple->args[0].as()->value; if (static_cast(dim) >= vinfo.size()) { @@ 
-322,9 +322,9 @@ class StorageFlattener : public StmtExprMutator { } else { PrimExpr load = e.buffer.vload(e.RelIndex(args), e.buffer->dtype); PrimExpr address = - Call(DataType::Handle(), builtin::address_of(), {load}, CallNode::PureIntrinsic); + Call(DataType::Handle(), tvm_address_of, {load}, CallNode::PureIntrinsic); PrimExpr prefetch = - Call(op->buffer->dtype, builtin::prefetch(), {address, 0, 3, 1}, CallNode::Intrinsic); + Call(op->buffer->dtype, CallNode::prefetch, {address, 0, 3, 1}, CallNode::Intrinsic); stmt = Evaluate(prefetch); PrimExpr extent = (op->bounds[i]->extent - 1) / stride + 1; stmt = For(vars[i], 0, extent, ForType::Serial, DeviceAPI::None, stmt); @@ -392,7 +392,7 @@ class StorageFlattener : public StmtExprMutator { const BufferNode* target = arr[1].as(); const CallNode* tuple = op->value.as(); CHECK(buffer && target); - CHECK(tuple && tuple->op.same_as(builtin::tvm_tuple())); + CHECK(tuple && tuple->is_intrinsic(intrinsic::tvm_tuple)); auto key = GetRef(target); auto it = buf_map_.find(key); diff --git a/src/tir/transforms/storage_rewrite.cc b/src/tir/transforms/storage_rewrite.cc index d7a258cffe305..283ab0f6f7035 100644 --- a/src/tir/transforms/storage_rewrite.cc +++ b/src/tir/transforms/storage_rewrite.cc @@ -26,7 +26,6 @@ #include #include #include -#include #include #include #include @@ -132,7 +131,7 @@ class LinearAccessPatternFinder final : public StmtExprVisitor { } } void VisitExpr_(const CallNode* op) final { - if (op->op.same_as(builtin::address_of())) { + if (op->is_intrinsic(intrinsic::tvm_address_of)) { const LoadNode* l = op->args[0].as(); this->VisitExpr(l->index); } else { @@ -388,7 +387,7 @@ class StoragePlanRewriter : public StmtExprMutator { } } PrimExpr VisitExpr_(const CallNode* op) final { - if (op->op.same_as(builtin::tvm_access_ptr())) { + if (op->is_intrinsic(intrinsic::tvm_access_ptr)) { CHECK_EQ(op->args.size(), 5U); DataType dtype = op->args[0].dtype(); const VarNode* buffer = op->args[1].as(); @@ -404,7 +403,7 @@ class StoragePlanRewriter : public StmtExprMutator { if (se->bits_offset != 0) { offset = make_const(offset.dtype(), se->bits_offset / elem_bits) + offset; } - return Call(op->dtype, op->op, {op->args[0], se->alloc_var, offset, extent, op->args[4]}, + return Call(op->dtype, op->name, {op->args[0], se->alloc_var, offset, extent, op->args[4]}, op->call_type); } else { return StmtExprMutator::VisitExpr_(op); @@ -912,7 +911,7 @@ class VectorAllocRewriter : public StmtExprMutator { return StmtExprMutator::VisitStmt_(op); } PrimExpr VisitExpr_(const CallNode* op) final { - if (op->op.same_as(builtin::tvm_access_ptr())) { + if (op->is_intrinsic(intrinsic::tvm_access_ptr)) { DataType dtype = op->args[0].dtype(); const VarNode* buffer = op->args[1].as(); UpdateTypeMap(buffer, dtype); diff --git a/src/tir/transforms/tensorcore_infer_fragment.cc b/src/tir/transforms/tensorcore_infer_fragment.cc index 1b3b3c44ff9ca..493aa516fbd72 100644 --- a/src/tir/transforms/tensorcore_infer_fragment.cc +++ b/src/tir/transforms/tensorcore_infer_fragment.cc @@ -53,8 +53,8 @@ class FragmentGetter : public StmtExprVisitor { void VisitExpr_(const CallNode* op) final { StmtExprVisitor::VisitExpr_(op); - if (op->op.same_as(builtin::tvm_load_matrix_sync()) || - op->op.same_as(builtin::tvm_store_matrix_sync())) { + if (op->is_intrinsic(intrinsic::tvm_load_matrix_sync) || + op->is_intrinsic(intrinsic::tvm_store_matrix_sync)) { // Get shape and layout information from load and store intrinsic CHECK_EQ(op->args.size(), 8U); const VarNode* buffer_var = 
op->args[0].as(); @@ -89,7 +89,7 @@ class FragmentGetter : public StmtExprVisitor { } fragments[buffer_var] = info; } - } else if (op->op.same_as(builtin::tvm_fill_fragment())) { + } else if (op->is_intrinsic(intrinsic::tvm_fill_fragment)) { // Get shape information from fill intrinsic CHECK_EQ(op->args.size(), 6U); const VarNode* buffer_var = op->args[0].as(); @@ -141,7 +141,7 @@ class FragmentChecker : public StmtExprVisitor { void VisitExpr_(const CallNode* op) final { StmtExprVisitor::VisitExpr_(op); // Check shape when calling tvm_mma_sync - if (op->op.same_as(builtin::tvm_mma_sync()) || op->op.same_as(builtin::tvm_bmma_sync())) { + if (op->is_intrinsic(intrinsic::tvm_mma_sync) || op->is_intrinsic(intrinsic::tvm_bmma_sync)) { CHECK_EQ(op->args.size(), 8U); const VarNode* buffer_var_d = op->args[0].as(); const VarNode* buffer_var_a = op->args[2].as(); diff --git a/src/tir/transforms/thread_storage_sync.cc b/src/tir/transforms/thread_storage_sync.cc index cdd9377e00d6b..612efb0923951 100644 --- a/src/tir/transforms/thread_storage_sync.cc +++ b/src/tir/transforms/thread_storage_sync.cc @@ -22,7 +22,6 @@ */ #include #include -#include #include #include #include @@ -210,7 +209,7 @@ class ThreadSyncInserter : public StmtExprMutator { if (sync_scope_.rank == StorageRank::kGlobal) { barrier = MakeGlobalBarrier(); } else { - barrier = Evaluate(Call(DataType::Int(32), builtin::tvm_storage_sync(), + barrier = Evaluate(Call(DataType::Int(32), intrinsic::tvm_storage_sync, {StringImm(sync_scope_.to_string())}, CallNode::Intrinsic)); } // Mutate after query, to avoid stmt change. @@ -260,7 +259,7 @@ class ThreadSyncInserter : public StmtExprMutator { } PrimExpr VisitExpr_(const CallNode* op) final { - if (op->op.same_as(builtin::tvm_access_ptr())) { + if (op->is_intrinsic(intrinsic::tvm_access_ptr)) { PrimExpr expr = StmtExprMutator::VisitExpr_(op); op = expr.as(); CHECK_EQ(op->args.size(), 5U); @@ -300,7 +299,7 @@ class ThreadSyncInserter : public StmtExprMutator { CHECK(op != nullptr); Array pargs = {StringImm(runtime::symbol::tvm_prepare_global_barrier)}; Stmt prep = - Evaluate(Call(DataType::Int(32), builtin::tvm_call_packed(), pargs, CallNode::Intrinsic)); + Evaluate(Call(DataType::Int(32), intrinsic::tvm_call_packed, pargs, CallNode::Intrinsic)); Stmt body = op->body; for (const auto& kv : rw_stats_) { const auto& e = kv.second; @@ -310,7 +309,7 @@ class ThreadSyncInserter : public StmtExprMutator { } rw_stats_.clear(); Stmt kinit = Evaluate( - Call(DataType::Int(32), builtin::tvm_global_barrier_kinit(), {}, CallNode::Intrinsic)); + Call(DataType::Int(32), intrinsic::tvm_global_barrier_kinit, {}, CallNode::Intrinsic)); body = SeqStmt({kinit, body}); body = AttrStmt(op->node, op->attr_key, op->value, body); return SeqStmt({prep, body}); @@ -333,7 +332,7 @@ class ThreadSyncInserter : public StmtExprMutator { } else { CHECK_EQ(num_work_dim_, thread_extents_.size()); } - return Evaluate(Call(DataType::Int(32), builtin::tvm_storage_sync(), + return Evaluate(Call(DataType::Int(32), intrinsic::tvm_storage_sync, {StringImm(sync_scope_.to_string()), is_lead_, num_blocks_}, CallNode::Intrinsic)); } diff --git a/src/tir/transforms/vectorize_loop.cc b/src/tir/transforms/vectorize_loop.cc index 1a2ec502f6055..227aea2eb5754 100644 --- a/src/tir/transforms/vectorize_loop.cc +++ b/src/tir/transforms/vectorize_loop.cc @@ -23,10 +23,8 @@ // Loop vectorizer as in Halide pipeline. 
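// An illustrative sketch (not part of this patch) of the lane broadcasting the
// Vectorizer below performs when an operation mixes scalar and vector
// operands: the scalar side is replicated to the vector width first, the
// analogue of tir::Broadcast(scalar, lanes). Names are hypothetical.
#include <array>

template <int Lanes>
std::array<float, Lanes> BroadcastTo(float scalar) {
  std::array<float, Lanes> v{};
  v.fill(scalar);  // replicate the scalar across all lanes
  return v;
}

template <int Lanes>
std::array<float, Lanes> AddVecScalar(std::array<float, Lanes> a, float b) {
  auto bv = BroadcastTo<Lanes>(b);
  for (int i = 0; i < Lanes; ++i) a[i] += bv[i];  // now a lane-wise vector add
  return a;
}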
#include #include -#include #include #include -#include #include #include @@ -214,18 +212,15 @@ class Vectorizer : public StmtExprMutator { int lanes = std::max(t.dtype().lanes(), f.dtype().lanes()); t = BroadcastTo(t, lanes); f = BroadcastTo(f, lanes); - return Call(op->dtype.with_lanes(lanes), op->op, {cond, t, f}, op->call_type); + return Call(op->dtype.with_lanes(lanes), op->name, {cond, t, f}, op->call_type); } } // Call PrimExpr VisitExpr_(const CallNode* op) final { - if (op->op.same_as(builtin::if_then_else())) { + if (op->name == intrinsic::tvm_if_then_else) { return MutateIfThenElseExpr_(op); } - auto* op_ptr = op->op.as(); - bool vectorizable = op_ptr && op_vectorizable_.get(GetRef(op_ptr), false); - - if (!vectorizable) { + if (!op->is_vectorizable()) { // Cannot vectorize this op Array new_args; for (auto arg : op->args) { @@ -239,7 +234,7 @@ class Vectorizer : public StmtExprMutator { if (op->args.same_as(new_args)) { return GetRef(op); } else { - return Call(op->dtype, op->op, new_args, op->call_type); + return Call(op->dtype, op->name, new_args, op->call_type); } } else { int lane = 0; @@ -248,7 +243,7 @@ class Vectorizer : public StmtExprMutator { if (op->args.same_as(new_args)) { return GetRef(op); } else { - return Call(op->dtype.with_lanes(lane), op->op, new_args, op->call_type); + return Call(op->dtype.with_lanes(lane), op->name, new_args, op->call_type); } } } @@ -385,9 +380,6 @@ class Vectorizer : public StmtExprMutator { bool need_scalarize_{false}; // The lets std::unordered_map lets_; - // vectorizable property - OpAttrMap op_vectorizable_ = Op::GetAttrMap("TVectorizable"); - // mutate array, with given lane requirement // when finished, p_lane updates the lane requirement. Array MutateArray(Array arr, int* p_lanes) { diff --git a/tests/cpp/ir_functor_test.cc b/tests/cpp/ir_functor_test.cc index ce50ed0c45f72..8dae79929fe8a 100644 --- a/tests/cpp/ir_functor_test.cc +++ b/tests/cpp/ir_functor_test.cc @@ -20,7 +20,6 @@ #include #include #include -#include #include #include #include @@ -193,10 +192,9 @@ TEST(IRF, StmtMutator) { } { - auto body = Evaluate(Call(DataType::Int(32), builtin::call_extern(), {StringImm("xyz"), x + 1}, - CallNode::Extern)); + auto body = Evaluate(Call(DataType::Int(32), "xyz", {x + 1}, CallNode::Extern)); auto res = v(std::move(body)); - CHECK(res.as()->value.as()->args[1].same_as(x)); + CHECK(res.as()->value.as()->args[0].same_as(x)); } { Stmt body = fmakealloc(); diff --git a/tests/python/relay/test_ir_parser.py b/tests/python/relay/test_ir_parser.py index 1e4fe6b668307..c4ac042bdb221 100644 --- a/tests/python/relay/test_ir_parser.py +++ b/tests/python/relay/test_ir_parser.py @@ -853,8 +853,8 @@ def test_duplicate_adt_cons_defn(): def test_duplicate_global_var(): parse_text( """ - def @id[A](%%x: A) -> A { x } - def @id[A](%%x: A) -> A { x } + def @id[A](%x: A) -> A { x } + def @id[A](%x: A) -> A { x } """ ) diff --git a/tests/python/unittest/test_arith_canonical_simplify.py b/tests/python/unittest/test_arith_canonical_simplify.py index 2fbc82f06ccf9..525cd6c30736d 100644 --- a/tests/python/unittest/test_arith_canonical_simplify.py +++ b/tests/python/unittest/test_arith_canonical_simplify.py @@ -202,8 +202,7 @@ def test_reduce_combiner_simplify(): assert tvm.ir.structural_equal(lhs, rhs) # Test that components with side effects are not removed - dummy = tvm.ir.GlobalVar("dummy") - side_effect = lambda *xs: tvm.tir.Call("int32", dummy, xs, tvm.tir.Call.Intrinsic) + side_effect = lambda *xs: tvm.tir.Call("int32", "dummy", xs, 
tvm.tir.Call.Intrinsic) ck.verify(sum_and_prod((A[k], side_effect(A[10-k])), k)[0], sum_and_prod((A[k], side_effect(A[10-k])), k)[0]) ck.verify(sum_and_prod((side_effect(A[k]), A[10-k]), k)[0], diff --git a/tests/python/unittest/test_target_codegen_c_host.py b/tests/python/unittest/test_target_codegen_c_host.py index 18a98eed06734..0f00e08f9192d 100644 --- a/tests/python/unittest/test_target_codegen_c_host.py +++ b/tests/python/unittest/test_target_codegen_c_host.py @@ -98,7 +98,7 @@ def test_reinterpret(): nn = 1024 n = tvm.runtime.convert(nn) A = te.placeholder((n,), name='A', dtype="int32") - B = te.compute(A.shape, lambda *i: tvm.tir.call_pure_intrin("float32", "tir.reinterpret", A(*i)), name='B') + B = te.compute(A.shape, lambda *i: tvm.tir.call_pure_intrin("float32", "reinterpret", A(*i)), name='B') s = te.create_schedule(B.op) def check_c(): diff --git a/tests/python/unittest/test_target_codegen_llvm.py b/tests/python/unittest/test_target_codegen_llvm.py index a6a231564033a..0b415b0de6bad 100644 --- a/tests/python/unittest/test_target_codegen_llvm.py +++ b/tests/python/unittest/test_target_codegen_llvm.py @@ -29,12 +29,12 @@ def test_llvm_intrin(): n = tvm.runtime.convert(4) A = ib.pointer("float32", name="A") args = [ - tvm.tir.call_pure_intrin("handle", "tir.address_of", A[0]), + tvm.tir.call_pure_intrin("handle", "tvm_address_of", A[0]), 0, 3, 1 ] ib.emit(tvm.tir.Evaluate( tvm.tir.Call( - "int32", "tir.prefetch", args, tvm.tir.Call.Intrinsic))) + "int32", "prefetch", args, tvm.tir.Call.Intrinsic))) body = ib.get() mod = tvm.IRModule.from_expr( @@ -738,20 +738,20 @@ def _transform(f, *_): tvm.testing.assert_allclose(c_.asnumpy(), (a_.asnumpy() * 2).astype('int32')) def np_float2np_bf16(arr): - ''' Convert a numpy array of float to a numpy array + ''' Convert a numpy array of float to a numpy array of bf16 in uint16''' orig = arr.view('> y) == '@tir.shift_right(x: int32, y: int32, dtype=int32, type="pure_intrin")' - assert str(x & y) == '@tir.bitwise_and(x: int32, y: int32, dtype=int32, type="pure_intrin")' - assert str(x | y) == '@tir.bitwise_or(x: int32, y: int32, dtype=int32, type="pure_intrin")' - assert str(x ^ y) == '@tir.bitwise_xor(x: int32, y: int32, dtype=int32, type="pure_intrin")' - assert str(10 & x) == '@tir.bitwise_and(10, x: int32, dtype=int32, type="pure_intrin")' - assert str(10 | x) == '@tir.bitwise_or(10, x: int32, dtype=int32, type="pure_intrin")' - assert str(10 ^ x) == '@tir.bitwise_xor(10, x: int32, dtype=int32, type="pure_intrin")' - assert str(10 >> x) == '@tir.shift_right(10, x: int32, dtype=int32, type="pure_intrin")' - assert str(10 << x) == '@tir.shift_left(10, x: int32, dtype=int32, type="pure_intrin")' + assert str(x << y) == '@shift_left(x: int32, y: int32, dtype=int32, type="pure_intrin")' + assert str(x >> y) == '@shift_right(x: int32, y: int32, dtype=int32, type="pure_intrin")' + assert str(x & y) == '@bitwise_and(x: int32, y: int32, dtype=int32, type="pure_intrin")' + assert str(x | y) == '@bitwise_or(x: int32, y: int32, dtype=int32, type="pure_intrin")' + assert str(x ^ y) == '@bitwise_xor(x: int32, y: int32, dtype=int32, type="pure_intrin")' + assert str(10 & x) == '@bitwise_and(10, x: int32, dtype=int32, type="pure_intrin")' + assert str(10 | x) == '@bitwise_or(10, x: int32, dtype=int32, type="pure_intrin")' + assert str(10 ^ x) == '@bitwise_xor(10, x: int32, dtype=int32, type="pure_intrin")' + assert str(10 >> x) == '@shift_right(10, x: int32, dtype=int32, type="pure_intrin")' + assert str(10 << x) == '@shift_left(10, x: int32, dtype=int32, 
type="pure_intrin")' assert str(10 % x) == 'floormod(10, x: int32)' - - assert str(~x) == '@tir.bitwise_not(x: int32, dtype=int32, type="pure_intrin")' + assert str(~x) == '@bitwise_not(x: int32, dtype=int32, type="pure_intrin")' assert(tvm.tir.const(1, "int8x2") >> 1).dtype == "int8x2" assert(x >> tvm.tir.const(1, "int32x2")).dtype == "int32x2" assert(te.var("z", "int8x2") << tvm.tir.const(1, "int8x2")).dtype == "int8x2" @@ -240,10 +239,10 @@ def test_divide_by_zero(): def test_isnan(): x = te.var('x', 'float32') - assert str(tvm.tir.isnan(x)) == '@tir.isnan(x: float32, dtype=bool, type="pure_intrin")' + assert str(tvm.tir.isnan(x)) == '@isnan(x: float32, dtype=bool, type="pure_intrin")' assert str(tvm.tir.isnan(x).dtype) == 'bool' y = te.var('y', 'float16') - assert str(tvm.tir.isnan(y)) == '@tir.isnan(cast(float32, y: float16), dtype=bool, type="pure_intrin")' + assert str(tvm.tir.isnan(y)) == '@isnan(cast(float32, y: float16), dtype=bool, type="pure_intrin")' z = te.var('z', 'int32') assert str(tvm.tir.isnan(z)) == 'False' k = te.var('k', 'int8x2') diff --git a/tests/python/unittest/test_tir_stmt_functor_ir_transform.py b/tests/python/unittest/test_tir_stmt_functor_ir_transform.py index 61accf2716317..38529e927d529 100644 --- a/tests/python/unittest/test_tir_stmt_functor_ir_transform.py +++ b/tests/python/unittest/test_tir_stmt_functor_ir_transform.py @@ -26,21 +26,20 @@ def test_ir_transform(): ib.emit(tvm.tir.call_extern("int32", "TestB", x)) ib.emit(tvm.tir.call_extern("int32", "TestC", x)) body = ib.get() - builtin_call_extern = tvm.ir.Op.get("tir.call_extern") def preorder(op): - if op.op.same_as(builtin_call_extern) and op.args[0].value == "TestC": + if op.name == "TestC": return tvm.tir.const(0, "int32") return None def postorder(op): assert isinstance(op, tvm.tir.Call) - if op.op.same_as(builtin_call_extern) and op.args[0].value == "TestA": - return tvm.tir.call_extern("int32", "TestB", op.args[1] + 1) + if op.name == "TestA": + return tvm.tir.call_extern("int32", "TestB", op.args[0] + 1) return op body = tvm.tir.stmt_functor.ir_transform(body, preorder, postorder, ["tir.Call"]) stmt_list = tvm.tir.stmt_list(body.body.body) - assert stmt_list[0].value.args[1].args[0].value == "TestB" + assert stmt_list[0].value.args[0].name == "TestB" assert stmt_list[1].value.value == 0 if __name__ == "__main__": diff --git a/tests/python/unittest/test_tir_transform_bf16_legalize.py b/tests/python/unittest/test_tir_transform_bf16_legalize.py index 55a6819aeced7..77a06022ac701 100644 --- a/tests/python/unittest/test_tir_transform_bf16_legalize.py +++ b/tests/python/unittest/test_tir_transform_bf16_legalize.py @@ -116,19 +116,19 @@ def test_legalize(): def to32(v): uint32_v = topi.cast(v, "uint32") uint32_v = tvm.tir.call_pure_intrin( - "uint32", "tir.shift_left", uint32_v, tvm.tir.const(16, "uint32")) - return tvm.tir.call_pure_intrin("float32", "tir.reinterpret", uint32_v) + "uint32", "shift_left", uint32_v, tvm.tir.const(16, "uint32")) + return tvm.tir.call_pure_intrin("float32", "reinterpret", uint32_v) def to16(v): - uint32_v = tvm.tir.call_pure_intrin("uint32", "tir.reinterpret", v) + uint32_v = tvm.tir.call_pure_intrin("uint32", "reinterpret", v) rounding_bias = tvm.tir.call_pure_intrin( - "uint32", "tir.shift_right", uint32_v, tvm.tir.const(16, "uint32")) + "uint32", "shift_right", uint32_v, tvm.tir.const(16, "uint32")) rounding_bias = tvm.tir.call_pure_intrin( - "uint32", "tir.bitwise_and", rounding_bias, tvm.tir.const(1, "uint32")) + "uint32", "bitwise_and", rounding_bias, 
tvm.tir.const(1, "uint32")) rounding_bias = rounding_bias + tvm.tir.const(0x7FFF, "uint16") uint32_v = uint32_v + rounding_bias uint32_v = tvm.tir.call_pure_intrin( - "uint32", "tir.shift_right", uint32_v, tvm.tir.const(16, "uint32")) + "uint32", "shift_right", uint32_v, tvm.tir.const(16, "uint32")) return topi.cast(uint32_v, 'uint16') def check(fcompute_before, fcompute_after): diff --git a/tests/python/unittest/test_tir_transform_combine_context_call.py b/tests/python/unittest/test_tir_transform_combine_context_call.py index d7a25ca0156e5..29a3303196229 100644 --- a/tests/python/unittest/test_tir_transform_combine_context_call.py +++ b/tests/python/unittest/test_tir_transform_combine_context_call.py @@ -22,7 +22,7 @@ def test_for(): def device_context(dev_id): ctx = tvm.tir.call_extern("handle", "device_context", dev_type, dev_id) return tvm.tir.Call( - "handle", "tir.tvm_thread_context", [ctx], tvm.tir.Call.Intrinsic) + "handle", "tvm_thread_context", [ctx], tvm.tir.Call.Intrinsic) ib = tvm.tir.ir_builder.create() n = te.var("n") diff --git a/tests/python/unittest/test_tir_transform_coproc_sync.py b/tests/python/unittest/test_tir_transform_coproc_sync.py index 8469bc953b712..f6583493d6463 100644 --- a/tests/python/unittest/test_tir_transform_coproc_sync.py +++ b/tests/python/unittest/test_tir_transform_coproc_sync.py @@ -17,14 +17,6 @@ import tvm from tvm import te -# register the ops -tvm.ir.register_op_attr("tir.cop.coproc_sync", "TGlobalSymbol", "coproc_sync") -tvm.ir.register_op_attr("tir.cop.coproc_read_barrier", "TGlobalSymbol", "coproc_readb") -tvm.ir.register_op_attr("tir.cop.coproc_write_barrier", "TGlobalSymbol", "coproc_writeb") -tvm.ir.register_op_attr("tir.cop.coproc_dep_push", "TGlobalSymbol", "coproc_dep_push") -tvm.ir.register_op_attr("tir.cop.coproc_dep_pop", "TGlobalSymbol", "coproc_dep_pop") - - def test_coproc_sync(): @tvm.register_func("tvm.info.mem.global.cache") def meminfo_cache(): @@ -34,7 +26,6 @@ def meminfo_cache(): max_simd_bits=32, max_num_bits=128, head_address=tvm.tir.call_extern("handle", "global_cache")) - ib = tvm.tir.ir_builder.create() n = te.size_var("n") cp = te.thread_axis((0, 1), "cop") @@ -52,11 +43,10 @@ def meminfo_cache(): body = stmt.body.body.body blist = tvm.tir.stmt_list(body) - - assert(blist[1].value.op.same_as(tvm.ir.Op.get("tir.cop.coproc_read_barrier"))) + assert(blist[1].value.name == "cop.coproc_read_barrier") assert(blist[1].value.args[3].value == 80) - assert(blist[-2].value.op.same_as(tvm.ir.Op.get("tir.cop.coproc_sync"))) - assert(blist[-1].value.op.same_as(tvm.ir.Op.get("tir.cop.coproc_write_barrier"))) + assert(blist[-2].value.name == "cop.coproc_sync") + assert(blist[-1].value.name == "cop.coproc_write_barrier") assert(blist[-1].value.args[3].value == 10) @@ -116,9 +106,9 @@ def __check_list(tvm_array, py_list): slist = tvm.tir.stmt_list(slist[-1]) pop_st = slist[0].body[0] - assert(push_st.value.op.same_as(tvm.ir.Op.get("tir.cop.coproc_dep_push"))) + assert(push_st.value.name == "cop.coproc_dep_push") assert(__check_list(push_st.value.args, [2,3])) - assert(pop_st.value.op.same_as(tvm.ir.Op.get("tir.cop.coproc_dep_pop"))) + assert(pop_st.value.name == "cop.coproc_dep_pop") assert(__check_list(pop_st.value.args, [2,3])) diff --git a/tests/python/unittest/test_tir_transform_inject_double_buffer.py b/tests/python/unittest/test_tir_transform_inject_double_buffer.py index cf5863204bfe7..0b6b167c86602 100644 --- a/tests/python/unittest/test_tir_transform_inject_double_buffer.py +++ 
b/tests/python/unittest/test_tir_transform_inject_double_buffer.py @@ -56,7 +56,7 @@ def test_double_buffer(): f = tvm.tir.transform.ThreadSync("shared")(mod)["db"] count = [0] def count_sync(op): - if isinstance(op, tvm.tir.Call) and op.op.same_as(tvm.ir.Op.get("tir.tvm_storage_sync")): + if isinstance(op, tvm.tir.Call) and op.name == "tvm_storage_sync": count[0] += 1 tvm.tir.stmt_functor.post_order_visit(f.body, count_sync) assert count[0] == 4 diff --git a/tests/python/unittest/test_tir_transform_inject_virtual_thread.py b/tests/python/unittest/test_tir_transform_inject_virtual_thread.py index 4964039a4c142..c0789c654fbfc 100644 --- a/tests/python/unittest/test_tir_transform_inject_virtual_thread.py +++ b/tests/python/unittest/test_tir_transform_inject_virtual_thread.py @@ -36,7 +36,7 @@ def get_vthread(name): bbuffer = tvm.tir.decl_buffer((m,), dtype=B.dtype, data=B.asobject()) ib.emit(tvm.tir.call_extern("int32", "Run", bbuffer.access_ptr("r"), - tvm.tir.call_pure_intrin("int32", "tir.tvm_context_id"))) + tvm.tir.call_pure_intrin("int32", "tvm_context_id"))) C[i * nthread + tx] = B[i] + 1 return ib.get() diff --git a/tests/python/unittest/test_tir_transform_rewrite_unsafe_select.py b/tests/python/unittest/test_tir_transform_rewrite_unsafe_select.py index 9f1104dcc5121..229c11b783a63 100644 --- a/tests/python/unittest/test_tir_transform_rewrite_unsafe_select.py +++ b/tests/python/unittest/test_tir_transform_rewrite_unsafe_select.py @@ -39,10 +39,8 @@ def test_rewrite_Select(): mod = tvm.IRModule.from_expr( tvm.tir.PrimFunc([i], tvm.tir.Evaluate(a))) aa = tvm.tir.transform.RewriteUnsafeSelect()(mod)["main"].body.value - builtin_if_then_else = tvm.ir.Op.get("tir.if_then_else") - - assert yy.op.same_as(builtin_if_then_else) - assert yy.op.same_as(builtin_if_then_else) + assert yy.name == "tvm_if_then_else" + assert zz.name == "tvm_if_then_else" assert isinstance(aa, tvm.tir.Select) diff --git a/tests/python/unittest/test_tir_transform_storage_flatten.py b/tests/python/unittest/test_tir_transform_storage_flatten.py index 468867a425cde..5fea580fbf5c1 100644 --- a/tests/python/unittest/test_tir_transform_storage_flatten.py +++ b/tests/python/unittest/test_tir_transform_storage_flatten.py @@ -125,7 +125,7 @@ def test_flatten_double_buffer(): count = [0] def count_sync(op): - if isinstance(op, tvm.tir.Call) and op.op.same_as(tvm.ir.Op.get("tir.tvm_storage_sync")): + if isinstance(op, tvm.tir.Call) and op.name == "tvm_storage_sync": count[0] += 1 tvm.tir.stmt_functor.post_order_visit(f.body, count_sync) assert count[0] == 4 diff --git a/tests/python/unittest/test_tir_transform_thread_sync.py b/tests/python/unittest/test_tir_transform_thread_sync.py index 3ff6804cf7e0c..783b66983c48a 100644 --- a/tests/python/unittest/test_tir_transform_thread_sync.py +++ b/tests/python/unittest/test_tir_transform_thread_sync.py @@ -49,7 +49,7 @@ def test_thread_storage_sync(): cuda_target = tvm.target.create("cuda") f = tvm.tir.transform.ThreadSync("shared")(mod)["test_kernel0"] body_list = tvm.tir.stmt_list(f.body.body.body.body) - assert(body_list[1].value.op.same_as(tvm.ir.Op.get("tir.tvm_storage_sync"))) + assert(body_list[1].value.name == "tvm_storage_sync") diff --git a/tests/python/unittest/test_tir_transform_vectorize.py b/tests/python/unittest/test_tir_transform_vectorize.py index a69c9d36c6938..d7124b6b7e89c 100644 --- a/tests/python/unittest/test_tir_transform_vectorize.py +++ b/tests/python/unittest/test_tir_transform_vectorize.py @@ -117,7 +117,7 @@ def test_vectorize_if_then_else(): ib = 
diff --git a/tests/python/unittest/test_tir_transform_vectorize.py b/tests/python/unittest/test_tir_transform_vectorize.py
index a69c9d36c6938..d7124b6b7e89c 100644
--- a/tests/python/unittest/test_tir_transform_vectorize.py
+++ b/tests/python/unittest/test_tir_transform_vectorize.py
@@ -117,7 +117,7 @@ def test_vectorize_if_then_else():
     ib = tvm.tir.ir_builder.create()
     A = ib.pointer("float32", name="A")
     with ib.for_range(0, 4, for_type="vectorize") as i:
-        A[i] = tvm.tir.call_intrin("float32", "tir.if_then_else",
+        A[i] = tvm.tir.call_intrin("float32", "tvm_if_then_else",
                                    i > 0, A[i] + 1, A[i])
     stmt = ib.get()
@@ -132,7 +132,7 @@ def test_vectorize_if_then_else():
     A = ib.pointer("float32", name="A")
     with ib.for_range(0, n) as k:
         with ib.for_range(0, 4, for_type="vectorize") as i:
-            A[k * 4 + i] = tvm.tir.call_intrin("float32", "tir.if_then_else",
+            A[k * 4 + i] = tvm.tir.call_intrin("float32", "tvm_if_then_else",
                                                k > 0, A[k * 4 + i], 0)
     stmt = ib.get()
diff --git a/topi/include/topi/detail/extern.h b/topi/include/topi/detail/extern.h
index 7068b95bec6cf..b84fbc7722a14 100644
--- a/topi/include/topi/detail/extern.h
+++ b/topi/include/topi/detail/extern.h
@@ -25,7 +25,6 @@
 #define TOPI_DETAIL_EXTERN_H_
 
 #include
-#include <tvm/tir/builtin.h>
 #include
@@ -112,11 +111,11 @@ inline Array<Tensor> make_extern(const Array<Array<PrimExpr> >& out_shapes,
  */
 inline PrimExpr pack_buffer(Buffer buf) {
   CHECK_GT(buf->shape.size(), 0) << "buf shape must have at least one element";
-  auto shape = tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_stack_make_shape(),
+  auto shape = tvm::tir::Call(DataType::Handle(), tvm::tir::intrinsic::tvm_stack_make_shape,
                               buf->shape, tvm::tir::CallNode::CallType::Intrinsic);
   PrimExpr strides;
   if (buf->strides.size() > 0) {
-    strides = tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_stack_make_shape(),
+    strides = tvm::tir::Call(DataType::Handle(), tvm::tir::intrinsic::tvm_stack_make_shape,
                              buf->shape, tvm::tir::CallNode::CallType::Intrinsic);
   } else {
     strides = 0;
@@ -127,7 +126,7 @@ inline PrimExpr pack_buffer(Buffer buf) {
                          make_const(DataType::Int(32), static_cast<int64_t>(buf->shape.size())),
                          make_const(buf->dtype, 0),
                          buf->elem_offset};
-  return tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_stack_make_array(), pack_args,
+  return tvm::tir::Call(DataType::Handle(), tvm::tir::intrinsic::tvm_stack_make_array, pack_args,
                         tvm::tir::CallNode::CallType::Intrinsic);
 }
 
@@ -141,7 +140,7 @@ inline PrimExpr pack_buffer(Buffer buf) {
  * \return An expression representing the invocation
 */
 inline PrimExpr call_packed(Array<PrimExpr> args) {
-  return tvm::tir::Call(DataType::Int(32), tvm::tir::builtin::tvm_call_packed(), args,
+  return tvm::tir::Call(DataType::Int(32), tvm::tir::intrinsic::tvm_call_packed, args,
                         tvm::tir::CallNode::CallType::Intrinsic);
 }
diff --git a/topi/include/topi/elemwise.h b/topi/include/topi/elemwise.h
index 0ec7e4d212bf4..a92d21c27afe6 100644
--- a/topi/include/topi/elemwise.h
+++ b/topi/include/topi/elemwise.h
@@ -25,7 +25,6 @@
 #define TOPI_ELEMWISE_H_
 
 #include
-#include <tvm/tir/builtin.h>
 #include
@@ -310,8 +309,7 @@ inline Tensor reinterpret(const Tensor& x, DataType type, std::string name = "te
   return compute(
       x->shape,
       [&](const Array<Var>& i) {
-        return tvm::tir::Call(type, tvm::tir::builtin::reinterpret(), {x(i)},
-                              tvm::tir::CallNode::PureIntrinsic);
+        return tvm::tir::Call(type, "reinterpret", {x(i)}, tvm::tir::CallNode::PureIntrinsic);
       },
       name, tag);
 }
diff --git a/topi/python/topi/arm_cpu/bitserial_conv2d.py b/topi/python/topi/arm_cpu/bitserial_conv2d.py
index f035251a8c299..ac1ac45c1b38c 100644
--- a/topi/python/topi/arm_cpu/bitserial_conv2d.py
+++ b/topi/python/topi/arm_cpu/bitserial_conv2d.py
@@ -231,10 +231,8 @@ def _instr(index):
                     cnts = tvm.tir.popcount(w_ & x_) - tvm.tir.popcount(~w_ & x_)
                 else:
                     cnts = tvm.tir.popcount(w_ & x_)
-                upper_half = tvm.tir.call_pure_intrin(
-                    half_dtype, 'tir.vectorhigh', cnts)
-                lower_half = tvm.tir.call_pure_intrin(
-                    half_dtype, 'tir.vectorlow', cnts)
+                upper_half = tvm.tir.call_pure_intrin(half_dtype, 'vectorhigh', cnts)
+                lower_half = tvm.tir.call_pure_intrin(half_dtype, 'vectorlow', cnts)
                 cnts8[i] = upper_half + lower_half
             for i in range(m//2):
                 cnts4[i] = tvm.tir.call_llvm_intrin(half_dtype, vpadd,
@@ -243,7 +241,7 @@ def _instr(index):
                 cnts2[i] = tvm.tir.call_llvm_intrin(half_dtype, vpadd,
                                                     args_2, cnts4[i*2], cnts4[i*2+1])
             cnts = tvm.tir.call_pure_intrin(
-                full_dtype, 'tir.vectorcombine', cnts2[0], cnts2[1])
+                full_dtype, 'vectorcombine', cnts2[0], cnts2[1])
             shifted_cnts = cnts << tvm.tir.const(bw+bx, pack_dtype)
             out = tvm.tir.call_llvm_intrin(
                 return_dtype, vpadalu,
@@ -263,7 +261,7 @@ def _instr(index):
                 cnts2[i] = tvm.tir.call_llvm_intrin(half_dtype, vpadd,
                                                     args_2, cnts4[i*2], cnts4[i*2+1])
             cnts = tvm.tir.call_pure_intrin(
-                full_dtype, 'tir.vectorcombine', cnts2[0], cnts2[1])
+                full_dtype, 'vectorcombine', cnts2[0], cnts2[1])
             shifted_cnts = cnts << tvm.tir.const(bw+bx, pack_dtype)
             out = tvm.tir.call_llvm_intrin(
                 return_dtype, vpadalu,
diff --git a/topi/python/topi/arm_cpu/tensor_intrin.py b/topi/python/topi/arm_cpu/tensor_intrin.py
index da9c71a5346b7..bab91578e77ee 100644
--- a/topi/python/topi/arm_cpu/tensor_intrin.py
+++ b/topi/python/topi/arm_cpu/tensor_intrin.py
@@ -86,11 +86,11 @@ def _instr(index):
             dtype_c = '%s32x%d' % (dtype, int32_lanes)
 
             a_int8 = ins[0].vload([0], dtype_a)
-            re_int32 = tvm.tir.call_pure_intrin('%s32' % dtype, 'tir.reinterpret', a_int8)
+            re_int32 = tvm.tir.call_pure_intrin('%s32' % dtype, 'reinterpret', a_int8)
             # broadcast a
             vec_ai32 = re_int32.astype(dtype_c)
 
-            vec_a = tvm.tir.call_pure_intrin(dtype_b, 'tir.reinterpret', vec_ai32)
+            vec_a = tvm.tir.call_pure_intrin(dtype_b, 'reinterpret', vec_ai32)
             vec_b = ins[1].vload([0, 0], dtype_b)
             vec_c = outs[0].vload([0], dtype_c)
diff --git a/topi/python/topi/cuda/nms.py b/topi/python/topi/cuda/nms.py
index c98d7e99d3ee5..f2c1143b5fb82 100644
--- a/topi/python/topi/cuda/nms.py
+++ b/topi/python/topi/cuda/nms.py
@@ -38,10 +38,9 @@ def cuda_atomic_add_rule(op):
 
 tvm.target.intrin.register_intrin_rule(
     "cuda", "atomic_add", cuda_atomic_add_rule, override=True)
-tvm.ir.register_op_attr("tir.atomic_add", "TVectorizable", False)
 
 def atomic_add(x, y):
-    return tvm.tir.call_pure_intrin(y.dtype, "tir.atomic_add", x, y)
+    return tvm.tir.call_pure_intrin(y.dtype, "atomic_add", x, y)
 
 
 def get_valid_counts_ir(data, valid_count, out, out_indices,
@@ -114,7 +113,7 @@ def get_valid_counts_ir(data, valid_count, out, out_indices,
             with ib.if_scope(
                     tvm.tir.all(data[tid * elem_length + score_index] > score_threshold,
                                 tvm.tir.any(id_index < 0,
                                             data[tid * elem_length + id_index] >= 0))):
-                atomic_add_return[0] = atomic_add(tvm.tir.call_pure_intrin("handle", "tir.address_of",
+                atomic_add_return[0] = atomic_add(tvm.tir.call_pure_intrin("handle", "tvm_address_of",
                                                                            valid_count[i]), one_count)
             with ib.for_range(0, elem_length) as k:
                 out[tid * elem_length + k] = data[tid * elem_length + k]
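The nms.py hunk above also shows the registration side of the change. A minimal
sketch of both flows for the custom `atomic_add` intrinsic (illustrative;
`cuda_atomic_add_rule` is the lowering callback defined in that file, and `x`,
`y` are assumed to be PrimExprs):

    # post-#5863 (being reverted): the op must first exist in the global
    # registry, so a dummy attribute registration creates it
    tvm.ir.register_op_attr("tir.atomic_add", "TVectorizable", False)
    expr = tvm.tir.call_pure_intrin(y.dtype, "tir.atomic_add", x, y)

    # pre-#5863 (being restored): a bare name plus a per-target lowering rule
    tvm.target.intrin.register_intrin_rule(
        "cuda", "atomic_add", cuda_atomic_add_rule, override=True)
    expr = tvm.tir.call_pure_intrin(y.dtype, "atomic_add", x, y)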
diff --git a/topi/python/topi/cuda/rcnn/proposal.py b/topi/python/topi/cuda/rcnn/proposal.py
index 5b7e0905de63b..f713bb216808a 100644
--- a/topi/python/topi/cuda/rcnn/proposal.py
+++ b/topi/python/topi/cuda/rcnn/proposal.py
@@ -185,7 +185,7 @@ def argsort_ir(data_buf, out_index_buf):
                     temp_index[0] = index_out[offset]
                     index_out[offset] = index_out[offset + 1]
                     index_out[offset + 1] = temp_index[0]
-        ib.emit(tvm.tir.Call(None, 'tir.tvm_storage_sync',
+        ib.emit(tvm.tir.Call(None, 'tvm_storage_sync',
                              tvm.runtime.convert(['shared']),
                              tvm.tir.Call.Intrinsic))
     return ib.get()
@@ -246,7 +246,7 @@ def calculate_overlap(out_tensor, box_a_idx, box_b_idx):
             iou = calculate_overlap(p_data, (base_idx + l) * 5, (base_idx + i) * 5)
             with ib.if_scope(iou > nms_threshold):
                 p_out[base_idx + i] = True
-        ib.emit(tvm.tir.Call(None, 'tir.tvm_storage_sync',
+        ib.emit(tvm.tir.Call(None, 'tvm_storage_sync',
                              tvm.runtime.convert(['shared']),
                              tvm.tir.Call.Intrinsic))
     return ib.get()
diff --git a/topi/python/topi/cuda/sort.py b/topi/python/topi/cuda/sort.py
index 7181d57216845..ddae2bd96135f 100644
--- a/topi/python/topi/cuda/sort.py
+++ b/topi/python/topi/cuda/sort.py
@@ -115,7 +115,7 @@ def sort_ir(data, values_out, axis, is_ascend, indices_out=None):
         if indices_out is not None:
             indices_out[base_idx + tid * axis_mul_after] = \
                 tvm.tir.generic.cast(tid, indices_out.dtype)
-    ib.emit(tvm.tir.Call(None, 'tir.tvm_storage_sync',
+    ib.emit(tvm.tir.Call(None, 'tvm_storage_sync',
                          tvm.runtime.convert(['shared']),
                          tvm.tir.Call.Intrinsic))
     idxd = tvm.tir.indexdiv
@@ -143,7 +143,7 @@ def sort_ir(data, values_out, axis, is_ascend, indices_out=None):
                         temp_index[0] = indices_out[offset]
                         indices_out[offset] = indices_out[offset + axis_mul_after]
                         indices_out[offset + axis_mul_after] = temp_index[0]
-            ib.emit(tvm.tir.Call(None, 'tir.tvm_storage_sync',
+            ib.emit(tvm.tir.Call(None, 'tvm_storage_sync',
                                  tvm.runtime.convert(['shared']),
                                  tvm.tir.Call.Intrinsic))
@@ -235,7 +235,7 @@ def sort_nms_ir(data, valid_count, output, axis, is_ascend):
                         temp_index[0] = output[offset]
                         output[offset] = output[offset + axis_mul_after]
                         output[offset + axis_mul_after] = temp_index[0]
-            ib.emit(tvm.tir.Call(None, 'tir.tvm_storage_sync',
+            ib.emit(tvm.tir.Call(None, 'tvm_storage_sync',
                                  tvm.runtime.convert(['shared']),
                                  tvm.tir.Call.Intrinsic))
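These barriers build the tvm.tir.Call node directly rather than through a
helper. Note that in both the reverted and the restored form the constructor
still takes an explicit call-type flag; only the callee changes, from the
registered op 'tir.tvm_storage_sync' to a bare name. A sketch of the restored
form (illustrative; `ib` is the ir_builder instance from the hunks above):

    ib.emit(tvm.tir.Call(None, 'tvm_storage_sync',
                         tvm.runtime.convert(['shared']),
                         tvm.tir.Call.Intrinsic))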
diff --git a/topi/python/topi/cuda/tensor_intrin.py b/topi/python/topi/cuda/tensor_intrin.py
index c2b7d250293df..3941c00cc4646 100644
--- a/topi/python/topi/cuda/tensor_intrin.py
+++ b/topi/python/topi/cuda/tensor_intrin.py
@@ -100,7 +100,7 @@ def intrin_func(ins, outs):
         BC = outs[0]
         row = wmma_m * wmma_k
         warp_index = BC.elem_offset // row + BC.elem_offset % row // wmma_k
-        ib.emit(tvm.tir.call_intrin('handle', 'tir.tvm_load_matrix_sync',
+        ib.emit(tvm.tir.call_intrin('handle', 'tvm_load_matrix_sync',
                                     BC.data, wmma_m, wmma_n, wmma_k, warp_index,
                                     BA.access_ptr('r'), strides_from[0], layout))
         return ib.get()
@@ -128,7 +128,7 @@ def intrin_func(ins, outs):
         BC = outs[0]
         row = wmma_n * wmma_k
         warp_index = BC.elem_offset // row + BC.elem_offset % row // wmma_n
-        ib.emit(tvm.tir.call_intrin('handle', 'tir.tvm_load_matrix_sync',
+        ib.emit(tvm.tir.call_intrin('handle', 'tvm_load_matrix_sync',
                                     BC.data, wmma_m, wmma_n, wmma_k, warp_index,
                                     BA.access_ptr('r'), strides_from[0], layout))
         return ib.get()
@@ -156,7 +156,7 @@ def intrin_func(ins, outs):
         BC = outs[0]
         row = wmma_m * wmma_n
         warp_index = BA.elem_offset // row + BA.elem_offset % row // wmma_n
-        ib.emit(tvm.tir.call_intrin('handle', 'tir.tvm_store_matrix_sync',
+        ib.emit(tvm.tir.call_intrin('handle', 'tvm_store_matrix_sync',
                                     BA.data, wmma_m, wmma_n, wmma_k, warp_index,
                                     BC.access_ptr('w'), strides_dst[0], 'row_major'))
         return ib.get()
@@ -207,14 +207,13 @@ def warp_idnex(offset, row, col):
     def init():
         ib = tvm.tir.ir_builder.create()
         ib.emit(
-            tvm.tir.call_intrin('handle', 'tir.tvm_fill_fragment',
-                                BC.data, wmma_m, wmma_n, wmma_k,
+            tvm.tir.call_intrin('handle', 'tvm_fill_fragment', BC.data, wmma_m, wmma_n, wmma_k,
                                 warp_index_C, 0.0))
         return ib.get()
 
     def update():
         ib = tvm.tir.ir_builder.create()
-        ib.emit(tvm.tir.call_intrin('handle', 'tir.tvm_mma_sync',
+        ib.emit(tvm.tir.call_intrin('handle', 'tvm_mma_sync',
                                     BC.data, warp_index_C,
                                     BA.data, warp_index_A,
                                     BB.data, warp_index_B,
diff --git a/topi/python/topi/x86/tensor_intrin.py b/topi/python/topi/x86/tensor_intrin.py
index 31de70e92f182..ee8d83dbef075 100644
--- a/topi/python/topi/x86/tensor_intrin.py
+++ b/topi/python/topi/x86/tensor_intrin.py
@@ -88,9 +88,9 @@ def _instr(index):
             return ib.get()
 
         a_int8 = ins[0].vload([0], "uint8x4")
-        re_int32 = tvm.tir.call_pure_intrin('int32', 'tir.reinterpret', a_int8)
+        re_int32 = tvm.tir.call_pure_intrin('int32', 'reinterpret', a_int8)
         vec_ai32 = re_int32.astype('int32x16')
-        vec_a = tvm.tir.call_pure_intrin('int8x64', 'tir.reinterpret', vec_ai32)
+        vec_a = tvm.tir.call_pure_intrin('int8x64', 'reinterpret', vec_ai32)
         vec_b = ins[1].vload([0, 0], "int8x64")
         vec_one = tvm.tir.const(1, "int16x32")
         pair_reduction = tvm.tir.call_llvm_intrin('int16x32',
@@ -174,9 +174,9 @@ def _instr(index):
             return ib.get()
 
         a_int8 = ins[0].vload([0], "uint8x2")
-        re_int16 = tvm.tir.call_pure_intrin('int16', 'tir.reinterpret', a_int8)
+        re_int16 = tvm.tir.call_pure_intrin('int16', 'reinterpret', a_int8)
         vec_ai16 = re_int16.astype('int16x32')
-        vec_a = tvm.tir.call_pure_intrin('int8x64', 'tir.reinterpret', vec_ai16)
+        vec_a = tvm.tir.call_pure_intrin('int8x64', 'reinterpret', vec_ai16)
 
         for i in range(4):
             vec_b = ins[1].vload([i*32, 0], "int8x64")
@@ -254,7 +254,7 @@ def _instr(index):
             return ib.get()
 
         a_int8 = ins[0].vload([0], "uint8x4")
-        re_int32 = tvm.tir.call_pure_intrin('int32', 'tir.reinterpret', a_int8)
+        re_int32 = tvm.tir.call_pure_intrin('int32', 'reinterpret', a_int8)
         vec_ai32 = re_int32.astype('int32x16')
         vec_b = ins[1].vload([0, 0], "int8x64")
 
@@ -262,7 +262,7 @@ def _instr(index):
         llvm_id = tvm.target.codegen.llvm_lookup_intrinsic_id(vnni_inst_name)
         if llvm_id != 0: # VNNI is available for current LLVM version
-            vec_bi32 = tvm.tir.call_pure_intrin('int32x16', 'tir.reinterpret', vec_b)
+            vec_bi32 = tvm.tir.call_pure_intrin('int32x16', 'reinterpret', vec_b)
             vec_zero = tvm.tir.const(0, "int32x16")
             quad_reduction = tvm.tir.call_llvm_intrin('int32x16',
                                                       'llvm.x86.avx512.vpdpbusd.512',
@@ -270,7 +270,7 @@ def _instr(index):
                                                       vec_zero, vec_ai32, vec_bi32)
         else: # Fall back to the normal AVX512
-            vec_a = tvm.tir.call_pure_intrin('int8x64', 'tir.reinterpret', vec_ai32)
+            vec_a = tvm.tir.call_pure_intrin('int8x64', 'reinterpret', vec_ai32)
             vec_one = tvm.tir.const(1, "int16x32")
             pair_reduction = tvm.tir.call_llvm_intrin('int16x32',
                                                       'llvm.x86.avx512.pmaddubs.w.512',
diff --git a/topi/tests/python/test_topi_basic.py b/topi/tests/python/test_topi_basic.py
index a83ff50bd5b19..13f1463da7ff5 100644
--- a/topi/tests/python/test_topi_basic.py
+++ b/topi/tests/python/test_topi_basic.py
@@ -34,7 +34,7 @@ def test_ewise():
     def test_apply(func, name):
         B = func(A)
        assert tuple(B.shape) == tuple(A.shape)
-        assert B.op.body[0].op.name == "tir." + name
+        assert B.op.body[0].name == name
 
     test_apply(topi.exp, "exp")
     test_apply(topi.erf, "erf")
diff --git a/topi/tests/python/test_topi_math.py b/topi/tests/python/test_topi_math.py
index 6f1e8588fd7c6..ea980833ae204 100644
--- a/topi/tests/python/test_topi_math.py
+++ b/topi/tests/python/test_topi_math.py
@@ -50,11 +50,11 @@ def test_apply(
     B = func(A)
     assert tuple(B.shape) == tuple(A.shape)
     if not skip_name_check:
-        assert B.op.body[0].op.name == "tir." + name
+        assert B.op.body[0].name == name
     a_np = np.random.uniform(low=low, high=high, size=shape).astype(A.dtype) * 10
     # avoid round check too close to boundary
     if check_round:
-        a_np += ((np.abs(np.fmod(a_np, 1)) - 0.5) < 1e-6) * 1e-4
+        a_np += ((np.abs(np.fmod(a_np, 1)) - 0.5) < 1e-6) * 1e-5
     b_np = f_numpy(a_np)
 
     def check_device(device):
@@ -89,7 +89,7 @@ def test_isnan(
     B = topi.isnan(A)
     assert tuple(B.shape) == tuple(A.shape)
     if not skip_name_check:
-        assert B.op.body[0].op.name == "tir.isnan"
+        assert B.op.body[0].name == "isnan"
     a_np = np.random.uniform(low=low, high=high, size=shape).astype(A.dtype) * 10
     a_np.ravel()[np.random.choice(a_np.size, int(a_np.size * 0.5), replace=False)] = np.nan
     # avoid round check too close to boundary
diff --git a/tutorials/language/intrin_math.py b/tutorials/language/intrin_math.py
index 65bfd4c38681d..146263dab5d76 100644
--- a/tutorials/language/intrin_math.py
+++ b/tutorials/language/intrin_math.py
@@ -100,15 +100,12 @@ def my_cuda_math_rule(op):
     """Customized CUDA intrinsic lowering rule"""
     assert isinstance(op, tvm.tir.Call)
-    name = op.op.name
-    assert name.startswith("tir.")
-    dispatch_name = name[4:]
     if op.dtype == "float32":
         # call float function
-        return tvm.tir.call_pure_extern("float32", "%sf" % dispatch_name, op.args[0])
+        return tvm.tir.call_pure_extern("float32", "%sf" % op.name, op.args[0])
     elif op.dtype == "float64":
         # call double function
-        return tvm.tir.call_pure_extern("float32", dispatch_name, op.args[0])
+        return tvm.tir.call_pure_extern("float32", op.name, op.args[0])
     else:
         # cannot do translation, return self.
         return op
@@ -135,7 +132,7 @@ def my_cuda_math_rule(op):
 
 def mylog(x):
     """customized log intrinsic function"""
-    return tvm.tir.call_pure_intrin(x.dtype, "tir.mylog", x)
+    return tvm.tir.call_pure_intrin(x.dtype, "mylog", x)
 
 
 def my_cuda_mylog_rule(op):
@@ -147,8 +144,7 @@ def my_cuda_mylog_rule(op):
     else:
         return op
 
-# new op registration is triggered by registering an attribute of the op
-tvm.ir.register_op_attr("tir.mylog", "TVectorizable", True)
+
 tvm.target.register_intrin_rule("cuda", "mylog", my_cuda_mylog_rule, override=True)
 
 n = te.var("n")
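The tutorial hunk above is the clearest statement of what the revert means for
lowering rules: with registered ops a rule reads `op.op.name` and strips the
"tir." prefix before forming the extern call; with bare names, `op.name` is
already the dispatch key. A condensed sketch of the two rule bodies, float32
case only (illustrative):

    def rule_with_ops(op):       # post-#5863, being reverted
        name = op.op.name        # e.g. "tir.exp"
        assert name.startswith("tir.")
        return tvm.tir.call_pure_extern("float32", "%sf" % name[4:], op.args[0])

    def rule_with_names(op):     # pre-#5863, being restored
        return tvm.tir.call_pure_extern("float32", "%sf" % op.name, op.args[0])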
diff --git a/tutorials/optimize/opt_conv_tensorcore.py b/tutorials/optimize/opt_conv_tensorcore.py
index 4b2823c08d036..cd40a91ac6c89 100644
--- a/tutorials/optimize/opt_conv_tensorcore.py
+++ b/tutorials/optimize/opt_conv_tensorcore.py
@@ -163,7 +163,7 @@ def intrin_func(ins, outs):
         BA = ins[0]
         BC = outs[0]
-        ib.emit(tvm.tir.call_intrin('handle', 'tir.tvm_load_matrix_sync',
+        ib.emit(tvm.tir.call_intrin('handle', 'tvm_load_matrix_sync',
                                     BC.data, n, n, n, BC.elem_offset // 256,
                                     BA.access_ptr('r'), n, 'row_major'))
         return ib.get()
@@ -190,12 +190,12 @@ def intrin_func(ins, outs):
 
         def init():
             ib = tvm.tir.ir_builder.create()
-            ib.emit(tvm.tir.call_intrin('handle', 'tir.tvm_fill_fragment', BC.data, n, n, n, BC.elem_offset // 256, 0.0))
+            ib.emit(tvm.tir.call_intrin('handle', 'tvm_fill_fragment', BC.data, n, n, n, BC.elem_offset // 256, 0.0))
             return ib.get()
 
         def update():
             ib = tvm.tir.ir_builder.create()
-            ib.emit(tvm.tir.call_intrin('handle', 'tir.tvm_mma_sync',
+            ib.emit(tvm.tir.call_intrin('handle', 'tvm_mma_sync',
                                         BC.data, BC.elem_offset // 256,
                                         BA.data, BA.elem_offset // 256,
                                         BB.data, BB.elem_offset // 256,
@@ -218,7 +218,7 @@ def intrin_func(ins, outs):
         ib = tvm.tir.ir_builder.create()
         BA = ins[0]
         BC = outs[0]
-        ib.emit(tvm.tir.call_intrin('handle', 'tir.tvm_store_matrix_sync',
+        ib.emit(tvm.tir.call_intrin('handle', 'tvm_store_matrix_sync',
                                     BA.data, n, n, n, BA.elem_offset // 256,
                                     BC.access_ptr('w'), n, 'row_major'))
         return ib.get()
diff --git a/vta/python/vta/environment.py b/vta/python/vta/environment.py
index 947c583ed55f8..e68f098ba53f1 100644
--- a/vta/python/vta/environment.py
+++ b/vta/python/vta/environment.py
@@ -77,9 +77,9 @@ class DevContext(object):
     def __init__(self, env):
         self.vta_axis = te.thread_axis("vta")
         self.vta_push_uop = tvm.tir.StringImm("VTAPushGEMMOp")
-        ctx = tvm.tir.call_intrin("handle", "tir.vta.command_handle")
+        ctx = tvm.tir.call_extern("handle", "VTATLSCommandHandle")
         self.command_handle = tvm.tir.Call(
-            "handle", "tir.tvm_thread_context", [ctx],
+            "handle", "tvm_thread_context", [ctx],
             tvm.tir.Call.Intrinsic)
         self.DEBUG_NO_SYNC = False
         env._dev_ctx = self
@@ -298,7 +296,6 @@ def coproc_sync(op):
         tvm.runtime.const(1<<31, dtype="uint32"))
 
 
-
 @tvm.register_func("tvm.intrin.rule.default.vta.coproc_dep_push")
 def coproc_dep_push(op):
     return tvm.tir.call_extern(
@@ -314,15 +313,6 @@ def coproc_dep_pop(op):
         get_env().dev.command_handle,
         op.args[0], op.args[1])
 
 
-# register a dummy into to trigger registration of the ops
-# change the info to lowering rule later.
-tvm.ir.register_op_attr("tir.vta.coproc_sync", "TVectorizable", False)
-tvm.ir.register_op_attr("tir.vta.coproc_dep_push", "TVectorizable", False)
-tvm.ir.register_op_attr("tir.vta.coproc_dep_pop", "TVectorizable", False)
-
-tvm.ir.register_op_attr("tir.vta.uop_push", "TGlobalSymbol", "VTAUopPush")
-tvm.ir.register_op_attr("tir.vta.command_handle", "TGlobalSymbol", "VTATLSCommandHandle")
-
 def _init_env():
     """Initialize the default global env"""
diff --git a/vta/python/vta/intrin.py b/vta/python/vta/intrin.py
index 897bbcba4cd35..8532ffa318b56 100644
--- a/vta/python/vta/intrin.py
+++ b/vta/python/vta/intrin.py
@@ -82,16 +82,16 @@ def instr(index):
             irb.scope_attr(dev.vta_axis, "coproc_uop_scope",
                            dev.vta_push_uop)
             if index in (0, 2):
-                irb.emit(tvm.tir.call_intrin(
-                    "int32", "tir.vta.uop_push",
+                irb.emit(tvm.tir.call_extern(
+                    "int32", "VTAUopPush",
                     0, 0,
                     dout.access_ptr("rw", "int32"),
                     dinp.access_ptr("r", "int32"),
                     dwgt.access_ptr("r", "int32"),
                     0, 0, 0))
             else:
-                irb.emit(tvm.tir.call_intrin(
-                    "int32", "tir.vta.uop_push",
+                irb.emit(tvm.tir.call_extern(
+                    "int32", "VTAUopPush",
                     0, 1,
                     dout.access_ptr("rw", "int32"),
                     0,
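The VTA hunks above make the symbol binding explicit again: post-#5863,
"tir.vta.uop_push" is an op whose TGlobalSymbol attribute ("VTAUopPush") tells
codegen which C symbol to emit; the revert calls that symbol directly. Side by
side, with names (`irb`, `dout`, `dinp`, `dwgt`) taken from the intrin.py hunk
above (illustrative only):

    # post-#5863 (being reverted): op + TGlobalSymbol attribute
    tvm.ir.register_op_attr("tir.vta.uop_push", "TGlobalSymbol", "VTAUopPush")
    irb.emit(tvm.tir.call_intrin(
        "int32", "tir.vta.uop_push",
        0, 0,
        dout.access_ptr("rw", "int32"),
        dinp.access_ptr("r", "int32"),
        dwgt.access_ptr("r", "int32"),
        0, 0, 0))

    # pre-#5863 (being restored): extern call straight to the runtime symbol
    irb.emit(tvm.tir.call_extern(
        "int32", "VTAUopPush",
        0, 0,
        dout.access_ptr("rw", "int32"),
        dinp.access_ptr("r", "int32"),
        dwgt.access_ptr("r", "int32"),
        0, 0, 0))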
diff --git a/vta/python/vta/transform.py b/vta/python/vta/transform.py
index e92b178a5be6c..207f784b5885f 100644
--- a/vta/python/vta/transform.py
+++ b/vta/python/vta/transform.py
@@ -59,12 +59,11 @@ def _fold_outermost_loop(body):
     loop_var = stmt.loop_var
     gemm_offsets = [None, None, None]
     fail = [False]
-    builtin_uop_push = tvm.ir.Op.get("tir.vta.uop_push")
 
     def _post_order(op):
         assert isinstance(op, tvm.tir.Call)
         base_args = 2
-        if op.op.same_as(builtin_uop_push):
+        if op.name == "VTAUopPush":
             args = []
             args += op.args[:base_args]
             for i in range(3):
@@ -82,8 +81,8 @@ def _post_order(op):
                 gemm_offsets[i] = m[0]
                 args.append(m[1])
             args += op.args[base_args+3:]
-            return tvm.tir.call_intrin("int32", builtin_uop_push, *args)
-        if op.op.name not in ("tir.vta.command_handle", "tir.tvm_thread_context"):
+            return tvm.tir.call_extern("int32", "VTAUopPush", *args)
+        if op.name not in ("VTATLSCommandHandle", "tvm_thread_context"):
             raise RuntimeError("unexpected op %s" % op)
         return op
 
@@ -644,7 +643,7 @@ def _do_fold(op):
             dev = env.dev
             irb.scope_attr(dev.vta_axis, "coproc_scope", dev.get_task_qid(dev.QID_COMPUTE))
             irb.scope_attr(dev.vta_axis, "coproc_uop_scope", dev.vta_push_uop)
-            irb.emit(tvm.tir.call_intrin("int32", "tir.vta.uop_push",
+            irb.emit(tvm.tir.call_extern("int32", "VTAUopPush",
                                          0, 1,
                                          dout.access_ptr("rw", "int32"),
                                          0, 0,
@@ -659,7 +658,7 @@ def _do_fold(op):
             tpl = (args[0], 1, args[1], 1, args[2], 1, args[3], 1, 0, 1, 0, env.BLOCK_OUT)
             inner = tvm.tir.AttrStmt(
                 [dout, res_buffer], 'buffer_bind_scope',
-                tvm.tir.call_intrin('handle', 'tir.tvm_tuple', *tpl), inner)
+                tvm.tir.call_intrin('handle', 'tvm_tuple', *tpl), inner)
             return inner
         else:
             conv_call, data_call, kernel_call = calls[-3:]
@@ -679,7 +678,7 @@ def _do_fold(op):
             irb.scope_attr(
                 dev.vta_axis, "coproc_scope", dev.get_task_qid(dev.QID_COMPUTE))
             irb.scope_attr(dev.vta_axis, "coproc_uop_scope", dev.vta_push_uop)
-            irb.emit(tvm.tir.call_intrin("int32", "tir.vta.uop_push",
+            irb.emit(tvm.tir.call_extern("int32", "VTAUopPush",
                                          0, 0,
                                          dout.access_ptr("rw", "int32"),
                                          dinp.access_ptr("r", "int32"),
@@ -692,19 +691,19 @@ def _do_fold(op):
                    1, 0, 1, 0, env.BLOCK_OUT)
             inner = tvm.tir.AttrStmt(
                 [dout, res_tensor], 'buffer_bind_scope',
-                tvm.tir.call_intrin('handle', 'tir.tvm_tuple', *tpl), inner)
+                tvm.tir.call_intrin('handle', 'tvm_tuple', *tpl), inner)
             args = kernel_call.indices
             tpl = (args[0], 1, args[1], 1, args[2], 1, args[3], 1,
                    0, env.BLOCK_OUT, 0, env.BLOCK_IN)
             inner = tvm.tir.AttrStmt(
                 [dwgt, kernel_tensor], 'buffer_bind_scope',
-                tvm.tir.call_intrin('handle', 'tir.tvm_tuple', *tpl), inner)
+                tvm.tir.call_intrin('handle', 'tvm_tuple', *tpl), inner)
             args = data_call.indices
             tpl = (args[0], 1, args[1], 1, args[2], 1, args[3], 1,
                    0, 1, 0, env.BLOCK_IN)
             inner = tvm.tir.AttrStmt(
                 [dinp, pad_data_tensor], 'buffer_bind_scope',
-                tvm.tir.call_intrin('handle', 'tir.tvm_tuple', *tpl), inner)
+                tvm.tir.call_intrin('handle', 'tvm_tuple', *tpl), inner)
             return inner
         return None
 
@@ -834,11 +833,11 @@ def _flatten_loop(src_coeff, dst_coeff, extents):
             lhs = loop_body.value.a
             rhs = loop_body.value.b
         elif isinstance(loop_body.value, tvm.tir.Call):
-            if loop_body.value.op.name == 'tir.shift_left':
+            if loop_body.value.name == 'shift_left':
                 alu_opcode = env.dev.ALU_OPCODE_SHR
                 lhs = loop_body.value.args[0]
                 rhs = analyzer.simplify(-loop_body.value.args[1])
-            elif loop_body.value.op.name == 'tir.shift_right':
+            elif loop_body.value.name == 'shift_right':
                 alu_opcode = env.dev.ALU_OPCODE_SHR
                 lhs = loop_body.value.args[0]
                 rhs = loop_body.value.args[1]
@@ -943,8 +942,8 @@ def _flatten_loop(src_coeff, dst_coeff, extents):
                     "int32", "VTAUopLoopBegin",
                     extent, dst_coeff[idx], src_coeff[idx], 0))
             use_imm = int(use_imm)
-            irb.emit(tvm.tir.call_intrin(
-                "int32", "tir.vta.uop_push",
+            irb.emit(tvm.tir.call_extern(
+                "int32", "VTAUopPush",
                 1, 0,
                 dst_coeff[len(dst_coeff)-1], src_coeff[len(src_coeff)-1],