diff --git a/src/Lux.jl b/src/Lux.jl index 1551ad82a..57525eed3 100644 --- a/src/Lux.jl +++ b/src/Lux.jl @@ -12,6 +12,7 @@ using ConcreteStructs: @concrete using ConstructionBase: ConstructionBase using EnzymeCore: EnzymeCore, EnzymeRules using FastClosures: @closure +using ForwardDiff: ForwardDiff using Functors: Functors, fmap using GPUArraysCore: GPUArraysCore, @allowscalar using LossFunctions: LossFunctions diff --git a/src/helpers/nested_ad.jl b/src/helpers/nested_ad.jl index e56baa70f..ea0670a07 100644 --- a/src/helpers/nested_ad.jl +++ b/src/helpers/nested_ad.jl @@ -49,7 +49,7 @@ end function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(__internal_ad_gradient_call), grad_fn::G, f::F, x, y) where {G, F} - if !AUTOMATIC_NESTED_AD_SWITCHING + @static if !AUTOMATIC_NESTED_AD_SWITCHING return CRC.rrule_via_ad( cfg, __internal_ad_gradient_call_no_custom_rrule, grad_fn, f, x, y) end @@ -76,7 +76,7 @@ end function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(__internal_ad_pullback_call), pullback_fn::P, f::F, x, y, u) where {P, F} - if !AUTOMATIC_NESTED_AD_SWITCHING + @static if !AUTOMATIC_NESTED_AD_SWITCHING return CRC.rrule_via_ad( cfg, __internal_ad_pullback_call_no_custom_rrule, pullback_fn, f, x, y, u) end @@ -110,7 +110,7 @@ end function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(__internal_ad_jacobian_call), jac_fn::J, grad_fn::G, f::F, x::AbstractArray, y) where {J, G, F} - if !AUTOMATIC_NESTED_AD_SWITCHING + @static if !AUTOMATIC_NESTED_AD_SWITCHING return CRC.rrule_via_ad( cfg, __internal_ad_jacobian_call_no_custom_rrule, jac_fn, grad_fn, f, x, y) end