Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Emit aliases to FP16 conversion routines #45649

Merged
merged 3 commits into from
Jul 20, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/APInt-C.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,7 @@ void LLVMByteSwap(unsigned numbits, integerPart *pa, integerPart *pr) {
void LLVMFPtoInt(unsigned numbits, void *pa, unsigned onumbits, integerPart *pr, bool isSigned, bool *isExact) {
double Val;
if (numbits == 16)
Val = __gnu_h2f_ieee(*(uint16_t*)pa);
Val = julia__gnu_h2f_ieee(*(uint16_t*)pa);
else if (numbits == 32)
Val = *(float*)pa;
else if (numbits == 64)
Expand Down Expand Up @@ -391,7 +391,7 @@ void LLVMSItoFP(unsigned numbits, integerPart *pa, unsigned onumbits, integerPar
val = a.roundToDouble(true);
}
if (onumbits == 16)
*(uint16_t*)pr = __gnu_f2h_ieee(val);
*(uint16_t*)pr = julia__gnu_f2h_ieee(val);
else if (onumbits == 32)
*(float*)pr = val;
else if (onumbits == 64)
Expand All @@ -408,7 +408,7 @@ void LLVMUItoFP(unsigned numbits, integerPart *pa, unsigned onumbits, integerPar
val = a.roundToDouble(false);
}
if (onumbits == 16)
*(uint16_t*)pr = __gnu_f2h_ieee(val);
*(uint16_t*)pr = julia__gnu_f2h_ieee(val);
else if (onumbits == 32)
*(float*)pr = val;
else if (onumbits == 64)
Expand Down
36 changes: 36 additions & 0 deletions src/aotcompile.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include <llvm/Analysis/BasicAliasAnalysis.h>
#include <llvm/Analysis/TypeBasedAliasAnalysis.h>
#include <llvm/Analysis/ScopedNoAliasAA.h>
#include <llvm/IR/IRBuilder.h>
#include <llvm/IR/PassManager.h>
#include <llvm/IR/Verifier.h>
#include <llvm/Transforms/IPO.h>
Expand All @@ -32,6 +33,7 @@
#include <llvm/Transforms/InstCombine/InstCombine.h>
#include <llvm/Transforms/Scalar/InstSimplifyPass.h>
#include <llvm/Transforms/Utils/SimplifyCFGOptions.h>
#include <llvm/Transforms/Utils/ModuleUtils.h>
#include <llvm/Passes/PassBuilder.h>
#include <llvm/Passes/PassPlugin.h>
#if defined(USE_POLLY)
Expand Down Expand Up @@ -434,6 +436,23 @@ static void reportWriterError(const ErrorInfoBase &E)
jl_safe_printf("ERROR: failed to emit output file %s\n", err.c_str());
}

static void injectCRTAlias(Module &M, StringRef name, StringRef alias, FunctionType *FT)
{
Function *target = M.getFunction(alias);
if (!target) {
target = Function::Create(FT, Function::ExternalLinkage, alias, M);
}
Function *interposer = Function::Create(FT, Function::WeakAnyLinkage, name, M);
appendToCompilerUsed(M, {interposer});

llvm::IRBuilder<> builder(BasicBlock::Create(M.getContext(), "top", interposer));
SmallVector<Value *, 4> CallArgs;
for (auto &arg : interposer->args())
CallArgs.push_back(&arg);
auto val = builder.CreateCall(target, CallArgs);
builder.CreateRet(val);
}


// takes the running content that has collected in the shadow module and dump it to disk
// this builds the object file portion of the sysimage files for fast startup
Expand Down Expand Up @@ -550,7 +569,24 @@ void jl_dump_native_impl(void *native_code,
auto add_output = [&] (Module &M, StringRef unopt_bc_Name, StringRef bc_Name, StringRef obj_Name, StringRef asm_Name) {
preopt.run(M);
optimizer.run(M);

// We would like to emit an alias or an weakref alias to redirect these symbols
// but LLVM doesn't let us emit a GlobalAlias to a declaration...
// So for now we inject a definition of these functions that calls our runtime
// functions. We do so after optimization to avoid cloning these functions.
injectCRTAlias(M, "__gnu_h2f_ieee", "julia__gnu_h2f_ieee",
FunctionType::get(Type::getFloatTy(Context), { Type::getHalfTy(Context) }, false));
injectCRTAlias(M, "__extendhfsf2", "julia__gnu_h2f_ieee",
FunctionType::get(Type::getFloatTy(Context), { Type::getHalfTy(Context) }, false));
injectCRTAlias(M, "__gnu_f2h_ieee", "julia__gnu_f2h_ieee",
FunctionType::get(Type::getHalfTy(Context), { Type::getFloatTy(Context) }, false));
injectCRTAlias(M, "__truncsfhf2", "julia__gnu_f2h_ieee",
FunctionType::get(Type::getHalfTy(Context), { Type::getFloatTy(Context) }, false));
injectCRTAlias(M, "__truncdfhf2", "julia__truncdfhf2",
FunctionType::get(Type::getHalfTy(Context), { Type::getDoubleTy(Context) }, false));

postopt.run(M);

if (unopt_bc_fname)
emit_result(unopt_bc_Archive, unopt_bc_Buffer, unopt_bc_Name, outputs);
if (bc_fname)
Expand Down
18 changes: 16 additions & 2 deletions src/jitlayers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1107,12 +1107,26 @@ JuliaOJIT::JuliaOJIT()
}

JD.addToLinkOrder(GlobalJD, orc::JITDylibLookupFlags::MatchExportedSymbolsOnly);

orc::SymbolAliasMap jl_crt = {
{ mangle("__gnu_h2f_ieee"), { mangle("julia__gnu_h2f_ieee"), JITSymbolFlags::Exported } },
{ mangle("__extendhfsf2"), { mangle("julia__gnu_h2f_ieee"), JITSymbolFlags::Exported } },
{ mangle("__gnu_f2h_ieee"), { mangle("julia__gnu_f2h_ieee"), JITSymbolFlags::Exported } },
{ mangle("__truncsfhf2"), { mangle("julia__gnu_f2h_ieee"), JITSymbolFlags::Exported } },
{ mangle("__truncdfhf2"), { mangle("julia__truncdfhf2"), JITSymbolFlags::Exported } }
};
cantFail(GlobalJD.define(orc::symbolAliases(jl_crt)));
}

void JuliaOJIT::addGlobalMapping(StringRef Name, uint64_t Addr)
orc::SymbolStringPtr JuliaOJIT::mangle(StringRef Name)
{
std::string MangleName = getMangledName(Name);
cantFail(JD.define(orc::absoluteSymbols({{ES.intern(MangleName), JITEvaluatedSymbol::fromPointer((void*)Addr)}})));
return ES.intern(MangleName);
}

void JuliaOJIT::addGlobalMapping(StringRef Name, uint64_t Addr)
{
cantFail(JD.define(orc::absoluteSymbols({{mangle(Name), JITEvaluatedSymbol::fromPointer((void*)Addr)}})));
}

void JuliaOJIT::addModule(orc::ThreadSafeModule TSM)
Expand Down
1 change: 1 addition & 0 deletions src/jitlayers.h
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,7 @@ class JuliaOJIT {
void RegisterJITEventListener(JITEventListener *L);
#endif

orc::SymbolStringPtr mangle(StringRef Name);
void addGlobalMapping(StringRef Name, uint64_t Addr);
void addModule(orc::ThreadSafeModule M);

Expand Down
6 changes: 0 additions & 6 deletions src/julia.expmap
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,6 @@
environ;
__progname;

/* compiler run-time intrinsics */
__gnu_h2f_ieee;
__extendhfsf2;
__gnu_f2h_ieee;
__truncdfhf2;

local:
*;
};
14 changes: 12 additions & 2 deletions src/julia_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -1543,8 +1543,18 @@ jl_sym_t *_jl_symbol(const char *str, size_t len) JL_NOTSAFEPOINT;
#define JL_GC_ASSERT_LIVE(x) (void)(x)
#endif

float __gnu_h2f_ieee(uint16_t param) JL_NOTSAFEPOINT;
uint16_t __gnu_f2h_ieee(float param) JL_NOTSAFEPOINT;
JL_DLLEXPORT float julia__gnu_h2f_ieee(uint16_t param) JL_NOTSAFEPOINT;
JL_DLLEXPORT uint16_t julia__gnu_f2h_ieee(float param) JL_NOTSAFEPOINT;
JL_DLLEXPORT uint16_t julia__truncdfhf2(double param) JL_NOTSAFEPOINT;
//JL_DLLEXPORT double julia__extendhfdf2(uint16_t n) JL_NOTSAFEPOINT;
//JL_DLLEXPORT int32_t julia__fixhfsi(uint16_t n) JL_NOTSAFEPOINT;
//JL_DLLEXPORT int64_t julia__fixhfdi(uint16_t n) JL_NOTSAFEPOINT;
//JL_DLLEXPORT uint32_t julia__fixunshfsi(uint16_t n) JL_NOTSAFEPOINT;
//JL_DLLEXPORT uint64_t julia__fixunshfdi(uint16_t n) JL_NOTSAFEPOINT;
//JL_DLLEXPORT uint16_t julia__floatsihf(int32_t n) JL_NOTSAFEPOINT;
//JL_DLLEXPORT uint16_t julia__floatdihf(int64_t n) JL_NOTSAFEPOINT;
//JL_DLLEXPORT uint16_t julia__floatunsihf(uint32_t n) JL_NOTSAFEPOINT;
//JL_DLLEXPORT uint16_t julia__floatundihf(uint64_t n) JL_NOTSAFEPOINT;

#ifdef __cplusplus
}
Expand Down
64 changes: 37 additions & 27 deletions src/runtime_intrinsics.c
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,6 @@
const unsigned int host_char_bit = 8;

