Skip to content

Commit

Permalink
Adapt to upstream changes wrt. native support for BFloat16 (#51)
Browse files Browse the repository at this point in the history
  • Loading branch information
maleadt authored Nov 2, 2023
1 parent a42c4fa commit 730511b
Show file tree
Hide file tree
Showing 2 changed files with 202 additions and 57 deletions.
245 changes: 190 additions & 55 deletions src/bfloat16.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ import Base: isfinite, isnan, precision, iszero, eps,
typemin, typemax, floatmin, floatmax,
sign_mask, exponent_mask, significand_mask,
exponent_bits, significand_bits, exponent_bias,
exponent_one, exponent_half,
exponent_one, exponent_half, leading_zeros,
signbit, exponent, significand, frexp, ldexp,
round, Int16, Int32, Int64,
+, -, *, /, ^, ==, <, <=, >=, >, !=, inv,
Expand All @@ -13,33 +13,66 @@ import Base: isfinite, isnan, precision, iszero, eps,
asin, acos, atan, acsc, asec, acot,
sinh, cosh, tanh, csch, sech, coth,
asinh, acosh, atanh, acsch, asech, acoth,
bitstring

primitive type BFloat16 <: AbstractFloat 16 end
bitstring, isinteger

import Printf

# LLVM 11 added support for BFloat16 in the IR; Julia 1.11 added support for generating
# code that uses the `bfloat` IR type, together with the necessary runtime functions.
# However, not all LLVM targets support `bfloat`. If the target can store/load BFloat16s
# (and supports synthesizing constants) we can use the `bfloat` IR type, otherwise we fall
# back to defining a primitive type that will be represented as an `i16`. If, in addition,
# the target supports BFloat16 arithmetic, we can use LLVM intrinsics.
# - x86: storage and arithmetic support in LLVM 15
# - aarch64: storage support in LLVM 17
# Whether this Julia build can represent BFloat16 with the native `bfloat` LLVM
# type: Core.BFloat16 must exist (Julia >= 1.11) and the LLVM target must be
# able to load/store bfloat values (x86 since LLVM 15, aarch64 since LLVM 17).
const llvm_storage =
    isdefined(Core, :BFloat16) &&
    ((Sys.ARCH in (:x86_64, :i686) && Base.libllvm_version >= v"15") ||
     (Sys.ARCH == :aarch64 && Base.libllvm_version >= v"17"))
# Whether the target also supports BFloat16 *arithmetic* (not just storage).
# When storage is available we alias Julia's native Core.BFloat16; otherwise we
# declare our own 16-bit primitive type, which LLVM will represent as `i16`.
const llvm_arithmetic = if llvm_storage
    using Core: BFloat16
    # Only x86 has arithmetic support (since LLVM 15); aarch64 is storage-only.
    if Sys.ARCH in [:x86_64, :i686] && Base.libllvm_version >= v"15"
        true
    else
        false
    end
else
    # Fallback primitive type: same 16-bit layout, software conversions below.
    primitive type BFloat16 <: AbstractFloat 16 end
    false
end

# Bit-level reinterpretation to the matching 16-bit integer types.
Base.reinterpret(::Type{Unsigned}, x::BFloat16) = reinterpret(UInt16, x)
Base.reinterpret(::Type{Signed}, x::BFloat16) = reinterpret(Int16, x)

# Floating point property queries.
# BFloat16 is Float32 with the low 16 significand bits dropped, so every bit
# mask is just the corresponding Float32 mask shifted down by 16.
# (The scraped diff duplicated the tuple-continuation line here; keep it once.)
for f in (:sign_mask, :exponent_mask, :exponent_one,
          :exponent_half, :significand_mask)
    @eval $(f)(::Type{BFloat16}) = UInt16($(f)(Float32) >> 16)
end

# BFloat16 shares Float32's 8-bit exponent (bias 127) but keeps only 7
# explicit significand bits.
Base.exponent_bias(::Type{BFloat16}) = 127
Base.exponent_bits(::Type{BFloat16}) = 8
Base.significand_bits(::Type{BFloat16}) = 7
# The sign is the top bit of the 16-bit representation.
Base.signbit(x::BFloat16) = (reinterpret(Unsigned, x) & 0x8000) !== 0x0000

# Return the significand of `x` as a BFloat16 in [1, 2), carrying `x`'s sign.
# NaN/Inf and ±0 are returned unchanged; subnormals are renormalized first.
# (Reconstructed: the scraped diff merged the removed `abs_significand`-based
# implementation into this function's body.)
function Base.significand(x::BFloat16)
    xu = reinterpret(Unsigned, x)
    xs = xu & ~sign_mask(BFloat16)
    xs >= exponent_mask(BFloat16) && return x # NaN or Inf
    if xs <= (~exponent_mask(BFloat16) & ~sign_mask(BFloat16)) # x is subnormal
        xs == 0 && return x # +-0
        # shift the leading significand bit up into the implicit-bit position
        m = unsigned(leading_zeros(xs) - exponent_bits(BFloat16))
        xs <<= m
        xu = xs | (xu & sign_mask(BFloat16))
    end
    # force the exponent field to represent 2^0, keeping sign and significand
    xu = (xu & ~exponent_mask(BFloat16)) | exponent_one(BFloat16)
    return reinterpret(BFloat16, xu)
end

Base.exponent(x::BFloat16) =
Expand All @@ -65,16 +98,24 @@ isnan(x::BFloat16) = (reinterpret(Unsigned,x) & ~sign_mask(BFloat16)) > exponent
# 7 explicit significand bits + 1 implicit bit.
precision(::Type{BFloat16}) = 8
# eps(BFloat16) == 2^-7, i.e. the bit pattern 0b0_01111000_0000000 == 0x3c00.
eps(::Type{BFloat16}) = Base.bitcast(BFloat16, 0x3c00)

## Rounding ##
# Use the LLVM rounding intrinsics when the target supports bfloat arithmetic;
# otherwise round by going through Float32 (exact, since BFloat16 ⊂ Float32).
# (The scraped diff duplicated the old unconditional methods above this block;
# only the conditional definitions are kept.)
if llvm_arithmetic
    round(x::BFloat16, ::RoundingMode{:ToZero})  = Base.trunc_llvm(x)
    round(x::BFloat16, ::RoundingMode{:Down})    = Base.floor_llvm(x)
    round(x::BFloat16, ::RoundingMode{:Up})      = Base.ceil_llvm(x)
    round(x::BFloat16, ::RoundingMode{:Nearest}) = Base.rint_llvm(x)
else
    round(x::BFloat16, r::RoundingMode{:ToZero})  = BFloat16(trunc(Float32(x)))
    round(x::BFloat16, r::RoundingMode{:Down})    = BFloat16(floor(Float32(x)))
    round(x::BFloat16, r::RoundingMode{:Up})      = BFloat16(ceil(Float32(x)))
    round(x::BFloat16, r::RoundingMode{:Nearest}) = BFloat16(round(Float32(x)))
end
# round(::Type{Signed}, x::BFloat16, r::RoundingMode) = round(Int, x, r)
# round(::Type{Unsigned}, x::BFloat16, r::RoundingMode) = round(UInt, x, r)
# round(::Type{Integer}, x::BFloat16, r::RoundingMode) = round(Int, x, r)

Base.trunc(bf::BFloat16) = signbit(bf) ? ceil(bf) : floor(bf)

Int64(x::BFloat16) = Int64(Float32(x))
Int32(x::BFloat16) = Int32(Float32(x))
Int16(x::BFloat16) = Int16(Float32(x))

## floating point traits ##
"""
InfB16
Expand All @@ -100,56 +141,88 @@ Base.trunc(::Type{BFloat16}, x::Float32) = reinterpret(BFloat16,
(reinterpret(UInt32, x) >> 16) % UInt16
)

# Conversion from floating-point types.
# (The scraped diff duplicated the old unconditional `BFloat16(::Float32)` path
# above this block; only the conditional definitions are kept.)
if llvm_arithmetic
    BFloat16(x::Float32) = Base.fptrunc(BFloat16, x)
    BFloat16(x::Float64) = Base.fptrunc(BFloat16, x)

    # XXX: can LLVM do this natively?
    BFloat16(x::Float16) = BFloat16(Float32(x))
else
    # Conversion from Float32
    function BFloat16(x::Float32)
        isnan(x) && return NaNB16
        # Round to nearest even (matches TensorFlow and our convention for
        # rounding to lower precision floating point types).
        h = reinterpret(UInt32, x)
        h += 0x7fff + ((h >> 16) & 1)
        return reinterpret(BFloat16, (h >> 16) % UInt16)
    end

    # Conversion from Float64, via Float32 (double rounding is benign here:
    # NOTE(review) assumed, matching the original code path — confirm)
    function BFloat16(x::Float64)
        BFloat16(Float32(x))
    end

    # Conversion from Float16, via Float32 (exact widening, then round)
    function BFloat16(x::Float16)
        BFloat16(Float32(x))
    end
end

# Conversion from Integer.
# (Reconstructed: the scraped diff left an orphaned fragment of the removed
# unconditional `BFloat16(x::Integer)` method, missing its `end`.)
if llvm_arithmetic
    # int→bfloat LLVM intrinsics, signed and unsigned respectively
    for st in (Int8, Int16, Int32, Int64)
        @eval begin
            BFloat16(x::($st)) = Base.sitofp(BFloat16, x)
        end
    end
    for ut in (Bool, UInt8, UInt16, UInt32, UInt64)
        @eval begin
            BFloat16(x::($ut)) = Base.uitofp(BFloat16, x)
        end
    end
else
    # fallback: round through Float32
    BFloat16(x::Integer) = convert(BFloat16, convert(Float32, x))
end
# TODO: optimize
# 128-bit integers are not covered by the intrinsic loops above; go via Float64.
BFloat16(x::UInt128) = convert(BFloat16, Float64(x))
BFloat16(x::Int128) = convert(BFloat16, Float64(x))

# Conversion to Float16: round-trip through Float32 (BFloat16→Float32 is
# exact, then Float32→Float16 rounds).
Base.Float16(x::BFloat16) = Float16(Float32(x))

# Expansion to wider floating-point types.
# (The scraped diff duplicated the old unconditional `Float32(::BFloat16)`
# definition above this block; only the conditional definitions are kept.)
if llvm_arithmetic
    Base.Float32(x::BFloat16) = Base.fpext(Float32, x)
    Base.Float64(x::BFloat16) = Base.fpext(Float64, x)
else
    # Expansion to Float32: pad the significand with 16 zero bits (exact)
    function Base.Float32(x::BFloat16)
        reinterpret(Float32, UInt32(reinterpret(Unsigned, x)) << 16)
    end

    # Expansion to Float64, via Float32 (both steps exact)
    function Base.Float64(x::BFloat16)
        Float64(Float32(x))
    end
end

# Truncation to integer types
Base.unsafe_trunc(T::Type{<:Integer}, x::BFloat16) = unsafe_trunc(T, Float32(x))
Base.trunc(::Type{T}, x::BFloat16) where {T<:Integer} = trunc(T, Float32(x))

# Basic arithmetic.
# (Reconstructed: the scraper dropped the `⊻` operator in the unary-minus
# definition, and duplicated the old unconditional operator loop.)
if llvm_arithmetic
    +(x::T, y::T) where {T<:BFloat16} = Base.add_float(x, y)
    -(x::T, y::T) where {T<:BFloat16} = Base.sub_float(x, y)
    *(x::T, y::T) where {T<:BFloat16} = Base.mul_float(x, y)
    /(x::T, y::T) where {T<:BFloat16} = Base.div_float(x, y)
    -(x::BFloat16) = Base.neg_float(x)
    # power has no bfloat intrinsic; compute in Float32 and round back
    ^(x::BFloat16, y::BFloat16) = BFloat16(Float32(x)^Float32(y))
else
    # fallback: compute in Float32, round the result back to BFloat16
    for f in (:+, :-, :*, :/, :^)
        @eval ($f)(x::BFloat16, y::BFloat16) = BFloat16($(f)(Float32(x), Float32(y)))
    end
    # negation just flips the sign bit
    -(x::BFloat16) = reinterpret(BFloat16, reinterpret(Unsigned, x) ⊻ sign_mask(BFloat16))
end
^(x::BFloat16, y::Integer) = BFloat16(Float32(x)^y)

# Frequently-used constants.
const ZeroBFloat16 = BFloat16(0.0f0)
const OneBFloat16 = BFloat16(1.0f0)
Expand Down Expand Up @@ -185,7 +258,68 @@ for t in (Int8, Int16, Int32, Int64, Int128, UInt8, UInt16, UInt32, UInt64, UInt
end

# Wide multiplication: widen both operands first so the product is exact.
# (The scraped diff showed both the old Float32-based and the new widen-based
# definitions; keep the new one. `widen(::BFloat16)` is presumably Float32,
# defined elsewhere in this file — confirm against the promotion rules.)
Base.widemul(x::BFloat16, y::BFloat16) = widen(x) * widen(y)

# Truncation to integer types
if llvm_arithmetic
    # Use the fp→int LLVM intrinsics directly. Per the `unsafe_trunc` contract,
    # out-of-range inputs yield an unspecified value rather than throwing.
    for Ti in (Int8, Int16, Int32, Int64)
        @eval begin
            Base.unsafe_trunc(::Type{$Ti}, x::BFloat16) = Base.fptosi($Ti, x)
        end
    end
    for Ti in (UInt8, UInt16, UInt32, UInt64)
        @eval begin
            Base.unsafe_trunc(::Type{$Ti}, x::BFloat16) = Base.fptoui($Ti, x)
        end
    end
else
    # Fallback: truncate via Float32 (BFloat16→Float32 is exact).
    Base.unsafe_trunc(T::Type{<:Integer}, x::BFloat16) = unsafe_trunc(T, Float32(x))
end
# Generate checked `trunc(Ti, x)` and constructor `Ti(x)` methods for every
# integer type, with range checks evaluated at definition time (the `$(...)`
# splices below are computed once, when the @eval runs).
for Ti in (Int8, Int16, Int32, Int64, Int128, UInt8, UInt16, UInt32, UInt64, UInt128)
    if Ti <: Unsigned || sizeof(Ti) < 2
        # Here `BFloat16(typemin(Ti))-1` is exact, so we can compare the lower-bound
        # directly. `BFloat16(typemax(Ti))+1` is either always exactly representable, or
        # rounded to `Inf` (e.g. when `Ti==UInt128 && BFloat16==Float32`).
        @eval begin
            function Base.trunc(::Type{$Ti}, x::BFloat16)
                if $(BFloat16(typemin(Ti))-one(BFloat16)) < x < $(BFloat16(typemax(Ti))+one(BFloat16))
                    return Base.unsafe_trunc($Ti,x)
                else
                    throw(InexactError(:trunc, $Ti, x))
                end
            end
            function (::Type{$Ti})(x::BFloat16)
                # constructors additionally require x to be an exact integer
                if ($(BFloat16(typemin(Ti))) <= x <= $(BFloat16(typemax(Ti)))) && isinteger(x)
                    return Base.unsafe_trunc($Ti,x)
                else
                    throw(InexactError($(Expr(:quote,Ti.name.name)), $Ti, x))
                end
            end
        end
    else
        # Here `eps(BFloat16(typemin(Ti))) > 1`, so the only value which can be
        # truncated to `BFloat16(typemin(Ti))` is itself. Similarly,
        # `BFloat16(typemax(Ti))` is inexact and will be rounded up. This assumes that
        # `BFloat16(typemin(Ti)) > -Inf`, which is true for these types, but not for
        # `Float16` or larger integer types.
        @eval begin
            function Base.trunc(::Type{$Ti}, x::BFloat16)
                if $(BFloat16(typemin(Ti))) <= x < $(BFloat16(typemax(Ti)))
                    return unsafe_trunc($Ti,x)
                else
                    throw(InexactError(:trunc, $Ti, x))
                end
            end
            function (::Type{$Ti})(x::BFloat16)
                # constructors additionally require x to be an exact integer
                if ($(BFloat16(typemin(Ti))) <= x < $(BFloat16(typemax(Ti)))) && isinteger(x)
                    return unsafe_trunc($Ti,x)
                else
                    throw(InexactError($(Expr(:quote,Ti.name.name)), $Ti, x))
                end
            end
        end
    end
end

# Showing
function Base.show(io::IO, x::BFloat16)
Expand All @@ -200,6 +334,7 @@ function Base.show(io::IO, x::BFloat16)
hastypeinfo || print(io, ")")
end
end
Printf.tofloat(x::BFloat16) = Float32(x)

# Random
import Random: rand, randn, randexp, AbstractRNG, Sampler
Expand Down
14 changes: 12 additions & 2 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
using Test, BFloat16s, Printf, Random

@info "Testing BFloat16s" BFloat16s.llvm_storage BFloat16s.llvm_arithmetic

@testset "comparisons" begin
@test BFloat16(1) < BFloat16(2)
@test BFloat16(1f0) < BFloat16(2f0)
Expand Down Expand Up @@ -27,6 +29,14 @@ end
@test Int64(BFloat16(10)) == Int64(10)
end

# Check that BFloat16 values survive function call/return boundaries, i.e.
# that the native `bfloat` ABI (when enabled) round-trips correctly.
@testset "abi" begin
    f() = BFloat16(1)
    @test f() == BFloat16(1)

    g(x) = x+BFloat16(1)
    @test g(BFloat16(2)) == BFloat16(3)
end

@testset "functions" begin
@test abs(BFloat16(-10)) == BFloat16(10)
@test BFloat16(2) ^ BFloat16(4) == BFloat16(16)
Expand Down Expand Up @@ -58,7 +68,7 @@ end
("%.2a", "0x1.3cp+0"),
("%.2A", "0X1.3CP+0")),
num in (BFloat16(1.234),)
@test @eval(@sprintf($fmt, $num) == $val)
@eval @test @sprintf($fmt, $num) == $val
end
@test (@sprintf "%f" BFloat16(Inf)) == "Inf"
@test (@sprintf "%f" BFloat16(NaN)) == "NaN"
Expand All @@ -73,7 +83,7 @@ end
("%-+10.5g", "+123.5 "),
("%010.5g", "00000123.5")),
num in (BFloat16(123.5),)
@test @eval(@sprintf($fmt, $num) == $val)
@eval @test @sprintf($fmt, $num) == $val
end
@test( @sprintf( "%10.5g", BFloat16(-123.5) ) == " -123.5")
@test( @sprintf( "%010.5g", BFloat16(-123.5) ) == "-0000123.5")
Expand Down

0 comments on commit 730511b

Please sign in to comment.