diff --git a/stdlib/Random/src/RNGs.jl b/stdlib/Random/src/RNGs.jl index 7b1d0684ba0b8..c5b06f9dfc4d7 100644 --- a/stdlib/Random/src/RNGs.jl +++ b/stdlib/Random/src/RNGs.jl @@ -137,6 +137,7 @@ true MersenneTwister(seed=nothing) = srand(MersenneTwister(Vector{UInt32}(), DSFMT_state()), seed) + function copy!(dst::MersenneTwister, src::MersenneTwister) copyto!(resize!(dst.seed, length(src.seed)), src.seed) copy!(dst.state, src.state) @@ -161,6 +162,14 @@ copy(src::MersenneTwister) = hash(r::MersenneTwister, h::UInt) = foldr(hash, h, (r.seed, r.state, r.vals, r.ints, r.idxF, r.idxI)) +function fillcache_zeros!(r::MersenneTwister) + # the use of this function is not strictly necessary, but it makes + # comparing two MersenneTwister RNGs easier + fill!(r.vals, 0.0) + fill!(r.ints, zero(UInt128)) + r +end + ### low level API @@ -271,9 +280,8 @@ function srand(r::MersenneTwister, seed::Vector{UInt32}) copyto!(resize!(r.seed, length(seed)), seed) dsfmt_init_by_array(r.state, r.seed) mt_setempty!(r) - fill!(r.vals, 0.0) # not strictly necessary, but why not, makes comparing two MT easier mt_setempty!(r, UInt128) - fill!(r.ints, 0) + fillcache_zeros!(r) return r end @@ -563,7 +571,7 @@ randjump(r::MersenneTwister, steps::Integer, len::Integer) = _randjump(r::MersenneTwister, jumppoly::DSFMT.GF2X) = - MersenneTwister(copy(r.seed), DSFMT.dsfmt_jump(r.state, jumppoly)) + fillcache_zeros!(MersenneTwister(copy(r.seed), DSFMT.dsfmt_jump(r.state, jumppoly))) function _randjump(mt::MersenneTwister, jumppoly::DSFMT.GF2X, len::Integer) mts = MersenneTwister[] diff --git a/stdlib/Random/test/runtests.jl b/stdlib/Random/test/runtests.jl index 31b483854945f..fe0224f018336 100644 --- a/stdlib/Random/test/runtests.jl +++ b/stdlib/Random/test/runtests.jl @@ -498,8 +498,8 @@ let mta = MersenneTwister(42), mtb = MersenneTwister(42) @test sprand(mta,10,10,0.3) == sprand(mtb,10,10,0.3) end -# test MersenneTwister polynomial generation and jump -let seed = rand(UInt) +@testset "MersenneTwister polynomial generation and jump" begin + seed = rand(UInt) mta = MersenneTwister(seed) mtb = MersenneTwister(seed) step = 25000*2 @@ -520,6 +520,11 @@ let seed = rand(UInt) for x in (rand(mts[k], Float64) for j=1:step, k=1:size) @test rand(mtb, Float64) == x end + + @testset "generated RNGs are in a deterministic state (relatively to ==)" begin + m = MersenneTwister() + @test randjump(m, 25000, 2) == randjump(m, 25000, 2) + end end # test that the following is not an error (#16925)