Skip to content

Commit

Permalink
random: introduce State to formalize hooking into rand machinery
Browse files Browse the repository at this point in the history
  • Loading branch information
rfourquet committed Oct 5, 2017
1 parent f2fd1f8 commit 004b12c
Show file tree
Hide file tree
Showing 5 changed files with 320 additions and 308 deletions.
172 changes: 94 additions & 78 deletions base/random/RNGs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@

## RandomDevice

const BoolBitIntegerType = Union{Type{Bool},Base.BitIntegerType}
const BoolBitIntegerArray = Union{Array{Bool},Base.BitIntegerArray}

StateTypes(U::Union) = Union{map(T->StateType{T}, Base.uniontypes(U))...}
const StateBoolBitInteger = StateTypes(Union{Bool, Base.BitInteger})

if Sys.iswindows()
struct RandomDevice <: AbstractRNG
Expand All @@ -12,15 +13,9 @@ if Sys.iswindows()
RandomDevice() = new(Vector{UInt128}(1))
end

function rand(rd::RandomDevice, T::BoolBitIntegerType)
function rand(rd::RandomDevice, st::StateBoolBitInteger)
rand!(rd, rd.buffer)
@inbounds return rd.buffer[1] % T
end

function rand!(rd::RandomDevice, A::BoolBitIntegerArray)
ccall((:SystemFunction036, :Advapi32), stdcall, UInt8, (Ptr{Void}, UInt32),
A, sizeof(A))
A
@inbounds return rd.buffer[1] % st[]
end
else # !windows
struct RandomDevice <: AbstractRNG
Expand All @@ -31,10 +26,22 @@ else # !windows
new(open(unlimited ? "/dev/urandom" : "/dev/random"), unlimited)
end

rand(rd::RandomDevice, T::BoolBitIntegerType) = read( rd.file, T)
rand!(rd::RandomDevice, A::BoolBitIntegerArray) = read!(rd.file, A)
rand(rd::RandomDevice, st::StateBoolBitInteger) = read( rd.file, st[])
end # os-test

# NOTE: this can't be put in within the if-else block above
for T in (Bool, Base.BitInteger_types...)
if Sys.iswindows()
@eval function rand!(rd::RandomDevice, A::Array{$T}, ::StateType{$T})
ccall((:SystemFunction036, :Advapi32), stdcall, UInt8, (Ptr{Void}, UInt32),
A, sizeof(A))
A
end
else
@eval rand!(rd::RandomDevice, A::Array{$T}, ::StateType{$T}) = read!(rd.file, A)
end
end

