This repository has been archived by the owner on Nov 1, 2024. It is now read-only.

Commit: Efficient gradient of jacobian
avik-pal committed Mar 13, 2024
1 parent 978a6c3 commit 1b4d36b
Showing 1 changed file with 44 additions and 33 deletions.
src/chainrules.jl (77 changes: 44 additions & 33 deletions)
@@ -3,69 +3,80 @@ function __batched_value_and_jacobian(ad, f::F, x) where {F}
     return f(x), J
 end
 
-# FIXME: Gradient of jacobians is really inefficient here
-# TODO: Use OneElement for this
 function CRC.rrule(::typeof(batched_jacobian), ad, f::F, x::AbstractMatrix) where {F}
-    N, B = size(x)
-    J, H = __batched_value_and_jacobian(
-        ad, @closure(y->reshape(batched_jacobian(ad, f, y).data, :, B)), x)
+    if !_is_extension_loaded(Val(:ForwardDiff)) || !_is_extension_loaded(Val(:Zygote))
+        throw(ArgumentError("`ForwardDiff.jl` and `Zygote.jl` need to be loaded to \
+                             compute the gradient of `batched_jacobian`."))
+    end
+
+    J = batched_jacobian(ad, f, x)
+
-    function ∇batched_jacobian(Δ)
-        ∂x = reshape(
-            batched_mul(reshape(Δ.data, 1, :, nbatches(Δ)), H.data), :, nbatches(Δ))
+    ∇batched_jacobian = Δ -> begin
+        gradient_ad = AutoZygote()
+        _map_fnₓ = ((i, Δᵢ),) -> _jacobian_vector_product(AutoForwardDiff(),
+            x -> batched_gradient(gradient_ad, x_ -> sum(vec(f(x_))[i:i]), x),
+            x, reshape(Δᵢ, size(x)))
+        ∂x = reshape(mapreduce(_map_fnₓ, +, enumerate(eachrow(Δ))), size(x))
         return NoTangent(), NoTangent(), NoTangent(), ∂x
     end
 
-    return UniformBlockDiagonalMatrix(reshape(J, :, N, B)), ∇batched_jacobian
+    return J, ∇batched_jacobian
 end
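For intuition: the new pullback assembles the cotangent contraction ⟨Δ, ∂J/∂x⟩ one Jacobian row at a time, each row as a ForwardDiff JVP pushed through a Zygote gradient (forward-over-reverse), so the Jacobian-of-the-Jacobian is never materialized. Below is a minimal standalone sketch of the same idea with plain ForwardDiff.jl and Zygote.jl; `jacobian_vjp` is a hypothetical helper for illustration, not part of this package:

    using ForwardDiff, Zygote

    # VJP through a Jacobian without materializing second-order terms:
    # for J = Zygote.jacobian(f, x)[1] and a cotangent Δ of the same shape,
    # row i of Δ pairs with the gradient of the i-th output, so
    #     <Δ, ∂J/∂x> = Σᵢ JVP(x -> ∇fᵢ(x), x, Δᵢ)
    function jacobian_vjp(f, x::AbstractVector, Δ::AbstractMatrix)
        mapreduce(+, enumerate(eachrow(Δ))) do (i, Δᵢ)
            ∇fᵢ = x_ -> Zygote.gradient(z -> f(z)[i], x_)[1]      # reverse pass, output i
            ForwardDiff.derivative(t -> ∇fᵢ(x .+ t .* Δᵢ), 0.0)   # forward JVP at t = 0
        end
    end

This costs one JVP per output row instead of one dense second-order Jacobian, which is where the efficiency in the commit title comes from.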

 function CRC.rrule(::typeof(batched_jacobian), ad, f::F, x, p) where {F}
-    N, B = size(x)
-    J, H = __batched_value_and_jacobian(
-        ad, @closure(y->reshape(batched_jacobian(ad, f, y, p).data, :, B)), x)
-
-    p_size = size(p)
-    _, Jₚ_ = __batched_value_and_jacobian(
-        ad, @closure(p->reshape(batched_jacobian(ad, f, x, reshape(p, p_size)).data, :, B)),
-        vec(p))
-    Jₚ = dropdims(Jₚ_.data; dims=3)
-
-    function ∇batched_jacobian(Δ)
-        ∂x = reshape(
-            batched_mul(reshape(Δ.data, 1, :, nbatches(Δ)), H.data), :, nbatches(Δ))
-        ∂p = reshape(reshape(Δ.data, 1, :) * Jₚ, p_size)
+    if !_is_extension_loaded(Val(:ForwardDiff)) || !_is_extension_loaded(Val(:Zygote))
+        throw(ArgumentError("`ForwardDiff.jl` and `Zygote.jl` need to be loaded to \
+                             compute the gradient of `batched_jacobian`."))
+    end
+
+    J = batched_jacobian(ad, f, x, p)
+
+    ∇batched_jacobian = Δ -> begin
+        _map_fnₓ = ((i, Δᵢ),) -> _jacobian_vector_product(AutoForwardDiff(),
+            x -> batched_gradient(AutoZygote(), x_ -> sum(vec(f(x_, p))[i:i]), x),
+            x, reshape(Δᵢ, size(x)))
+
+        ∂x = reshape(mapreduce(_map_fnₓ, +, enumerate(eachrow(Δ))), size(x))
+
+        _map_fnₚ = ((i, Δᵢ),) -> _jacobian_vector_product(AutoForwardDiff(),
+            (x, p_) -> batched_gradient(AutoZygote(), p__ -> sum(vec(f(x, p__))[i:i]), p_),
+            x, reshape(Δᵢ, size(x)), p)
+
+        ∂p = reshape(mapreduce(_map_fnₚ, +, enumerate(eachrow(Δ))), size(p))
+
         return NoTangent(), NoTangent(), NoTangent(), ∂x, ∂p
     end
 
-    return UniformBlockDiagonalMatrix(reshape(J, :, N, B)), ∇batched_jacobian
+    return J, ∇batched_jacobian
 end
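With both extensions loaded, reverse-mode AD can flow through `batched_jacobian` with respect to both `x` and `p`. A hypothetical usage sketch; the function `f` and the sizes here are illustrative only, and `AutoForwardDiff()` is the ADTypes backend selector already used in this file:

    using BatchedRoutines, ForwardDiff, Zygote   # loads both required extensions

    f(x, p) = p .* tanh.(x)    # illustrative batched function
    x = randn(4, 8)            # 4 states per sample, batch of 8
    p = randn(4)

    loss(x, p) = sum(abs2, batched_jacobian(AutoForwardDiff(), f, x, p))
    ∂x, ∂p = Zygote.gradient(loss, x, p)   # outer reverse pass hits the rrule above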

 function CRC.rrule(::typeof(batched_gradient), ad, f::F, x) where {F}
-    BatchedRoutines._is_extension_loaded(Val(:ForwardDiff)) ||
+    _is_extension_loaded(Val(:ForwardDiff)) ||
         throw(ArgumentError("`ForwardDiff.jl` needs to be loaded to compute the gradient \
                              of `batched_gradient`."))
 
-    dx = BatchedRoutines.batched_gradient(ad, f, x)
+    dx = batched_gradient(ad, f, x)
     ∇batched_gradient = @closure Δ -> begin
         ∂x = _jacobian_vector_product(
-            AutoForwardDiff(), @closure(x->BatchedRoutines.batched_gradient(ad, f, x)),
+            AutoForwardDiff(), @closure(x->batched_gradient(ad, f, x)),
             x, reshape(Δ, size(x)))
         return NoTangent(), NoTangent(), NoTangent(), ∂x
     end
     return dx, ∇batched_gradient
 end
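This pullback rests on the identity ∂⟨∇f(x), Δ⟩/∂x = H(x)Δ: the cotangent of a gradient is a Hessian-vector product, which forward-over-reverse evaluates without ever forming H. A self-contained sketch; `hvp` is a hypothetical stand-in for the package's `_jacobian_vector_product`:

    using ForwardDiff, Zygote

    # Hessian-vector product: push the dual perturbation x + t*v through the
    # reverse-mode gradient, then differentiate at t = 0.
    hvp(f, x, v) = ForwardDiff.derivative(t -> Zygote.gradient(f, x .+ t .* v)[1], 0.0)

    g(x) = sum(abs2, x) / 2    # ∇g(x) = x, so the Hessian is the identity
    x, v = randn(5), randn(5)
    hvp(g, x, v) ≈ v           # holds, since H * v == v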

 function CRC.rrule(::typeof(batched_gradient), ad, f::F, x, p) where {F}
-    BatchedRoutines._is_extension_loaded(Val(:ForwardDiff)) ||
+    _is_extension_loaded(Val(:ForwardDiff)) ||
         throw(ArgumentError("`ForwardDiff.jl` needs to be loaded to compute the gradient \
                              of `batched_gradient`."))
 
-    dx = BatchedRoutines.batched_gradient(ad, f, x, p)
+    dx = batched_gradient(ad, f, x, p)
     ∇batched_gradient = @closure Δ -> begin
-        ∂x = _jacobian_vector_product(AutoForwardDiff(),
-            @closure(x->BatchedRoutines.batched_gradient(ad, Base.Fix2(f, p), x)),
+        ∂x = _jacobian_vector_product(
+            AutoForwardDiff(), @closure(x->batched_gradient(ad, Base.Fix2(f, p), x)),
             x, reshape(Δ, size(x)))
-        ∂p = _jacobian_vector_product(AutoForwardDiff(),
-            @closure((x, p)->BatchedRoutines.batched_gradient(ad, Base.Fix1(f, x), p)),
+        ∂p = _jacobian_vector_product(
+            AutoForwardDiff(), @closure((x, p)->batched_gradient(ad, Base.Fix1(f, x), p)),
             x, reshape(Δ, size(x)), p)
         return NoTangent(), NoTangent(), NoTangent(), ∂x, ∂p
     end
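The same pattern with parameters: the rule itself needs only ForwardDiff.jl for its JVPs, while Zygote below merely drives the outer reverse pass. A hypothetical usage sketch with an illustrative per-batch objective:

    using BatchedRoutines, ForwardDiff, Zygote

    f(x, p) = sum(abs2, p .* x; dims=1)   # one scalar per batch column (illustrative)
    x = randn(3, 6)
    p = randn(3)

    loss(x, p) = sum(abs2, batched_gradient(AutoForwardDiff(), f, x, p))
    ∂x, ∂p = Zygote.gradient(loss, x, p)  # dispatches to the rrule above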
