Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Further ChainRulesCore.rrule Integration #254

Open
wants to merge 44 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
44 commits
Select commit Hold shift + click to select a range
cef527f
Bump patch version
willtebbutt Sep 12, 2024
b9c3f65
Fix usage with benchmarktools
willtebbutt Sep 12, 2024
1f49d85
Merge branch 'main' into wct/actually-improve-rrule-integration
willtebbutt Sep 13, 2024
8f0f75d
Initial pass
willtebbutt Sep 13, 2024
e791cef
Bump patch
willtebbutt Sep 13, 2024
f45456e
Unit test to_tapir_tangent and to_cr_tangent
willtebbutt Sep 13, 2024
bec9f06
Make use of macro
willtebbutt Sep 13, 2024
d037101
More testing and tidying up
willtebbutt Sep 13, 2024
54947f0
Add some basic type checking and a test
willtebbutt Sep 13, 2024
bc88483
Improve formatting and commenting
willtebbutt Sep 13, 2024
f29b8f3
Formatting
willtebbutt Sep 13, 2024
50d7dd8
Improve documentation
willtebbutt Sep 13, 2024
1788c07
Explain how not to use rrule functionality
willtebbutt Sep 13, 2024
b4e80bc
Add rules for BLAS utilities
willtebbutt Sep 13, 2024
4a2b8e0
Initial NNlib integration
willtebbutt Sep 13, 2024
d1d9fae
Thunks and batched_mul
willtebbutt Sep 13, 2024
6f036ad
More rules + kwargs + rename
willtebbutt Sep 13, 2024
e225a0a
Fix link in docs
willtebbutt Sep 13, 2024
3bba38e
Rename chain_rules_macro to chain_rules_interop
willtebbutt Sep 13, 2024
619f0ce
Complete rename of chain rules interop file
willtebbutt Sep 16, 2024
345c46a
Refactor chain rules interop
willtebbutt Sep 16, 2024
8e87d11
Add more nnlib functionality
willtebbutt Sep 16, 2024
d345978
Remove old tests
willtebbutt Sep 16, 2024
0f3fe90
Some work
willtebbutt Sep 16, 2024
ae93a27
Remove errant show statment
willtebbutt Sep 17, 2024
82ecd82
Remove redundant test
willtebbutt Sep 17, 2024
ca93535
Support where
willtebbutt Sep 17, 2024
fc6c00f
Make use of where params
willtebbutt Sep 17, 2024
473bc02
Improve kwarg interface
willtebbutt Sep 17, 2024
1cfbfcc
Default kwargs test
willtebbutt Sep 17, 2024
8ac2903
Improve docstring
willtebbutt Sep 17, 2024
f60ca36
Merge in main
willtebbutt Sep 19, 2024
ce5afd9
Some work
willtebbutt Sep 25, 2024
3539d46
Merge in main
willtebbutt Sep 29, 2024
8a80218
Merge branch 'main' into wct/actually-improve-rrule-integration
willtebbutt Sep 29, 2024
ccdef0b
Merge branch 'main' into wct/actually-improve-rrule-integration
willtebbutt Sep 29, 2024
6edc9a4
Some work
willtebbutt Sep 30, 2024
f66cc9c
Better conv support in nnlib rules
willtebbutt Oct 1, 2024
f865fde
More LuxLib rules
willtebbutt Oct 1, 2024
149e7b4
Permit :meta nodes in IR
willtebbutt Oct 1, 2024
2dcd535
Remove redundant test
willtebbutt Oct 1, 2024
0933f37
Uncomment some tests
willtebbutt Oct 1, 2024
d217102
Rename chain rules doc
willtebbutt Oct 1, 2024
c6f8cf0
Add notes to docs on rule writing strategies
willtebbutt Oct 1, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,17 @@ CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1"
LuxLib = "82251201-b29d-42c6-8e01-566dec8acb11"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"

