Skip to content

Commit

Permalink
Base: correctly rounded floats constructed from rationals
Browse files Browse the repository at this point in the history
Constructing a floating-point number from a `Rational` should now be
correctly rounded.

Implementation approach:

1. Convert the (numerator, denominator) pair to a (sign bit, integral
   significand, exponent) triplet using integer arithmetic. The integer
   type in question must be wide enough.

2. Convert the above triplet into an instance of the chosen FP type.
   There is special support for IEEE 754 floating-point and for
   `BigFloat`, otherwise a fallback using `ldexp` is used.

As a bonus, constructing a `BigFloat` from a `Rational` should now be
thread-safe when the rounding mode and precision are provided to the
constructor, because there is no access to the global precision or
rounding mode settings.

Updates #45213

Updates #50940

Updates #52507

Fixes #52394

Closes #52395

Fixes #52859
  • Loading branch information
nsajko committed Jan 15, 2024
1 parent fc6295d commit 314db19
Show file tree
Hide file tree
Showing 8 changed files with 773 additions and 11 deletions.
1 change: 1 addition & 0 deletions base/Base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,7 @@ include("rounding.jl")
include("div.jl")
include("rawbigints.jl")
include("float.jl")
include("rational_to_float.jl")
include("twiceprecision.jl")
include("complex.jl")
include("rational.jl")
Expand Down
19 changes: 19 additions & 0 deletions base/docs/basedocs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2777,6 +2777,25 @@ julia> 4/2
julia> 4.5/2
2.25
```
This function may convert integer arguments to a floating-point number type
([`AbstractFloat`](@ref)), potentially resulting in a loss of accuracy. To avoid this,
instead construct a [`Rational`](@ref) from the arguments, then convert the resulting
rational number to a specific floating-point type of your choice:
```jldoctest
julia> n = 100000000000000000
100000000000000000
julia> m = n + 6
100000000000000006
julia> n/m
1.0
julia> Float64(n//m) # `//` constructs a `Rational`
0.9999999999999999
```
"""
/(x, y)

Expand Down
52 changes: 46 additions & 6 deletions base/mpfr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ import
isone, big, _string_n, decompose, minmax,
sinpi, cospi, sincospi, tanpi, sind, cosd, tand, asind, acosd, atand,
uinttype, exponent_max, exponent_min, ieee754_representation, significand_mask,
RawBigIntRoundingIncrementHelper, truncated, RawBigInt
RawBigIntRoundingIncrementHelper, truncated, RawBigInt, unsafe_rational,
RationalToFloat, rational_to_floating_point


using .Base.Libc
Expand Down Expand Up @@ -310,12 +311,51 @@ BigFloat(x::Union{UInt8,UInt16,UInt32}, r::MPFRRoundingMode=ROUNDING_MODE[]; pre
BigFloat(x::Union{Float16,Float32}, r::MPFRRoundingMode=ROUNDING_MODE[]; precision::Integer=DEFAULT_PRECISION[]) =
BigFloat(Float64(x), r; precision=precision)

function BigFloat(x::Rational, r::MPFRRoundingMode=ROUNDING_MODE[]; precision::Integer=DEFAULT_PRECISION[])
setprecision(BigFloat, precision) do
setrounding_raw(BigFloat, r) do
BigFloat(numerator(x))::BigFloat / BigFloat(denominator(x))::BigFloat
function set_2exp!(z::BigFloat, n::BigInt, exp::Int, rm::MPFRRoundingMode)
ccall(
(:mpfr_set_z_2exp, libmpfr),
Int32,
(Ref{BigFloat}, Ref{BigInt}, Int, MPFRRoundingMode),
z, n, exp, rm,
)
nothing
end

function RationalToFloat.to_floating_point_impl(::Type{BigFloat}, ::Type{BigInt}, num, den, romo, prec)
num_is_zero = iszero(num)
den_is_zero = iszero(den)
s = Int8(sign(num))
sb = signbit(s)
is_zero = num_is_zero & !den_is_zero
is_inf = !num_is_zero & den_is_zero
is_regular = !num_is_zero & !den_is_zero

if is_regular
let rtfc = RationalToFloat.to_float_components
c = rtfc(BigInt, num, den, prec, nothing, romo, sb)
ret = BigFloat(precision = prec)
mpfr_romo = convert(MPFRRoundingMode, romo)
set_2exp!(ret, s * c.integral_significand, Int(c.exponent - prec + true), mpfr_romo)
ret
end
end
else
if is_zero
BigFloat(false, MPFRRoundToZero, precision = prec)
elseif is_inf
BigFloat(s * Inf, MPFRRoundToZero, precision = prec)
else
BigFloat(precision = prec)
end
end::BigFloat
end

function BigFloat(x::Rational, r::RoundingMode; precision::Integer = DEFAULT_PRECISION[])
rational_to_floating_point(BigFloat, x, r, precision)
end

function BigFloat(x::Rational, r::MPFRRoundingMode = ROUNDING_MODE[];
precision::Integer = DEFAULT_PRECISION[])
rational_to_floating_point(BigFloat, x, r, precision)
end

function tryparse(::Type{BigFloat}, s::AbstractString; base::Integer=0, precision::Integer=DEFAULT_PRECISION[], rounding::MPFRRoundingMode=ROUNDING_MODE[])
Expand Down
21 changes: 17 additions & 4 deletions base/rational.jl
Original file line number Diff line number Diff line change
Expand Up @@ -140,10 +140,23 @@ Bool(x::Rational) = x==0 ? false : x==1 ? true :
(::Type{T})(x::Rational) where {T<:Integer} = (isinteger(x) ? convert(T, x.num)::T :
throw(InexactError(nameof(T), T, x)))

AbstractFloat(x::Rational) = (float(x.num)/float(x.den))::AbstractFloat
function (::Type{T})(x::Rational{S}) where T<:AbstractFloat where S
P = promote_type(T,S)
convert(T, convert(P,x.num)/convert(P,x.den))::T
function numerator_denominator_promoted(x)
y = unsafe_rational(numerator(x), denominator(x))
(numerator(y), denominator(y))
end

function rational_to_floating_point(::Type{F}, x, rm, prec) where {F}
nd = numerator_denominator_promoted(x)
RationalToFloat.to_floating_point(F, nd..., rm, prec)::F
end

function (::Type{F})(x::Rational, rm::RoundingMode = RoundNearest) where {F<:AbstractFloat}
rational_to_floating_point(F, x, rm, precision(F))::F
end

function AbstractFloat(x::Q) where {Q<:Rational}
T = float(Q)
T(x)::T::AbstractFloat
end

function Rational{T}(x::AbstractFloat) where T<:Integer
Expand Down
214 changes: 214 additions & 0 deletions base/rational_to_float.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,214 @@
# This file is a part of Julia. License is MIT: https://julialang.org/license

module RationalToFloat

const Rnd = Base.Rounding

# Performance optimization. Unlike raw `<<` or `>>>`, this is supposed
# to compile to a single instruction, because the semantics correspond
# to what hardware usually provides.
function machine_shift(shift::S, a::T, b) where {S,T<:Base.BitInteger}
@inline begin
mask = 8*sizeof(T) - 1
c = b & mask
shift(a, c)
end
end

machine_shift(::S, a::Bool, ::Any) where {S} = error("unsupported")

# Fallback for `BigInt` etc.
machine_shift(shift::S, a, b) where {S} = shift(a, b)

# Arguments are positive integers.
function div_significand_with_remainder(num, den, minimum_significand_size)
clamped = x -> max(zero(x), x)::typeof(x)
bw = Base.top_set_bit # bit width
shift = clamped(minimum_significand_size + bw(den) - bw(num) + 0x2)
t = machine_shift(<<, num, shift)
(divrem(t, den, RoundToZero)..., shift)
end

# `divrem(n, 1<<k, RoundToZero)`
function divrem_2(n, k)
quo = machine_shift(>>>, n, k)
tmp = machine_shift(<<, quo, k)
rem = n - tmp
(quo, rem)
end

function to_float_components_impl(num, den, precision, max_subnormal_exp)
# `+1` because we need an extra, "round", bit for some rounding modes.
#
# TODO: as a performance optimization, only do this when required
# by the rounding mode
prec_p_1 = precision + true

(quo0, rem0, shift) = div_significand_with_remainder(num, den, prec_p_1)
width = Base.top_set_bit(quo0)
excess_width = width - prec_p_1
exp = width - shift - true

exp_underflow = if isnothing(max_subnormal_exp)
zero(exp)
else
let d = max_subnormal_exp - exp, T = typeof(d), z = zero(d)::T
(signbit(d) ? z : d + true)::T
end
end

(quo1, rem1) = divrem_2(quo0, exp_underflow + excess_width)
integral_significand = quo1 >>> true
round_bit = quo1 % Bool
sticky_bit = !iszero(rem1) | !iszero(rem0)

(; integral_significand, exponent = exp, round_bit, sticky_bit)
end

struct RoundingIncrementHelper
final_bit::Bool
round_bit::Bool
sticky_bit::Bool
end

(h::RoundingIncrementHelper)(::Rnd.FinalBit) = h.final_bit
(h::RoundingIncrementHelper)(::Rnd.RoundBit) = h.round_bit
(h::RoundingIncrementHelper)(::Rnd.StickyBit) = h.sticky_bit

function to_float_components_rounded(num, den, precision, max_subnormal_exp, romo, sign_bit)
overflows = (x, p) -> x == machine_shift(<<, one(x), p)
t = to_float_components_impl(num, den, precision, max_subnormal_exp)
raw_significand = t.integral_significand
rh = RoundingIncrementHelper(raw_significand % Bool, t.round_bit, t.sticky_bit)
incr = Rnd.correct_rounding_requires_increment(rh, romo, sign_bit)
rounded = raw_significand + incr
(integral_significand, exponent) = let exp = t.exponent
if overflows(rounded, precision)
(rounded >>> true, exp + true)
else
(rounded, exp)
end
end
(; integral_significand, exponent)
end

function to_float_components(::Type{T}, num, den, precision, max_subnormal_exp, romo, sb) where {T}
to_float_components_rounded(abs(T(num)), den, precision, max_subnormal_exp, romo, sb)
end

function to_floating_point_fallback(::Type{T}, ::Type{S}, num, den, rm, prec) where {T,S}
num_is_zero = iszero(num)
den_is_zero = iszero(den)
sb = signbit(num)
is_zero = num_is_zero & !den_is_zero
is_inf = !num_is_zero & den_is_zero
is_regular = !num_is_zero & !den_is_zero
if is_regular
let
c = to_float_components(S, num, den, prec, nothing, rm, sb)
exp = c.exponent
signif = T(c.integral_significand)::T
let x = ldexp(signif, exp - prec + true)::T
sb ? -x : x
end::T
end
else
if is_zero
zero(T)::T
elseif is_inf
T(Inf)::T
else
T(NaN)::T
end
end::T
end

function to_floating_point_impl(::Type{T}, ::Type{S}, num, den, rm, prec) where {T,S}
to_floating_point_fallback(T, S, num, den, rm, prec)
end

function to_floating_point_impl(::Type{T}, ::Type{S}, num, den, rm, prec) where {T<:Base.IEEEFloat,S}
num_is_zero = iszero(num)
den_is_zero = iszero(den)
sb = signbit(num)
is_zero = num_is_zero & !den_is_zero
is_inf = !num_is_zero & den_is_zero
is_regular = !num_is_zero & !den_is_zero
(rm_is_to_zero, rm_is_from_zero) = if Rnd.rounds_to_nearest(rm)
(false, false)
else
let from = Rnd.rounds_away_from_zero(rm, sb)
(!from, from)
end
end::NTuple{2,Bool}
exp_max = Base.exponent_max(T)
exp_min = Base.exponent_min(T)
ieee_repr = Base.ieee754_representation
repr_zero = ieee_repr(T, sb, Val(:zero))
repr_inf = ieee_repr(T, sb, Val(:inf))
repr_nan = ieee_repr(T, sb, Val(:nan))
U = typeof(repr_zero)
repr_zero::U
repr_inf::U
repr_nan::U

ret_u = if is_regular
let
c = let e = exp_min - 1
to_float_components(S, num, den, prec, e, rm, sb)
end
exp = c.exponent
exp_diff = exp - exp_min
is_normal = 0 exp_diff
exp_is_huge_p = exp_max < exp
exp_is_huge_n = signbit(exp_diff + prec)
rounds_to_inf = exp_is_huge_p & !rm_is_to_zero
rounds_to_zero = exp_is_huge_n & !rm_is_from_zero

if !rounds_to_zero & !exp_is_huge_p
let signif = (c.integral_significand % U) & Base.significand_mask(T)
exp_field = (max(exp_diff, zero(exp_diff)) + is_normal) % U
ieee_repr(T, sb, exp_field, signif)::U
end
elseif rounds_to_zero
repr_zero
elseif rounds_to_inf
repr_inf
else
ieee_repr(T, sb, Val(:omega))
end
end
else
if is_zero
repr_zero
elseif is_inf
repr_inf
else
repr_nan
end
end::U

reinterpret(T, ret_u)::T
end

# `BigInt` is a safe default.
to_float_promote_type(::Type{F}, ::Type{S}) where {F,S} = BigInt

const BitIntegerOrBool = Union{Bool,Base.BitInteger}

# As an optimization, use an integer type narrower than `BigInt` when possible.
function to_float_promote_type(::Type{F}, ::Type{S}) where {F<:Base.IEEEFloat,S<:BitIntegerOrBool}
Max = if sizeof(F) sizeof(S)
S
else
(S <: Signed) ? Base.inttype(F) : Base.uinttype(F)
end
widen(Max)
end

function to_floating_point(::Type{F}, num::T, den::T, rm, prec) where {F,T}
S = to_float_promote_type(F, T)
to_floating_point_impl(F, S, num, den, rm, prec)
end

end
2 changes: 1 addition & 1 deletion test/choosetests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ const TESTNAMES = [
"char", "strings", "triplequote", "unicode", "intrinsics",
"dict", "hashing", "iobuffer", "staged", "offsetarray",
"arrayops", "tuple", "reduce", "reducedim", "abstractarray",
"intfuncs", "simdloop", "vecelement", "rational",
"intfuncs", "simdloop", "vecelement", "rational", "rational_to_float",
"bitarray", "copy", "math", "fastmath", "functional", "iterators",
"operators", "ordering", "path", "ccall", "parse", "loading", "gmp",
"sorting", "spawn", "backtrace", "exceptions",
Expand Down
Loading

0 comments on commit 314db19

Please sign in to comment.