From 41692be64a2fc062865711383b1c5b5e3456d25b Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Wed, 30 Oct 2024 01:05:41 -0400 Subject: [PATCH] Fix empty forward gradient --- src/Enzyme.jl | 3 ++- src/internal_rules.jl | 10 +++++++--- src/rules/customrules.jl | 28 +++++++++++++++++++++++----- test/sugar.jl | 6 ++++++ 4 files changed, 38 insertions(+), 9 deletions(-) diff --git a/src/Enzyme.jl b/src/Enzyme.jl index c3769e35ac..2e8643744b 100644 --- a/src/Enzyme.jl +++ b/src/Enzyme.jl @@ -1563,6 +1563,7 @@ end end end end +@inline onehot(x::Tuple{}) = () @inline function onehot(x::NTuple{N,T}) where {T,N} onehot(NTuple{N,T}) end @@ -2141,7 +2142,7 @@ gradient(Forward, mul, [2.0, 3.0], Const([2.7, 3.1])) end end else - :(specialize_output(TupleArray($tmp, size($arg)), $(vals[1]))) + tmp end else tmp diff --git a/src/internal_rules.jl b/src/internal_rules.jl index 6fe70df8cf..539223ffa5 100644 --- a/src/internal_rules.jl +++ b/src/internal_rules.jl @@ -1303,18 +1303,22 @@ end function EnzymeRules.augmented_primal( config::EnzymeRules.RevConfig, func::Const{Colon}, - ::Type{<:Active}, + ::Type{RT}, start::Annotation{<:AbstractFloat}, step::Annotation{<:AbstractFloat}, stop::Annotation{<:AbstractFloat}, -) +) where RT <: Active if EnzymeRules.needs_primal(config) primal = func.val(start.val, step.val, stop.val) else primal = nothing end - return EnzymeRules.AugmentedReturn(primal, nothing, nothing) + return EnzymeRules.AugmentedReturn{ + EnzymeRules.primal_type(config, RT), + Nothing, + Nothing + }(primal, nothing, nothing) end function EnzymeRules.reverse( diff --git a/src/rules/customrules.jl b/src/rules/customrules.jl index fbd646866b..96661849b2 100644 --- a/src/rules/customrules.jl +++ b/src/rules/customrules.jl @@ -968,6 +968,7 @@ end TapeT = Nothing + if ( aug_RT <: EnzymeRules.AugmentedReturn || aug_RT <: EnzymeRules.AugmentedReturnFlexShadow @@ -996,7 +997,7 @@ end else TapeT = Any end - + mod = LLVM.parent(LLVM.parent(LLVM.parent(orig))) llvmf = nothing @@ -1027,8 +1028,16 @@ end rkwfunc = Core.kwfunc(EnzymeRules.reverse) if EnzymeRules.isapplicable(rkwfunc, rev_TT; world) @safe_debug "Applying custom reverse rule (kwcall)" TT = rev_TT - llvmf = nested_codegen!(mode, mod, rkwfunc, rev_TT, world) - rev_RT = Core.Compiler.return_type(rkwfunc, rev_TT, world) + try + llvmf = nested_codegen!(mode, mod, rkwfunc, rev_TT, world) + rev_RT = Core.Compiler.return_type(rkwfunc, rev_TT, world) + catch e + rev_TT = Tuple{typeof(world),typeof(rkwfunc),rev_TT.parameters...} + llvmf = nested_codegen!(mode, mod, custom_rule_method_error, rev_TT, world) + pushfirst!(args, LLVM.ConstantInt(world)) + rev_RT = Union{} + applicablefn = false + end else rev_TT = Tuple{typeof(world),typeof(rkwfunc),rev_TT.parameters...} llvmf = nested_codegen!(mode, mod, custom_rule_method_error, rev_TT, world) @@ -1039,8 +1048,17 @@ end else if EnzymeRules.isapplicable(EnzymeRules.reverse, rev_TT; world) @safe_debug "Applying custom reverse rule" TT = rev_TT - llvmf = nested_codegen!(mode, mod, EnzymeRules.reverse, rev_TT, world) - rev_RT = Core.Compiler.return_type(EnzymeRules.reverse, rev_TT, world) + try + llvmf = nested_codegen!(mode, mod, EnzymeRules.reverse, rev_TT, world) + rev_RT = Core.Compiler.return_type(EnzymeRules.reverse, rev_TT, world) + catch e + rev_TT = + Tuple{typeof(world),typeof(EnzymeRules.reverse),rev_TT.parameters...} + llvmf = nested_codegen!(mode, mod, custom_rule_method_error, rev_TT, world) + pushfirst!(args, LLVM.ConstantInt(world)) + rev_RT = Union{} + applicablefn = false + end else rev_TT = Tuple{typeof(world),typeof(EnzymeRules.reverse),rev_TT.parameters...} diff --git a/test/sugar.jl b/test/sugar.jl index 097472ab22..340a54c569 100644 --- a/test/sugar.jl +++ b/test/sugar.jl @@ -4,6 +4,12 @@ using LinearAlgebra mul_scalar(x, y) = x[1]*y[2] + x[2]*y[1] mul_vector(x, y) = [x[1]*y[2], x[2]*y[1]] +@testset "Forward Empty Gradient" begin + inp = Float64[] + res = gradient(Forward, sin, inp) + @test res[1] === inp +end + @testset "Forward Multi-Arg Gradient" begin res = gradient(Forward, mul_scalar, [2.0, 3.0], [2.7, 3.1]) @test res[1] ≈ [3.1, 2.7]