Skip to content

Commit

Permalink
fallback randn/randexp for AbstractFloat
Browse files Browse the repository at this point in the history
  • Loading branch information
stevengj committed Mar 23, 2022
1 parent 62e0729 commit 3ca78c1
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 6 deletions.
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ Standard library changes

#### Random

* `randn` and `randexp` now work for any `AbstractFloat` type defining `rand` ([#44713]).

#### REPL

#### SparseArrays
Expand Down
13 changes: 13 additions & 0 deletions stdlib/Random/src/normal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,16 @@ randn(rng::AbstractRNG, ::Type{Complex{T}}) where {T<:AbstractFloat} =
Complex{T}(SQRT_HALF * randn(rng, T), SQRT_HALF * randn(rng, T))


### fallback randn for float types defining rand:
function randn(rng::AbstractRNG, ::Type{T}) where {T<:AbstractFloat}
# Marsaglia polar variant of Box–Muller transform:
while true
x, y = 2rand(rng, T)-1, 2rand(rng, T)-1
0 < (s = x^2 + y^2) < 1 || continue
return x * sqrt(-2log(s)/s) # and/or y * sqrt(...)
end
end

## randexp

"""
Expand Down Expand Up @@ -137,6 +147,9 @@ end
end
end

### fallback randexp for float types defining rand:
randexp(rng::AbstractRNG, ::Type{T}) where {T<:AbstractFloat} =
-log(rand(rng, T))

## arrays & other scalar methods

Expand Down
32 changes: 26 additions & 6 deletions stdlib/Random/test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -307,9 +307,32 @@ let a = [rand(RandomDevice(), UInt128) for i=1:10]
@test reduce(|, a)>>>64 != 0
end

# wrapper around Float64 to check fallback random generators
struct FakeFloat64 <: AbstractFloat
x::Float64
end
Base.rand(rng::AbstractRNG, ::Random.SamplerTrivial{Random.CloseOpen01{FakeFloat64}}) = FakeFloat64(rand(rng))
for f in (:sqrt, :log, :one, :zero, :abs, :+, :-)
@eval Base.$f(x::FakeFloat64) = FakeFloat64($f(x.x))
end
for f in (:+, :-, :*, :/)
@eval begin
Base.$f(x::FakeFloat64, y::FakeFloat64) = FakeFloat64($f(x.x,y.x))
Base.$f(x::FakeFloat64, y::Real) = FakeFloat64($f(x.x,y))
Base.$f(x::Real, y::FakeFloat64) = FakeFloat64($f(x,y.x))
end
end
for f in (:<, :<=, :>, :>=, :(==), :(!=))
@eval begin
Base.$f(x::FakeFloat64, y::FakeFloat64) = $f(x.x,y.x)
Base.$f(x::FakeFloat64, y::Real) = $f(x.x,y)
Base.$f(x::Real, y::FakeFloat64) = $f(x,y.x)
end
end

# test all rand APIs
for rng in ([], [MersenneTwister(0)], [RandomDevice()], [Xoshiro()])
ftypes = [Float16, Float32, Float64]
ftypes = [Float16, Float32, Float64, FakeFloat64, BigFloat]
cftypes = [ComplexF16, ComplexF32, ComplexF64, ftypes...]
types = [Bool, Char, BigFloat, Base.BitInteger_types..., ftypes...]
randset = Set(rand(Int, 20))
Expand Down Expand Up @@ -406,15 +429,12 @@ for rng in ([], [MersenneTwister(0)], [RandomDevice()], [Xoshiro()])
rand!(rng..., BitMatrix(undef, 2, 3)) ::BitArray{2}

# Test that you cannot call randn or randexp with non-Float types.
for r in [randn, randexp, randn!, randexp!]
local r
for r in [randn, randexp]
@test_throws MethodError r(Int)
@test_throws MethodError r(Int32)
@test_throws MethodError r(Bool)
@test_throws MethodError r(String)
@test_throws MethodError r(AbstractFloat)
# TODO(#17627): Consider adding support for randn(BigFloat) and removing this test.
@test_throws MethodError r(BigFloat)
@test_throws ArgumentError r(AbstractFloat)

@test_throws MethodError r(Int64, (2,3))
@test_throws MethodError r(String, 1)
Expand Down

0 comments on commit 3ca78c1

Please sign in to comment.