"""
RandomDevice()
Expand All @@ -49,7 +56,7 @@ srand(rng::RandomDevice) = rng

### generation of floats

rand(r::RandomDevice, I::FloatInterval) = rand_generic(r, I)
rand(r::RandomDevice, st::StateTrivial{<:FloatInterval}) = rand_generic(r, st[])


## MersenneTwister
Expand Down Expand Up @@ -229,30 +236,30 @@ rand_ui23_raw(r::MersenneTwister) = rand_ui52_raw(r)

#### floats

rand(r::MersenneTwister, I::FloatInterval_64) = (reserve_1(r); rand_inbounds(r, I))
rand(r::MersenneTwister, st::StateTrivial{<:FloatInterval_64}) = (reserve_1(r); rand_inbounds(r, st[]))

rand(r::MersenneTwister, I::FloatInterval) = rand_generic(r, I)
rand(r::MersenneTwister, st::StateTrivial{<:FloatInterval}) = rand_generic(r, st[])

#### integers

rand(r::MersenneTwister,
::Type{T}) where {T<:Union{Bool,Int8,UInt8,Int16,UInt16,Int32,UInt32}} =
rand_ui52_raw(r) % T
T::StateTypes(Union{Bool,Int8,UInt8,Int16,UInt16,Int32,UInt32})) =
rand_ui52_raw(r) % T[]

function rand(r::MersenneTwister, ::Type{UInt64})
function rand(r::MersenneTwister, ::StateType{UInt64})
reserve(r, 2)
rand_ui52_raw_inbounds(r) << 32 rand_ui52_raw_inbounds(r)
end

function rand(r::MersenneTwister, ::Type{UInt128})
function rand(r::MersenneTwister, ::StateType{UInt128})
reserve(r, 3)
xor(rand_ui52_raw_inbounds(r) % UInt128 << 96,
rand_ui52_raw_inbounds(r) % UInt128 << 48,
rand_ui52_raw_inbounds(r))
end

rand(r::MersenneTwister, ::Type{Int64}) = reinterpret(Int64, rand(r, UInt64))
rand(r::MersenneTwister, ::Type{Int128}) = reinterpret(Int128, rand(r, UInt128))
rand(r::MersenneTwister, ::StateType{Int64}) = reinterpret(Int64, rand(r, UInt64))
rand(r::MersenneTwister, ::StateType{Int128}) = reinterpret(Int128, rand(r, UInt128))

#### arrays of floats

Expand All @@ -278,16 +285,17 @@ function rand_AbstractArray_Float64!(r::MersenneTwister, A::AbstractArray{Float6
A
end

rand!(r::MersenneTwister, A::AbstractArray{Float64}) = rand_AbstractArray_Float64!(r, A)
rand!(r::MersenneTwister, A::AbstractArray{Float64}, I::StateTrivial{<:FloatInterval_64}) =
rand_AbstractArray_Float64!(r, A, length(A), I[])

fill_array!(s::DSFMT_state, A::Ptr{Float64}, n::Int, ::CloseOpen_64) =
dsfmt_fill_array_close_open!(s, A, n)

fill_array!(s::DSFMT_state, A::Ptr{Float64}, n::Int, ::Close1Open2_64) =
dsfmt_fill_array_close1_open2!(s, A, n)

function rand!(r::MersenneTwister, A::Array{Float64}, n::Int=length(A),
I::FloatInterval_64=CloseOpen())
function _rand!(r::MersenneTwister, A::Array{Float64}, n::Int,
I::FloatInterval_64)
# depending on the alignment of A, the data written by fill_array! may have
# to be left-shifted by up to 15 bytes (cf. unsafe_copy! below) for
# reproducibility purposes;
Expand Down Expand Up @@ -317,65 +325,63 @@ function rand!(r::MersenneTwister, A::Array{Float64}, n::Int=length(A),
A
end

rand!(r::MersenneTwister, A::Array{Float64}, st::StateTrivial{<:FloatInterval_64}) =
_rand!(r, A, length(A), st[])

mask128(u::UInt128, ::Type{Float16}) =
(u & 0x03ff03ff03ff03ff03ff03ff03ff03ff) | 0x3c003c003c003c003c003c003c003c00

mask128(u::UInt128, ::Type{Float32}) =
(u & 0x007fffff007fffff007fffff007fffff) | 0x3f8000003f8000003f8000003f800000

function rand!(r::MersenneTwister, A::Union{Array{Float16},Array{Float32}},
::Close1Open2_64)
T = eltype(A)
n = length(A)
n128 = n * sizeof(T) ÷ 16
Base.@gc_preserve A rand!(r, unsafe_wrap(Array, convert(Ptr{Float64}, pointer(A)), 2*n128),
2*n128, Close1Open2())
# FIXME: This code is completely invalid!!!
A128 = unsafe_wrap(Array, convert(Ptr{UInt128}, pointer(A)), n128)
@inbounds for i in 1:n128
u = A128[i]
u ⊻= u << 26
# at this point, the 64 low bits of u, "k" being the k-th bit of A128[i] and "+"
# the bit xor, are:
# [..., 58+32,..., 53+27, 52+26, ..., 33+7, 32+6, ..., 27+1, 26, ..., 1]
# the bits needing to be random are
# [1:10, 17:26, 33:42, 49:58] (for Float16)
# [1:23, 33:55] (for Float32)
# this is obviously satisfied on the 32 low bits side, and on the high side,
# the entropy comes from bits 33:52 of A128[i] and then from bits 27:32
# (which are discarded on the low side)
# this is similar for the 64 high bits of u
A128[i] = mask128(u, T)
end
for i in 16*n128÷sizeof(T)+1:n
@inbounds A[i] = rand(r, T) + oneunit(T)
for T in (Float16, Float32)
@eval function rand!(r::MersenneTwister, A::Array{$T}, ::StateTrivial{Close1Open2{$T}})
n = length(A)
n128 = n * sizeof($T) ÷ 16
Base.@gc_preserve A _rand!(r, unsafe_wrap(Array, convert(Ptr{Float64}, pointer(A)), 2*n128),
2*n128, Close1Open2())
# FIXME: This code is completely invalid!!!
A128 = unsafe_wrap(Array, convert(Ptr{UInt128}, pointer(A)), n128)
@inbounds for i in 1:n128
u = A128[i]
u ⊻= u << 26
# at this point, the 64 low bits of u, "k" being the k-th bit of A128[i] and "+"
# the bit xor, are:
# [..., 58+32,..., 53+27, 52+26, ..., 33+7, 32+6, ..., 27+1, 26, ..., 1]
# the bits needing to be random are
# [1:10, 17:26, 33:42, 49:58] (for Float16)
# [1:23, 33:55] (for Float32)
# this is obviously satisfied on the 32 low bits side, and on the high side,
# the entropy comes from bits 33:52 of A128[i] and then from bits 27:32
# (which are discarded on the low side)
# this is similar for the 64 high bits of u
A128[i] = mask128(u, $T)
end
for i in 16*n128÷sizeof($T)+1:n
@inbounds A[i] = rand(r, $T) + oneunit($T)
end
A
end
A
end

function rand!(r::MersenneTwister, A::Union{Array{Float16},Array{Float32}}, ::CloseOpen_64)
rand!(r, A, Close1Open2())
I32 = one(Float32)
for i in eachindex(A)
@inbounds A[i] = Float32(A[i])-I32 # faster than "A[i] -= one(T)" for T==Float16
@eval function rand!(r::MersenneTwister, A::Array{$T}, ::StateTrivial{CloseOpen{$T}})
rand!(r, A, Close1Open2($T))
I32 = one(Float32)
for i in eachindex(A)
@inbounds A[i] = Float32(A[i])-I32 # faster than "A[i] -= one(T)" for T==Float16
end
A
end
A
end

rand!(r::MersenneTwister, A::Union{Array{Float16},Array{Float32}}) =
rand!(r, A, CloseOpen())

#### arrays of integers

function rand!(r::MersenneTwister, A::Array{UInt128}, n::Int=length(A))
if n > length(A)
throw(BoundsError(A,n))
end
function rand!(r::MersenneTwister, A::Array{UInt128}, ::StateType{UInt128})
n::Int=length(A)
# FIXME: This code is completely invalid!!!
Af = unsafe_wrap(Array, convert(Ptr{Float64}, pointer(A)), 2n)
i = n
while true
rand!(r, Af, 2i, Close1Open2())
_rand!(r, Af, 2i, Close1Open2())
n < 5 && break
i = 0
@inbounds while n-i >= 5
Expand All @@ -396,17 +402,18 @@ function rand!(r::MersenneTwister, A::Array{UInt128}, n::Int=length(A))
A
end

# A::Array{UInt128} will match the specialized method above
function rand!(r::MersenneTwister, A::Base.BitIntegerArray)
n = length(A)
T = eltype(A)
n128 = n * sizeof(T) ÷ 16
# FIXME: This code is completely invalid!!!
rand!(r, unsafe_wrap(Array, convert(Ptr{UInt128}, pointer(A)), n128))
for i = 16*n128÷sizeof(T)+1:n
@inbounds A[i] = rand(r, T)
for T in Base.BitInteger_types
T === UInt128 && continue
@eval function rand!(r::MersenneTwister, A::Array{$T}, ::StateType{$T})
n = length(A)
n128 = n * sizeof($T) ÷ 16
# FIXME: This code is completely invalid!!!
rand!(r, unsafe_wrap(Array, convert(Ptr{UInt128}, pointer(A)), n128))
for i = 16*n128÷sizeof($T)+1:n
@inbounds A[i] = rand(r, $T)
end
A
end
A
end

#### from a range
Expand All @@ -418,7 +425,9 @@ function rand_lteq(r::AbstractRNG, randfun, u::U, mask::U) where U<:Integer
end
end

function rand(rng::MersenneTwister, r::UnitRange{T}) where T<:Union{Base.BitInteger64,Bool}
function rand(rng::MersenneTwister,
st::StateTrivial{UnitRange{T}}) where T<:Union{Base.BitInteger64,Bool}
r = st[]
isempty(r) && throw(ArgumentError("range must be non-empty"))
m = last(r) % UInt64 - first(r) % UInt64
bw = (64 - leading_zeros(m)) % UInt # bit-width
Expand All @@ -428,7 +437,9 @@ function rand(rng::MersenneTwister, r::UnitRange{T}) where T<:Union{Base.BitInte
(x + first(r) % UInt64) % T
end

function rand(rng::MersenneTwister, r::UnitRange{T}) where T<:Union{Int128,UInt128}
function rand(rng::MersenneTwister,
st::StateTrivial{UnitRange{T}}) where T<:Union{Int128,UInt128}
r = st[]
isempty(r) && throw(ArgumentError("range must be non-empty"))
m = (last(r)-first(r)) % UInt128
bw = (128 - leading_zeros(m)) % UInt # bit-width
Expand All @@ -439,6 +450,11 @@ function rand(rng::MersenneTwister, r::UnitRange{T}) where T<:Union{Int128,UInt1
x % T + first(r)
end

for T in (Bool, Base.BitInteger_types...) # eval because of ambiguity otherwise
@eval State(rng::MersenneTwister, r::UnitRange{$T}, ::Val{1}) =
StateTrivial(r)
end


### randjump

Expand Down
Loading

0 comments on commit 004b12c

Please sign in to comment.