Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adapt to upstream changes wrt. native support for BFloat16 #51

Merged
merged 7 commits into from
Nov 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
245 changes: 190 additions & 55 deletions src/bfloat16.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ import Base: isfinite, isnan, precision, iszero, eps,
typemin, typemax, floatmin, floatmax,
sign_mask, exponent_mask, significand_mask,
exponent_bits, significand_bits, exponent_bias,
exponent_one, exponent_half,
exponent_one, exponent_half, leading_zeros,
signbit, exponent, significand, frexp, ldexp,
round, Int16, Int32, Int64,
+, -, *, /, ^, ==, <, <=, >=, >, !=, inv,
Expand All @@ -13,33 +13,66 @@ import Base: isfinite, isnan, precision, iszero, eps,
asin, acos, atan, acsc, asec, acot,
sinh, cosh, tanh, csch, sech, coth,
asinh, acosh, atanh, acsch, asech, acoth,
bitstring

primitive type BFloat16 <: AbstractFloat 16 end
bitstring, isinteger

import Printf

# LLVM 11 added support for BFloat16 in the IR; Julia 1.11 added support for generating
# code that uses the `bfloat` IR type, together with the necessary runtime functions.
# However, not all LLVM targets support `bfloat`. If the target can store/load BFloat16s
# (and supports synthesizing constants) we can use the `bfloat` IR type, otherwise we fall
# back to defining a primitive type that will be represented as an `i16`. If, in addition,
# the target supports BFloat16 arithmetic, we can use LLVM intrinsics.
# - x86: storage and arithmetic support in LLVM 15
# - aarch64: storage support in LLVM 17
const llvm_storage = if isdefined(Core, :BFloat16)
if Sys.ARCH in [:x86_64, :i686] && Base.libllvm_version >= v"15"
true
elseif Sys.ARCH == :aarch64 && Base.libllvm_version >= v"17"
true
else
false
end
else
false
end
const llvm_arithmetic = if llvm_storage
using Core: BFloat16
if Sys.ARCH in [:x86_64, :i686] && Base.libllvm_version >= v"15"
true
else
false
end
else
primitive type BFloat16 <: AbstractFloat 16 end
false
end

Base.reinterpret(::Type{Unsigned}, x::BFloat16) = reinterpret(UInt16, x)
Base.reinterpret(::Type{Signed}, x::BFloat16) = reinterpret(Int16, x)

# Floating point property queries
for f in (:sign_mask, :exponent_mask, :exponent_one,
:exponent_half, :significand_mask)
:exponent_half, :significand_mask)
@eval $(f)(::Type{BFloat16}) = UInt16($(f)(Float32) >> 16)
end

Base.exponent_bias(::Type{BFloat16}) = 127
Base.exponent_bits(::Type{BFloat16}) = 8
Base.significand_bits(::Type{BFloat16}) = 7
Base.signbit(x::BFloat16) = (reinterpret(Unsigned, x) & 0x8000) !== 0x0000

function Base.significand(x::BFloat16)
result = abs_significand(x)
ifelse(signbit(x), -result, result)
end

@inline function abs_significand(x::BFloat16)
usig = Base.significand_mask(BFloat16) & reinterpret(Unsigned, x)
isig = Int16(usig)
1 + isig / BFloat16(2)^7
xu = reinterpret(Unsigned, x)
xs = xu & ~sign_mask(BFloat16)
xs >= exponent_mask(BFloat16) && return x # NaN or Inf
if xs <= (~exponent_mask(BFloat16) & ~sign_mask(BFloat16)) # x is subnormal
xs == 0 && return x # +-0
m = unsigned(leading_zeros(xs) - exponent_bits(BFloat16))
xs <<= m
xu = xs | (xu & sign_mask(BFloat16))
end
xu = (xu & ~exponent_mask(BFloat16)) | exponent_one(BFloat16)
return reinterpret(BFloat16, xu)
end

Base.exponent(x::BFloat16) =
Expand All @@ -65,16 +98,24 @@ isnan(x::BFloat16) = (reinterpret(Unsigned,x) & ~sign_mask(BFloat16)) > exponent
precision(::Type{BFloat16}) = 8
eps(::Type{BFloat16}) = Base.bitcast(BFloat16, 0x3c00)

