Skip to content

Commit

Permalink
Make "none" DataType explicit (apache#5491)
Browse files Browse the repository at this point in the history
* Make "none" DataType explicit

The None data type is created when converting an empty string to DataType.
Add functions to create it and recognize it. Convert it to the "void" LLVM
type in LLVM codegen.

* Rename "none" to "void"

* Map VoidType:Type -> Void:DataType in GetRuntimeDataType

* Map Void:DataType -> VoidType:Type in GetType
  • Loading branch information
Krzysztof Parzyszek authored and trevor-m committed Jun 18, 2020
1 parent 08ea4ee commit 0d87a0b
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 6 deletions.
20 changes: 17 additions & 3 deletions include/tvm/runtime/data_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ class DataType {
}
/*! \return whether type is a handle type. */
bool is_handle() const {
return code() == DataType::kHandle;
return code() == DataType::kHandle && !is_void();
}
/*! \return whether type is a vector type. */
bool is_vector() const {
Expand All @@ -117,6 +117,10 @@ class DataType {
bool is_vector_bool() const {
return is_vector() && bits() == 1;
}
/*! \return whether type is a Void type. */
bool is_void() const {
return code() == DataType::kHandle && bits() == 0 && lanes() == 0;
}
/*!
* \brief Create a new data type by change lanes to a specified value.
* \param lanes The target number of lanes.
Expand Down Expand Up @@ -211,6 +215,13 @@ class DataType {
static DataType Handle(int bits = 64, int lanes = 1) {
return DataType(kHandle, bits, lanes);
}
/*!
* \brief Construct a Void type.
* \return The constructed data type.
*/
static DataType Void() {
return DataType(kHandle, 0, 0);
}
/*!
* \brief Get the corresponding type of TVMShapeIndex.
* \return The type of TVM shape index.
Expand Down Expand Up @@ -335,6 +346,9 @@ inline std::ostream& operator<<(std::ostream& os, DLDataType t) { // NOLINT(*)
if (t.bits == 1 && t.lanes == 1 && t.code == kDLUInt) {
os << "bool"; return os;
}
if (DataType(t).is_void()) {
return os << "void";
}
if (t.code < kTVMCustomBegin) {
os << TypeCode2Str(t.code);
} else {
Expand All @@ -361,9 +375,9 @@ inline std::string DLDataType2String(DLDataType t) {

inline DLDataType String2DLDataType(std::string s) {
DLDataType t;
// handle None type
// handle void type
if (s.length() == 0) {
t.bits = 0; t.lanes = 0; t.code = kTVMOpaqueHandle;
t = DataType::Void();
return t;
}
t.bits = 32; t.lanes = 1;
Expand Down
3 changes: 3 additions & 0 deletions src/target/llvm/codegen_llvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,9 @@ llvm::Type* CodeGenLLVM::DTypeToLLVMType(const DataType& dtype) const {
CHECK_EQ(dtype.lanes(), 1);
return t_void_p_;
}
if (dtype.is_void()) {
return t_void_;
}
llvm::Type* etype = nullptr;
if (dtype.is_int() || dtype.is_uint()) {
etype = llvm::Type::getIntNTy(*ctx_, dtype.bits());
Expand Down
7 changes: 4 additions & 3 deletions src/tir/ir/op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ runtime::DataType GetRuntimeDataType(const Type& type) {
return n->dtype;
} else if (type.as<PointerTypeNode>()) {
return DataType::Handle();
} else if (IsVoidType(type)) {
return DataType::Void();
} else {
LOG(FATAL) << "Type " << type
<< " does not have a corresponding runtime::DataType";
Expand All @@ -57,9 +59,8 @@ Type GetType(const PrimExpr& expr) {
}
// Default: return the type indicated by the dtype.
runtime::DataType dtype = expr.dtype();
// These types already implies the specific type.
if (dtype.is_int() || dtype.is_uint() || dtype.is_float()) {
return PrimType(dtype);
if (dtype.is_void()) {
return VoidType();
}
return PrimType(dtype);
}
Expand Down
12 changes: 12 additions & 0 deletions tests/python/unittest/test_target_codegen_llvm.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,18 @@ def test_llvm_intrin():
fcode = tvm.build(mod, None, "llvm")


def test_llvm_void_intrin():
ib = tvm.tir.ir_builder.create()
A = ib.pointer("uint8", name="A")
# Create an intrinsic that returns void.
x = tvm.tir.call_llvm_intrin('', 'llvm.va_start', tvm.tir.const(1, 'uint32'), A)
ib.emit(x)
body = ib.get()
mod = tvm.IRModule.from_expr(
tvm.tir.PrimFunc([A], body).with_attr("global_symbol", "main"))
fcode = tvm.build(mod, None, "llvm")


def test_llvm_overloaded_intrin():
# Name lookup for overloaded intrinsics in LLVM 4- requires a name
# that includes the overloaded types.
Expand Down

0 comments on commit 0d87a0b

Please sign in to comment.