From 09f2bdb632d98210e1d5d384e7da48b9c475481f Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Wed, 4 Oct 2023 10:32:47 +0200 Subject: [PATCH 1/7] Use LLVM intrinsics when Julia supports BFloat16. --- src/bfloat16.jl | 199 +++++++++++++++++++++++++++++++++++++----------- 1 file changed, 153 insertions(+), 46 deletions(-) diff --git a/src/bfloat16.jl b/src/bfloat16.jl index 218d48a..d0b8a0a 100644 --- a/src/bfloat16.jl +++ b/src/bfloat16.jl @@ -13,19 +13,25 @@ import Base: isfinite, isnan, precision, iszero, eps, asin, acos, atan, acsc, asec, acot, sinh, cosh, tanh, csch, sech, coth, asinh, acosh, atanh, acsch, asech, acoth, - bitstring - -primitive type BFloat16 <: AbstractFloat 16 end + bitstring, isinteger + +# Julia 1.11 provides codegen support for BFloat16 +if isdefined(Core, :BFloat16) + using Core: BFloat16 + const codegen_support = true +else + primitive type BFloat16 <: AbstractFloat 16 end + const codegen_support = false +end Base.reinterpret(::Type{Unsigned}, x::BFloat16) = reinterpret(UInt16, x) Base.reinterpret(::Type{Signed}, x::BFloat16) = reinterpret(Int16, x) # Floating point property queries for f in (:sign_mask, :exponent_mask, :exponent_one, - :exponent_half, :significand_mask) + :exponent_half, :significand_mask) @eval $(f)(::Type{BFloat16}) = UInt16($(f)(Float32) >> 16) end - Base.exponent_bias(::Type{BFloat16}) = 127 Base.exponent_bits(::Type{BFloat16}) = 8 Base.significand_bits(::Type{BFloat16}) = 7 @@ -65,16 +71,24 @@ isnan(x::BFloat16) = (reinterpret(Unsigned,x) & ~sign_mask(BFloat16)) > exponent precision(::Type{BFloat16}) = 8 eps(::Type{BFloat16}) = Base.bitcast(BFloat16, 0x3c00) -round(x::BFloat16, r::RoundingMode{:Up}) = BFloat16(ceil(Float32(x))) -round(x::BFloat16, r::RoundingMode{:Down}) = BFloat16(floor(Float32(x))) -round(x::BFloat16, r::RoundingMode{:Nearest}) = BFloat16(round(Float32(x))) +## Rounding ## +if codegen_support + round(x::BFloat16, ::RoundingMode{:ToZero}) = Base.trunc_llvm(x) + round(x::BFloat16, 
::RoundingMode{:Down}) = Base.floor_llvm(x) + round(x::BFloat16, ::RoundingMode{:Up}) = Base.ceil_llvm(x) + round(x::BFloat16, ::RoundingMode{:Nearest}) = Base.rint_llvm(x) +else + round(x::BFloat16, r::RoundingMode{:ToZero}) = BFloat16(trunc(Float32(x))) + round(x::BFloat16, r::RoundingMode{:Down}) = BFloat16(floor(Float32(x))) + round(x::BFloat16, r::RoundingMode{:Up}) = BFloat16(ceil(Float32(x))) + round(x::BFloat16, r::RoundingMode{:Nearest}) = BFloat16(round(Float32(x))) +end +# round(::Type{Signed}, x::BFloat16, r::RoundingMode) = round(Int, x, r) +# round(::Type{Unsigned}, x::BFloat16, r::RoundingMode) = round(UInt, x, r) +# round(::Type{Integer}, x::BFloat16, r::RoundingMode) = round(Int, x, r) Base.trunc(bf::BFloat16) = signbit(bf) ? ceil(bf) : floor(bf) -Int64(x::BFloat16) = Int64(Float32(x)) -Int32(x::BFloat16) = Int32(Float32(x)) -Int16(x::BFloat16) = Int16(Float32(x)) - ## floating point traits ## """ InfB16 @@ -100,56 +114,88 @@ Base.trunc(::Type{BFloat16}, x::Float32) = reinterpret(BFloat16, (reinterpret(UInt32, x) >> 16) % UInt16 ) -# Conversion from Float32 -function BFloat16(x::Float32) - isnan(x) && return NaNB16 - # Round to nearest even (matches TensorFlow and our convention for - # rounding to lower precision floating point types). - h = reinterpret(UInt32, x) - h += 0x7fff + ((h >> 16) & 1) - return reinterpret(BFloat16, (h >> 16) % UInt16) -end +if codegen_support + BFloat16(x::Float32) = Base.fptrunc(BFloat16, x) + BFloat16(x::Float64) = Base.fptrunc(BFloat16, x) + + # XXX: can LLVM do this natively? + BFloat16(x::Float16) = BFloat16(Float32(x)) +else + # Conversion from Float32 + function BFloat16(x::Float32) + isnan(x) && return NaNB16 + # Round to nearest even (matches TensorFlow and our convention for + # rounding to lower precision floating point types). 
+ h = reinterpret(UInt32, x) + h += 0x7fff + ((h >> 16) & 1) + return reinterpret(BFloat16, (h >> 16) % UInt16) + end -# Conversion from Float64 -function BFloat16(x::Float64) - BFloat16(Float32(x)) -end + # Conversion from Float64 + function BFloat16(x::Float64) + BFloat16(Float32(x)) + end -# Conversion from Float16 -function BFloat16(x::Float16) - BFloat16(Float32(x)) + # Conversion from Float16 + function BFloat16(x::Float16) + BFloat16(Float32(x)) + end end # Conversion from Integer -function BFloat16(x::Integer) - convert(BFloat16, convert(Float32, x)) +if codegen_support + for st in (Int8, Int16, Int32, Int64) + @eval begin + BFloat16(x::($st)) = Base.sitofp(BFloat16, x) + end + end + for ut in (Bool, UInt8, UInt16, UInt32, UInt64) + @eval begin + BFloat16(x::($ut)) = Base.uitofp(BFloat16, x) + end + end +else + BFloat16(x::Integer) = convert(BFloat16, convert(Float32, x)) end +# TODO: optimize +BFloat16(x::UInt128) = convert(BFloat16, Float64(x)) +BFloat16(x::Int128) = convert(BFloat16, Float64(x)) # Conversion to Float16 function Base.Float16(x::BFloat16) Float16(Float32(x)) end -# Expansion to Float32 -function Base.Float32(x::BFloat16) - reinterpret(Float32, UInt32(reinterpret(Unsigned, x)) << 16) -end +if codegen_support + Base.Float32(x::BFloat16) = Base.fpext(Float32, x) + Base.Float64(x::BFloat16) = Base.fpext(Float64, x) +else + # Expansion to Float32 + function Base.Float32(x::BFloat16) + reinterpret(Float32, UInt32(reinterpret(Unsigned, x)) << 16) + end -# Expansion to Float64 -function Base.Float64(x::BFloat16) - Float64(Float32(x)) + # Expansion to Float64 + function Base.Float64(x::BFloat16) + Float64(Float32(x)) + end end -# Truncation to integer types -Base.unsafe_trunc(T::Type{<:Integer}, x::BFloat16) = unsafe_trunc(T, Float32(x)) -Base.trunc(::Type{T}, x::BFloat16) where {T<:Integer} = trunc(T, Float32(x)) - # Basic arithmetic -for f in (:+, :-, :*, :/, :^) - @eval ($f)(x::BFloat16, y::BFloat16) = BFloat16($(f)(Float32(x), Float32(y))) +if 
codegen_support + +(x::T, y::T) where {T<:BFloat16} = Base.add_float(x, y) + -(x::T, y::T) where {T<:BFloat16} = Base.sub_float(x, y) + *(x::T, y::T) where {T<:BFloat16} = Base.mul_float(x, y) + /(x::T, y::T) where {T<:BFloat16} = Base.div_float(x, y) + -(x::BFloat16) = Base.neg_float(x) + ^(x::BFloat16, y::BFloat16) = BFloat16(Float32(x)^Float32(y)) +else + for f in (:+, :-, :*, :/, :^) + @eval ($f)(x::BFloat16, y::BFloat16) = BFloat16($(f)(Float32(x), Float32(y))) + end + -(x::BFloat16) = reinterpret(BFloat16, reinterpret(Unsigned, x) ⊻ sign_mask(BFloat16)) end --(x::BFloat16) = reinterpret(BFloat16, reinterpret(Unsigned, x) ⊻ sign_mask(BFloat16)) -^(x::BFloat16, y::Integer) = BFloat16(^(Float32(x), y)) +^(x::BFloat16, y::Integer) = BFloat16(Float32(x)^y) const ZeroBFloat16 = BFloat16(0.0f0) const OneBFloat16 = BFloat16(1.0f0) @@ -185,7 +231,68 @@ for t in (Int8, Int16, Int32, Int64, Int128, UInt8, UInt16, UInt32, UInt64, UInt end # Wide multiplication -Base.widemul(x::BFloat16, y::BFloat16) = Float32(x) * Float32(y) +Base.widemul(x::BFloat16, y::BFloat16) = widen(x) * widen(y) + +# Truncation to integer types +if codegen_support + for Ti in (Int8, Int16, Int32, Int64) + @eval begin + Base.unsafe_trunc(::Type{$Ti}, x::BFloat16) = Base.fptosi($Ti, x) + end + end + for Ti in (UInt8, UInt16, UInt32, UInt64) + @eval begin + Base.unsafe_trunc(::Type{$Ti}, x::BFloat16) = Base.fptoui($Ti, x) + end + end +else + Base.unsafe_trunc(T::Type{<:Integer}, x::BFloat16) = unsafe_trunc(T, Float32(x)) +end +for Ti in (Int8, Int16, Int32, Int64, Int128, UInt8, UInt16, UInt32, UInt64, UInt128) + if Ti <: Unsigned || sizeof(Ti) < 2 + # Here `BFloat16(typemin(Ti))-1` is exact, so we can compare the lower-bound + # directly. `BFloat16(typemax(Ti))+1` is either always exactly representable, or + # rounded to `Inf` (e.g. when `Ti == UInt128`, whose `typemax` exceeds `floatmax(BFloat16)`). 
+ @eval begin + function Base.trunc(::Type{$Ti}, x::BFloat16) + if $(BFloat16(typemin(Ti))-one(BFloat16)) < x < $(BFloat16(typemax(Ti))+one(BFloat16)) + return Base.unsafe_trunc($Ti,x) + else + throw(InexactError(:trunc, $Ti, x)) + end + end + function (::Type{$Ti})(x::BFloat16) + if ($(BFloat16(typemin(Ti))) <= x <= $(BFloat16(typemax(Ti)))) && isinteger(x) + return Base.unsafe_trunc($Ti,x) + else + throw(InexactError($(Expr(:quote,Ti.name.name)), $Ti, x)) + end + end + end + else + # Here `eps(BFloat16(typemin(Ti))) > 1`, so the only value which can be + # truncated to `BFloat16(typemin(Ti))` is itself. Similarly, + # `BFloat16(typemax(Ti))` is inexact and will be rounded up. This assumes that + # `BFloat16(typemin(Ti)) > -Inf`, which is true for these types, but not for + # `Float16` or larger integer types. + @eval begin + function Base.trunc(::Type{$Ti}, x::BFloat16) + if $(BFloat16(typemin(Ti))) <= x < $(BFloat16(typemax(Ti))) + return unsafe_trunc($Ti,x) + else + throw(InexactError(:trunc, $Ti, x)) + end + end + function (::Type{$Ti})(x::BFloat16) + if ($(BFloat16(typemin(Ti))) <= x < $(BFloat16(typemax(Ti)))) && isinteger(x) + return unsafe_trunc($Ti,x) + else + throw(InexactError($(Expr(:quote,Ti.name.name)), $Ti, x)) + end + end + end + end +end # Showing function Base.show(io::IO, x::BFloat16) From 80a1280aa758261315b7a9b554810e512a6ef855 Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Wed, 4 Oct 2023 10:33:11 +0200 Subject: [PATCH 2/7] Use Base's 'significand' implementation to avoid UB. 
--- src/bfloat16.jl | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/src/bfloat16.jl b/src/bfloat16.jl index d0b8a0a..308ed59 100644 --- a/src/bfloat16.jl +++ b/src/bfloat16.jl @@ -2,7 +2,7 @@ import Base: isfinite, isnan, precision, iszero, eps, typemin, typemax, floatmin, floatmax, sign_mask, exponent_mask, significand_mask, exponent_bits, significand_bits, exponent_bias, - exponent_one, exponent_half, + exponent_one, exponent_half, leading_zeros, signbit, exponent, significand, frexp, ldexp, round, Int16, Int32, Int64, +, -, *, /, ^, ==, <, <=, >=, >, !=, inv, @@ -38,14 +38,17 @@ Base.significand_bits(::Type{BFloat16}) = 7 Base.signbit(x::BFloat16) = (reinterpret(Unsigned, x) & 0x8000) !== 0x0000 function Base.significand(x::BFloat16) - result = abs_significand(x) - ifelse(signbit(x), -result, result) -end - -@inline function abs_significand(x::BFloat16) - usig = Base.significand_mask(BFloat16) & reinterpret(Unsigned, x) - isig = Int16(usig) - 1 + isig / BFloat16(2)^7 + xu = reinterpret(Unsigned, x) + xs = xu & ~sign_mask(BFloat16) + xs >= exponent_mask(BFloat16) && return x # NaN or Inf + if xs <= (~exponent_mask(BFloat16) & ~sign_mask(BFloat16)) # x is subnormal + xs == 0 && return x # +-0 + m = unsigned(leading_zeros(xs) - exponent_bits(BFloat16)) + xs <<= m + xu = xs | (xu & sign_mask(BFloat16)) + end + xu = (xu & ~exponent_mask(BFloat16)) | exponent_one(BFloat16) + return reinterpret(BFloat16, xu) end Base.exponent(x::BFloat16) = From 22f2a04ab555ccb5257adbd529af18075263d1bf Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Fri, 6 Oct 2023 09:58:48 +0200 Subject: [PATCH 3/7] Restrict use of codegen to x86 only. 
--- src/bfloat16.jl | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/bfloat16.jl b/src/bfloat16.jl index 308ed59..03ac981 100644 --- a/src/bfloat16.jl +++ b/src/bfloat16.jl @@ -16,12 +16,13 @@ import Base: isfinite, isnan, precision, iszero, eps, bitstring, isinteger # Julia 1.11 provides codegen support for BFloat16 -if isdefined(Core, :BFloat16) +const codegen_support = if isdefined(Core, :BFloat16) && + Sys.ARCH in [:x86_64, :i686] using Core: BFloat16 - const codegen_support = true + true else primitive type BFloat16 <: AbstractFloat 16 end - const codegen_support = false + false end Base.reinterpret(::Type{Unsigned}, x::BFloat16) = reinterpret(UInt16, x) From 867d6f55a13e3c423ff55af4ebe9e0ed6b1851b5 Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Fri, 6 Oct 2023 10:34:38 +0200 Subject: [PATCH 4/7] Split codegen support in storage and arithmetic. --- src/bfloat16.jl | 41 +++++++++++++++++++++++++++++++---------- test/runtests.jl | 2 ++ 2 files changed, 33 insertions(+), 10 deletions(-) diff --git a/src/bfloat16.jl b/src/bfloat16.jl index 03ac981..1dff1da 100644 --- a/src/bfloat16.jl +++ b/src/bfloat16.jl @@ -15,11 +15,32 @@ import Base: isfinite, isnan, precision, iszero, eps, asinh, acosh, atanh, acsch, asech, acoth, bitstring, isinteger -# Julia 1.11 provides codegen support for BFloat16 -const codegen_support = if isdefined(Core, :BFloat16) && - Sys.ARCH in [:x86_64, :i686] +# LLVM 11 added support for BFloat16 in the IR; Julia 1.11 added support for generating +# code that uses the `bfloat` IR type, together with the necessary runtime functions. +# However, not all LLVM targets support `bfloat`. If the target can store/load BFloat16s +# (and supports synthesizing constants) we can use the `bfloat` IR type, otherwise we fall +# back to defining a primitive type that will be represented as an `i16`. If, in addition, +# the target supports BFloat16 arithmetic, we can use LLVM intrinsics. 
+# - x86: storage and arithmetic support in LLVM 15 +# - aarch64: storage support in LLVM 17 +const llvm_storage = if isdefined(Core, :BFloat16) + if Sys.ARCH in [:x86_64, :i686] && Base.libllvm_version >= v"15" + true + elseif Sys.ARCH == :aarch64 && Base.libllvm_version >= v"17" + true + else + false + end +else + false +end +const llvm_arithmetic = if llvm_storage using Core: BFloat16 - true + if Sys.ARCH in [:x86_64, :i686] && Base.libllvm_version >= v"15" + true + else + false + end else primitive type BFloat16 <: AbstractFloat 16 end false @@ -76,7 +97,7 @@ precision(::Type{BFloat16}) = 8 eps(::Type{BFloat16}) = Base.bitcast(BFloat16, 0x3c00) ## Rounding ## -if codegen_support +if llvm_arithmetic round(x::BFloat16, ::RoundingMode{:ToZero}) = Base.trunc_llvm(x) round(x::BFloat16, ::RoundingMode{:Down}) = Base.floor_llvm(x) round(x::BFloat16, ::RoundingMode{:Up}) = Base.ceil_llvm(x) @@ -118,7 +139,7 @@ Base.trunc(::Type{BFloat16}, x::Float32) = reinterpret(BFloat16, (reinterpret(UInt32, x) >> 16) % UInt16 ) -if codegen_support +if llvm_arithmetic BFloat16(x::Float32) = Base.fptrunc(BFloat16, x) BFloat16(x::Float64) = Base.fptrunc(BFloat16, x) @@ -147,7 +168,7 @@ else end # Conversion from Integer -if codegen_support +if llvm_arithmetic for st in (Int8, Int16, Int32, Int64) @eval begin BFloat16(x::($st)) = Base.sitofp(BFloat16, x) @@ -170,7 +191,7 @@ function Base.Float16(x::BFloat16) Float16(Float32(x)) end -if codegen_support +if llvm_arithmetic Base.Float32(x::BFloat16) = Base.fpext(Float32, x) Base.Float64(x::BFloat16) = Base.fpext(Float64, x) else @@ -186,7 +207,7 @@ else end # Basic arithmetic -if codegen_support +if llvm_arithmetic +(x::T, y::T) where {T<:BFloat16} = Base.add_float(x, y) -(x::T, y::T) where {T<:BFloat16} = Base.sub_float(x, y) *(x::T, y::T) where {T<:BFloat16} = Base.mul_float(x, y) @@ -238,7 +259,7 @@ end Base.widemul(x::BFloat16, y::BFloat16) = widen(x) * widen(y) # Truncation to integer types -if codegen_support +if llvm_arithmetic for 
Ti in (Int8, Int16, Int32, Int64) @eval begin Base.unsafe_trunc(::Type{$Ti}, x::BFloat16) = Base.fptosi($Ti, x) diff --git a/test/runtests.jl b/test/runtests.jl index 44f4bcd..ab7a504 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,5 +1,7 @@ using Test, BFloat16s, Printf, Random +@info "Testing BFloat16s" BFloat16s.llvm_storage BFloat16s.llvm_arithmetic + @testset "comparisons" begin @test BFloat16(1) < BFloat16(2) @test BFloat16(1f0) < BFloat16(2f0) From 3c6ed434804c23d6ec22edecbee433e0c293f026 Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Fri, 6 Oct 2023 14:21:44 +0200 Subject: [PATCH 5/7] Switch order of macros to improve error message. --- test/runtests.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index ab7a504..da0a3d6 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -60,7 +60,7 @@ end ("%.2a", "0x1.3cp+0"), ("%.2A", "0X1.3CP+0")), num in (BFloat16(1.234),) - @test @eval(@sprintf($fmt, $num) == $val) + @eval @test @sprintf($fmt, $num) == $val end @test (@sprintf "%f" BFloat16(Inf)) == "Inf" @test (@sprintf "%f" BFloat16(NaN)) == "NaN" @@ -75,7 +75,7 @@ end ("%-+10.5g", "+123.5 "), ("%010.5g", "00000123.5")), num in (BFloat16(123.5),) - @test @eval(@sprintf($fmt, $num) == $val) + @eval @test @sprintf($fmt, $num) == $val end @test( @sprintf( "%10.5g", BFloat16(-123.5) ) == " -123.5") @test( @sprintf( "%010.5g", BFloat16(-123.5) ) == "-0000123.5") From 569dcc58426c99bb519345b0abb1cc49da9198ec Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Fri, 6 Oct 2023 14:40:08 +0200 Subject: [PATCH 6/7] Try using Float32 as printf conversion type. 
--- src/bfloat16.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/bfloat16.jl b/src/bfloat16.jl index 1dff1da..ebe02d6 100644 --- a/src/bfloat16.jl +++ b/src/bfloat16.jl @@ -15,6 +15,8 @@ import Base: isfinite, isnan, precision, iszero, eps, asinh, acosh, atanh, acsch, asech, acoth, bitstring, isinteger +import Printf + # LLVM 11 added support for BFloat16 in the IR; Julia 1.11 added support for generating # code that uses the `bfloat` IR type, together with the necessary runtime functions. # However, not all LLVM targets support `bfloat`. If the target can store/load BFloat16s @@ -332,6 +334,7 @@ function Base.show(io::IO, x::BFloat16) hastypeinfo || print(io, ")") end end +Printf.tofloat(x::BFloat16) = Float32(x) # Random import Random: rand, randn, randexp, AbstractRNG, Sampler From ca9744233e82ca560782f5afb6205688df677a56 Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Fri, 6 Oct 2023 14:42:52 +0200 Subject: [PATCH 7/7] Add ABI test. --- test/runtests.jl | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/test/runtests.jl b/test/runtests.jl index da0a3d6..8196e9d 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -29,6 +29,14 @@ end @test Int64(BFloat16(10)) == Int64(10) end +@testset "abi" begin + f() = BFloat16(1) + @test f() == BFloat16(1) + + g(x) = x+BFloat16(1) + @test g(BFloat16(2)) == BFloat16(3) +end + @testset "functions" begin @test abs(BFloat16(-10)) == BFloat16(10) @test BFloat16(2) ^ BFloat16(4) == BFloat16(16)