diff --git a/src/internal_rules.jl b/src/internal_rules.jl index 96f774f69e..238f7f7b03 100644 --- a/src/internal_rules.jl +++ b/src/internal_rules.jl @@ -964,3 +964,65 @@ function EnzymeRules.reverse( ) return () end + +function EnzymeRules.forward(config::EnzymeRules.FwdConfig, + Ty::Const{typeof(Random.rand!)}, + RT::Type, + rng::Annotation{rngty}, + dst::Annotation{<:Array{FT}}, + smpl::Annotation{<:Random.SamplerTrivial{Random.CloseOpen01{FT}}}, + ) where {rngty <: Union{TaskLocalRNG, Xoshiro}, FT <: Union{Float32, Float64}} + Ty.val(rng.val, dst.val, smpl.val) + if RT <: Duplicated + fill!(dst.dval, 0) + Duplicated(dst.val, dst.dval) + elseif RT <: Const + dst.val + elseif RT <: DuplicatedNoNeed + fill!(dst.dval, 0) + dst.dval + else + ntuple(Val(EnzymeRules.width(config))) do i + Base.@_inline_meta + fill!(dst.dval[i], 0) + nothing + end + if RT <: BatchDuplicated + BatchDuplicated(dst.val, dst.dval) + else + dst.dval + end + end +end + +function EnzymeRules.augmented_primal(config::EnzymeRules.RevConfig, + Ty::Const{typeof(Random.rand!)}, + RT::Type, + rng::Annotation{rngty}, + dst::Annotation{<:Array{FT}}, + smpl::Annotation{<:Random.SamplerTrivial{Random.CloseOpen01{FT}}}, + ) where {rngty <: Union{TaskLocalRNG, Xoshiro}, FT <: Union{Float32, Float64}} + Ty.val(rng.val, dst.val, smpl.val) + if RT <: Duplicated || RT <: DuplicatedNoNeed + fill!(dst.dval, 0) + dst.dval + elseif RT <: BatchDuplicated || RT <: BatchDuplicatedNoNeed + ntuple(Val(EnzymeRules.width(config))) do i + Base.@_inline_meta + fill!(dst.dval[i], 0) + nothing + end + end + return EnzymeRules.AugmentedReturn(EnzymeRules.needs_primal(config) ? dst.val : nothing, EnzymeRules.needs_shadow(config) ? dst.dval : nothing, nothing) +end + +function EnzymeRules.reverse(config::EnzymeRules.RevConfig, + Ty::Const{typeof(Random.rand!)}, + RT::Type, + tape, + rng::Annotation{rngty}, + dst::Annotation{<:Array{FT}}, + smpl::Annotation{<:Random.SamplerTrivial{Random.CloseOpen01{FT}}}, + ) where {rngty <: Union{TaskLocalRNG, Xoshiro}, FT <: Union{Float32, Float64}} + return (nothing, nothing, nothing) +end