Skip to content

Commit

Permalink
Split codegen support in storage and arithmetic.
Browse files Browse the repository at this point in the history
  • Loading branch information
maleadt committed Oct 6, 2023
1 parent afd9f40 commit 4f3861c
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 10 deletions.
41 changes: 31 additions & 10 deletions src/bfloat16.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,32 @@ import Base: isfinite, isnan, precision, iszero, eps,
asinh, acosh, atanh, acsch, asech, acoth,
bitstring, isinteger

# Julia 1.11 provides codegen support for BFloat16
const codegen_support = if isdefined(Core, :BFloat16) &&
Sys.ARCH in [:x86_64, :i686]
# LLVM 11 added support for BFloat16 in the IR; Julia 1.11 added support for generating
# code that uses the `bfloat` IR type, together with the necessary runtime functions.
# However, not all LLVM targets support `bfloat`. If the target can store/load BFloat16s
# (and supports synthesizing constants) we can use the `bfloat` IR type, otherwise we fall
# back to defining a primitive type that will be represented as an `i16`. If, in addition,
# the target supports BFloat16 arithmetic, we can use LLVM intrinsics.
# - x86: storage and arithmetic support in LLVM 15
# - aarch64: storage support in LLVM 17
const llvm_storage = if isdefined(Core, :BFloat16)
if Sys.ARCH in [:x86_64, :i686] && Base.libllvm_version >= v"15"
true
elseif Sys.ARCH == :aarch64 && Base.libllvm_version >= v"17"
true
else
false
end
else
false
end
const llvm_arithmetic = if llvm_storage
using Core: BFloat16
true
if Sys.ARCH in [:x86_64, :i686] && Base.libllvm_version >= v"15"
true
else
false
end
else
primitive type BFloat16 <: AbstractFloat 16 end
false
Expand Down Expand Up @@ -76,7 +97,7 @@ precision(::Type{BFloat16}) = 8
eps(::Type{BFloat16}) = Base.bitcast(BFloat16, 0x3c00)

## Rounding ##
if codegen_support
if llvm_arithmetic
round(x::BFloat16, ::RoundingMode{:ToZero}) = Base.trunc_llvm(x)
round(x::BFloat16, ::RoundingMode{:Down}) = Base.floor_llvm(x)
round(x::BFloat16, ::RoundingMode{:Up}) = Base.ceil_llvm(x)
Expand Down Expand Up @@ -118,7 +139,7 @@ Base.trunc(::Type{BFloat16}, x::Float32) = reinterpret(BFloat16,
(reinterpret(UInt32, x) >> 16) % UInt16
)

if codegen_support
if llvm_arithmetic
BFloat16(x::Float32) = Base.fptrunc(BFloat16, x)
BFloat16(x::Float64) = Base.fptrunc(BFloat16, x)

Check warning on line 144 in src/bfloat16.jl

View check run for this annotation

Codecov / codecov/patch

src/bfloat16.jl#L143-L144

Added lines #L143 - L144 were not covered by tests

Expand Down Expand Up @@ -147,7 +168,7 @@ else
end

# Conversion from Integer
if codegen_support
if llvm_arithmetic
for st in (Int8, Int16, Int32, Int64)
@eval begin
BFloat16(x::($st)) = Base.sitofp(BFloat16, x)

Check warning on line 174 in src/bfloat16.jl

View check run for this annotation

Codecov / codecov/patch

src/bfloat16.jl#L174

Added line #L174 was not covered by tests
Expand All @@ -170,7 +191,7 @@ function Base.Float16(x::BFloat16)
Float16(Float32(x))
end

if codegen_support
if llvm_arithmetic
Base.Float32(x::BFloat16) = Base.fpext(Float32, x)
Base.Float64(x::BFloat16) = Base.fpext(Float64, x)

Check warning on line 196 in src/bfloat16.jl

View check run for this annotation

Codecov / codecov/patch

src/bfloat16.jl#L195-L196

Added lines #L195 - L196 were not covered by tests
else
Expand All @@ -186,7 +207,7 @@ else
end

# Basic arithmetic
if codegen_support
if llvm_arithmetic
+(x::T, y::T) where {T<:BFloat16} = Base.add_float(x, y)
-(x::T, y::T) where {T<:BFloat16} = Base.sub_float(x, y)
*(x::T, y::T) where {T<:BFloat16} = Base.mul_float(x, y)
Expand Down Expand Up @@ -238,7 +259,7 @@ end
Base.widemul(x::BFloat16, y::BFloat16) = widen(x) * widen(y)

Check warning on line 259 in src/bfloat16.jl

View check run for this annotation

Codecov / codecov/patch

src/bfloat16.jl#L259

Added line #L259 was not covered by tests

# Truncation to integer types
if codegen_support
if llvm_arithmetic
for Ti in (Int8, Int16, Int32, Int64)
@eval begin
Base.unsafe_trunc(::Type{$Ti}, x::BFloat16) = Base.fptosi($Ti, x)

Check warning on line 265 in src/bfloat16.jl

View check run for this annotation

Codecov / codecov/patch

src/bfloat16.jl#L265

Added line #L265 was not covered by tests
Expand Down
2 changes: 2 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
using Test, BFloat16s, Printf, Random

@info "Testing BFloat16s" BFloat16s.llvm_storage BFloat16s.llvm_arithmetic

@testset "comparisons" begin
@test BFloat16(1) < BFloat16(2)
@test BFloat16(1f0) < BFloat16(2f0)
Expand Down

0 comments on commit 4f3861c

Please sign in to comment.