
Forward over reverse drops custom rule if corresponding function is inlined #1795

Open
danielwe opened this issue Sep 5, 2024 · 2 comments


danielwe (Contributor) commented Sep 5, 2024

My custom reverse rule works as expected in first-order reverse mode. However, for it to be picked up in second-order forward-over-reverse mode, I have to mark the corresponding function `@noinline`. In the MWE below I've introduced a deliberate bug in the rule, so both the gradient and the Hessian-vector (hv) product should differ between `g` and `g_custom`. The gradients always differ, but the hv products only differ when I add `@noinline`.

```julia
using Enzyme

#=@noinline=# f(x) = sum(abs2, x)
#=@noinline=# f_custom(x) = sum(abs2, x)

g(x) = cos(f(x))
g_custom(x) = cos(f_custom(x))

function dg_deferred!(dx, x)
    make_zero!(dx)
    autodiff_deferred(Reverse, g, Active, Duplicated(x, dx))
    return nothing
end

function dg_custom_deferred!(dx, x)
    make_zero!(dx)
    autodiff_deferred(Reverse, g_custom, Active, Duplicated(x, dx))
    return nothing
end

function EnzymeRules.augmented_primal(
    config::EnzymeRules.Config, f::Const{typeof(f_custom)}, ::Type{<:Active}, x::Duplicated
)
    tape = EnzymeRules.overwritten(config)[2] ? copy(x.val) : nothing
    primal = EnzymeRules.needs_primal(config) ? f.val(x.val) : nothing
    return EnzymeRules.AugmentedReturn(primal, nothing#=shadow=#, tape)
end

function EnzymeRules.reverse(
    config::EnzymeRules.Config,
    ::Const{typeof(f_custom)},
    dret::Active,
    tape,
    x::Duplicated,
)
    xval = EnzymeRules.overwritten(config)[2] ? tape : x.val
    x.dval .= (2dret.val) .* xval
    x.dval .^= 2  # Deliberate bug as signature of custom rule 🐛
    return (nothing,)
end

x = [2.0]
dx, dx_custom = make_zero(x), make_zero(x)

v = first(onehot(x))
hv, hv_custom = make_zero(v), make_zero(v)

# gradients
dg_deferred!(dx, x)
@show dx

dg_custom_deferred!(dx_custom, x)
@show dx_custom

# hvps
autodiff(Forward, dg_deferred!, Const, Duplicated(dx, hv), Duplicated(x, v))
@show hv

autodiff(
    Forward,
    dg_custom_deferred!,
    Const,
    Duplicated(dx_custom, hv_custom),
    Duplicated(x, v),
)
@show hv_custom
```

Output as written: the gradients differ, but the hv products are equal.

```
dx = [3.027209981231713]
dx_custom = [9.164000270468907]
hv = [11.971902924433648]
hv_custom = [11.971902924433648]
```

Output with `@noinline`: both the gradients and the hv products differ.

```
dx = [3.027209981231713]
dx_custom = [9.164000270468907]
hv = [11.971902924433648]
hv_custom = [72.48292805436535]
```
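For reference, these numbers can be checked by hand. With `x = [2.0]` and `v = [1.0]` the problem is scalar, $g(x) = \cos(x^2)$, and the buggy rule returns the square of the correct gradient, $(2 \cdot (-\sin(x^2)) \cdot x)^2$. Differentiating each gradient once more gives the expected hv products:

```math
\begin{aligned}
g'(x) &= -2x\sin(x^2), &\quad g'(2) &= -4\sin 4 \approx 3.0272,\\
\bigl(g'(x)\bigr)^2 &= 4x^2\sin^2(x^2), &\quad \bigl(g'(2)\bigr)^2 &= 16\sin^2 4 \approx 9.1640,\\
g''(x) &= -2\sin(x^2) - 4x^2\cos(x^2), &\quad g''(2) &= -2\sin 4 - 16\cos 4 \approx 11.9719,\\
\frac{d}{dx}\bigl(g'(x)\bigr)^2 &= 8x\sin^2(x^2) + 16x^3\sin(x^2)\cos(x^2), &\quad \text{at } x = 2: &\; 16\sin^2 4 + 128\sin 4\cos 4 \approx 72.4829.
\end{aligned}
```

So `hv_custom ≈ 72.48` is what the custom rule should produce, while `hv_custom ≈ 11.97` (equal to `hv`) indicates that the rule was bypassed in the forward-over-reverse pass.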

wsmoses (Member) commented Dec 7, 2024

@danielwe with the just-landed fixes to nesting (in particular, calling regular `autodiff` rather than `autodiff_deferred` on the inside), does this still err?

danielwe (Contributor, Author) commented Jan 9, 2025

Finally got around to testing this. After updating the MWE for the API changes (`Config` -> `RevConfig`), it now works correctly if I replace `autodiff_deferred` with plain `autodiff`. Awesome! (The original bug report predates `autodiff` being allowed in second-order AD.)

However, if I keep `autodiff_deferred` (adding the `Const` annotation on the function, as the simplified API requires), I still see the bug unless I use `@noinline` as described in the OP.

So if `autodiff_deferred` is deprecated as public API, this issue can be closed (and ideally the docs updated accordingly). If end users should still be able to use `autodiff_deferred` and get correct results, there's a remaining issue here.
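A minimal sketch of the updated MWE described in this comment is shown below. It assumes the Enzyme.jl v0.13 API, where the rule config type is `EnzymeRules.RevConfig` and the inner gradient calls use plain `autodiff`. The driver names `dg!` and `dg_custom!` are hypothetical renames (the calls are no longer deferred). The rules and the deliberate bug are unchanged from the OP.

```julia
using Enzyme

# Sketch assuming the Enzyme.jl v0.13+ API (EnzymeRules.RevConfig).
f(x) = sum(abs2, x)
f_custom(x) = sum(abs2, x)

g(x) = cos(f(x))
g_custom(x) = cos(f_custom(x))

# Hypothetical driver names: the inner gradients use plain `autodiff`,
# not `autodiff_deferred`.
function dg!(dx, x)
    make_zero!(dx)
    autodiff(Reverse, g, Active, Duplicated(x, dx))
    return nothing
end

function dg_custom!(dx, x)
    make_zero!(dx)
    autodiff(Reverse, g_custom, Active, Duplicated(x, dx))
    return nothing
end

function EnzymeRules.augmented_primal(
    config::EnzymeRules.RevConfig, f::Const{typeof(f_custom)}, ::Type{<:Active}, x::Duplicated
)
    # Save x on the tape only if it may be overwritten before the reverse pass
    tape = EnzymeRules.overwritten(config)[2] ? copy(x.val) : nothing
    primal = EnzymeRules.needs_primal(config) ? f.val(x.val) : nothing
    return EnzymeRules.AugmentedReturn(primal, nothing, tape)
end

function EnzymeRules.reverse(
    config::EnzymeRules.RevConfig, ::Const{typeof(f_custom)}, dret::Active, tape, x::Duplicated
)
    xval = EnzymeRules.overwritten(config)[2] ? tape : x.val
    x.dval .= (2dret.val) .* xval
    x.dval .^= 2  # same deliberate bug as in the OP, marks the custom rule
    return (nothing,)
end

x = [2.0]
v = first(onehot(x))
dx, dx_custom = make_zero(x), make_zero(x)
hv, hv_custom = make_zero(v), make_zero(v)

# Forward over reverse: the primal run fills dx/dx_custom, the tangents give hv products
autodiff(Forward, dg!, Const, Duplicated(dx, hv), Duplicated(x, v))
autodiff(Forward, dg_custom!, Const, Duplicated(dx_custom, hv_custom), Duplicated(x, v))
@show dx dx_custom hv hv_custom
```

According to the comment above, going back to the deferred variant, e.g. `autodiff_deferred(Reverse, Const(g_custom), Active, Duplicated(x, dx))` inside `dg_custom!`, makes `hv_custom` equal `hv` again unless `f_custom` is marked `@noinline`.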
