Skip to content

Commit

Permalink
Add workaround for the const-or-not user_context issue (#635) (#7291)
Browse files Browse the repository at this point in the history
Add a workaround for the const-or-not user_context issue (#635)
  • Loading branch information
steven-johnson authored Jan 20, 2023
1 parent 2cc0468 commit c601e4e
Show file tree
Hide file tree
Showing 7 changed files with 156 additions and 34 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -1608,7 +1608,7 @@ $(FILTERS_DIR)/autograd_grad.a: $(BIN_DIR)/autograd.generator $(BIN_MULLAPUDI201
# all have the form nested_externs_*).
$(FILTERS_DIR)/nested_externs_%.a: $(BIN_DIR)/nested_externs.generator
@mkdir -p $(@D)
$(CURDIR)/$< -g nested_externs_$* $(GEN_AOT_OUTPUTS) -o $(CURDIR)/$(FILTERS_DIR) target=$(TARGET)-no_runtime-c_plus_plus_name_mangling
$(CURDIR)/$< -g nested_externs_$* $(GEN_AOT_OUTPUTS) -o $(CURDIR)/$(FILTERS_DIR) target=$(TARGET)-no_runtime-user_context-c_plus_plus_name_mangling

# Similarly, gpu_multi needs two different kernels to test compilation caching.
# Also requires user-context.
Expand Down
86 changes: 72 additions & 14 deletions src/CodeGen_C.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1937,27 +1937,32 @@ void CodeGen_C::compile(const LoweredFunc &f, const MetadataNameMap &metadata_na
}

if (output_kind != CPlusPlusFunctionInfoHeader) {
// Streams the comma-separated parameter declarations for `args`.
// Buffer arguments are declared as `struct halide_buffer_t *<name>_buffer`;
// every other argument is declared with its printed type followed by its name.
// When `ucon_type` is a valid type (non-zero bit width), it replaces the
// declared type of the "__user_context" argument.
const auto emit_arg_decls = [&](const Type &ucon_type = Type()) {
    const bool override_ucon = ucon_type.bits() != 0;
    bool first = true;
    for (const auto &arg : args) {
        // Separator goes before every declaration except the first.
        if (!first) {
            stream << ", ";
        }
        first = false;

        if (arg.is_buffer()) {
            stream << "struct halide_buffer_t *" << print_name(arg.name) << "_buffer";
            continue;
        }

        const bool use_override = override_ucon && arg.name == "__user_context";
        stream << print_type(use_override ? ucon_type : arg.type, AppendSpace)
               << print_name(arg.name);
    }
};

// Emit the function prototype
if (f.linkage == LinkageType::Internal) {
// If the function isn't public, mark it static.
stream << "static ";
}
stream << "HALIDE_FUNCTION_ATTRS\n";
stream << "int " << simple_name << "(";
for (size_t i = 0; i < args.size(); i++) {
if (args[i].is_buffer()) {
stream << "struct halide_buffer_t *"
<< print_name(args[i].name)
<< "_buffer";
} else {
stream << print_type(args[i].type, AppendSpace)
<< print_name(args[i].name);
}

if (i < args.size() - 1) {
stream << ", ";
}
}
emit_arg_decls();

if (is_header_or_extern_decl()) {
stream << ");\n";
Expand Down Expand Up @@ -1995,6 +2000,59 @@ void CodeGen_C::compile(const LoweredFunc &f, const MetadataNameMap &metadata_na
close_scope("");
}

// Workaround for https://github.com/halide/Halide/issues/635:
// For historical reasons, Halide-generated AOT code
// defines user_context as `void const*`, but expects all
// define_extern code with user_context usage to use `void *`. This
// usually isn't an issue, but if both the caller and the callee
// pass a user_context, *and* c_plus_plus_name_mangling is enabled,
// we get link errors because of this dichotomy. Fixing this
// "correctly" (ie so that everything always uses identical types for
// user_context in all cases) will require a *lot* of downstream
// churn (see https://github.com/halide/Halide/issues/7298),
// so this is a workaround: Add a wrapper with `void*`
// ucon -> `void const*` ucon. In most cases this will be ignored
// (and probably dead-stripped), but in these cases it's critical.
//
// (Note that we don't check to see if c_plus_plus_name_mangling is
// enabled, since that would have to be done on the caller side, and
// this is purely a callee-side fix.)
// Emit an overload of the generated function whose __user_context parameter
// is `void *`; it simply forwards to the real (`void const *`) entry point.
// Only done for non-internal functions in the C++ implementation output when
// both the mangling and user_context target features are set.
if (f.linkage != LinkageType::Internal &&
    output_kind == CPlusPlusImplementation &&
    target.has_feature(Target::CPlusPlusMangling) &&
    get_target().has_feature(Target::UserContext)) {

    // Find the declared type of the __user_context argument, if there is one.
    // A default-constructed Type means "not found".
    Type ucon_type = Type();
    for (const auto &arg : args) {
        if (arg.name == "__user_context") {
            ucon_type = arg.type;
            break;
        }
    }
    if (ucon_type == type_of<void const *>()) {
        stream << "\nHALIDE_FUNCTION_ATTRS\n";
        stream << "int " << simple_name << "(";
        emit_arg_decls(type_of<void *>());
        stream << ") ";
        open_scope();
        stream << get_indent() << " return " << simple_name << "(";
        const char *comma = "";
        for (const auto &arg : args) {
            // The separator must come before the cast: otherwise a
            // __user_context argument that is not first in the list would
            // produce `prev(void const *), __user_context`.
            stream << comma;
            if (arg.name == "__user_context") {
                // Add an explicit cast here so we won't call ourselves into oblivion
                stream << "(void const *)";
            }
            stream << print_name(arg.name);
            if (arg.is_buffer()) {
                stream << "_buffer";
            }
            comma = ", ";
        }
        stream << ");\n";
        close_scope("");
    }
}

if (f.linkage == LinkageType::ExternalPlusArgv || f.linkage == LinkageType::ExternalPlusMetadata) {
// Emit the argv version
emit_argv_wrapper(simple_name, args);
Expand Down
59 changes: 59 additions & 0 deletions src/CodeGen_LLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -536,6 +536,65 @@ std::unique_ptr<llvm::Module> CodeGen_LLVM::compile(const Module &input) {
names.simple_name, f.args, input.get_metadata_name_map());
}
}

// Workaround for https://github.com/halide/Halide/issues/635:
// For historical reasons, Halide-generated AOT code
// defines user_context as `void const*`, but expects all
// define_extern code with user_context usage to use `void *`. This
// usually isn't an issue, but if both the caller and the callee
// pass a user_context, *and* c_plus_plus_name_mangling is enabled,
// we get link errors because of this dichotomy. Fixing this
// "correctly" (ie so that everything always uses identical types for
// user_context in all cases) will require a *lot* of downstream
// churn (see https://github.com/halide/Halide/issues/7298),
// so this is a workaround: Add a wrapper with `void*`
// ucon -> `void const*` ucon. In most cases this will be ignored
// (and probably dead-stripped), but in these cases it's critical.
//
// (Note that we don't check to see if c_plus_plus_name_mangling is
// enabled, since that would have to be done on the caller side, and
// this is purely a callee-side fix.)
// Build a forwarding wrapper function whose __user_context parameter is
// `void *` instead of `void const *`. The wrapper calls `function` with the
// pointer cast back to `void const *` and returns its result. Only attempted
// for non-internal functions when both the mangling and user_context target
// features are set.
if (f.linkage != LinkageType::Internal &&
target.has_feature(Target::CPlusPlusMangling) &&
target.has_feature(Target::UserContext)) {

// Stays -1 unless a `void const *` __user_context argument is found;
// in that case no wrapper is emitted.
int wrapper_ucon_index = -1;
auto wrapper_args = f.args; // make a copy
auto wrapper_llvm_arg_types = arg_types; // make a copy
for (int i = 0; i < (int)wrapper_args.size(); i++) {
if (wrapper_args[i].name == "__user_context" && wrapper_args[i].type == type_of<void const *>()) {
// Update the type of the user_context argument to be void* rather than void const*
wrapper_args[i].type = type_of<void *>();
wrapper_llvm_arg_types[i] = llvm_type_of(upgrade_type_for_argument_passing(wrapper_args[i].type));
wrapper_ucon_index = i;
}
}
if (wrapper_ucon_index >= 0) {
// Names are computed from the modified argument list (presumably so the
// mangled symbol reflects the `void *` signature -- confirm against
// get_mangled_names).
const auto wrapper_names = get_mangled_names(f.name, f.linkage, f.name_mangling, wrapper_args, target);

// Declare the wrapper: returns i32, takes the modified argument types, not variadic.
FunctionType *wrapper_func_t = FunctionType::get(i32_t, wrapper_llvm_arg_types, false);
llvm::Function *wrapper_func = llvm::Function::Create(wrapper_func_t,
llvm::GlobalValue::ExternalLinkage,
wrapper_names.extern_name,
module.get());
set_function_attributes_from_halide_target_options(*wrapper_func);
// Note: this moves the builder's insert point into the wrapper's entry block.
llvm::BasicBlock *wrapper_block = llvm::BasicBlock::Create(module->getContext(), "entry", wrapper_func);
builder->SetInsertPoint(wrapper_block);

// Forward every wrapper parameter unchanged, except the user_context
// pointer, which is cast to the type used for `void const *`.
std::vector<llvm::Value *> wrapper_call_args;
for (auto &arg : wrapper_func->args()) {
wrapper_call_args.push_back(&arg);
}
wrapper_call_args[wrapper_ucon_index] = builder->CreatePointerCast(wrapper_call_args[wrapper_ucon_index],
llvm_type_of(type_of<void const *>()));

llvm::CallInst *wrapper_result = builder->CreateCall(function, wrapper_call_args);
// This call should never inline
wrapper_result->setIsNoInline();
builder->CreateRet(wrapper_result);
// verifyFunction returns true when the function is broken; diagnostics go to llvm::errs().
internal_assert(!verifyFunction(*wrapper_func, &llvm::errs()));
}
}
}
// Define all functions
int idx = 0;
Expand Down
7 changes: 2 additions & 5 deletions src/IRPrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,8 @@ ostream &operator<<(ostream &out, const Type &type) {
out << "float";
break;
case Type::Handle:
if (type.handle_type) {
out << "(" << type.handle_type->inner_name.name << " *)";
} else {
out << "(void *)";
}
// ensure that 'const' (etc) qualifiers are emitted when appropriate
out << "(" << type_to_c_type(type, false) << ")";
break;
case Type::BFloat:
out << "bfloat";
Expand Down
2 changes: 1 addition & 1 deletion test/generator/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -508,7 +508,7 @@ _add_halide_aot_tests(multitarget
add_halide_generator(nested_externs.generator SOURCES nested_externs_generator.cpp)
set(NESTED_EXTERNS_LIBS nested_externs_root nested_externs_inner nested_externs_combine nested_externs_leaf)
foreach (LIB IN LISTS NESTED_EXTERNS_LIBS)
_add_halide_libraries(${LIB} FROM nested_externs.generator GENERATOR_NAME ${LIB} FEATURES c_plus_plus_name_mangling)
_add_halide_libraries(${LIB} FROM nested_externs.generator GENERATOR_NAME ${LIB} FEATURES user_context c_plus_plus_name_mangling)
endforeach ()
_add_halide_aot_tests(nested_externs
HALIDE_LIBRARIES ${NESTED_EXTERNS_LIBS})
Expand Down
3 changes: 2 additions & 1 deletion test/generator/nested_externs_aottest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ int main(int argc, char **argv) {
auto val = Buffer<float, 0>::make_scalar();
val() = 38.5f;

nested_externs_root(val, buf);
void const *ucon = nullptr;
nested_externs_root(ucon, val, buf);

buf.for_each_element([&](int x, int y, int c) {
const float correct = 158.0f;
Expand Down
31 changes: 19 additions & 12 deletions test/generator/nested_externs_generator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class NestedExternsCombine : public Generator<NestedExternsCombine> {
Output<Buffer<>> combine{"combine"}; // unspecified type-and-dim will be inferred

void generate() {
Var x, y, c;
Var x{"x"}, y{"y"}, c{"c"};
combine(x, y, c) = input_a(x, y, c) + input_b(x, y, c);
}

Expand All @@ -35,10 +35,11 @@ class NestedExternsInner : public Generator<NestedExternsInner> {
Output<Buffer<float, 3>> inner{"inner"};

void generate() {
extern_stage_1.define_extern("nested_externs_leaf", {value}, Float(32), 3);
extern_stage_2.define_extern("nested_externs_leaf", {value + 1}, Float(32), 3);
Expr ucon = user_context_value();
extern_stage_1.define_extern("nested_externs_leaf", {ucon, value}, Float(32), 3);
extern_stage_2.define_extern("nested_externs_leaf", {ucon, value + 1}, Float(32), 3);
extern_stage_combine.define_extern("nested_externs_combine",
{extern_stage_1, extern_stage_2}, Float(32), 3);
{ucon, extern_stage_1, extern_stage_2}, Float(32), 3);
inner(x, y, c) = extern_stage_combine(x, y, c);
}

Expand All @@ -51,8 +52,10 @@ class NestedExternsInner : public Generator<NestedExternsInner> {
}

private:
Var x, y, c;
Func extern_stage_1, extern_stage_2, extern_stage_combine;
Var x{"x"}, y{"y"}, c{"c"};
Func extern_stage_1{"extern_stage_1_inner"},
extern_stage_2{"extern_stage_2_inner"},
extern_stage_combine{"extern_stage_combine_inner"};
};

// Basically a memset.
Expand All @@ -62,7 +65,7 @@ class NestedExternsLeaf : public Generator<NestedExternsLeaf> {
Output<Buffer<float, 3>> leaf{"leaf"};

void generate() {
Var x, y, c;
Var x{"x"}, y{"y"}, c{"c"};
leaf(x, y, c) = value;
}

Expand All @@ -80,10 +83,11 @@ class NestedExternsRoot : public Generator<NestedExternsRoot> {
Output<Buffer<float, 3>> root{"root"};

void generate() {
extern_stage_1.define_extern("nested_externs_inner", {value()}, Float(32), 3);
extern_stage_2.define_extern("nested_externs_inner", {value() + 1}, Float(32), 3);
Expr ucon = user_context_value();
extern_stage_1.define_extern("nested_externs_inner", {ucon, value()}, Float(32), 3);
extern_stage_2.define_extern("nested_externs_inner", {ucon, value() + 1}, Float(32), 3);
extern_stage_combine.define_extern("nested_externs_combine",
{extern_stage_1, extern_stage_2}, Float(32), 3);
{ucon, extern_stage_1, extern_stage_2}, Float(32), 3);
root(x, y, c) = extern_stage_combine(x, y, c);
}

Expand All @@ -94,11 +98,14 @@ class NestedExternsRoot : public Generator<NestedExternsRoot> {
}
set_interleaved(root);
root.reorder_storage(c, x, y);
root.parallel(y, 8);
}

private:
Var x, y, c;
Func extern_stage_1, extern_stage_2, extern_stage_combine;
Var x{"x"}, y{"y"}, c{"c"};
Func extern_stage_1{"extern_stage_1_root"},
extern_stage_2{"extern_stage_2_root"},
extern_stage_combine{"extern_stage_combine_root"};
};

} // namespace
Expand Down

0 comments on commit c601e4e

Please sign in to comment.