Skip to content

Commit

Permalink
Fix rand set (#1833)
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses authored Sep 15, 2024
1 parent e10ad8c commit 7faa410
Showing 1 changed file with 62 additions and 0 deletions.
62 changes: 62 additions & 0 deletions src/internal_rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -964,3 +964,65 @@ function EnzymeRules.reverse(
)
return ()
end

function EnzymeRules.forward(config::EnzymeRules.FwdConfig,
Ty::Const{typeof(Random.rand!)},
RT::Type,
rng::Annotation{rngty},
dst::Annotation{<:Array{FT}},
smpl::Annotation{<:Random.SamplerTrivial{Random.CloseOpen01{FT}}},
) where {rngty <: Union{TaskLocalRNG, Xoshiro}, FT <: Union{Float32, Float64}}
Ty.val(rng.val, dst.val, smpl.val)
if RT <: Duplicated
fill!(dst.dval, 0)
Duplicated(dst.val, dst.dval)
elseif RT <: Const
dst.val
elseif RT <: DuplicatedNoNeed
fill!(dst.dval, 0)
dst.dval
else
ntuple(Val(EnzymeRules.width(config))) do i
Base.@_inline_meta
fill!(dst.dval[i], 0)
nothing
end
if RT <: BatchDuplicated
BatchDuplicated(dst.val, dst.dval)
else
dst.dval
end
end
end

function EnzymeRules.augmented_primal(config::EnzymeRules.RevConfig,
Ty::Const{typeof(Random.rand!)},
RT::Type,
rng::Annotation{rngty},
dst::Annotation{<:Array{FT}},
smpl::Annotation{<:Random.SamplerTrivial{Random.CloseOpen01{FT}}},
) where {rngty <: Union{TaskLocalRNG, Xoshiro}, FT <: Union{Float32, Float64}}
Ty.val(rng.val, dst.val, smpl.val)
if RT <: Duplicated || RT <: DuplicatedNoNeed
fill!(dst.dval, 0)
dst.dval
elseif RT <: BatchDuplicated || RT <: BatchDuplicatedNoNeed
ntuple(Val(EnzymeRules.width(config))) do i
Base.@_inline_meta
fill!(dst.dval[i], 0)
nothing
end
end
return EnzymeRules.AugmentedReturn(EnzymeRules.needs_primal(config) ? dst.val : nothing, EnzymeRules.needs_shadow(config) ? dst.dval : nothing, nothing)
end

function EnzymeRules.reverse(config::EnzymeRules.RevConfig,
Ty::Const{typeof(Random.rand!)},
RT::Type,
tape,
rng::Annotation{rngty},
dst::Annotation{<:Array{FT}},
smpl::Annotation{<:Random.SamplerTrivial{Random.CloseOpen01{FT}}},
) where {rngty <: Union{TaskLocalRNG, Xoshiro}, FT <: Union{Float32, Float64}}
return (nothing, nothing, nothing)
end

0 comments on commit 7faa410

Please sign in to comment.