Skip to content

Commit

Permalink
Add helper functions to query properties of the lowered Target (#8192)
Browse files Browse the repository at this point in the history
  • Loading branch information
steven-johnson committed Jul 29, 2024
1 parent 5d1472f commit 122ff30
Show file tree
Hide file tree
Showing 11 changed files with 207 additions and 0 deletions.
2 changes: 2 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -594,6 +594,7 @@ SOURCE_FILES = \
StripAsserts.cpp \
Substitute.cpp \
Target.cpp \
TargetQueryOps.cpp \
Tracing.cpp \
TrimNoOps.cpp \
Tuple.cpp \
Expand Down Expand Up @@ -778,6 +779,7 @@ HEADER_FILES = \
StripAsserts.h \
Substitute.h \
Target.h \
TargetQueryOps.h \
Tracing.h \
TrimNoOps.h \
Tuple.h \
Expand Down
2 changes: 2 additions & 0 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,7 @@ set(HEADER_FILES
StripAsserts.h
Substitute.h
Target.h
TargetQueryOps.h
Tracing.h
TrimNoOps.h
Tuple.h
Expand Down Expand Up @@ -346,6 +347,7 @@ set(SOURCE_FILES
StripAsserts.cpp
Substitute.cpp
Target.cpp
TargetQueryOps.cpp
Tracing.cpp
TrimNoOps.cpp
Tuple.cpp
Expand Down
5 changes: 5 additions & 0 deletions src/IR.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -680,6 +680,11 @@ const char *const intrinsic_op_names[] = {
"sorted_avg",
"strict_float",
"stringify",
"target_arch_is",
"target_bits",
"target_has_feature",
"target_natural_vector_size",
"target_os_is",
"undef",
"unreachable",
"unsafe_promise_clamped",
Expand Down
7 changes: 7 additions & 0 deletions src/IR.h
Original file line number Diff line number Diff line change
Expand Up @@ -612,6 +612,13 @@ struct Call : public ExprNode<Call> {
sorted_avg,
strict_float,
stringify,

target_arch_is,
target_bits,
target_has_feature,
target_natural_vector_size,
target_os_is,

undef,
unreachable,
unsafe_promise_clamped,
Expand Down
20 changes: 20 additions & 0 deletions src/IROperator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2735,4 +2735,24 @@ Expr concat_bits(const std::vector<Expr> &e) {
return Call::make(t.with_bits(t.bits() * (int)e.size()), Call::concat_bits, e, Call::Intrinsic);
}

Expr target_arch_is(Target::Arch arch) {
return Call::make(Bool(), Call::target_arch_is, {Expr((int)arch)}, Call::PureIntrinsic);
}

Expr target_os_is(Target::OS os) {
return Call::make(Bool(), Call::target_os_is, {Expr((int)os)}, Call::PureIntrinsic);
}

Expr target_bits() {
return Call::make(Int(32), Call::target_bits, {}, Call::PureIntrinsic);
}

Expr target_has_feature(Target::Feature feat) {
return Call::make(Bool(), Call::target_has_feature, {Expr((int)feat)}, Call::PureIntrinsic);
}

Expr target_natural_vector_size(Type t) {
return Call::make(Int(32), Call::target_natural_vector_size, {make_zero(t.element_of())}, Call::PureIntrinsic);
}

} // namespace Halide
41 changes: 41 additions & 0 deletions src/IROperator.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include <map>

#include "Expr.h"
#include "Target.h"
#include "Tuple.h"

namespace Halide {
Expand Down Expand Up @@ -1689,6 +1690,46 @@ Expr rounding_mul_shift_right(Expr a, Expr b, Expr q);
Expr rounding_mul_shift_right(Expr a, Expr b, int q);
//@}

/** Return a boolean Expr for the corresponding field of the Target
* being used during lowering; they can be useful in writing library
* code without having to plumb a Target through call sites, so that you
* can do things like
\code
Expr e = select(target_arch_is(Target::ARM), something, something_else);
\endcode
*/
//@{
Expr target_arch_is(Target::Arch arch);
Expr target_os_is(Target::OS os);
Expr target_has_feature(Target::Feature feat);
//@}

/** Return the bit width of the Target used during lowering; this can be useful
* in writing library code without having to plumb a Target through call sites, so that you
* can do things like
\code
Expr e = select(target_bits() == 32, something, something_else);
\endcode
*/
Expr target_bits();

/** Return the natural vector width for the given Type for the Target
* being used during lowering; this can be useful in writing library
* code without having to plumb a Target through call sites, so that you
* can do things like
\code
f.vectorize(x, target_natural_vector_size(Float(32)));
\endcode
*/
//@{
Expr target_natural_vector_size(Type t);
template<typename data_t>
Expr target_natural_vector_size() {
return target_natural_vector_size(type_of<data_t>());
}
//@}


} // namespace Halide

#endif
3 changes: 3 additions & 0 deletions src/Lower.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@
#include "StrictifyFloat.h"
#include "StripAsserts.h"
#include "Substitute.h"
#include "TargetQueryOps.h"
#include "Tracing.h"
#include "TrimNoOps.h"
#include "UnifyDuplicateLets.h"
Expand Down Expand Up @@ -144,6 +145,8 @@ void lower_impl(const vector<Function> &output_funcs,
// Create a deep-copy of the entire graph of Funcs.
auto [outputs, env] = deep_copy(output_funcs, build_environment(output_funcs));

lower_target_query_ops(env, t);

bool any_strict_float = strictify_float(env, t);
result_module.set_any_strict_float(any_strict_float);

Expand Down
54 changes: 54 additions & 0 deletions src/TargetQueryOps.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
#include "TargetQueryOps.h"

#include "Function.h"
#include "IRMutator.h"
#include "IROperator.h"

namespace Halide {
namespace Internal {

namespace {

class LowerTargetQueryOps : public IRMutator {
const Target &t;

using IRMutator::visit;

Expr visit(const Call *call) override {
if (call->is_intrinsic(Call::target_arch_is)) {
Target::Arch arch = (Target::Arch)*as_const_int(call->args[0]);
return make_bool(t.arch == arch);
} else if (call->is_intrinsic(Call::target_has_feature)) {
Target::Feature feat = (Target::Feature)*as_const_int(call->args[0]);
return make_bool(t.has_feature(feat));
} else if (call->is_intrinsic(Call::target_natural_vector_size)) {
Expr zero = call->args[0];
return Expr(t.natural_vector_size(zero.type()));
} else if (call->is_intrinsic(Call::target_os_is)) {
Target::OS os = (Target::OS)*as_const_int(call->args[0]);
return make_bool(t.os == os);
} else if (call->is_intrinsic(Call::target_bits)) {
return Expr(t.bits);
}

return IRMutator::visit(call);
}

public:
LowerTargetQueryOps(const Target &t)
: t(t) {
}
};

} // namespace

void lower_target_query_ops(std::map<std::string, Function> &env, const Target &t) {
for (auto &iter : env) {
Function &func = iter.second;
LowerTargetQueryOps ltqo(t);
func.mutate(&ltqo);
}
}

} // namespace Internal
} // namespace Halide
24 changes: 24 additions & 0 deletions src/TargetQueryOps.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
#ifndef HALIDE_TARGET_QUERY_OPS_H
#define HALIDE_TARGET_QUERY_OPS_H

/** \file
* Defines a lowering pass to lower all target_is() and target_has() helpers.
*/

#include <map>
#include <string>

namespace Halide {

struct Target;

namespace Internal {

class Function;

void lower_target_query_ops(std::map<std::string, Function> &env, const Target &t);

} // namespace Internal
} // namespace Halide

#endif
1 change: 1 addition & 0 deletions test/correctness/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,7 @@ tests(GROUPS correctness
strict_float_bounds.cpp
strided_load.cpp
target.cpp
target_query.cpp
tiled_matmul.cpp
tracing.cpp
tracing_bounds.cpp
Expand Down
48 changes: 48 additions & 0 deletions test/correctness/target_query.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
#include "Halide.h"
#include <stdio.h>

using namespace Halide;

int main(int argc, char **argv) {
// For simplicity, only run this test on hosts that we can predict.
Target t = get_host_target();
if (t.arch != Target::X86 || t.bits != 64 || t.os != Target::OSX) {
printf("[SKIP] This test only runs on x86-64-osx.\n");
return 0;
}

t = t.with_feature(Target::Debug);

// Full specification round-trip, crazy features
Target t1 = Target(Target::OSX, Target::X86, 64,
{Target::CUDA, Target::Debug});

Expr is_arm = target_arch_is(Target::ARM);
Expr is_x86 = target_arch_is(Target::X86);
Expr bits = target_bits();
Expr is_android = target_os_is(Target::Android);
Expr is_osx = target_os_is(Target::OSX);
Expr vec = target_natural_vector_size<float>();
Expr has_cuda = target_has_feature(Target::CUDA);
Expr has_vulkan = target_has_feature(Target::Vulkan);

Func f;
Var x;

f(x) = select(is_arm, 1, 0) +
select(is_x86, 2, 0) +
select(vec == 4, 4, 0) +
select(is_android, 8, 0) +
select(is_osx, 16, 0) +
select(bits == 32, 32, 0) +
select(bits == 64, 64, 0) +
select(has_cuda, 128, 0) +
select(has_vulkan, 256, 0);

Buffer<int> result = f.realize({1}, t1);

assert(result(0) == 2 + 4 + 16 + 64 + 128);

printf("Success!\n");
return 0;
}

0 comments on commit 122ff30

Please sign in to comment.