// float16 intrinsics
// TODO: use LLVM's compiler-rt on all platforms (Xcode already links compiler-rt)

#if !defined(_OS_DARWIN_)

static inline float half_to_float(uint16_t ival) JL_NOTSAFEPOINT
{
Expand Down Expand Up @@ -188,22 +185,17 @@ static inline uint16_t float_to_half(float param) JL_NOTSAFEPOINT
return h;
}

JL_DLLEXPORT float __gnu_h2f_ieee(uint16_t param)
JL_DLLEXPORT float julia__gnu_h2f_ieee(uint16_t param)
{
return half_to_float(param);
}

JL_DLLEXPORT float __extendhfsf2(uint16_t param)
{
return half_to_float(param);
}

JL_DLLEXPORT uint16_t __gnu_f2h_ieee(float param)
JL_DLLEXPORT uint16_t julia__gnu_f2h_ieee(float param)
{
return float_to_half(param);
}

JL_DLLEXPORT uint16_t __truncdfhf2(double param)
JL_DLLEXPORT uint16_t julia__truncdfhf2(double param)
{
float res = (float)param;
uint32_t resi;
Expand All @@ -225,7 +217,25 @@ JL_DLLEXPORT uint16_t __truncdfhf2(double param)
return float_to_half(res);
}

