Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add workaround for the const-or-not user_context issue (#635) #7291

Merged
merged 7 commits into from
Jan 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -1609,7 +1609,7 @@ $(FILTERS_DIR)/autograd_grad.a: $(BIN_DIR)/autograd.generator $(BIN_MULLAPUDI201
# all have the form nested_externs_*).
$(FILTERS_DIR)/nested_externs_%.a: $(BIN_DIR)/nested_externs.generator
@mkdir -p $(@D)
$(CURDIR)/$< -g nested_externs_$* $(GEN_AOT_OUTPUTS) -o $(CURDIR)/$(FILTERS_DIR) target=$(TARGET)-no_runtime-c_plus_plus_name_mangling
$(CURDIR)/$< -g nested_externs_$* $(GEN_AOT_OUTPUTS) -o $(CURDIR)/$(FILTERS_DIR) target=$(TARGET)-no_runtime-user_context-c_plus_plus_name_mangling

# Similarly, gpu_multi needs two different kernels to test compilation caching.
# Also requires user-context.
Expand Down
86 changes: 72 additions & 14 deletions src/CodeGen_C.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1937,27 +1937,32 @@ void CodeGen_C::compile(const LoweredFunc &f, const MetadataNameMap &metadata_na
}

if (output_kind != CPlusPlusFunctionInfoHeader) {
const auto emit_arg_decls = [&](const Type &ucon_type = Type()) {
const char *comma = "";
for (const auto &arg : args) {
stream << comma;
if (arg.is_buffer()) {
stream << "struct halide_buffer_t *"
<< print_name(arg.name)
<< "_buffer";
} else {
// If this arg is the user_context value, *and* ucon_type is valid,
// use ucon_type instead of arg.type.
const Type &t = (arg.name == "__user_context" && ucon_type.bits() != 0) ? ucon_type : arg.type;
stream << print_type(t, AppendSpace) << print_name(arg.name);
}
comma = ", ";
}
};

// Emit the function prototype
if (f.linkage == LinkageType::Internal) {
// If the function isn't public, mark it static.
stream << "static ";
}
stream << "HALIDE_FUNCTION_ATTRS\n";
stream << "int " << simple_name << "(";
for (size_t i = 0; i < args.size(); i++) {
if (args[i].is_buffer()) {
stream << "struct halide_buffer_t *"
<< print_name(args[i].name)
<< "_buffer";
} else {
stream << print_type(args[i].type, AppendSpace)
<< print_name(args[i].name);
}

if (i < args.size() - 1) {
stream << ", ";
}
}
emit_arg_decls();

if (is_header_or_extern_decl()) {
stream << ");\n";
Expand Down Expand Up @@ -1995,6 +2000,59 @@ void CodeGen_C::compile(const LoweredFunc &f, const MetadataNameMap &metadata_na
close_scope("");
}

// Workaround for https://github.com/halide/Halide/issues/635:
// For historical reasons, Halide-generated AOT code
// defines user_context as `void const*`, but expects all
// define_extern code with user_context usage to use `void *`. This
// usually isn't an issue, but if both the caller and the callee
// pass a user_context, *and* c_plus_plus_name_mangling is enabled,
// we get link errors because of this dichotomy. Fixing this
// "correctly" (ie so that everything always uses identical types for
// user_context in all cases) will require a *lot* of downstream
// churn (see https://github.com/halide/Halide/issues/7298),
// so this is a workaround: Add a wrapper with `void*`
// ucon -> `void const*` ucon. In most cases this will be ignored
// (and probably dead-stripped), but in these cases it's critical.
//
// (Note that we don't check to see if c_plus_plus_name_mangling is
// enabled, since that would have to be done on the caller side, and
// this is purely a callee-side fix.)
if (f.linkage != LinkageType::Internal &&
output_kind == CPlusPlusImplementation &&
target.has_feature(Target::CPlusPlusMangling) &&
get_target().has_feature(Target::UserContext)) {

Type ucon_type = Type();
for (const auto &arg : args) {
if (arg.name == "__user_context") {
ucon_type = arg.type;
break;
}
}
if (ucon_type == type_of<void const *>()) {
stream << "\nHALIDE_FUNCTION_ATTRS\n";
stream << "int " << simple_name << "(";
emit_arg_decls(type_of<void *>());
stream << ") ";
open_scope();
stream << get_indent() << " return " << simple_name << "(";
const char *comma = "";
for (const auto &arg : args) {
if (arg.name == "__user_context") {
// Add an explicit cast here so we won't call ourselves into oblivion
stream << "(void const *)";
}
stream << comma << print_name(arg.name);
if (arg.is_buffer()) {
stream << "_buffer";
}
comma = ", ";
}
stream << ");\n";
close_scope("");
}
}

if (f.linkage == LinkageType::ExternalPlusArgv || f.linkage == LinkageType::ExternalPlusMetadata) {
// Emit the argv version
emit_argv_wrapper(simple_name, args);
Expand Down
59 changes: 59 additions & 0 deletions src/CodeGen_LLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -536,6 +536,65 @@ std::unique_ptr<llvm::Module> CodeGen_LLVM::compile(const Module &input) {
names.simple_name, f.args, input.get_metadata_name_map());
}
}

// Workaround for https://github.com/halide/Halide/issues/635:
// For historical reasons, Halide-generated AOT code
// defines user_context as `void const*`, but expects all
// define_extern code with user_context usage to use `void *`. This
// usually isn't an issue, but if both the caller and the callee
// pass a user_context, *and* c_plus_plus_name_mangling is enabled,
// we get link errors because of this dichotomy. Fixing this
// "correctly" (ie so that everything always uses identical types for
// user_context in all cases) will require a *lot* of downstream
// churn (see https://github.com/halide/Halide/issues/7298),
// so this is a workaround: Add a wrapper with `void*`
// ucon -> `void const*` ucon. In most cases this will be ignored
// (and probably dead-stripped), but in these cases it's critical.
//
// (Note that we don't check to see if c_plus_plus_name_mangling is
// enabled, since that would have to be done on the caller side, and
// this is purely a callee-side fix.)
if (f.linkage != LinkageType::Internal &&
target.has_feature(Target::CPlusPlusMangling) &&
target.has_feature(Target::UserContext)) {

int wrapper_ucon_index = -1;
auto wrapper_args = f.args; // make a copy
auto wrapper_llvm_arg_types = arg_types; // make a copy
for (int i = 0; i < (int)wrapper_args.size(); i++) {
if (wrapper_args[i].name == "__user_context" && wrapper_args[i].type == type_of<void const *>()) {
// Update the type of the user_context argument to be void* rather than void const*
wrapper_args[i].type = type_of<void *>();
wrapper_llvm_arg_types[i] = llvm_type_of(upgrade_type_for_argument_passing(wrapper_args[i].type));
wrapper_ucon_index = i;
}
}
if (wrapper_ucon_index >= 0) {
const auto wrapper_names = get_mangled_names(f.name, f.linkage, f.name_mangling, wrapper_args, target);

FunctionType *wrapper_func_t = FunctionType::get(i32_t, wrapper_llvm_arg_types, false);
llvm::Function *wrapper_func = llvm::Function::Create(wrapper_func_t,
llvm::GlobalValue::ExternalLinkage,
wrapper_names.extern_name,
module.get());
set_function_attributes_from_halide_target_options(*wrapper_func);
llvm::BasicBlock *wrapper_block = llvm::BasicBlock::Create(module->getContext(), "entry", wrapper_func);
builder->SetInsertPoint(wrapper_block);

std::vector<llvm::Value *> wrapper_call_args;
for (auto &arg : wrapper_func->args()) {
wrapper_call_args.push_back(&arg);
}
wrapper_call_args[wrapper_ucon_index] = builder->CreatePointerCast(wrapper_call_args[wrapper_ucon_index],
llvm_type_of(type_of<void const *>()));

llvm::CallInst *wrapper_result = builder->CreateCall(function, wrapper_call_args);
// This call should never inline
wrapper_result->setIsNoInline();
builder->CreateRet(wrapper_result);
internal_assert(!verifyFunction(*wrapper_func, &llvm::errs()));
}
}
}
// Define all functions
int idx = 0;
Expand Down
7 changes: 2 additions & 5 deletions src/IRPrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,8 @@ ostream &operator<<(ostream &out, const Type &type) {
out << "float";
break;
case Type::Handle:
if (type.handle_type) {
out << "(" << type.handle_type->inner_name.name << " *)";
} else {
out << "(void *)";
}
// ensure that 'const' (etc) qualifiers are emitted when appropriate
out << "(" << type_to_c_type(type, false) << ")";
break;
case Type::BFloat:
out << "bfloat";
Expand Down
2 changes: 1 addition & 1 deletion test/generator/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -508,7 +508,7 @@ _add_halide_aot_tests(multitarget
add_halide_generator(nested_externs.generator SOURCES nested_externs_generator.cpp)
set(NESTED_EXTERNS_LIBS nested_externs_root nested_externs_inner nested_externs_combine nested_externs_leaf)
foreach (LIB IN LISTS NESTED_EXTERNS_LIBS)
_add_halide_libraries(${LIB} FROM nested_externs.generator GENERATOR_NAME ${LIB} FEATURES c_plus_plus_name_mangling)
_add_halide_libraries(${LIB} FROM nested_externs.generator GENERATOR_NAME ${LIB} FEATURES user_context c_plus_plus_name_mangling)
endforeach ()
_add_halide_aot_tests(nested_externs
HALIDE_LIBRARIES ${NESTED_EXTERNS_LIBS})
Expand Down
3 changes: 2 additions & 1 deletion test/generator/nested_externs_aottest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ int main(int argc, char **argv) {
auto val = Buffer<float, 0>::make_scalar();
val() = 38.5f;

nested_externs_root(val, buf);
void const *ucon = nullptr;
nested_externs_root(ucon, val, buf);

buf.for_each_element([&](int x, int y, int c) {
const float correct = 158.0f;
Expand Down
31 changes: 19 additions & 12 deletions test/generator/nested_externs_generator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class NestedExternsCombine : public Generator<NestedExternsCombine> {
Output<Buffer<>> combine{"combine"}; // unspecified type-and-dim will be inferred

void generate() {
Var x, y, c;
Var x{"x"}, y{"y"}, c{"c"};
combine(x, y, c) = input_a(x, y, c) + input_b(x, y, c);
}

Expand All @@ -35,10 +35,11 @@ class NestedExternsInner : public Generator<NestedExternsInner> {
Output<Buffer<float, 3>> inner{"inner"};

void generate() {
extern_stage_1.define_extern("nested_externs_leaf", {value}, Float(32), 3);
extern_stage_2.define_extern("nested_externs_leaf", {value + 1}, Float(32), 3);
Expr ucon = user_context_value();
extern_stage_1.define_extern("nested_externs_leaf", {ucon, value}, Float(32), 3);
extern_stage_2.define_extern("nested_externs_leaf", {ucon, value + 1}, Float(32), 3);
extern_stage_combine.define_extern("nested_externs_combine",
{extern_stage_1, extern_stage_2}, Float(32), 3);
{ucon, extern_stage_1, extern_stage_2}, Float(32), 3);
inner(x, y, c) = extern_stage_combine(x, y, c);
}

Expand All @@ -51,8 +52,10 @@ class NestedExternsInner : public Generator<NestedExternsInner> {
}

private:
Var x, y, c;
Func extern_stage_1, extern_stage_2, extern_stage_combine;
Var x{"x"}, y{"y"}, c{"c"};
Func extern_stage_1{"extern_stage_1_inner"},
extern_stage_2{"extern_stage_2_inner"},
extern_stage_combine{"extern_stage_combine_inner"};
};

// Basically a memset.
Expand All @@ -62,7 +65,7 @@ class NestedExternsLeaf : public Generator<NestedExternsLeaf> {
Output<Buffer<float, 3>> leaf{"leaf"};

void generate() {
Var x, y, c;
Var x{"x"}, y{"y"}, c{"c"};
leaf(x, y, c) = value;
}

Expand All @@ -80,10 +83,11 @@ class NestedExternsRoot : public Generator<NestedExternsRoot> {
Output<Buffer<float, 3>> root{"root"};

void generate() {
extern_stage_1.define_extern("nested_externs_inner", {value()}, Float(32), 3);
extern_stage_2.define_extern("nested_externs_inner", {value() + 1}, Float(32), 3);
Expr ucon = user_context_value();
extern_stage_1.define_extern("nested_externs_inner", {ucon, value()}, Float(32), 3);
extern_stage_2.define_extern("nested_externs_inner", {ucon, value() + 1}, Float(32), 3);
extern_stage_combine.define_extern("nested_externs_combine",
{extern_stage_1, extern_stage_2}, Float(32), 3);
{ucon, extern_stage_1, extern_stage_2}, Float(32), 3);
root(x, y, c) = extern_stage_combine(x, y, c);
}

Expand All @@ -94,11 +98,14 @@ class NestedExternsRoot : public Generator<NestedExternsRoot> {
}
set_interleaved(root);
root.reorder_storage(c, x, y);
root.parallel(y, 8);
}

private:
Var x, y, c;
Func extern_stage_1, extern_stage_2, extern_stage_combine;
Var x{"x"}, y{"y"}, c{"c"};
Func extern_stage_1{"extern_stage_1_root"},
extern_stage_2{"extern_stage_2_root"},
extern_stage_combine{"extern_stage_combine_root"};
};

} // namespace
Expand Down