This repository has been archived by the owner on Nov 1, 2024. It is now read-only.

Commit 978a6c3: Patch bug

avik-pal committed Mar 13, 2024 (1 parent: 3b2654e)

Showing 3 changed files with 9 additions and 5 deletions.
ext/BatchedRoutinesFiniteDiffExt.jl (4 changes: 3 additions & 1 deletion)
@@ -32,7 +32,9 @@ end
 
 # NOTE: This doesn't exploit batching
 @inline function BatchedRoutines._batched_gradient(ad::AutoFiniteDiff, f::F, x) where {F}
-    return FiniteDiff.finite_difference_batched_gradient(f, x, ad.fdjtype)
+    return FiniteDiff.finite_difference_gradient(f, x, ad.fdjtype)
 end
 
+# TODO: For the gradient call just use FiniteDiff over FiniteDiff
+
 end
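For context on the fix: the commit swaps finite_difference_batched_gradient, which FiniteDiff does not appear to provide, for finite_difference_gradient. Below is a minimal sketch of what the corrected line computes, using a hypothetical toy function f that is not part of the commit; the third positional argument of FiniteDiff.finite_difference_gradient selects the difference scheme, which is what ad.fdjtype supplies above.

using FiniteDiff

f(x) = sum(abs2, x)   # toy objective with known gradient 2x
x = rand(4)

# Approximate the gradient with forward differences (the scheme is the
# third positional argument, analogous to ad.fdjtype in the patched code).
g = FiniteDiff.finite_difference_gradient(f, x, Val(:forward))

@assert isapprox(g, 2 .* x; atol=1e-4)   # O(h)-accurate forward differences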
ext/BatchedRoutinesForwardDiffExt.jl (8 changes: 4 additions & 4 deletions)
@@ -162,8 +162,8 @@ function BatchedRoutines._jacobian_vector_product(ad::AutoForwardDiff, f::F, x,
     Tag = ad.tag === nothing ? typeof(ForwardDiff.Tag(f, eltype(x))) : typeof(ad.tag)
     T = promote_type(eltype(x), eltype(u))
     partials = ForwardDiff.Partials{1, T}.(tuple.(u))
-    u_dual = ForwardDiff.Dual{Tag, T, 1}.(u, partials)
-    y_dual = f(u_dual)
+    x_dual = ForwardDiff.Dual{Tag, T, 1}.(x, partials)
+    y_dual = f(x_dual)
     return ForwardDiff.partials.(y_dual, 1)
 end
 

@@ -172,8 +172,8 @@ function BatchedRoutines._jacobian_vector_product(
     Tag = ad.tag === nothing ? typeof(ForwardDiff.Tag(f, eltype(x))) : typeof(ad.tag)
     T = promote_type(eltype(x), eltype(u))
     partials = ForwardDiff.Partials{1, T}.(tuple.(u))
-    u_dual = ForwardDiff.Dual{Tag, T, 1}.(u, partials)
-    y_dual = f(u_dual, p)
+    x_dual = ForwardDiff.Dual{Tag, T, 1}.(x, partials)
+    y_dual = f(x_dual, p)
     return ForwardDiff.partials.(y_dual, 1)
 end
 

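The bug fixed in both hunks: the duals were seeded at the tangent u rather than the primal point x, so f was differentiated at the wrong point. Below is a minimal sketch of the pattern the fix restores, with a hypothetical toy f that is not from the repository. Seeding duals whose value part is x and whose single partial carries u makes the extracted partials of f(x_dual) equal the Jacobian-vector product J_f(x) * u.

using ForwardDiff

f(x) = [x[1]^2 + x[2], 3 * x[2]]   # toy map with known Jacobian [2x1 1; 0 3]
x = [1.0, 2.0]                     # primal point
u = [1.0, 0.0]                     # tangent (direction) vector

Tag = typeof(ForwardDiff.Tag(f, eltype(x)))
T = promote_type(eltype(x), eltype(u))
partials = ForwardDiff.Partials{1, T}.(tuple.(u))

# Value part is x, partials carry u; this is exactly the u -> x swap
# the commit makes.
x_dual = ForwardDiff.Dual{Tag, T, 1}.(x, partials)
y_dual = f(x_dual)
jvp = ForwardDiff.partials.(y_dual, 1)

@assert jvp ≈ ForwardDiff.jacobian(f, x) * u   # [2.0, 0.0]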
ext/BatchedRoutinesReverseDiffExt.jl (2 changes: 2 additions & 0 deletions)
@@ -18,6 +18,8 @@ function BatchedRoutines._batched_gradient(::AutoReverseDiff, f::F, u) where {F}
     return ReverseDiff.gradient(f, u)
 end
 
+# TODO: Fix the gradient call over ReverseDiff
+
 @concrete struct ReverseDiffPullbackFunction <: Function
     tape
     ∂input
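The function just above the inserted TODO falls back to a plain, non-batched reverse-mode sweep. A minimal sketch of that call, with a hypothetical toy f:

using ReverseDiff

f(u) = sum(abs2, u)   # toy objective with known gradient 2u
u = rand(3)

g = ReverseDiff.gradient(f, u)   # records a tape and runs one reverse pass

@assert g ≈ 2 .* u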