#endif
//JL_DLLEXPORT double julia__extendhfdf2(uint16_t n) { return (double)julia__gnu_h2f_ieee(n); }
//JL_DLLEXPORT int32_t julia__fixhfsi(uint16_t n) { return (int32_t)julia__gnu_h2f_ieee(n); }
//JL_DLLEXPORT int64_t julia__fixhfdi(uint16_t n) { return (int64_t)julia__gnu_h2f_ieee(n); }
//JL_DLLEXPORT uint32_t julia__fixunshfsi(uint16_t n) { return (uint32_t)julia__gnu_h2f_ieee(n); }
//JL_DLLEXPORT uint64_t julia__fixunshfdi(uint16_t n) { return (uint64_t)julia__gnu_h2f_ieee(n); }
//JL_DLLEXPORT uint16_t julia__floatsihf(int32_t n) { return julia__gnu_f2h_ieee((float)n); }
//JL_DLLEXPORT uint16_t julia__floatdihf(int64_t n) { return julia__gnu_f2h_ieee((float)n); }
//JL_DLLEXPORT uint16_t julia__floatunsihf(uint32_t n) { return julia__gnu_f2h_ieee((float)n); }
//JL_DLLEXPORT uint16_t julia__floatundihf(uint64_t n) { return julia__gnu_f2h_ieee((float)n); }
//HANDLE_LIBCALL(F16, F128, __extendhftf2)
//HANDLE_LIBCALL(F16, F80, __extendhfxf2)
//HANDLE_LIBCALL(F80, F16, __truncxfhf2)
//HANDLE_LIBCALL(F128, F16, __trunctfhf2)
//HANDLE_LIBCALL(PPCF128, F16, __trunctfhf2)
//HANDLE_LIBCALL(F16, I128, __fixhfti)
//HANDLE_LIBCALL(F16, I128, __fixunshfti)
//HANDLE_LIBCALL(I128, F16, __floattihf)
//HANDLE_LIBCALL(I128, F16, __floatuntihf)


