diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml index 4062b67..224ec0e 100644 --- a/.buildkite/pipeline.yml +++ b/.buildkite/pipeline.yml @@ -29,4 +29,4 @@ env: RETESTITEMS_NWORKER_THREADS: 2 JULIA_AMDGPU_LOGGING_ENABLED: true RETESTITEMS_TESTITEM_TIMEOUT: 10000 - SECRET_CODECOV_TOKEN: "NkRNr3aGg3k4bKi07RGhGdhhsPV8t97y0VASfmra5BzT+6h/S+hEZ6p7U6SE0/1LQrHxRBy9vaWiwF+VW1ZHk7KMUetuOYmymXON/AUBbiE4LsfFOVwrna7U0kuqWHZbdKn8XAJxu6au1uRMrOXPXw176KkuWRzwF/jLWvvv7s+KqX4oaiDirXxGCRSssVizT2hdWkkrtct+GjLeF/g9jgGa8xn8j2Pp8AS62EPMoC/YKgV/e3yK58LSOKOBF+1ddvYzaFoDABkNMehHA52MNXgDoxikTc0YGnd8nMGfTUiRPaNLHRQaXS/M0oaVT7PkXFlJe6O6izCnkIx2+Ix57w==;U2FsdGVkX1/k+16T0rj/Tntc6gOaH8GwRDvncs1a+BwbnnnnXeAmiwvbowfRSnoldKtpHhJcwLQLbXFDXD8U6g==" + SECRET_CODECOV_TOKEN: "zLFPthE27DkLNSAv2AWwzWtPIyEFzhsPtiDMN9NNm34ZPUVMeGn1dDhZwzpMCnQs0GbwUUYFkPiZhp52xhaxCWIrgy1vazeuiqZxxoDlkBPlrwe9afa3HpFDoFG2CAAv8UZtWM7U5XpKyUCFQX9iQ89RgkXpU4bV7U0342PEqBl7zG/mVkBWbkJA0Tf7HWTCdxJ2YbNHuMnErahMLL2u7vKRrN+jwzhuYbHU3bWNqgyh+DI3AONhUy+2ClKb3JKJYBlYpwcdPuF2M0dV7Rgd7MuNXFZ1uiuPOSRLjuGU44c1OU67GDye4HkVNgWZaOhw5ccSnTD2WBBrklnXc9Uy1w==;U2FsdGVkX19q4a438CUZNXPYDrkhPFYW7x8VZRVePbU9l0hswT4iZjZyJNxxVryjgDm89v3wFNBaBpa3dEabaw==" diff --git a/ext/BatchedRoutinesReverseDiffExt.jl b/ext/BatchedRoutinesReverseDiffExt.jl index ad8de30..fae9a5d 100644 --- a/ext/BatchedRoutinesReverseDiffExt.jl +++ b/ext/BatchedRoutinesReverseDiffExt.jl @@ -25,6 +25,9 @@ function BatchedRoutines.batched_jacobian( return BatchedRoutines.batched_jacobian(ad, f, ArrayInterface.aos_to_soa(x)) end +ReverseDiff.@grad_from_chainrules BatchedRoutines.batched_jacobian( + ad, f, x::ReverseDiff.TrackedArray) + function BatchedRoutines.batched_jacobian( ad, f::F, x::AbstractArray{<:ReverseDiff.TrackedReal}, p::AbstractArray{<:ReverseDiff.TrackedReal}) where {F} @@ -37,8 +40,10 @@ function BatchedRoutines.batched_jacobian( return BatchedRoutines.batched_jacobian(ad, f, x, ArrayInterface.aos_to_soa(p)) end -ReverseDiff.@grad_from_chainrules BatchedRoutines.batched_jacobian( - ad, f, x::ReverseDiff.TrackedArray) +function BatchedRoutines.batched_jacobian( + ad, f::F, x::AbstractArray{<:ReverseDiff.TrackedReal}, p::AbstractArray) where {F} + return BatchedRoutines.batched_jacobian(ad, f, ArrayInterface.aos_to_soa(x), p) +end ReverseDiff.@grad_from_chainrules BatchedRoutines.batched_jacobian( ad, f, x::ReverseDiff.TrackedArray, p::ReverseDiff.TrackedArray) @@ -49,7 +54,39 @@ ReverseDiff.@grad_from_chainrules BatchedRoutines.batched_jacobian( ReverseDiff.@grad_from_chainrules BatchedRoutines.batched_jacobian( ad, f, x::ReverseDiff.TrackedArray, p) -# TODO: Fix the gradient call over ReverseDiff +function BatchedRoutines.batched_gradient( + ad, f::F, x::AbstractArray{<:ReverseDiff.TrackedReal}) where {F} + return BatchedRoutines.batched_gradient(ad, f, ArrayInterface.aos_to_soa(x)) +end + +ReverseDiff.@grad_from_chainrules BatchedRoutines.batched_gradient( + ad, f, x::ReverseDiff.TrackedArray) + +function BatchedRoutines.batched_gradient( + ad, f::F, x::AbstractArray{<:ReverseDiff.TrackedReal}, + p::AbstractArray{<:ReverseDiff.TrackedReal}) where {F} + return BatchedRoutines.batched_gradient( + ad, f, ArrayInterface.aos_to_soa(x), ArrayInterface.aos_to_soa(p)) +end + +function BatchedRoutines.batched_gradient( + ad, f::F, x::AbstractArray, p::AbstractArray{<:ReverseDiff.TrackedReal}) where {F} + return BatchedRoutines.batched_gradient(ad, f, x, ArrayInterface.aos_to_soa(p)) +end + +function BatchedRoutines.batched_gradient( + ad, f::F, x::AbstractArray{<:ReverseDiff.TrackedReal}, p::AbstractArray) where {F} + return BatchedRoutines.batched_gradient(ad, f, ArrayInterface.aos_to_soa(x), p) +end + +ReverseDiff.@grad_from_chainrules BatchedRoutines.batched_gradient( + ad, f, x::ReverseDiff.TrackedArray, p::ReverseDiff.TrackedArray) + +ReverseDiff.@grad_from_chainrules BatchedRoutines.batched_gradient( + ad, f, x, p::ReverseDiff.TrackedArray) + +ReverseDiff.@grad_from_chainrules BatchedRoutines.batched_gradient( + ad, f, x::ReverseDiff.TrackedArray, p) @concrete struct ReverseDiffPullbackFunction <: Function tape