Skip to content

Commit

Permalink
Improved tvm::GetType for tvm_access_ptr and address_of
Browse files Browse the repository at this point in the history
These `Call` instances can return a
`PointerType(PrimType(pointee_dtype))` rather than a
`PrimType(DataType::Handle())`.
  • Loading branch information
Lunderberg committed Jul 3, 2023
1 parent e9196b2 commit 60da49a
Showing 1 changed file with 26 additions and 0 deletions.
26 changes: 26 additions & 0 deletions src/tir/op/op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,32 @@ Type GetType(const PrimExpr& expr) {
return ptr->type_annotation;
}
}

if (auto* access = expr.as<tir::CallNode>()) {
if (access->op.same_as(builtin::tvm_access_ptr())) {
ICHECK(access->args.size()) << "Builtin tvm_access_ptr() may not have empty arguments";
auto type_annotation = Downcast<Call>(access->args[0]);
static auto builtin_op = Op::Get("tir.type_annotation");
ICHECK(type_annotation->op.same_as(builtin_op))
<< "Expected the first argument of builtin tvm_access_ptr() "
<< "to be a type annotation, but found " << type_annotation->op;
return PointerType(PrimType(type_annotation->dtype));
}
}

if (auto* address_of = expr.as<tir::CallNode>()) {
if (address_of->op.same_as(builtin::address_of())) {
ICHECK_EQ(address_of->args.size(), 1)
<< "Builtin address_of() expects a single argument, but received arguments "
<< address_of->args;
auto* address = address_of->args[0].as<BufferLoadNode>();
ICHECK(address)
<< "Builtin address_of() expects the argument to be a BufferLoad, but received argument "
<< address_of->args[0];

return PointerType(PrimType(address->dtype));
}
}
// Default: return the type indicated by the dtype.
runtime::DataType dtype = expr.dtype();
return GetTypeFromRuntimeDataType(dtype);
Expand Down

0 comments on commit 60da49a

Please sign in to comment.