// run time version of bitcast intrinsic
JL_DLLEXPORT jl_value_t *jl_bitcast(jl_value_t *ty, jl_value_t *v)
Expand Down Expand Up @@ -551,9 +561,9 @@ static inline unsigned select_by_size(unsigned sz) JL_NOTSAFEPOINT
}

#define fp_select(a, func) \
sizeof(a) == sizeof(float) ? func##f((float)a) : func(a)
sizeof(a) <= sizeof(float) ? func##f((float)a) : func(a)
#define fp_select2(a, b, func) \
sizeof(a) == sizeof(float) ? func##f(a, b) : func(a, b)
sizeof(a) <= sizeof(float) ? func##f(a, b) : func(a, b)

// fast-function generators //

Expand Down Expand Up @@ -597,11 +607,11 @@ static inline void name(unsigned osize, void *pa, void *pr) JL_NOTSAFEPOINT \
static inline void name(unsigned osize, void *pa, void *pr) JL_NOTSAFEPOINT \
{ \
uint16_t a = *(uint16_t*)pa; \
float A = __gnu_h2f_ieee(a); \
float A = julia__gnu_h2f_ieee(a); \
if (osize == 16) { \
float R; \
OP(&R, A); \
*(uint16_t*)pr = __gnu_f2h_ieee(R); \
*(uint16_t*)pr = julia__gnu_f2h_ieee(R); \
} else { \
OP((uint16_t*)pr, A); \
} \
Expand All @@ -625,11 +635,11 @@ static void jl_##name##16(unsigned runtime_nbits, void *pa, void *pb, void *pr)
{ \
uint16_t a = *(uint16_t*)pa; \
uint16_t b = *(uint16_t*)pb; \
float A = __gnu_h2f_ieee(a); \
float B = __gnu_h2f_ieee(b); \
float A = julia__gnu_h2f_ieee(a); \
float B = julia__gnu_h2f_ieee(b); \
runtime_nbits = 16; \
float R = OP(A, B); \
*(uint16_t*)pr = __gnu_f2h_ieee(R); \
*(uint16_t*)pr = julia__gnu_f2h_ieee(R); \
}

// float or integer inputs, bool output
Expand All @@ -650,8 +660,8 @@ static int jl_##name##16(unsigned runtime_nbits, void *pa, void *pb) JL_NOTSAFEP
{ \
uint16_t a = *(uint16_t*)pa; \
uint16_t b = *(uint16_t*)pb; \
float A = __gnu_h2f_ieee(a); \
float B = __gnu_h2f_ieee(b); \
float A = julia__gnu_h2f_ieee(a); \
float B = julia__gnu_h2f_ieee(b); \
runtime_nbits = 16; \
return OP(A, B); \
}
Expand Down Expand Up @@ -691,12 +701,12 @@ static void jl_##name##16(unsigned runtime_nbits, void *pa, void *pb, void *pc,
uint16_t a = *(uint16_t*)pa; \
uint16_t b = *(uint16_t*)pb; \
uint16_t c = *(uint16_t*)pc; \
float A = __gnu_h2f_ieee(a); \
float B = __gnu_h2f_ieee(b); \
float C = __gnu_h2f_ieee(c); \
float A = julia__gnu_h2f_ieee(a); \
float B = julia__gnu_h2f_ieee(b); \
float C = julia__gnu_h2f_ieee(c); \
runtime_nbits = 16; \
float R = OP(A, B, C); \
*(uint16_t*)pr = __gnu_f2h_ieee(R); \
*(uint16_t*)pr = julia__gnu_f2h_ieee(R); \
}


