Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Minor][FFI] Allow implicit conversion in TVM FFI to tvm::Bool #5907

Merged
merged 1 commit into from
Jun 24, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 30 additions & 1 deletion include/tvm/ir/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -483,9 +483,9 @@ inline const TTypeNode* RelayExprNode::type_as() const {

namespace tvm {
namespace runtime {
// common rule for RetValue and ArgValue
template <>
struct PackedFuncValueConverter<PrimExpr> {
// common rule for both RetValue and ArgValue.
static PrimExpr From(const TVMPODValue_& val) {
if (val.type_code() == kTVMNullptr) {
return PrimExpr(ObjectPtr<Object>(nullptr));
Expand All @@ -500,6 +500,35 @@ struct PackedFuncValueConverter<PrimExpr> {
return PrimExpr::FromObject_(val.AsObjectRef<ObjectRef>());
}
};

template <>
struct PackedFuncValueConverter<tvm::Integer> {
static tvm::Integer From(const TVMPODValue_& val) {
if (val.type_code() == kTVMNullptr) {
return Integer(ObjectPtr<Object>(nullptr));
}
if (val.type_code() == kTVMArgInt) {
return Integer(val.operator int());
}
return val.AsObjectRef<tvm::Integer>();
}
};

template <>
struct PackedFuncValueConverter<tvm::Bool> {
static tvm::Bool From(const TVMPODValue_& val) {
if (val.type_code() == kTVMNullptr) {
return Bool(ObjectPtr<Object>(nullptr));
}
if (val.type_code() == kTVMArgInt) {
int v = val.operator int();
CHECK(v == 0 || v == 1) << "ValueError: boolean value can only be 0 or 1, but get " << v;
return Bool(static_cast<bool>(v));
}
return val.AsObjectRef<tvm::Bool>();
}
};

} // namespace runtime
} // namespace tvm
#endif // TVM_IR_EXPR_H_
20 changes: 0 additions & 20 deletions include/tvm/tir/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -1147,26 +1147,6 @@ inline std::unordered_map<K, V> as_unordered_map(const Map<K, V>& dmap) {
} // namespace tir
} // namespace tvm

namespace tvm {
namespace runtime {
// Additional implementattion overloads for PackedFunc.

template <>
struct PackedFuncValueConverter<tvm::Integer> {
// common rule for RetValue and ArgValue
static tvm::Integer From(const TVMPODValue_& val) {
if (val.type_code() == kTVMNullptr) {
return Integer(ObjectPtr<Object>(nullptr));
}
if (val.type_code() == kDLInt) {
return Integer(val.operator int());
}
return val.AsObjectRef<tvm::Integer>();
}
};
} // namespace runtime
} // namespace tvm

namespace std {
template <>
struct hash<::tvm::tir::IterVar> : public ::tvm::ObjectPtrHash {};
Expand Down