diff --git a/include/tvm/runtime/data_type.h b/include/tvm/runtime/data_type.h index 44385d63263b..940818a04965 100644 --- a/include/tvm/runtime/data_type.h +++ b/include/tvm/runtime/data_type.h @@ -107,7 +107,7 @@ class DataType { } /*! \return whether type is a handle type. */ bool is_handle() const { - return code() == DataType::kHandle; + return code() == DataType::kHandle && !is_void(); } /*! \return whether type is a vector type. */ bool is_vector() const { @@ -117,6 +117,10 @@ class DataType { bool is_vector_bool() const { return is_vector() && bits() == 1; } + /*! \return whether type is a Void type. */ + bool is_void() const { + return code() == DataType::kHandle && bits() == 0 && lanes() == 0; + } /*! * \brief Create a new data type by change lanes to a specified value. * \param lanes The target number of lanes. @@ -211,6 +215,13 @@ class DataType { static DataType Handle(int bits = 64, int lanes = 1) { return DataType(kHandle, bits, lanes); } + /*! + * \brief Construct a Void type. + * \return The constructed data type. + */ + static DataType Void() { + return DataType(kHandle, 0, 0); + } /*! * \brief Get the corresponding type of TVMShapeIndex. * \return The type of TVM shape index. @@ -335,6 +346,9 @@ inline std::ostream& operator<<(std::ostream& os, DLDataType t) { // NOLINT(*) if (t.bits == 1 && t.lanes == 1 && t.code == kDLUInt) { os << "bool"; return os; } + if (DataType(t).is_void()) { + return os << "void"; + } if (t.code < kTVMCustomBegin) { os << TypeCode2Str(t.code); } else { @@ -361,9 +375,9 @@ inline std::string DLDataType2String(DLDataType t) { inline DLDataType String2DLDataType(std::string s) { DLDataType t; - // handle None type + // handle void type if (s.length() == 0) { - t.bits = 0; t.lanes = 0; t.code = kTVMOpaqueHandle; + t = DataType::Void(); return t; } t.bits = 32; t.lanes = 1; diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index 86cd5a3acf61..74bda71cb8a8 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -309,6 +309,9 @@ llvm::Type* CodeGenLLVM::DTypeToLLVMType(const DataType& dtype) const { CHECK_EQ(dtype.lanes(), 1); return t_void_p_; } + if (dtype.is_void()) { + return t_void_; + } llvm::Type* etype = nullptr; if (dtype.is_int() || dtype.is_uint()) { etype = llvm::Type::getIntNTy(*ctx_, dtype.bits()); diff --git a/src/tir/ir/op.cc b/src/tir/ir/op.cc index 4ad244ff02b2..6224321c70ff 100644 --- a/src/tir/ir/op.cc +++ b/src/tir/ir/op.cc @@ -38,6 +38,8 @@ runtime::DataType GetRuntimeDataType(const Type& type) { return n->dtype; } else if (type.as()) { return DataType::Handle(); + } else if (IsVoidType(type)) { + return DataType::Void(); } else { LOG(FATAL) << "Type " << type << " does not have a corresponding runtime::DataType"; @@ -57,9 +59,8 @@ Type GetType(const PrimExpr& expr) { } // Default: return the type indicated by the dtype. runtime::DataType dtype = expr.dtype(); - // These types already implies the specific type. - if (dtype.is_int() || dtype.is_uint() || dtype.is_float()) { - return PrimType(dtype); + if (dtype.is_void()) { + return VoidType(); } return PrimType(dtype); } diff --git a/tests/python/unittest/test_target_codegen_llvm.py b/tests/python/unittest/test_target_codegen_llvm.py index a7e1e57481a7..c6591721d247 100644 --- a/tests/python/unittest/test_target_codegen_llvm.py +++ b/tests/python/unittest/test_target_codegen_llvm.py @@ -43,6 +43,18 @@ def test_llvm_intrin(): fcode = tvm.build(mod, None, "llvm") +def test_llvm_void_intrin(): + ib = tvm.tir.ir_builder.create() + A = ib.pointer("uint8", name="A") + # Create an intrinsic that returns void. + x = tvm.tir.call_llvm_intrin('', 'llvm.va_start', tvm.tir.const(1, 'uint32'), A) + ib.emit(x) + body = ib.get() + mod = tvm.IRModule.from_expr( + tvm.tir.PrimFunc([A], body).with_attr("global_symbol", "main")) + fcode = tvm.build(mod, None, "llvm") + + def test_llvm_overloaded_intrin(): # Name lookup for overloaded intrinsics in LLVM 4- requires a name # that includes the overloaded types.