Expand Down Expand Up @@ -1318,7 +1328,7 @@ static inline int fpiseq##nbits(c_type a, c_type b) JL_NOTSAFEPOINT { \
fpiseq_n(float, 32)
fpiseq_n(double, 64)
#define fpiseq(a,b) \
sizeof(a) == sizeof(float) ? fpiseq32(a, b) : fpiseq64(a, b)
sizeof(a) <= sizeof(float) ? fpiseq32(a, b) : fpiseq64(a, b)

bool_fintrinsic(eq,eq_float)
bool_fintrinsic(ne,ne_float)
Expand Down Expand Up @@ -1367,7 +1377,7 @@ cvt_iintrinsic(LLVMFPtoUI, fptoui)
if (!(osize < 8 * sizeof(a))) \
jl_error("fptrunc: output bitsize must be < input bitsize"); \
else if (osize == 16) \
*(uint16_t*)pr = __gnu_f2h_ieee(a); \
*(uint16_t*)pr = julia__gnu_f2h_ieee(a); \
else if (osize == 32) \
*(float*)pr = a; \
else if (osize == 64) \
Expand Down
24 changes: 24 additions & 0 deletions test/intrinsics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,30 @@ end
@test_intrinsic Core.Intrinsics.fptoui UInt Float16(3.3) UInt(3)
end

if Sys.ARCH == :aarch64
# On AArch64 we are following the `_Float16` ABI. Buthe these functions expect `Int16`.
# TODO: SHould we have `Chalf == Int16` and `Cfloat16 == Float16`?
extendhfsf2(x::Float16) = ccall("extern __extendhfsf2", llvmcall, Float32, (Int16,), reinterpret(Int16, x))
gnu_h2f_ieee(x::Float16) = ccall("extern __gnu_h2f_ieee", llvmcall, Float32, (Int16,), reinterpret(Int16, x))
truncsfhf2(x::Float32) = reinterpret(Float16, ccall("extern __truncsfhf2", llvmcall, Int16, (Float32,), x))
gnu_f2h_ieee(x::Float32) = reinterpret(Float16, ccall("extern __gnu_f2h_ieee", llvmcall, Int16, (Float32,), x))
truncdfhf2(x::Float64) = reinterpret(Float16, ccall("extern __truncdfhf2", llvmcall, Int16, (Float64,), x))
else
extendhfsf2(x::Float16) = ccall("extern __extendhfsf2", llvmcall, Float32, (Float16,), x)
gnu_h2f_ieee(x::Float16) = ccall("extern __gnu_h2f_ieee", llvmcall, Float32, (Float16,), x)
truncsfhf2(x::Float32) = ccall("extern __truncsfhf2", llvmcall, Float16, (Float32,), x)
gnu_f2h_ieee(x::Float32) = ccall("extern __gnu_f2h_ieee", llvmcall, Float16, (Float32,), x)
truncdfhf2(x::Float64) = ccall("extern __truncdfhf2", llvmcall, Float16, (Float64,), x)
end

@testset "Float16 intrinsics (crt)" begin
@test extendhfsf2(Float16(3.3)) == 3.3007812f0
@test gnu_h2f_ieee(Float16(3.3)) == 3.3007812f0
@test truncsfhf2(3.3f0) == Float16(3.3)
@test gnu_f2h_ieee(3.3f0) == Float16(3.3)
@test truncdfhf2(3.3) == Float16(3.3)
end

using Base.Experimental: @force_compile
@test_throws ConcurrencyViolationError("invalid atomic ordering") (@force_compile; Core.Intrinsics.atomic_fence(:u)) === nothing
@test_throws ConcurrencyViolationError("invalid atomic ordering") (@force_compile; Core.Intrinsics.atomic_fence(Symbol("u", "x"))) === nothing
Expand Down