Skip to content

Commit

Permalink
WIP: allow ccall library name to be non-constant
Browse files Browse the repository at this point in the history
Fixes #36458
  • Loading branch information
JeffBezanson committed Aug 19, 2020
1 parent 89d6b46 commit a6d0bf7
Show file tree
Hide file tree
Showing 5 changed files with 112 additions and 19 deletions.
101 changes: 82 additions & 19 deletions src/ccall.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,8 @@ static bool runtime_sym_gvs(jl_codegen_params_t &emission_context, const char *f
static Value *runtime_sym_lookup(
jl_codegen_params_t &emission_context,
IRBuilder<> &irbuilder,
PointerType *funcptype, const char *f_lib,
jl_codectx_t *ctx,
PointerType *funcptype, const char *f_lib, jl_value_t *lib_expr,
const char *f_name, Function *f,
GlobalVariable *libptrgv,
GlobalVariable *llvmgv, bool runtime_lib)
Expand Down Expand Up @@ -106,16 +107,25 @@ static Value *runtime_sym_lookup(
assert(f->getParent() != NULL);
f->getBasicBlockList().push_back(dlsym_lookup);
irbuilder.SetInsertPoint(dlsym_lookup);
Value *libname;
if (runtime_lib) {
libname = stringConstPtr(emission_context, irbuilder, f_lib);
Value *llvmf;
Value *nameval = stringConstPtr(emission_context, irbuilder, f_name);
if (lib_expr) {
jl_cgval_t libval = emit_expr(*ctx, lib_expr);
llvmf = irbuilder.CreateCall(prepare_call_in(jl_builderModule(irbuilder), jllazydlsym_func),
{ boxed(*ctx, libval), nameval });
}
else {
// f_lib is actually one of the special sentinel values
libname = ConstantExpr::getIntToPtr(ConstantInt::get(T_size, (uintptr_t)f_lib), T_pint8);
Value *libname;
if (runtime_lib) {
libname = stringConstPtr(emission_context, irbuilder, f_lib);
}
else {
// f_lib is actually one of the special sentinel values
libname = ConstantExpr::getIntToPtr(ConstantInt::get(T_size, (uintptr_t)f_lib), T_pint8);
}
llvmf = irbuilder.CreateCall(prepare_call_in(jl_builderModule(irbuilder), jldlsym_func),
{ libname, nameval, libptrgv });
}
Value *llvmf = irbuilder.CreateCall(prepare_call_in(jl_builderModule(irbuilder), jldlsym_func),
{ libname, stringConstPtr(emission_context, irbuilder, f_name), libptrgv });
StoreInst *store = irbuilder.CreateAlignedStore(llvmf, llvmgv, sizeof(void*));
store->setAtomic(AtomicOrdering::Release);
irbuilder.CreateBr(ccall_bb);
Expand All @@ -130,15 +140,43 @@ static Value *runtime_sym_lookup(

static Value *runtime_sym_lookup(
jl_codectx_t &ctx,
PointerType *funcptype, const char *f_lib,
PointerType *funcptype, const char *f_lib, jl_value_t *lib_expr,
const char *f_name, Function *f,
GlobalVariable *libptrgv,
GlobalVariable *llvmgv, bool runtime_lib)
{
return runtime_sym_lookup(ctx.emission_context, ctx.builder, &ctx, funcptype, f_lib, lib_expr,
f_name, f, libptrgv, llvmgv, runtime_lib);
}

static Value *runtime_sym_lookup(
jl_codectx_t &ctx,
PointerType *funcptype, const char *f_lib, jl_value_t *lib_expr,
const char *f_name, Function *f)
{
GlobalVariable *libptrgv;
GlobalVariable *llvmgv;
bool runtime_lib = runtime_sym_gvs(ctx.emission_context, f_lib, f_name, libptrgv, llvmgv);
libptrgv = prepare_global_in(jl_Module, libptrgv);
bool runtime_lib;
if (lib_expr) {
// for computed library names, generate a global variable to cache the function
// pointer just for this call site.
runtime_lib = true;
libptrgv = NULL;
std::string gvname = "libname_";
gvname += f_name;
gvname += "_";
gvname += std::to_string(globalUnique++);
Module *M = ctx.emission_context.shared_module(jl_LLVMContext);
llvmgv = new GlobalVariable(*M, T_pvoidfunc, false,
GlobalVariable::ExternalLinkage,
Constant::getNullValue(T_pvoidfunc), gvname);
}
else {
runtime_lib = runtime_sym_gvs(ctx.emission_context, f_lib, f_name, libptrgv, llvmgv);
libptrgv = prepare_global_in(jl_Module, libptrgv);
}
llvmgv = prepare_global_in(jl_Module, llvmgv);
return runtime_sym_lookup(ctx.emission_context, ctx.builder, funcptype, f_lib, f_name, f, libptrgv, llvmgv, runtime_lib);
return runtime_sym_lookup(ctx, funcptype, f_lib, lib_expr, f_name, f, libptrgv, llvmgv, runtime_lib);
}

// Emit a "PLT" entry that will be lazily initialized
Expand Down Expand Up @@ -169,7 +207,7 @@ static GlobalVariable *emit_plt_thunk(
fname);
BasicBlock *b0 = BasicBlock::Create(jl_LLVMContext, "top", plt);
IRBuilder<> irbuilder(b0);
Value *ptr = runtime_sym_lookup(emission_context, irbuilder, funcptype, f_lib, f_name, plt, libptrgv,
Value *ptr = runtime_sym_lookup(emission_context, irbuilder, NULL, funcptype, f_lib, NULL, f_name, plt, libptrgv,
llvmgv, runtime_lib);
StoreInst *store = irbuilder.CreateAlignedStore(irbuilder.CreateBitCast(ptr, T_pvoidfunc), got, sizeof(void*));
store->setAtomic(AtomicOrdering::Release);
Expand Down Expand Up @@ -475,6 +513,7 @@ typedef struct {
void (*fptr)(void); // if the argument is a constant pointer
const char *f_name; // if the symbol name is known
const char *f_lib; // if a library name is specified
jl_value_t *lib_expr; // expression to compute library path lazily
jl_value_t *gcroot;
} native_sym_arg_t;

Expand All @@ -488,6 +527,24 @@ static void interpret_symbol_arg(jl_codectx_t &ctx, native_sym_arg_t &out, jl_va

jl_value_t *ptr = static_eval(ctx, arg, true);
if (ptr == NULL) {
if (jl_is_expr(arg) && ((jl_expr_t*)arg)->head == call_sym && jl_expr_nargs(arg) == 3 &&
jl_is_globalref(jl_exprarg(arg,0)) && jl_globalref_mod(jl_exprarg(arg,0)) == jl_core_module &&
jl_globalref_name(jl_exprarg(arg,0)) == jl_symbol("tuple")) {
// attempt to interpret a non-constant 2-tuple expression as (func_name, lib_name()), where
// `lib_name()` will be executed when first used.
jl_value_t *name_val = static_eval(ctx, jl_exprarg(arg,1));
if (name_val && jl_is_symbol(name_val)) {
f_name = jl_symbol_name((jl_sym_t*)name_val);
out.lib_expr = jl_exprarg(arg, 2);
return;
}
else if (name_val && jl_is_string(name_val)) {
f_name = jl_string_data(name_val);
out.gcroot = name_val;
out.lib_expr = jl_exprarg(arg, 2);
return;
}
}
jl_cgval_t arg1 = emit_expr(ctx, arg);
jl_value_t *ptr_ty = arg1.typ;
if (!jl_is_cpointer_type(ptr_ty)) {
Expand Down Expand Up @@ -586,8 +643,11 @@ static jl_cgval_t emit_cglobal(jl_codectx_t &ctx, jl_value_t **args, size_t narg
jl_printf(JL_STDERR,"WARNING: literal address used in cglobal for %s; code cannot be statically compiled\n", sym.f_name);
}
else {
if (imaging_mode) {
res = runtime_sym_lookup(ctx, cast<PointerType>(T_pint8), sym.f_lib, sym.f_name, ctx.f);
if (sym.lib_expr) {
res = runtime_sym_lookup(ctx, cast<PointerType>(T_pint8), NULL, sym.lib_expr, sym.f_name, ctx.f);
}
else if (imaging_mode) {
res = runtime_sym_lookup(ctx, cast<PointerType>(T_pint8), sym.f_lib, NULL, sym.f_name, ctx.f);
res = ctx.builder.CreatePtrToInt(res, lrt);
}
else {
Expand All @@ -597,7 +657,7 @@ static jl_cgval_t emit_cglobal(jl_codectx_t &ctx, jl_value_t **args, size_t narg
if (!libsym || !jl_dlsym(libsym, sym.f_name, &symaddr, 0)) {
// Error mode, either the library or the symbol couldn't be find during compiletime.
// Fallback to a runtime symbol lookup.
res = runtime_sym_lookup(ctx, cast<PointerType>(T_pint8), sym.f_lib, sym.f_name, ctx.f);
res = runtime_sym_lookup(ctx, cast<PointerType>(T_pint8), sym.f_lib, NULL, sym.f_name, ctx.f);
res = ctx.builder.CreatePtrToInt(res, lrt);
} else {
// since we aren't saving this code, there's no sense in
Expand Down Expand Up @@ -1716,11 +1776,14 @@ jl_cgval_t function_sig_t::emit_a_ccall(
else {
assert(symarg.f_name != NULL);
PointerType *funcptype = PointerType::get(functype, 0);
if (imaging_mode) {
if (symarg.lib_expr) {
llvmf = runtime_sym_lookup(ctx, funcptype, NULL, symarg.lib_expr, symarg.f_name, ctx.f);
}
else if (imaging_mode) {
// vararg requires musttail,
// but musttail is incompatible with noreturn.
if (functype->isVarArg())
llvmf = runtime_sym_lookup(ctx, funcptype, symarg.f_lib, symarg.f_name, ctx.f);
llvmf = runtime_sym_lookup(ctx, funcptype, symarg.f_lib, NULL, symarg.f_name, ctx.f);
else
llvmf = emit_plt(ctx, functype, attributes, cc, symarg.f_lib, symarg.f_name);
}
Expand All @@ -1730,7 +1793,7 @@ jl_cgval_t function_sig_t::emit_a_ccall(
if (!libsym || !jl_dlsym(libsym, symarg.f_name, &symaddr, 0)) {
// either the library or the symbol could not be found, place a runtime
// lookup here instead.
llvmf = runtime_sym_lookup(ctx, funcptype, symarg.f_lib, symarg.f_name, ctx.f);
llvmf = runtime_sym_lookup(ctx, funcptype, symarg.f_lib, NULL, symarg.f_name, ctx.f);
} else {
// since we aren't saving this code, there's no sense in
// putting anything complicated here: just JIT the function address
Expand Down
6 changes: 6 additions & 0 deletions src/codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -703,6 +703,12 @@ static const auto jldlsym_func = new JuliaFunction{
{T_pint8, T_pint8, PointerType::get(T_pint8, 0)}, false); },
nullptr,
};
static const auto jllazydlsym_func = new JuliaFunction{
"jl_lazy_load_and_lookup",
[](LLVMContext &C) { return FunctionType::get(T_pvoidfunc,
{T_prjlvalue, T_pint8}, false); },
nullptr,
};
static const auto jltypeassert_func = new JuliaFunction{
"jl_typeassert",
[](LLVMContext &C) { return FunctionType::get(T_void,
Expand Down
1 change: 1 addition & 0 deletions src/julia_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -901,6 +901,7 @@ void *jl_get_library_(const char *f_lib, int throw_err);
#define jl_get_library(f_lib) jl_get_library_(f_lib, 1)
JL_DLLEXPORT void *jl_load_and_lookup(const char *f_lib, const char *f_name,
void **hnd);
JL_DLLEXPORT void *jl_lazy_load_and_lookup(jl_value_t *lib_val, const char *f_name);
JL_DLLEXPORT jl_value_t *jl_get_cfunction_trampoline(
jl_value_t *fobj, jl_datatype_t *result, htable_t *cache, jl_svec_t *fill,
void *(*init_trampoline)(void *tramp, void **nval),
Expand Down
17 changes: 17 additions & 0 deletions src/runtime_ccall.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,23 @@ void *jl_load_and_lookup(const char *f_lib, const char *f_name, void **hnd)
return ptr;
}

// jl_load_and_lookup, but with library computed at run time on first call
extern "C" JL_DLLEXPORT
void *jl_lazy_load_and_lookup(jl_value_t *lib_val, const char *f_name)
{
char *f_lib;

if (jl_is_symbol(lib_val))
f_lib = jl_symbol_name((jl_sym_t*)lib_val);
else if (jl_is_string(lib_val))
f_lib = jl_string_data(lib_val);
else
jl_type_error("ccall", (jl_value_t*)jl_symbol_type, lib_val);
void *ptr;
jl_dlsym(jl_get_library(f_lib), f_name, &ptr, 1);
return ptr;
}

// miscellany
std::string jl_get_cpu_name_llvm(void)
{
Expand Down
6 changes: 6 additions & 0 deletions test/ccall.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1701,3 +1701,9 @@ end
str = GC.@preserve buffer unsafe_string(Cwstring(pointer(buffer)))
@test str == "α+β=15"
end

# issue #36458
compute_lib_name() = "libcc" * "alltest"
ccall_lazy_lib_name(x) = ccall((:testUcharX, compute_lib_name()), Int32, (UInt8,), x % UInt8)
@test ccall_lazy_lib_name(0) == 0
@test ccall_lazy_lib_name(3) == 1

0 comments on commit a6d0bf7

Please sign in to comment.