round(x::BFloat16, r::RoundingMode{:Up}) = BFloat16(ceil(Float32(x)))
round(x::BFloat16, r::RoundingMode{:Down}) = BFloat16(floor(Float32(x)))
round(x::BFloat16, r::RoundingMode{:Nearest}) = BFloat16(round(Float32(x)))
## Rounding ##
if llvm_arithmetic
round(x::BFloat16, ::RoundingMode{:ToZero}) = Base.trunc_llvm(x)
round(x::BFloat16, ::RoundingMode{:Down}) = Base.floor_llvm(x)
round(x::BFloat16, ::RoundingMode{:Up}) = Base.ceil_llvm(x)
round(x::BFloat16, ::RoundingMode{:Nearest}) = Base.rint_llvm(x)
else
round(x::BFloat16, r::RoundingMode{:ToZero}) = BFloat16(trunc(Float32(x)))
round(x::BFloat16, r::RoundingMode{:Down}) = BFloat16(floor(Float32(x)))
round(x::BFloat16, r::RoundingMode{:Up}) = BFloat16(ceil(Float32(x)))
round(x::BFloat16, r::RoundingMode{:Nearest}) = BFloat16(round(Float32(x)))
end
# round(::Type{Signed}, x::BFloat16, r::RoundingMode) = round(Int, x, r)
# round(::Type{Unsigned}, x::BFloat16, r::RoundingMode) = round(UInt, x, r)
# round(::Type{Integer}, x::BFloat16, r::RoundingMode) = round(Int, x, r)

Base.trunc(bf::BFloat16) = signbit(bf) ? ceil(bf) : floor(bf)

Int64(x::BFloat16) = Int64(Float32(x))
Int32(x::BFloat16) = Int32(Float32(x))
Int16(x::BFloat16) = Int16(Float32(x))