[extensions]
MooncakeCUDAExt = "CUDA"
MooncakeDynamicPPLExt = "DynamicPPL"
MooncakeJETExt = "JET"
MooncakeLogDensityProblemsADExt = "LogDensityProblemsAD"
MooncakeLuxLibExt = "LuxLib"
MooncakeNNlibExt = "NNlib"
MooncakeSpecialFunctionsExt = "SpecialFunctions"

[compat]
Expand All @@ -46,7 +50,9 @@ FillArrays = "1"
Graphs = "1"
JET = "0.9"
LogDensityProblemsAD = "1"
LuxLib = "1.2"
MistyClosures = "1"
NNlib = "0.9"
PDMats = "0.11"
Setfield = "1"
SpecialFunctions = "2"
Expand All @@ -66,11 +72,14 @@ FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
KernelFunctions = "ec8451be-7e33-11e9-00cf-bbf324bd1392"
LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
LuxLib = "82251201-b29d-42c6-8e01-566dec8acb11"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
TemporalGPs = "e155a3c4-0841-43e1-8b83-a0e4f03cc18f"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["AbstractGPs", "BenchmarkTools", "CUDA", "DiffTests", "Distributions", "Documenter", "DynamicPPL", "FillArrays", "KernelFunctions", "JET", "LogDensityProblemsAD", "PDMats", "SpecialFunctions", "StableRNGs", "Test", "TemporalGPs"]
test = ["AbstractGPs", "BenchmarkTools", "CUDA", "DiffTests", "Distributions", "Documenter", "DynamicPPL", "FillArrays", "KernelFunctions", "JET", "LogDensityProblemsAD", "Lux", "LuxLib", "NNlib", "PDMats", "SpecialFunctions", "StableRNGs", "Test", "TemporalGPs"]
7 changes: 5 additions & 2 deletions docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,12 @@ makedocs(
"Algorithmic Differentiation" => "algorithmic_differentiation.md",
"Mooncake.jl's Rule System" => "mathematical_interpretation.md",
],
"Utilities" => [
"Tools for Rules" => "tools_for_rules.md",
"Debug Mode" => "debug_mode.md",
"Debugging and MWEs" => "debugging_and_mwes.md",
],
"Known Limitations" => "known_limitations.md",
"Debug Mode" => "debug_mode.md",
"Debugging and MWEs" => "debugging_and_mwes.md",
]
)

Expand Down
50 changes: 50 additions & 0 deletions docs/src/tools_for_rules.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# Tools for Rules

Most of the time, Mooncake.jl can just differentiate your code, but you will need to intervene if you make use of a language feature which is unsupported.
However, this does not always necessitate writing your own `rrule!!` from scratch.
In this section, we detail some useful strategies which can help you avoid having to write `rrule!!`s in many situations.

## Simplifying Code via Overlays

Suppose you have a function
```julia
foo(x::Float64) = bar(x)
```
where Mooncake.jl fails to differentiate `bar` for some reason.
If you have access to another function `baz`, which does the same thing as `bar`, but does so in a way which Mooncake.jl can differentiate, you can simply write:
```julia
Base.Experimental.@overlay Mooncake.mooncake_method_table foo(x::Float64) = baz(x)
```
When looking up the code for `foo(::Float64)`, Mooncake.jl will see this method, rather than the original, and should successfully differentiate it.
If you search for `@overlay` in the Mooncake.jl source code, you will see a variety of instances where this is used in practice.

This approach is often very straightforward, and we recommend you try this first before going down the path of writing rules.

## Functions with Zero Derivative

If the above strategy does not work, but you find yourself in the surprisingly common situation that the derivative of your function is always zero, you can very straightforwardly write a rule by making use of the following:
```@docs
Mooncake.simple_zero_adjoint
```
Suppose you have a function `foo(x, y, z)` whose derivative is zero. You would write an `rrule!!` for it as follows:
```julia
function Mooncake.rrule!!(f::CoDual{typeof(foo)}, x::CoDual, y::CoDual, z::CoDual)
return Mooncake.simple_zero_adjoint(f, x, y, z)
end
```
Users of ChainRules.jl should be familiar with this functionality -- it is morally the same as `ChainRulesCore.@non_differentiable`.
This approach is utilised often in Mooncake.jl's codebase.

