Skip to content
This repository has been archived by the owner on Nov 1, 2024. It is now read-only.

Commit

Permalink
Remove extra code
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Mar 15, 2024
1 parent 1686504 commit fdec976
Show file tree
Hide file tree
Showing 5 changed files with 25 additions and 99 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

[![CI](https://github.com/LuxDL/BatchedRoutines.jl/actions/workflows/CI.yml/badge.svg)](https://github.com/LuxDL/BatchedRoutines.jl/actions/workflows/CI.yml)
[![Build status](https://img.shields.io/buildkite/ba1f9622add5978c2d7b194563fd9327113c9c21e5734be20e/main.svg?label=gpu)](https://buildkite.com/julialang/lux-dot-jl)
[![codecov](https://codecov.io/gh/LuxDL/BatchedRoutines.jl/branch/main/graph/badge.svg?token=IMqBM1e3hz)](https://codecov.io/gh/LuxDL/BatchedRoutines.jl.jl)
[![codecov](https://codecov.io/gh/LuxDL/BatchedRoutines.jl/branch/main/graph/badge.svg?token=IMqBM1e3hz)](https://codecov.io/gh/LuxDL/BatchedRoutines.jl)
[![Package Downloads](https://shields.io/endpoint?url=https://pkgs.genieframework.com/api/v1/badge/BatchedRoutines)](https://pkgs.genieframework.com?packages=BatchedRoutines)
[![Aqua QA](https://raw.githubusercontent.com/JuliaTesting/Aqua.jl/master/badge.svg)](https://github.com/JuliaTesting/Aqua.jl)

Expand Down
33 changes: 0 additions & 33 deletions ext/BatchedRoutinesReverseDiffExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -88,37 +88,4 @@ ReverseDiff.@grad_from_chainrules BatchedRoutines.batched_gradient(
ReverseDiff.@grad_from_chainrules BatchedRoutines.batched_gradient(
ad, f, x::ReverseDiff.TrackedArray, p)

# Callable pullback for `_value_and_pullback(::AutoReverseDiff, f, x)`.
# Bundles the recorded ReverseDiff instruction tape with the derivative buffer
# of the input and the tracked output of the forward pass, so that calling the
# object with a cotangent Δ seeds the output adjoints, replays the tape in
# reverse, and returns the accumulated input gradient.
@concrete struct ReverseDiffPullbackFunction <: Function
tape # ReverseDiff.InstructionTape recorded while tracing `f` forward
∂input # derivative buffer aliased by the tracked input (filled by the reverse pass)
output # tracked forward result: TrackedArray or AbstractArray{<:TrackedReal}
end

# Apply the pullback: seed the stored output with the cotangent `Δ`, run the
# reverse pass over the recorded tape, and return the input gradient buffer.
# NOTE(review): the returned `∂input` is the buffer shared with the tracked
# input, so repeated calls overwrite/accumulate into the same array — callers
# presumably use the pullback once per forward trace; verify before reusing.
function (pb_f::ReverseDiffPullbackFunction)(Δ)
if pb_f.output isa AbstractArray{<:ReverseDiff.TrackedReal}
# Elementwise-tracked output: write each cotangent into the matching
# TrackedReal's `deriv` slot (flattened order must match `vec(Δ)`).
@inbounds for (oᵢ, Δᵢ) in zip(vec(pb_f.output), vec(Δ))
oᵢ.deriv = Δᵢ
end
else
# TrackedArray output: seed its derivative array in bulk.
vec(pb_f.output.deriv) .= vec(Δ)
end
# Propagate the seeded adjoints backward through the recorded operations;
# this accumulates the gradient into `pb_f.∂input`.
ReverseDiff.reverse_pass!(pb_f.tape)
return pb_f.∂input
end

# ReverseDiff backend for `BatchedRoutines._value_and_pullback`: trace `f` at
# `x` onto a fresh instruction tape and return `(value, pullback)`, where the
# pullback is a `ReverseDiffPullbackFunction` that maps a cotangent of the
# output to the gradient w.r.t. `x`.
function BatchedRoutines._value_and_pullback(::AutoReverseDiff, f::F, x) where {F}
tape = ReverseDiff.InstructionTape()
# Pre-allocated gradient buffer; the reverse pass writes into this array.
∂x = zero(x)
# Track `x` on the tape so every operation inside `f` is recorded.
x_tracked = ReverseDiff.TrackedArray(x, ∂x, tape)
y_tracked = f(x_tracked)

# `f` may return either an array of TrackedReals (elementwise tracking) or
# a single TrackedArray — extract the primal value accordingly.
if y_tracked isa AbstractArray{<:ReverseDiff.TrackedReal}
y = ReverseDiff.value.(y_tracked)
else
y = ReverseDiff.value(y_tracked)
end

return y, ReverseDiffPullbackFunction(tape, ∂x, y_tracked)
end

end
2 changes: 0 additions & 2 deletions ext/BatchedRoutinesZygoteExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,4 @@ function BatchedRoutines._batched_gradient(::AutoZygote, f::F, u) where {F}
return only(Zygote.gradient(f, u))
end

BatchedRoutines._value_and_pullback(::AutoZygote, f::F, x) where {F} = Zygote.pullback(f, x)

end
1 change: 0 additions & 1 deletion src/helpers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,6 @@ end
# Useful for computing the gradient of a gradient
function _jacobian_vector_product end
function _vector_jacobian_product end
function _value_and_pullback end
function _batched_jacobian end
function _batched_gradient end

Expand Down
86 changes: 24 additions & 62 deletions test/autodiff_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,32 +9,16 @@
return sum(abs2, X_ .* p; dims=1) .- sum(abs, X_ .* p; dims=1) .+ p .^ 2
end

X = randn(rng, 3, 2, 4) |> aType
p = randn(rng, 6) |> aType
Xs = (aType(randn(rng, 3, 2, 4)), aType(randn(rng, 2, 4)), aType(randn(rng, 3)))
ps = (aType(randn(rng, 6)), aType(randn(rng, 2)), aType(randn(rng, 3)))

J_fdiff = batched_jacobian(
AutoFiniteDiff(), simple_batched_function, Array(X), Array(p))
J_fwdiff = batched_jacobian(AutoForwardDiff(), simple_batched_function, X, p)
for (X, p) in zip(Xs, ps)
J_fdiff = batched_jacobian(
AutoFiniteDiff(), simple_batched_function, Array(X), Array(p))
J_fwdiff = batched_jacobian(AutoForwardDiff(), simple_batched_function, X, p)

@test Matrix(J_fdiff) ≈ Matrix(J_fwdiff) atol=1e-3

X = randn(rng, 2, 4) |> aType
p = randn(rng, 2) |> aType

J_fdiff = batched_jacobian(
AutoFiniteDiff(), simple_batched_function, Array(X), Array(p))
J_fwdiff = batched_jacobian(AutoForwardDiff(), simple_batched_function, X, p)

@test Matrix(J_fdiff) ≈ Matrix(J_fwdiff) atol=1e-3

X = randn(rng, 3) |> aType
p = randn(rng, 3) |> aType

J_fdiff = batched_jacobian(
AutoFiniteDiff(), simple_batched_function, Array(X), Array(p))
J_fwdiff = batched_jacobian(AutoForwardDiff(), simple_batched_function, X, p)

@test Matrix(J_fdiff) ≈ Matrix(J_fwdiff) atol=1e-3
@test Matrix(J_fdiff) ≈ Matrix(J_fwdiff) atol=1e-3
end
end
end

Expand All @@ -50,46 +34,24 @@ end
abs2, sum(abs2, X_ .* p; dims=1) .- sum(abs, X_ .* p; dims=1) .+ p .^ 2)
end

X = randn(rng, 3, 2, 4) |> aType
p = randn(rng, 6) |> aType
Xs = (aType(randn(rng, 3, 2, 4)), aType(randn(rng, 2, 4)), aType(randn(rng, 3)))
ps = (aType(randn(rng, 6)), aType(randn(rng, 2)), aType(randn(rng, 3)))

gs_fdiff = batched_gradient(
AutoFiniteDiff(), simple_batched_function, Array(X), Array(p))
gs_fwdiff = batched_gradient(AutoForwardDiff(), simple_batched_function, X, p)
gs_rdiff = batched_gradient(
AutoReverseDiff(), simple_batched_function, Array(X), Array(p))
gs_zygote = batched_gradient(AutoZygote(), simple_batched_function, X, p)
for (X, p) in zip(Xs, ps)
gs_fdiff = batched_gradient(
AutoFiniteDiff(), simple_batched_function, Array(X), Array(p))
gs_fwdiff = batched_gradient(AutoForwardDiff(), simple_batched_function, X, p)
gs_rdiff = batched_gradient(
AutoReverseDiff(), simple_batched_function, Array(X), Array(p))
gs_zygote = batched_gradient(AutoZygote(), simple_batched_function, X, p)

@test Array(gs_fdiff) ≈ Array(gs_fwdiff) atol=1e-3
@test Array(gs_fwdiff) ≈ Array(gs_rdiff) atol=1e-3
@test Array(gs_rdiff) ≈ Array(gs_zygote) atol=1e-3

X = randn(rng, 2, 4) |> aType
p = randn(rng, 2) |> aType

gs_fdiff = batched_gradient(
AutoFiniteDiff(), simple_batched_function, Array(X), Array(p))
gs_fwdiff = batched_gradient(AutoForwardDiff(), simple_batched_function, X, p)
gs_rdiff = batched_gradient(
AutoReverseDiff(), simple_batched_function, Array(X), Array(p))
gs_zygote = batched_gradient(AutoZygote(), simple_batched_function, X, p)

@test Array(gs_fdiff) ≈ Array(gs_fwdiff) atol=1e-3
@test Array(gs_fwdiff) ≈ Array(gs_rdiff) atol=1e-3
@test Array(gs_rdiff) ≈ Array(gs_zygote) atol=1e-3

X = randn(rng, 3) |> aType
p = randn(rng, 3) |> aType

J_fdiff = batched_gradient(
AutoFiniteDiff(), simple_batched_function, Array(X), Array(p))
J_fwdiff = batched_gradient(AutoForwardDiff(), simple_batched_function, X, p)
J_rdiff = batched_gradient(
AutoReverseDiff(), simple_batched_function, Array(X), Array(p))
J_zygote = batched_gradient(AutoZygote(), simple_batched_function, X, p)
@test Array(gs_fdiff) ≈ Array(gs_fwdiff) atol=1e-3
@test Array(gs_fwdiff) ≈ Array(gs_rdiff) atol=1e-3
@test Array(gs_rdiff) ≈ Array(gs_zygote) atol=1e-3
end

@test Array(J_fdiff) ≈ Array(J_fwdiff) atol=1e-3
@test Array(J_fwdiff) ≈ Array(J_rdiff) atol=1e-3
@test Array(J_rdiff) ≈ Array(J_zygote) atol=1e-3
@testset "Gradient of Gradient" begin
end
end
end

0 comments on commit fdec976

Please sign in to comment.