Skip to content

Commit

Permalink
Fix empty forward gradient
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Oct 30, 2024
1 parent 12c1abb commit 41692be
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 9 deletions.
3 changes: 2 additions & 1 deletion src/Enzyme.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1563,6 +1563,7 @@ end
end
end
end
# Base case: an empty tuple has no components, so its one-hot
# decomposition is likewise empty.
@inline onehot(x::Tuple{}) = ()

# Value-form entry point: delegate to the type-based `onehot`
# implementation, which builds the basis from `NTuple{N,T}` alone.
@inline onehot(x::NTuple{N,T}) where {N,T} = onehot(NTuple{N,T})
Expand Down Expand Up @@ -2141,7 +2142,7 @@ gradient(Forward, mul, [2.0, 3.0], Const([2.7, 3.1]))
end
end
else
:(specialize_output(TupleArray($tmp, size($arg)), $(vals[1])))
tmp
end
else
tmp
Expand Down
10 changes: 7 additions & 3 deletions src/internal_rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1303,18 +1303,22 @@ end
"""
Augmented primal for `Colon` (range construction) over float endpoints
in reverse mode. No shadow and no tape are needed; the primal range is
only materialized when the config requests it.
"""
function EnzymeRules.augmented_primal(
    config::EnzymeRules.RevConfig,
    func::Const{Colon},
    ::Type{RT},
    start::Annotation{<:AbstractFloat},
    step::Annotation{<:AbstractFloat},
    stop::Annotation{<:AbstractFloat},
) where RT <: Active
    # Build the range only if the caller needs the primal result.
    primal = if EnzymeRules.needs_primal(config)
        func.val(start.val, step.val, stop.val)
    else
        nothing
    end
    # Explicit type parameters keep the return concretely typed even
    # when the primal slot is Nothing.
    return EnzymeRules.AugmentedReturn{
        EnzymeRules.primal_type(config, RT),
        Nothing,
        Nothing,
    }(primal, nothing, nothing)
end

function EnzymeRules.reverse(
Expand Down
28 changes: 23 additions & 5 deletions src/rules/customrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -968,6 +968,7 @@ end

TapeT = Nothing


if (
aug_RT <: EnzymeRules.AugmentedReturn ||
aug_RT <: EnzymeRules.AugmentedReturnFlexShadow
Expand Down Expand Up @@ -996,7 +997,7 @@ end
else
TapeT = Any
end

mod = LLVM.parent(LLVM.parent(LLVM.parent(orig)))

llvmf = nothing
Expand Down Expand Up @@ -1027,8 +1028,16 @@ end
rkwfunc = Core.kwfunc(EnzymeRules.reverse)
if EnzymeRules.isapplicable(rkwfunc, rev_TT; world)
@safe_debug "Applying custom reverse rule (kwcall)" TT = rev_TT
llvmf = nested_codegen!(mode, mod, rkwfunc, rev_TT, world)
rev_RT = Core.Compiler.return_type(rkwfunc, rev_TT, world)
try
llvmf = nested_codegen!(mode, mod, rkwfunc, rev_TT, world)
rev_RT = Core.Compiler.return_type(rkwfunc, rev_TT, world)
catch e
rev_TT = Tuple{typeof(world),typeof(rkwfunc),rev_TT.parameters...}
llvmf = nested_codegen!(mode, mod, custom_rule_method_error, rev_TT, world)
pushfirst!(args, LLVM.ConstantInt(world))
rev_RT = Union{}
applicablefn = false
end
else
rev_TT = Tuple{typeof(world),typeof(rkwfunc),rev_TT.parameters...}
llvmf = nested_codegen!(mode, mod, custom_rule_method_error, rev_TT, world)
Expand All @@ -1039,8 +1048,17 @@ end
else
if EnzymeRules.isapplicable(EnzymeRules.reverse, rev_TT; world)
@safe_debug "Applying custom reverse rule" TT = rev_TT
llvmf = nested_codegen!(mode, mod, EnzymeRules.reverse, rev_TT, world)
rev_RT = Core.Compiler.return_type(EnzymeRules.reverse, rev_TT, world)
try
llvmf = nested_codegen!(mode, mod, EnzymeRules.reverse, rev_TT, world)
rev_RT = Core.Compiler.return_type(EnzymeRules.reverse, rev_TT, world)
catch e
rev_TT =
Tuple{typeof(world),typeof(EnzymeRules.reverse),rev_TT.parameters...}
llvmf = nested_codegen!(mode, mod, custom_rule_method_error, rev_TT, world)
pushfirst!(args, LLVM.ConstantInt(world))
rev_RT = Union{}
applicablefn = false
end
else
rev_TT =
Tuple{typeof(world),typeof(EnzymeRules.reverse),rev_TT.parameters...}
Expand Down
6 changes: 6 additions & 0 deletions test/sugar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,12 @@ using LinearAlgebra
# Scalar cross-style product of the first two components:
# x[1]*y[2] + x[2]*y[1].
function mul_scalar(x, y)
    first_term = x[1] * y[2]
    second_term = x[2] * y[1]
    return first_term + second_term
end

# Vector of the two cross terms: [x[1]*y[2], x[2]*y[1]].
function mul_vector(x, y)
    return [x[1] * y[2], x[2] * y[1]]
end

@testset "Forward Empty Gradient" begin
    empty_input = Float64[]
    derivs = gradient(Forward, sin, empty_input)
    # With nothing to differentiate, the first gradient slot should be
    # the very same (empty) array that was passed in.
    @test derivs[1] === empty_input
end

@testset "Forward Multi-Arg Gradient" begin
res = gradient(Forward, mul_scalar, [2.0, 3.0], [2.7, 3.1])
@test res[1] [3.1, 2.7]
Expand Down

0 comments on commit 41692be

Please sign in to comment.