## Using ChainRules.jl

[ChainRules.jl](https://github.com/JuliaDiff/ChainRules.jl) provides a large number of rules for differentiating functions in reverse-mode.
These rules are methods of the `ChainRulesCore.rrule` function.
There are some instances where it is most convenient to implement a `Mooncake.rrule!!` by wrapping an existing `ChainRulesCore.rrule`.

There is enough similarity between these two systems that most of the boilerplate code can be avoided.
The docstrings below explain this functionality, and how it should / should not be used.

```@docs
Mooncake.@from_rrule
Mooncake.rrule_wrapper
```
176 changes: 176 additions & 0 deletions ext/MooncakeLuxLibExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
module MooncakeLuxLibExt

using LuxLib, Random, Mooncake
using Base: IEEEFloat
using Base.Experimental: @overlay

import LuxLib: Impl
import LuxLib.Utils: static_training_mode_check
import Mooncake: @from_rrule, DefaultCtx, MooncakeInterpreter, mooncake_method_table, CoDual

# Derive Mooncake `rrule!!`s from the `ChainRulesCore.rrule`s that LuxLib ships
# for its dense / batched matmul kernels. Restricted to `IEEEFloat` element
# types via the `where {P<:IEEEFloat}` constraint on each signature.
@from_rrule(DefaultCtx, Tuple{typeof(Impl.matmul), Array{P}, Array{P}} where {P<:IEEEFloat})
@from_rrule(
DefaultCtx,
Tuple{typeof(Impl.matmuladd), Array{P}, Array{P}, Vector{P}} where {P<:IEEEFloat},
)
@from_rrule(
DefaultCtx,
Tuple{typeof(Impl.batched_matmul), Array{P, 3}, Array{P, 3}} where {P<:IEEEFloat},
)

# Re-implement a bunch of methods to ensure that Mooncake can differentiate them.
@overlay mooncake_method_table function LuxLib.Impl.fused_dense(

Check warning on line 22 in ext/MooncakeLuxLibExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/MooncakeLuxLibExt.jl#L22

Added line #L22 was not covered by tests
opmode,
act::F,
weight::AbstractMatrix,
x::AbstractMatrix,
b::LuxLib.Optional{<:AbstractVector},
) where {F}
return bias_activation(act, Impl.matmul(weight, x), b)
end

@overlay mooncake_method_table function LuxLib.Impl.bias_activation_loop!(

Check warning on line 32 in ext/MooncakeLuxLibExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/MooncakeLuxLibExt.jl#L32

Added line #L32 was not covered by tests
y::AbstractArray{yT, 3}, σ::F, x::AbstractArray{xT, 3}, bias::AbstractVector
) where {F, xT, yT}
return LuxLib.Impl.bias_activation_simd_loop!(y, σ, x, bias)

Check warning on line 35 in ext/MooncakeLuxLibExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/MooncakeLuxLibExt.jl#L35

Added line #L35 was not covered by tests
end

@overlay mooncake_method_table function LuxLib.Impl.activation_loop!(

Check warning on line 38 in ext/MooncakeLuxLibExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/MooncakeLuxLibExt.jl#L38

Added line #L38 was not covered by tests
y::AbstractArray, σ::F, x::AbstractArray
) where {F}
return LuxLib.Impl.activation_simd_loop!(y, σ, x)
end

@overlay mooncake_method_table function LuxLib.Impl.fused_conv(

Check warning on line 44 in ext/MooncakeLuxLibExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/MooncakeLuxLibExt.jl#L44

Added line #L44 was not covered by tests
::LuxLib.Impl.AbstractInternalArrayOpMode,
act::F,
weight::AbstractArray{wT, N},
x::AbstractArray{xT, N},
bias::LuxLib.Optional{<:AbstractVector},
cdims::LuxLib.Impl.ConvDims,
) where {F, wT, xT, N}
return LuxLib.Impl.bias_activation(act, LuxLib.Impl.conv(x, weight, cdims), bias)

Check warning on line 52 in ext/MooncakeLuxLibExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/MooncakeLuxLibExt.jl#L52

Added line #L52 was not covered by tests
end

# Register rules for LuxLib's SLEEF-accelerated activation functions: one rule
# for the scalar call, and one for `Broadcast.broadcasted` applied to the
# activation over a scalar or an array of `IEEEFloat`s.
# NOTE(review): `f` is a loop-local function object used inside `typeof(f)` in
# the macro call without `@eval` -- this relies on `typeof(f)` being evaluated
# per loop iteration in the macro's expansion; confirm `@from_rrule` supports
# this usage.
for f in [
Impl.SLEEFActivations.sigmoid_fast,
Impl.SLEEFActivations.softplus,
Impl.SLEEFActivations.logsigmoid,
Impl.SLEEFActivations.swish,
Impl.SLEEFActivations.lisht,
Impl.SLEEFActivations.tanh,
Impl.SLEEFActivations.tanh_fast,
]
@from_rrule DefaultCtx Tuple{typeof(f), IEEEFloat}
@from_rrule(
DefaultCtx,
Tuple{typeof(Broadcast.broadcasted), typeof(f), Union{IEEEFloat, Array{<:IEEEFloat}}},
)
end

# Mark `static_training_mode_check` as a primitive with an identically-zero
# adjoint. `simple_zero_adjoint` asserts the derivative w.r.t. every argument
# is zero -- morally the same as `ChainRulesCore.@non_differentiable`.
Mooncake.@is_primitive(DefaultCtx, Tuple{typeof(static_training_mode_check), Vararg})
function Mooncake.rrule!!(f::CoDual{typeof(static_training_mode_check)}, x::CoDual...)
return Mooncake.simple_zero_adjoint(f, x...)
end




# This is a really horrible hack that we need to do until Mooncake is able to support the
# call-back-into-ad interface that ChainRules exposes.

import LuxLib.Impl:
safe_eltype,
batchnorm_affine_normalize_internal,
batchnorm_affine_normalize_internal!,
∇batchnorm_affine_normalize,
AbstractInternalArrayOpMode

import ChainRulesCore as CRC

# ChainRules rrule for the `identity`-activation batch-norm kernel. Written as
# a `CRC.rrule` (rather than a native Mooncake rule) so it can be wrapped via
# `@from_rrule` -- part of the workaround noted above for Mooncake not yet
# supporting ChainRules' call-back-into-AD interface.
function CRC.rrule(
::typeof(batchnorm_affine_normalize_internal),
opmode::AbstractInternalArrayOpMode,
::typeof(identity),
x::AbstractArray{T, N},
μ::AbstractVector,
σ²::AbstractVector,
γ::LuxLib.Optional{<:AbstractVector},
β::LuxLib.Optional{<:AbstractVector},
ϵ::Real,
) where {T, N}
# Output buffer, eltype promoted across all inputs (`safe_eltype` handles the
# `nothing` cases of the optional `γ` / `β`).
y = similar(
x,
promote_type(
safe_eltype(x), safe_eltype(μ), safe_eltype(σ²), safe_eltype(γ), safe_eltype(β)
)
)
# Scratch buffer along dimension N-1 (the channel dimension, presumably --
# TODO confirm), reused by the pullback below.
γ′ = similar(
x, promote_type(safe_eltype(γ), safe_eltype(σ²), safe_eltype(ϵ)), size(x, N - 1)
)

batchnorm_affine_normalize_internal!(y, opmode, identity, x, μ, σ², γ, β, ϵ, γ′)

# Projectors restore the natural tangent type of each primal argument; when
# γ / β are absent (`nothing`) there is nothing to project, so use `identity`.
𝒫x, 𝒫μ, 𝒫σ² = CRC.ProjectTo(x), CRC.ProjectTo(μ), CRC.ProjectTo(σ²)
𝒫γ = γ === nothing ? identity : CRC.ProjectTo(γ)
𝒫β = β === nothing ? identity : CRC.ProjectTo(β)

# Pullback: `NoTangent` for the function itself, `opmode`, `identity`, and
# the trailing `ϵ`; projected cotangents for the array arguments.
∇batchnorm_affine_normalize_internal = LuxLib.Impl.@closure Δ -> begin
∂x, ∂μ, ∂σ², ∂γ, ∂β = ∇batchnorm_affine_normalize(opmode, Δ, x, μ, σ², γ, β, ϵ, γ′)
∂∅ = CRC.NoTangent()
return ∂∅, ∂∅, ∂∅, 𝒫x(∂x), 𝒫μ(∂μ), 𝒫σ²(∂σ²), 𝒫γ(∂γ), 𝒫β(∂β), ∂∅
end

return y, ∇batchnorm_affine_normalize_internal
end

# Wrap the `CRC.rrule` defined above as a Mooncake `rrule!!`. The argument
# types here mirror that rrule's signature.
@from_rrule(
DefaultCtx,
Tuple{
typeof(batchnorm_affine_normalize_internal),
AbstractInternalArrayOpMode,
typeof(identity),
AbstractArray,
AbstractVector,
AbstractVector,
LuxLib.Optional{<:AbstractVector},
LuxLib.Optional{<:AbstractVector},
Real,
},
)

@overlay mooncake_method_table function batchnorm_affine_normalize_internal(

Check warning on line 142 in ext/MooncakeLuxLibExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/MooncakeLuxLibExt.jl#L142

Added line #L142 was not covered by tests
opmode::LuxLib.AbstractInternalArrayOpMode,
act::F,
x::AbstractArray{xT, 3},
μ::AbstractVector,
σ²::AbstractVector,
γ::Union{Nothing, AbstractVector},
β::Union{Nothing, AbstractVector},
ϵ::Real,
) where {F, xT}
y = batchnorm_affine_normalize_internal(opmode, identity, x, μ, σ², γ, β, ϵ)
LuxLib.Impl.activation!(y, opmode, act, y)
return y

Check warning on line 154 in ext/MooncakeLuxLibExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/MooncakeLuxLibExt.jl#L152-L154

Added lines #L152 - L154 were not covered by tests
end

@overlay mooncake_method_table function batchnorm_affine_normalize_internal(

Check warning on line 157 in ext/MooncakeLuxLibExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/MooncakeLuxLibExt.jl#L157

Added line #L157 was not covered by tests
opmode::LuxLib.AbstractInternalArrayOpMode,
::typeof(identity),
x::AbstractArray{xT, 3},
μ::AbstractVector,
σ²::AbstractVector,
γ::Union{Nothing, AbstractVector},
β::Union{Nothing, AbstractVector},
ϵ::Real,
) where {xT}
y = similar(x,

Check warning on line 167 in ext/MooncakeLuxLibExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/MooncakeLuxLibExt.jl#L167

Added line #L167 was not covered by tests
promote_type(
safe_eltype(x), safe_eltype(μ), safe_eltype(σ²), safe_eltype(γ), safe_eltype(β)
)
)
batchnorm_affine_normalize_internal!(y, opmode, identity, x, μ, σ², γ, β, ϵ)
return y

Check warning on line 173 in ext/MooncakeLuxLibExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/MooncakeLuxLibExt.jl#L172-L173

Added lines #L172 - L173 were not covered by tests
end

end
65 changes: 65 additions & 0 deletions ext/MooncakeNNlibExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
# Mooncake.jl rules for NNlib, derived from the `ChainRulesCore.rrule`s that
# NNlib provides for its kernels, via `Mooncake.@from_rrule`.
module MooncakeNNlibExt

using NNlib, Random, Mooncake
using Base: IEEEFloat
using NNlib: dropout

using NNlib: conv, depthwiseconv
import Mooncake: @from_rrule, DefaultCtx, MinimalCtx

@from_rrule(
    MinimalCtx,
    Tuple{typeof(batched_mul), Array{P, 3}, Array{P, 3}} where {P<:IEEEFloat},
)
# The trailing `true` in the calls below marks signatures with keyword
# arguments (presumably -- confirm against the `@from_rrule` docstring).
@from_rrule(
    MinimalCtx,
    Tuple{typeof(dropout), AbstractRNG, Array{P}, P} where {P<:IEEEFloat},
    true,
)
@from_rrule(MinimalCtx, Tuple{typeof(softmax), Array{<:IEEEFloat}}, true)
@from_rrule(MinimalCtx, Tuple{typeof(logsoftmax), Array{<:IEEEFloat}}, true)
@from_rrule(MinimalCtx, Tuple{typeof(logsumexp), Array{<:IEEEFloat}}, true)
@from_rrule(
    MinimalCtx,
    Tuple{typeof(upsample_nearest), Array{<:IEEEFloat}, NTuple{N, Int} where {N}},
)
@from_rrule(
    MinimalCtx,
    Tuple{
        typeof(NNlib.fold), Array{<:IEEEFloat}, NTuple{N, Int} where {N}, DenseConvDims,
    },
)
@from_rrule(
    MinimalCtx, Tuple{typeof(NNlib.unfold), Array{<:IEEEFloat}, DenseConvDims}
)
@from_rrule(
    MinimalCtx,
    Tuple{typeof(NNlib.scatter), Any, Array, Array{<:Union{Integer, Tuple}}},
    true,
)
# Rules for each convolution flavour and both of its gradient helpers. The
# filter-gradient registration lives INSIDE the loop: previously it sat after
# `end`, where the loop-local `∇conv_filter` symbol is out of scope (falling
# back to NNlib's global `∇conv_filter` binding), and consequently no rule was
# ever registered for `∇depthwiseconv_filter`.
for conv in [:conv, :depthwiseconv]
    local ∇conv_data, ∇conv_filter = Symbol.(:∇, conv, [:_data, :_filter])

    @eval @from_rrule(
        MinimalCtx,
        Tuple{typeof($conv), Array{P}, Array{P}, ConvDims} where {P<:IEEEFloat},
        true,
    )
    @eval @from_rrule(
        MinimalCtx,
        Tuple{typeof($∇conv_data), Array{P}, Array{P}, ConvDims} where {P<:IEEEFloat},
        true,
    )
    @eval @from_rrule(
        MinimalCtx,
        Tuple{typeof($∇conv_filter), Array{P}, Array{P}, ConvDims} where {P<:IEEEFloat},
        true,
    )
end
for pool in [:maxpool, :meanpool]
    @eval @from_rrule(
        MinimalCtx, Tuple{typeof($pool), Array{<:IEEEFloat}, PoolDims}, true
    )
end
@from_rrule(MinimalCtx, Tuple{typeof(pad_constant), Array, Any, Any}, true)
end
3 changes: 2 additions & 1 deletion src/Mooncake.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ using
Random,
Setfield

# There are many clashing names, so we will always qualify uses of names from CRC.
import ChainRulesCore

using Base:
Expand Down Expand Up @@ -85,7 +86,7 @@ include(joinpath("rrules", "misc.jl"))
include(joinpath("rrules", "new.jl"))
include(joinpath("rrules", "tasks.jl"))

include("chain_rules_macro.jl")
include("chain_rules_interop.jl")
include("interface.jl")
include("config.jl")

Expand Down
Loading
Loading