From 203ba5a8a6177b676828a782d9700737984a1f15 Mon Sep 17 00:00:00 2001 From: Billy Moses Date: Fri, 27 Sep 2024 19:42:38 -0500 Subject: [PATCH] Fix deferred any active return --- Project.toml | 2 +- lib/EnzymeCore/Project.toml | 2 +- lib/EnzymeCore/src/EnzymeCore.jl | 24 +-- src/Enzyme.jl | 245 +++++++++++++++++++------------ test/abi.jl | 14 ++ 5 files changed, 182 insertions(+), 105 deletions(-) diff --git a/Project.toml b/Project.toml index fd0882e97e..0d5c846cd7 100644 --- a/Project.toml +++ b/Project.toml @@ -35,7 +35,7 @@ EnzymeStaticArraysExt = "StaticArrays" BFloat16s = "0.2, 0.3, 0.4, 0.5" CEnum = "0.4, 0.5" ChainRulesCore = "1" -EnzymeCore = "0.8.3" +EnzymeCore = "0.8.4" Enzyme_jll = "0.0.150" GPUCompiler = "0.21, 0.22, 0.23, 0.24, 0.25, 0.26, 0.27" LLVM = "6.1, 7, 8, 9" diff --git a/lib/EnzymeCore/Project.toml b/lib/EnzymeCore/Project.toml index 3a871b930c..2e45d2c2f6 100644 --- a/lib/EnzymeCore/Project.toml +++ b/lib/EnzymeCore/Project.toml @@ -1,7 +1,7 @@ name = "EnzymeCore" uuid = "f151be2c-9106-41f4-ab19-57ee4f262869" authors = ["William Moses ", "Valentin Churavy "] -version = "0.8.3" +version = "0.8.4" [compat] Adapt = "3, 4" diff --git a/lib/EnzymeCore/src/EnzymeCore.jl b/lib/EnzymeCore/src/EnzymeCore.jl index 3231674de5..54244c0544 100644 --- a/lib/EnzymeCore/src/EnzymeCore.jl +++ b/lib/EnzymeCore/src/EnzymeCore.jl @@ -269,21 +269,21 @@ Reverse mode differentiation. - `Width`: Batch Size (0 if to be automatically derived) - `ModifiedBetween`: Tuple of each argument's modified between state (true if to be automatically derived). 
""" -struct ReverseModeSplit{ReturnPrimal,ReturnShadow,Width,RuntimeActivity,ModifiedBetween,ABI, ErrIfFuncWritten} <: Mode{ABI, ErrIfFuncWritten,RuntimeActivity} end -const ReverseSplitNoPrimal = ReverseModeSplit{false, true, false, 0, true,DefaultABI, false}() -const ReverseSplitWithPrimal = ReverseModeSplit{true, true, false, 0, true,DefaultABI, false}() -@inline ReverseSplitModified(::ReverseModeSplit{ReturnPrimal, ReturnShadow, RuntimeActivity, Width, MBO, ABI, ErrIfFuncWritten}, ::Val{MB}) where {ReturnPrimal,ReturnShadow,RuntimeActivity,Width,MB,MBO,ABI, ErrIfFuncWritten} = ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity, Width,MB,ABI, ErrIfFuncWritten}() -@inline ReverseSplitWidth(::ReverseModeSplit{ReturnPrimal, ReturnShadow, RuntimeActivity, WidthO, MB, ABI, ErrIfFuncWritten}, ::Val{Width}) where {ReturnPrimal,ReturnShadow,RuntimeActivity,Width,MB,WidthO,ABI, ErrIfFuncWritten} = ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,MB,ABI, ErrIfFuncWritten}() +struct ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI,Holomorphic,ErrIfFuncWritten,ShadowInit} <: Mode{ABI, ErrIfFuncWritten,RuntimeActivity} end +const ReverseSplitNoPrimal = ReverseModeSplit{false, true, false, 0, true,DefaultABI, false, false, false}() +const ReverseSplitWithPrimal = ReverseModeSplit{true, true, false, 0, true,DefaultABI, false, false, false}() +@inline ReverseSplitModified(::ReverseModeSplit{ReturnPrimal, ReturnShadow, RuntimeActivity, Width, MBO, ABI, Holomorphic, ErrIfFuncWritten, ShadowInit}, ::Val{MB}) where {ReturnPrimal,ReturnShadow,RuntimeActivity,Width,MB,MBO,ABI, Holomorphic, ErrIfFuncWritten, ShadowInit} = ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity, Width,MB,ABI, Holomorphic, ErrIfFuncWritten, ShadowInit}() +@inline ReverseSplitWidth(::ReverseModeSplit{ReturnPrimal, ReturnShadow, RuntimeActivity, WidthO, MB, ABI, Holomorphic, ErrIfFuncWritten, ShadowInit}, ::Val{Width}) where 
{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,MB,WidthO,ABI, Holomorphic, ErrIfFuncWritten, ShadowInit} = ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,MB,ABI, Holomorphic, ErrIfFuncWritten, ShadowInit}() -@inline set_err_if_func_written(::ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, ErrIfFuncWritten}) where {ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, ErrIfFuncWritten} = ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, true}() -@inline clear_err_if_func_written(::ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, ErrIfFuncWritten}) where {ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, ErrIfFuncWritten} = ReverseModeSplit{ReturnPrimal,ReturnShadow,Width,ModifiedBetween,ABI, false}() +@inline set_err_if_func_written(::ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, Holomorphic, ErrIfFuncWritten, ShadowInit}) where {ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, Holomorphic, ErrIfFuncWritten, ShadowInit} = ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, Holomorphic, true, ShadowInit}() +@inline clear_err_if_func_written(::ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, Holomorphic, ErrIfFuncWritten, ShadowInit}) where {ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, Holomorphic, ErrIfFuncWritten, ShadowInit} = ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, Holomorphic, false, ShadowInit}() -@inline set_runtime_activity(::ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, ErrIfFuncWritten}) where {ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, ErrIfFuncWritten} = ReverseModeSplit{ReturnPrimal,ReturnShadow,true,Width,ModifiedBetween,ABI, ErrIfFuncWritten}() 
-@inline set_runtime_activity(::ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, ErrIfFuncWritten}, rt::Bool) where {ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, ErrIfFuncWritten} = ReverseModeSplit{ReturnPrimal,ReturnShadow,rt,Width,ModifiedBetween,ABI, ErrIfFuncWritten}() -@inline clear_runtime_activity(::ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, ErrIfFuncWritten}) where {ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, ErrIfFuncWritten} = ReverseModeSplit{ReturnPrimal,ReturnShadow,false,Width,ModifiedBetween,ABI, ErrIfFuncWritten}() +@inline set_runtime_activity(::ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, Holomorphic, ErrIfFuncWritten, ShadowInit}) where {ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, Holomorphic, ErrIfFuncWritten, ShadowInit} = ReverseModeSplit{ReturnPrimal,ReturnShadow,true,Width,ModifiedBetween,ABI, Holomorphic, ErrIfFuncWritten, ShadowInit}() +@inline set_runtime_activity(::ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, Holomorphic, ErrIfFuncWritten, ShadowInit}, rt::Bool) where {ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, Holomorphic, ErrIfFuncWritten, ShadowInit} = ReverseModeSplit{ReturnPrimal,ReturnShadow,rt,Width,ModifiedBetween,ABI, Holomorphic, ErrIfFuncWritten, ShadowInit}() +@inline clear_runtime_activity(::ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, Holomorphic, ErrIfFuncWritten, ShadowInit}) where {ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, Holomorphic, ErrIfFuncWritten, ShadowInit} = ReverseModeSplit{ReturnPrimal,ReturnShadow,false,Width,ModifiedBetween,ABI, Holomorphic, ErrIfFuncWritten, ShadowInit}() -@inline WithPrimal(::ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, 
ErrIfFuncWritten}) where {ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, ErrIfFuncWritten} = ReverseModeSplit{true,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, ErrIfFuncWritten}() -@inline NoPrimal(::ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, ErrIfFuncWritten}) where {ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, ErrIfFuncWritten} = ReverseModeSplit{false,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, ErrIfFuncWritten}() +@inline WithPrimal(::ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, Holomorphic, ErrIfFuncWritten, ShadowInit}) where {ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, Holomorphic, ErrIfFuncWritten, ShadowInit} = ReverseModeSplit{true,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, Holomorphic, ErrIfFuncWritten, ShadowInit}() +@inline NoPrimal(::ReverseModeSplit{ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, Holomorphic, ErrIfFuncWritten, ShadowInit}) where {ReturnPrimal,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, Holomorphic, ErrIfFuncWritten, ShadowInit} = ReverseModeSplit{false,ReturnShadow,RuntimeActivity,Width,ModifiedBetween,ABI, Holomorphic, ErrIfFuncWritten, ShadowInit}() """ diff --git a/src/Enzyme.jl b/src/Enzyme.jl index b49c3738f6..5224d98bf6 100644 --- a/src/Enzyme.jl +++ b/src/Enzyme.jl @@ -349,18 +349,13 @@ Enzyme.autodiff(ReverseWithPrimal, x->x*x, Active(3.0)) throw(ErrorException("Cannot differentiate with a batch size of 0")) end - ModifiedBetween = Val(falses_from_args(Nargs + 1)) + ModifiedBetweenT = falses_from_args(Nargs + 1) + ModifiedBetween = Val(ModifiedBetweenT) tt = Tuple{map(T -> eltype(Core.Typeof(T)), args)...} FTy = Core.Typeof(f.val) - opt_mi = if RABI <: NonGenABI - Compiler.fspec(eltype(FA), tt′) - else - Val(codegen_world_age(FTy, tt)) - end - rt = if A isa UnionAll Compiler.primal_return_type(rmode, 
Val(codegen_world_age(FTy, tt)), FTy, tt) else @@ -369,20 +364,22 @@ Enzyme.autodiff(ReverseWithPrimal, x->x*x, Active(3.0)) if A <: Active if (!allocatedinline(rt) || rt isa Union) && rt != Union{} - forward, adjoint = Enzyme.Compiler.thunk( - opt_mi, + forward, adjoint = autodiff_thunk( + ReverseModeSplit{ + ReturnPrimal, + #=ReturnShadow=#false, + RuntimeActivity, + width, + ModifiedBetweenT, + RABI, + Holomorphic, + ErrIfFuncWritten, + #=ShadowInit=#true + }(), FA, Duplicated{rt}, - tt′, - Val(API.DEM_ReverseModeGradient), - Val(width), - ModifiedBetween, - Val(ReturnPrimal), - Val(true), - RABI, - Val(ErrIfFuncWritten), - Val(RuntimeActivity), - ) #=ShadowInit=# + (tt′).parameters... + ) res = forward(f, args...) tape = res[1] if ReturnPrimal @@ -399,6 +396,12 @@ Enzyme.autodiff(ReverseWithPrimal, x->x*x, Active(3.0)) throw(ErrorException("Duplicated Returns not yet handled")) end + opt_mi = if RABI <: NonGenABI + Compiler.fspec(eltype(FA), tt′) + else + Val(codegen_world_age(FTy, tt)) + end + if (A <: Active && rt <: Complex) && rt != Union{} if Holomorphic seen = IdDict() @@ -650,7 +653,7 @@ Same as [`autodiff`](@ref) but uses deferred compilation to support usage in GPU code, as well as high-order differentiation. """ @inline function autodiff_deferred( - ::ReverseMode{ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten}, + rmode::ReverseMode{ReturnPrimal,RuntimeActivity,RABI,Holomorphic,ErrIfFuncWritten}, f::FA, ::Type{A}, args::Vararg{Annotation,Nargs}, @@ -659,7 +662,7 @@ code, as well as high-order differentiation. A<:Annotation, ReturnPrimal, Nargs, - ABI, + RABI<:ABI, Holomorphic, ErrIfFuncWritten, RuntimeActivity, @@ -671,27 +674,85 @@ code, as well as high-order differentiation. 
end tt = Tuple{map(T -> eltype(Core.Typeof(T)), args)...} - world = codegen_world_age(Core.Typeof(f.val), tt) + FTy = Core.Typeof(f.val) + world = codegen_world_age(FTy, tt) + + A2 = A if A isa UnionAll - rt = Core.Compiler.return_type(f.val, tt) - rt = A{rt} + rt = Compiler.primal_return_type(rmode, Val(world), FTy, tt) + A2 = A{rt} else @assert A isa DataType rt = A end - if eltype(rt) == Union{} + if rt == Union{} error("Return type inferred to be Union{}. Giving up.") end - ModifiedBetween = Val(falses_from_args(Nargs + 1)) + ModifiedBetweenT = falses_from_args(Nargs + 1) + ModifiedBetween = Val(ModifiedBetweenT) + + if A <: Active + if (!allocatedinline(rt) || rt isa Union) && rt != Union{} + rs = ReverseModeSplit{ + ReturnPrimal, + #=ReturnShadow=#false, + RuntimeActivity, + width, + ModifiedBetweenT, + RABI, + Holomorphic, + ErrIfFuncWritten, + #=ShadowInit=#true + }() + TapeType = tape_type(rs, FA, Duplicated{rt}, + (tt′).parameters...) + forward, adjoint = autodiff_deferred_thunk( + rs, + TapeType, + FA, + Duplicated{rt}, + (tt′).parameters... + ) + res = forward(f, args...) + tape = res[1] + if ReturnPrimal + return (adjoint(f, args..., tape)[1], res[2]) + else + return adjoint(f, args..., tape) + end + end + elseif A <: Duplicated || + A <: DuplicatedNoNeed || + A <: BatchDuplicated || + A <: BatchDuplicatedNoNeed || + A <: BatchDuplicatedFunc + throw(ErrorException("Duplicated Returns not yet handled")) + end + + if (A <: Active && rt <: Complex) && rt != Union{} + if Holomorphic + throw( + ErrorException( + "Reverse-mode Active Holomorphic is not yet implemented in deferred codegen", + ), + ) + end + + throw( + ErrorException( + "Reverse-mode Active Complex return is ambiguous and requires more information to specify the desired result. 
See https://enzyme.mit.edu/julia/stable/faq/#Complex-numbers for more details.", + ), + ) + end adjoint_ptr = Compiler.deferred_codegen( Val(world), FA, Val(tt′), - Val(rt), + Val(A), Val(API.DEM_ReverseModeCombined), Val(width), ModifiedBetween, @@ -703,9 +764,9 @@ code, as well as high-order differentiation. ) #=ShadowInit=# thunk = - Compiler.CombinedAdjointThunk{Ptr{Cvoid},FA,rt,tt′,width,ReturnPrimal}(adjoint_ptr) - if rt <: Active - args = (args..., Compiler.default_adjoint(eltype(rt))) + Compiler.CombinedAdjointThunk{Ptr{Cvoid},FA,A2,tt′,width,ReturnPrimal}(adjoint_ptr) + if A <: Active + args = (args..., Compiler.default_adjoint(rt)) elseif A <: Duplicated || A <: DuplicatedNoNeed || A <: BatchDuplicated || @@ -722,7 +783,7 @@ Same as `autodiff(::ForwardMode, f, Activity, args...)` but uses deferred compil code, as well as high-order differentiation. """ @inline function autodiff_deferred( - ::ForwardMode{ReturnPrimal,ABI,ErrIfFuncWritten,RuntimeActivity}, + ::ForwardMode{ReturnPrimal,RABI,ErrIfFuncWritten,RuntimeActivity}, f::FA, ::Type{A}, args::Vararg{Annotation,Nargs}, @@ -731,7 +792,7 @@ code, as well as high-order differentiation. 
FA<:Annotation, A<:Annotation, Nargs, - ABI, + RABI<:ABI, ErrIfFuncWritten, RuntimeActivity, } @@ -856,7 +917,9 @@ result, ∂v, ∂A Width, ModifiedBetweenT, RABI, + #=Holomorphic=#false, ErrIfFuncWritten, + ShadowInit }, ::Type{FA}, ::Type{A}, @@ -871,6 +934,7 @@ result, ∂v, ∂A RABI<:ABI, Nargs, ErrIfFuncWritten, + ShadowInit, RuntimeActivity, } width = if Width == 0 @@ -891,9 +955,6 @@ result, ∂v, ∂A tt = Tuple{map(eltype, args)...} - if !(A <: Const) - @assert ReturnShadow - end tt′ = Tuple{args...} opt_mi = if RABI <: NonGenABI Compiler.fspec(eltype(FA), tt′) @@ -909,7 +970,7 @@ result, ∂v, ∂A Val(width), ModifiedBetween, Val(ReturnPrimal), - Val(false), + Val(ShadowInit), RABI, Val(ErrIfFuncWritten), Val(RuntimeActivity), @@ -1054,7 +1115,9 @@ end Width, ModifiedBetweenT, RABI, + #=Holomorphic=#false, ErrIfFuncWritten, + ShadowInit, }, ::Type{FA}, ::Type{A}, @@ -1070,6 +1133,7 @@ end Nargs, ErrIfFuncWritten, RuntimeActivity, + ShadowInit, } width = if Width == 0 w = same_or_one(1, args...) @@ -1087,7 +1151,6 @@ end ModifiedBetween = Val(ModifiedBetweenT) end - @assert ReturnShadow TT = Tuple{args...} primal_tt = Tuple{map(eltype, args)...} @@ -1105,7 +1168,7 @@ end Val(width), ModifiedBetween, Val(ReturnPrimal), - Val(false), + Val(ShadowInit), RABI, Val(ErrIfFuncWritten), Val(RuntimeActivity), @@ -1133,6 +1196,9 @@ import .Compiler: fspec, remove_innerty, UnknownTapeType Width, ModifiedBetweenT, RABI, + #=Holomorphic=#false, + #=ErrIfFuncWritten=#false, + #=ShadowInit=#false, }, ::Type{FA}, ::Type{A}, @@ -1214,7 +1280,7 @@ import .Compiler: fspec, remove_innerty, UnknownTapeType end """ - autodiff_deferred_thunk(::ReverseModeSplit, ftype, Activity, argtypes::Type{<:Annotation}...) + autodiff_deferred_thunk(::ReverseModeSplit, TapeType::Type, ftype::Type{<:Annotation}, Activity::Type{<:Annotation}, argtypes::Type{<:Annotation}...) 
Provide the split forward and reverse pass functions for annotated function type ftype when called with args of type `argtypes` when using reverse mode. @@ -1265,7 +1331,9 @@ result, ∂v, ∂A Width, ModifiedBetweenT, RABI, + #=Holomorphic=#false, ErrIfFuncWritten, + ShadowInit, }, tt::Type{TapeType}, fa::Type{FA}, @@ -1283,6 +1351,7 @@ result, ∂v, ∂A Nargs, ErrIfFuncWritten, RuntimeActivity, + ShadowInit } @assert RABI == FFIABI width = if Width == 0 @@ -1301,7 +1370,6 @@ result, ∂v, ∂A ModifiedBetween = Val(ModifiedBetweenT) end - @assert ReturnShadow TT = Tuple{args...} primal_tt = Tuple{map(eltype, args)...} @@ -1316,7 +1384,7 @@ result, ∂v, ∂A Val(width), ModifiedBetween, Val(ReturnPrimal), - Val(false), + Val(ShadowInit), TapeType, Val(ErrIfFuncWritten), Val(RuntimeActivity), @@ -2053,7 +2121,6 @@ this function will retun an AbstractArray of shape `size(output)` of values of t jac end else - @assert !Holomorphic n_out_val = if length(Compiler.element(n_outs)) == 0 0 else @@ -2073,32 +2140,27 @@ this function will retun an AbstractArray of shape `size(output)` of values of t Core.Compiler.return_type(f, tt) end - ModifiedBetween = Val((false, false)) + ModifiedBetweenT = (false, false) FRT = Core.Typeof(f) FA = Const{FRT} - opt_mi = if RABI <: NonGenABI - Compiler.fspec(FRT, tt′) - else - Val(codegen_world_age(FRT, tt)) - end - if chunk == Val(1) || chunk == nothing - tt′ = MD ? Tuple{MixedDuplicated{XT}} : Tuple{Duplicated{XT}} - primal, adjoint = Enzyme.Compiler.thunk( - opt_mi, + primal, adjoint = autodiff_thunk( + ReverseModeSplit{ + #=ReturnPrimal=#false, + #=ReturnShadow=#true, + RuntimeActivity, + #=width=#1, + ModifiedBetweenT, + RABI, + Holomorphic, + ErrIfFuncWritten, + #=ShadowInit=#false + }(), FA, DuplicatedNoNeed{rt}, - tt′, - Val(API.DEM_ReverseModeGradient), - Val(1), - ModifiedBetween, - Val(false), - Val(false), - RABI, - Val(ErrIfFuncWritten), - Val(RuntimeActivity), - ) #=ShadowInit=# + MD ? 
MixedDuplicated{XT} : Duplicated{XT} + ) tmp = ntuple(Val(n_out_val)) do i Base.@_inline_meta z = make_zero(x) @@ -2114,23 +2176,22 @@ this function will retun an AbstractArray of shape `size(output)` of values of t rows, outshape else chunksize = Compiler.element(chunk) - tt′ = - MD ? Tuple{BatchMixedDuplicated{XT,chunksize}} : - Tuple{BatchDuplicated{XT,chunksize}} - primal, adjoint = Enzyme.Compiler.thunk( - opt_mi, + primal, adjoint = autodiff_thunk( + ReverseModeSplit{ + #=ReturnPrimal=#false, + #=ReturnShadow=#true, + RuntimeActivity, + chunksize, + ModifiedBetweenT, + RABI, + Holomorphic, + ErrIfFuncWritten, + #=ShadowInit=#false + }(), FA, - BatchDuplicatedNoNeed{rt}, - tt′, - Val(API.DEM_ReverseModeGradient), - chunk, - ModifiedBetween, - Val(false), - Val(false), - RABI, - Val(ErrIfFuncWritten), - Val(RuntimeActivity), - ) #=ShadowInit=# + BatchDuplicatedNoNeed{rt, chunksize}, + MD ? BatchMixedDuplicated{XT, chunksize} : BatchDuplicated{XT, chunksize} + ) num = ((n_out_val + chunksize - 1) ÷ chunksize) @@ -2140,20 +2201,22 @@ this function will retun an AbstractArray of shape `size(output)` of values of t else last_size = n_out_val - (num - 1) * chunksize tt′ = Tuple{BatchDuplicated{Core.Typeof(x),last_size}} - primal2, adjoint2 = Enzyme.Compiler.thunk( - opt_mi, + primal2, adjoint2 = autodiff_thunk( + ReverseModeSplit{ + #=ReturnPrimal=#false, + #=ReturnShadow=#true, + RuntimeActivity, + last_size, + ModifiedBetweenT, + RABI, + Holomorphic, + ErrIfFuncWritten, + #=ShadowInit=#false + }(), FA, - BatchDuplicatedNoNeed{rt}, - tt′, - Val(API.DEM_ReverseModeGradient), - Val(last_size), - ModifiedBetween, - Val(false), - Val(false), - RABI, - Val(ErrIfFuncWritten), - Val(RuntimeActivity), - ) #=ShadowInit=# + BatchDuplicatedNoNeed{rt, last_size}, + MD ? 
BatchMixedDuplicated{XT, last_size} : BatchDuplicated{XT, last_size} + ) end tmp = ntuple(num) do i diff --git a/test/abi.jl b/test/abi.jl index cbd467c155..5acb30e04f 100644 --- a/test/abi.jl +++ b/test/abi.jl @@ -300,6 +300,20 @@ using Test # returns: sret, const/ghost, !deserve_retbox end +unstable_load(x) = Base.inferencebarrier(x)[1] + +@testset "Any Return" begin + x = [2.7] + dx = [0.0] + Enzyme.autodiff(Reverse, Const(unstable_load), Active, Duplicated(x, dx)) + @test dx ≈ [1.0] + + x = [2.7] + dx = [0.0] + Enzyme.autodiff_deferred(Reverse, Const(unstable_load), Active, Duplicated(x, dx)) + @test dx ≈ [1.0] +end + @testset "Mutable Struct ABI" begin mutable struct MStruct val::Float32