## floating point traits ##
"""
InfB16
Expand All @@ -100,56 +141,88 @@ Base.trunc(::Type{BFloat16}, x::Float32) = reinterpret(BFloat16,
(reinterpret(UInt32, x) >> 16) % UInt16
)

# Conversion from Float32
function BFloat16(x::Float32)
isnan(x) && return NaNB16
# Round to nearest even (matches TensorFlow and our convention for
# rounding to lower precision floating point types).
h = reinterpret(UInt32, x)
h += 0x7fff + ((h >> 16) & 1)
return reinterpret(BFloat16, (h >> 16) % UInt16)
end
if llvm_arithmetic
BFloat16(x::Float32) = Base.fptrunc(BFloat16, x)
BFloat16(x::Float64) = Base.fptrunc(BFloat16, x)

# XXX: can LLVM do this natively?
BFloat16(x::Float16) = BFloat16(Float32(x))
else
# Conversion from Float32
function BFloat16(x::Float32)
isnan(x) && return NaNB16
# Round to nearest even (matches TensorFlow and our convention for
# rounding to lower precision floating point types).
h = reinterpret(UInt32, x)
h += 0x7fff + ((h >> 16) & 1)
return reinterpret(BFloat16, (h >> 16) % UInt16)
end

# Conversion from Float64
function BFloat16(x::Float64)
BFloat16(Float32(x))
end
# Conversion from Float64
function BFloat16(x::Float64)
BFloat16(Float32(x))
end

# Conversion from Float16
function BFloat16(x::Float16)
BFloat16(Float32(x))
# Conversion from Float16
function BFloat16(x::Float16)
BFloat16(Float32(x))
end
end

# Conversion from Integer
function BFloat16(x::Integer)
convert(BFloat16, convert(Float32, x))
if llvm_arithmetic
for st in (Int8, Int16, Int32, Int64)
@eval begin
BFloat16(x::($st)) = Base.sitofp(BFloat16, x)
end
end
for ut in (Bool, UInt8, UInt16, UInt32, UInt64)
@eval begin
BFloat16(x::($ut)) = Base.uitofp(BFloat16, x)
end
end
else
BFloat16(x::Integer) = convert(BFloat16, convert(Float32, x))
end
# TODO: optimize
BFloat16(x::UInt128) = convert(BFloat16, Float64(x))
BFloat16(x::Int128) = convert(BFloat16, Float64(x))

# Conversion to Float16
function Base.Float16(x::BFloat16)
Float16(Float32(x))
end

# Expansion to Float32
function Base.Float32(x::BFloat16)
reinterpret(Float32, UInt32(reinterpret(Unsigned, x)) << 16)
end
if llvm_arithmetic
Base.Float32(x::BFloat16) = Base.fpext(Float32, x)
Base.Float64(x::BFloat16) = Base.fpext(Float64, x)
else
# Expansion to Float32
function Base.Float32(x::BFloat16)
reinterpret(Float32, UInt32(reinterpret(Unsigned, x)) << 16)
end

# Expansion to Float64
function Base.Float64(x::BFloat16)
Float64(Float32(x))
# Expansion to Float64
function Base.Float64(x::BFloat16)
Float64(Float32(x))
end
end

# Truncation to integer types
Base.unsafe_trunc(T::Type{<:Integer}, x::BFloat16) = unsafe_trunc(T, Float32(x))
Base.trunc(::Type{T}, x::BFloat16) where {T<:Integer} = trunc(T, Float32(x))

# Basic arithmetic
for f in (:+, :-, :*, :/, :^)
@eval ($f)(x::BFloat16, y::BFloat16) = BFloat16($(f)(Float32(x), Float32(y)))
if llvm_arithmetic
+(x::T, y::T) where {T<:BFloat16} = Base.add_float(x, y)
-(x::T, y::T) where {T<:BFloat16} = Base.sub_float(x, y)
*(x::T, y::T) where {T<:BFloat16} = Base.mul_float(x, y)
/(x::T, y::T) where {T<:BFloat16} = Base.div_float(x, y)
-(x::BFloat16) = Base.neg_float(x)
^(x::BFloat16, y::BFloat16) = BFloat16(Float32(x)^Float32(y))
else
for f in (:+, :-, :*, :/, :^)
@eval ($f)(x::BFloat16, y::BFloat16) = BFloat16($(f)(Float32(x), Float32(y)))
end
-(x::BFloat16) = reinterpret(BFloat16, reinterpret(Unsigned, x) ⊻ sign_mask(BFloat16))
end
-(x::BFloat16) = reinterpret(BFloat16, reinterpret(Unsigned, x) ⊻ sign_mask(BFloat16))
^(x::BFloat16, y::Integer) = BFloat16(^(Float32(x), y))
^(x::BFloat16, y::Integer) = BFloat16(Float32(x)^y)

const ZeroBFloat16 = BFloat16(0.0f0)
const OneBFloat16 = BFloat16(1.0f0)
Expand Down Expand Up @@ -185,7 +258,68 @@ for t in (Int8, Int16, Int32, Int64, Int128, UInt8, UInt16, UInt32, UInt64, UInt
end

# Wide multiplication
Base.widemul(x::BFloat16, y::BFloat16) = Float32(x) * Float32(y)
Base.widemul(x::BFloat16, y::BFloat16) = widen(x) * widen(y)

# Truncation to integer types
if llvm_arithmetic
for Ti in (Int8, Int16, Int32, Int64)
@eval begin
Base.unsafe_trunc(::Type{$Ti}, x::BFloat16) = Base.fptosi($Ti, x)
end
end
for Ti in (UInt8, UInt16, UInt32, UInt64)
@eval begin
Base.unsafe_trunc(::Type{$Ti}, x::BFloat16) = Base.fptoui($Ti, x)
end
end
else
Base.unsafe_trunc(T::Type{<:Integer}, x::BFloat16) = unsafe_trunc(T, Float32(x))
end
for Ti in (Int8, Int16, Int32, Int64, Int128, UInt8, UInt16, UInt32, UInt64, UInt128)
if Ti <: Unsigned || sizeof(Ti) < 2
# Here `BFloat16(typemin(Ti))-1` is exact, so we can compare the lower-bound
# directly. `BFloat16(typemax(Ti))+1` is either always exactly representable, or
# rounded to `Inf` (e.g. when `Ti==UInt128 && BFloat16==Float32`).
@eval begin
function Base.trunc(::Type{$Ti}, x::BFloat16)
if $(BFloat16(typemin(Ti))-one(BFloat16)) < x < $(BFloat16(typemax(Ti))+one(BFloat16))
return Base.unsafe_trunc($Ti,x)
else
throw(InexactError(:trunc, $Ti, x))
end
end
function (::Type{$Ti})(x::BFloat16)
if ($(BFloat16(typemin(Ti))) <= x <= $(BFloat16(typemax(Ti)))) && isinteger(x)
return Base.unsafe_trunc($Ti,x)
else
throw(InexactError($(Expr(:quote,Ti.name.name)), $Ti, x))
end
end
end
else
# Here `eps(BFloat16(typemin(Ti))) > 1`, so the only value which can be
# truncated to `BFloat16(typemin(Ti)` is itself. Similarly,
# `BFloat16(typemax(Ti))` is inexact and will be rounded up. This assumes that
# `BFloat16(typemin(Ti)) > -Inf`, which is true for these types, but not for
# `Float16` or larger integer types.
@eval begin
function Base.trunc(::Type{$Ti}, x::BFloat16)
if $(BFloat16(typemin(Ti))) <= x < $(BFloat16(typemax(Ti)))
return unsafe_trunc($Ti,x)
else
throw(InexactError(:trunc, $Ti, x))
end
end
function (::Type{$Ti})(x::BFloat16)
if ($(BFloat16(typemin(Ti))) <= x < $(BFloat16(typemax(Ti)))) && isinteger(x)
return unsafe_trunc($Ti,x)
else
throw(InexactError($(Expr(:quote,Ti.name.name)), $Ti, x))
end
end
end
end
end

# Showing
function Base.show(io::IO, x::BFloat16)
Expand All @@ -200,6 +334,7 @@ function Base.show(io::IO, x::BFloat16)
hastypeinfo || print(io, ")")
end
end
Printf.tofloat(x::BFloat16) = Float32(x)

# Random
import Random: rand, randn, randexp, AbstractRNG, Sampler
Expand Down
14 changes: 12 additions & 2 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
using Test, BFloat16s, Printf, Random

@info "Testing BFloat16s" BFloat16s.llvm_storage BFloat16s.llvm_arithmetic

@testset "comparisons" begin
@test BFloat16(1) < BFloat16(2)
@test BFloat16(1f0) < BFloat16(2f0)
Expand Down Expand Up @@ -27,6 +29,14 @@ end
@test Int64(BFloat16(10)) == Int64(10)
end

@testset "abi" begin
f() = BFloat16(1)
@test f() == BFloat16(1)

g(x) = x+BFloat16(1)
@test g(BFloat16(2)) == BFloat16(3)
end

@testset "functions" begin
@test abs(BFloat16(-10)) == BFloat16(10)
@test BFloat16(2) ^ BFloat16(4) == BFloat16(16)
Expand Down Expand Up @@ -58,7 +68,7 @@ end
("%.2a", "0x1.3cp+0"),
("%.2A", "0X1.3CP+0")),
num in (BFloat16(1.234),)
@test @eval(@sprintf($fmt, $num) == $val)
@eval @test @sprintf($fmt, $num) == $val
end
@test (@sprintf "%f" BFloat16(Inf)) == "Inf"
@test (@sprintf "%f" BFloat16(NaN)) == "NaN"
Expand All @@ -73,7 +83,7 @@ end
("%-+10.5g", "+123.5 "),
("%010.5g", "00000123.5")),
num in (BFloat16(123.5),)
@test @eval(@sprintf($fmt, $num) == $val)
@eval @test @sprintf($fmt, $num) == $val
end
@test( @sprintf( "%10.5g", BFloat16(-123.5) ) == " -123.5")
@test( @sprintf( "%010.5g", BFloat16(-123.5) ) == "-0000123.5")
Expand Down
Loading