Skip to content
This repository has been archived by the owner on Nov 1, 2024. It is now read-only.

Commit

Permalink
ReverseDiff with chainrules
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Mar 15, 2024
1 parent a57e936 commit 1686504
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 4 deletions.
2 changes: 1 addition & 1 deletion .buildkite/pipeline.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,4 +29,4 @@ env:
RETESTITEMS_NWORKER_THREADS: 2
JULIA_AMDGPU_LOGGING_ENABLED: true
RETESTITEMS_TESTITEM_TIMEOUT: 10000
SECRET_CODECOV_TOKEN: "NkRNr3aGg3k4bKi07RGhGdhhsPV8t97y0VASfmra5BzT+6h/S+hEZ6p7U6SE0/1LQrHxRBy9vaWiwF+VW1ZHk7KMUetuOYmymXON/AUBbiE4LsfFOVwrna7U0kuqWHZbdKn8XAJxu6au1uRMrOXPXw176KkuWRzwF/jLWvvv7s+KqX4oaiDirXxGCRSssVizT2hdWkkrtct+GjLeF/g9jgGa8xn8j2Pp8AS62EPMoC/YKgV/e3yK58LSOKOBF+1ddvYzaFoDABkNMehHA52MNXgDoxikTc0YGnd8nMGfTUiRPaNLHRQaXS/M0oaVT7PkXFlJe6O6izCnkIx2+Ix57w==;U2FsdGVkX1/k+16T0rj/Tntc6gOaH8GwRDvncs1a+BwbnnnnXeAmiwvbowfRSnoldKtpHhJcwLQLbXFDXD8U6g=="
SECRET_CODECOV_TOKEN: "zLFPthE27DkLNSAv2AWwzWtPIyEFzhsPtiDMN9NNm34ZPUVMeGn1dDhZwzpMCnQs0GbwUUYFkPiZhp52xhaxCWIrgy1vazeuiqZxxoDlkBPlrwe9afa3HpFDoFG2CAAv8UZtWM7U5XpKyUCFQX9iQ89RgkXpU4bV7U0342PEqBl7zG/mVkBWbkJA0Tf7HWTCdxJ2YbNHuMnErahMLL2u7vKRrN+jwzhuYbHU3bWNqgyh+DI3AONhUy+2ClKb3JKJYBlYpwcdPuF2M0dV7Rgd7MuNXFZ1uiuPOSRLjuGU44c1OU67GDye4HkVNgWZaOhw5ccSnTD2WBBrklnXc9Uy1w==;U2FsdGVkX19q4a438CUZNXPYDrkhPFYW7x8VZRVePbU9l0hswT4iZjZyJNxxVryjgDm89v3wFNBaBpa3dEabaw=="
43 changes: 40 additions & 3 deletions ext/BatchedRoutinesReverseDiffExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ function BatchedRoutines.batched_jacobian(
return BatchedRoutines.batched_jacobian(ad, f, ArrayInterface.aos_to_soa(x))
end

ReverseDiff.@grad_from_chainrules BatchedRoutines.batched_jacobian(
ad, f, x::ReverseDiff.TrackedArray)

function BatchedRoutines.batched_jacobian(
ad, f::F, x::AbstractArray{<:ReverseDiff.TrackedReal},
p::AbstractArray{<:ReverseDiff.TrackedReal}) where {F}
Expand All @@ -37,8 +40,10 @@ function BatchedRoutines.batched_jacobian(
return BatchedRoutines.batched_jacobian(ad, f, x, ArrayInterface.aos_to_soa(p))
end

ReverseDiff.@grad_from_chainrules BatchedRoutines.batched_jacobian(
ad, f, x::ReverseDiff.TrackedArray)
function BatchedRoutines.batched_jacobian(
ad, f::F, x::AbstractArray{<:ReverseDiff.TrackedReal}, p::AbstractArray) where {F}
return BatchedRoutines.batched_jacobian(ad, f, ArrayInterface.aos_to_soa(x), p)
end

ReverseDiff.@grad_from_chainrules BatchedRoutines.batched_jacobian(
ad, f, x::ReverseDiff.TrackedArray, p::ReverseDiff.TrackedArray)
Expand All @@ -49,7 +54,39 @@ ReverseDiff.@grad_from_chainrules BatchedRoutines.batched_jacobian(
ReverseDiff.@grad_from_chainrules BatchedRoutines.batched_jacobian(
ad, f, x::ReverseDiff.TrackedArray, p)

# TODO: Fix the gradient call over ReverseDiff
function BatchedRoutines.batched_gradient(
ad, f::F, x::AbstractArray{<:ReverseDiff.TrackedReal}) where {F}
return BatchedRoutines.batched_gradient(ad, f, ArrayInterface.aos_to_soa(x))
end

ReverseDiff.@grad_from_chainrules BatchedRoutines.batched_gradient(
ad, f, x::ReverseDiff.TrackedArray)

function BatchedRoutines.batched_gradient(
ad, f::F, x::AbstractArray{<:ReverseDiff.TrackedReal},
p::AbstractArray{<:ReverseDiff.TrackedReal}) where {F}
return BatchedRoutines.batched_gradient(
ad, f, ArrayInterface.aos_to_soa(x), ArrayInterface.aos_to_soa(p))
end

function BatchedRoutines.batched_gradient(
ad, f::F, x::AbstractArray, p::AbstractArray{<:ReverseDiff.TrackedReal}) where {F}
return BatchedRoutines.batched_gradient(ad, f, x, ArrayInterface.aos_to_soa(p))
end

function BatchedRoutines.batched_gradient(
ad, f::F, x::AbstractArray{<:ReverseDiff.TrackedReal}, p::AbstractArray) where {F}
return BatchedRoutines.batched_gradient(ad, f, ArrayInterface.aos_to_soa(x), p)
end

ReverseDiff.@grad_from_chainrules BatchedRoutines.batched_gradient(
ad, f, x::ReverseDiff.TrackedArray, p::ReverseDiff.TrackedArray)

ReverseDiff.@grad_from_chainrules BatchedRoutines.batched_gradient(
ad, f, x, p::ReverseDiff.TrackedArray)

ReverseDiff.@grad_from_chainrules BatchedRoutines.batched_gradient(
ad, f, x::ReverseDiff.TrackedArray, p)

@concrete struct ReverseDiffPullbackFunction <: Function
tape
Expand Down

0 comments on commit 1686504

Please sign in to comment.