diff --git a/Project.toml b/Project.toml index 23eb7957..bb42013a 100644 --- a/Project.toml +++ b/Project.toml @@ -22,6 +22,8 @@ CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8" JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1" +LuxLib = "82251201-b29d-42c6-8e01-566dec8acb11" +NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" [extensions] @@ -29,6 +31,8 @@ MooncakeCUDAExt = "CUDA" MooncakeDynamicPPLExt = "DynamicPPL" MooncakeJETExt = "JET" MooncakeLogDensityProblemsADExt = "LogDensityProblemsAD" +MooncakeLuxLibExt = "LuxLib" +MooncakeNNlibExt = "NNlib" MooncakeSpecialFunctionsExt = "SpecialFunctions" [compat] @@ -46,7 +50,9 @@ FillArrays = "1" Graphs = "1" JET = "0.9" LogDensityProblemsAD = "1" +LuxLib = "1.2" MistyClosures = "1" +NNlib = "0.9" PDMats = "0.11" Setfield = "1" SpecialFunctions = "2" @@ -66,6 +72,9 @@ FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" KernelFunctions = "ec8451be-7e33-11e9-00cf-bbf324bd1392" LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1" +Lux = "b2108857-7c20-44ae-9111-449ecde12c47" +LuxLib = "82251201-b29d-42c6-8e01-566dec8acb11" +NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" @@ -73,4 +82,4 @@ TemporalGPs = "e155a3c4-0841-43e1-8b83-a0e4f03cc18f" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["AbstractGPs", "BenchmarkTools", "CUDA", "DiffTests", "Distributions", "Documenter", "DynamicPPL", "FillArrays", "KernelFunctions", "JET", "LogDensityProblemsAD", "PDMats", "SpecialFunctions", "StableRNGs", "Test", "TemporalGPs"] +test = ["AbstractGPs", "BenchmarkTools", "CUDA", "DiffTests", "Distributions", "Documenter", "DynamicPPL", "FillArrays", 
"KernelFunctions", "JET", "LogDensityProblemsAD", "Lux", "LuxLib", "NNlib", "PDMats", "SpecialFunctions", "StableRNGs", "Test", "TemporalGPs"] \ No newline at end of file diff --git a/docs/make.jl b/docs/make.jl index d4d8a17d..f23ec609 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -32,9 +32,12 @@ makedocs( "Algorithmic Differentiation" => "algorithmic_differentiation.md", "Mooncake.jl's Rule System" => "mathematical_interpretation.md", ], + "Utilities" => [ + "Tools for Rules" => "tools_for_rules.md", + "Debug Mode" => "debug_mode.md", + "Debugging and MWEs" => "debugging_and_mwes.md", + ], "Known Limitations" => "known_limitations.md", - "Debug Mode" => "debug_mode.md", - "Debugging and MWEs" => "debugging_and_mwes.md", ] ) diff --git a/docs/src/tools_for_rules.md b/docs/src/tools_for_rules.md new file mode 100644 index 00000000..a3740e99 --- /dev/null +++ b/docs/src/tools_for_rules.md @@ -0,0 +1,50 @@ +# Tools for Rules + +Most of the time, Mooncake.jl can just differentiate your code, but you will need to intervene if you make use of a language feature which is unsupported. +However, this does not always necessitate writing your own `rrule!!` from scratch. +In this section, we detail some useful strategies which can help you avoid having to write `rrule!!`s in many situations. + +## Simplfiying Code via Overlays + +Suppose you have a function +```julia +foo(x::Float64) = bar(x) +``` +where Mooncake.jl fails to differentiate `bar` for some reason. +If you have access to another function `baz`, which does the same thing as `bar`, but does so in a way which Mooncake.jl can differentiate, you can simply write: +```julia +Base.Experimental.@overlay Mooncake.mooncake_method_table foo(x::Float64) = baz(x) +``` +When looking up the code for `foo(::Float64)`, Mooncake.jl will see this method, rather than the original, and should successfully differentiate it. 
+If you search for `@overlay` in the Mooncake.jl source code, you will see a variety of instances where this is used in practice. +
+This approach is often very straightforward, and we recommend you try this first before going down the path of writing rules. + +## Functions with Zero Derivative + +If the above strategy does not work, but you find yourself in the surprisingly common situation that the derivative of your function is always zero, you can very straightforwardly write a rule by making use of the following: +```@docs +Mooncake.simple_zero_adjoint +``` +Suppose you have a function `foo(x, y, z)` whose derivative is zero, you would write an `rrule!!` as follows: +```julia +function Mooncake.rrule!!(f::CoDual{typeof(foo)}, x::CoDual, y::CoDual, z::CoDual) + return Mooncake.simple_zero_adjoint(f, x, y, z) +end +``` +Users of ChainRules.jl should be familiar with this functionality -- it is morally the same as `ChainRulesCore.@non_differentiable`. +This approach is utilised often in Mooncake.jl's codebase. + +## Using ChainRules.jl + +[ChainRules.jl](https://github.com/JuliaDiff/ChainRules.jl) provides a large number of rules for differentiating functions in reverse-mode. +These rules are methods of the `ChainRulesCore.rrule` function. +There are some instances where it is most convenient to implement a `Mooncake.rrule!!` by wrapping an existing `ChainRulesCore.rrule`. + +There is enough similarity between these two systems that most of the boilerplate code can be avoided. +The docstrings below explain this functionality, and how it should / should not be used. 
+ +```@docs +Mooncake.@from_rrule +Mooncake.rrule_wrapper +``` diff --git a/ext/MooncakeLuxLibExt.jl b/ext/MooncakeLuxLibExt.jl new file mode 100644 index 00000000..1bfa63db --- /dev/null +++ b/ext/MooncakeLuxLibExt.jl @@ -0,0 +1,176 @@ +module MooncakeLuxLibExt + +using LuxLib, Random, Mooncake +using Base: IEEEFloat +using Base.Experimental: @overlay + +import LuxLib: Impl +import LuxLib.Utils: static_training_mode_check +import Mooncake: @from_rrule, DefaultCtx, MooncakeInterpreter, mooncake_method_table, CoDual + +@from_rrule(DefaultCtx, Tuple{typeof(Impl.matmul), Array{P}, Array{P}} where {P<:IEEEFloat}) +@from_rrule( + DefaultCtx, + Tuple{typeof(Impl.matmuladd), Array{P}, Array{P}, Vector{P}} where {P<:IEEEFloat}, +) +@from_rrule( + DefaultCtx, + Tuple{typeof(Impl.batched_matmul), Array{P, 3}, Array{P, 3}} where {P<:IEEEFloat}, +) + +# Re-implement a bunch of methods to ensure that Mooncake can differentiate them. +@overlay mooncake_method_table function LuxLib.Impl.fused_dense( + opmode, + act::F, + weight::AbstractMatrix, + x::AbstractMatrix, + b::LuxLib.Optional{<:AbstractVector}, +) where {F} + return bias_activation(act, Impl.matmul(weight, x), b) +end + +@overlay mooncake_method_table function LuxLib.Impl.bias_activation_loop!( + y::AbstractArray{yT, 3}, σ::F, x::AbstractArray{xT, 3}, bias::AbstractVector +) where {F, xT, yT} + return LuxLib.Impl.bias_activation_simd_loop!(y, σ, x, bias) +end + +@overlay mooncake_method_table function LuxLib.Impl.activation_loop!( + y::AbstractArray, σ::F, x::AbstractArray +) where {F} + return LuxLib.Impl.activation_simd_loop!(y, σ, x) +end + +@overlay mooncake_method_table function LuxLib.Impl.fused_conv( + ::LuxLib.Impl.AbstractInternalArrayOpMode, + act::F, + weight::AbstractArray{wT, N}, + x::AbstractArray{xT, N}, + bias::LuxLib.Optional{<:AbstractVector}, + cdims::LuxLib.Impl.ConvDims, +) where {F, wT, xT, N} + return LuxLib.Impl.bias_activation(act, LuxLib.Impl.conv(x, weight, cdims), bias) +end + +for f in [ + 
Impl.SLEEFActivations.sigmoid_fast, + Impl.SLEEFActivations.softplus, + Impl.SLEEFActivations.logsigmoid, + Impl.SLEEFActivations.swish, + Impl.SLEEFActivations.lisht, + Impl.SLEEFActivations.tanh, + Impl.SLEEFActivations.tanh_fast, +] + @from_rrule DefaultCtx Tuple{typeof(f), IEEEFloat} + @from_rrule( + DefaultCtx, + Tuple{typeof(Broadcast.broadcasted), typeof(f), Union{IEEEFloat, Array{<:IEEEFloat}}}, + ) +end + +Mooncake.@is_primitive(DefaultCtx, Tuple{typeof(static_training_mode_check), Vararg}) +function Mooncake.rrule!!(f::CoDual{typeof(static_training_mode_check)}, x::CoDual...) + return Mooncake.simple_zero_adjoint(f, x...) +end + + + + +# This is a really horrible hack that we need to do until Mooncake is able to support the +# call-back-into-ad interface that ChainRules exposes. + +import LuxLib.Impl: + safe_eltype, + batchnorm_affine_normalize_internal, + batchnorm_affine_normalize_internal!, + ∇batchnorm_affine_normalize, + AbstractInternalArrayOpMode + +import ChainRulesCore as CRC + +function CRC.rrule( + ::typeof(batchnorm_affine_normalize_internal), + opmode::AbstractInternalArrayOpMode, + ::typeof(identity), + x::AbstractArray{T, N}, + μ::AbstractVector, + σ²::AbstractVector, + γ::LuxLib.Optional{<:AbstractVector}, + β::LuxLib.Optional{<:AbstractVector}, + ϵ::Real, +) where {T, N} + y = similar( + x, + promote_type( + safe_eltype(x), safe_eltype(μ), safe_eltype(σ²), safe_eltype(γ), safe_eltype(β) + ) + ) + γ′ = similar( + x, promote_type(safe_eltype(γ), safe_eltype(σ²), safe_eltype(ϵ)), size(x, N - 1) + ) + + batchnorm_affine_normalize_internal!(y, opmode, identity, x, μ, σ², γ, β, ϵ, γ′) + + 𝒫x, 𝒫μ, 𝒫σ² = CRC.ProjectTo(x), CRC.ProjectTo(μ), CRC.ProjectTo(σ²) + 𝒫γ = γ === nothing ? identity : CRC.ProjectTo(γ) + 𝒫β = β === nothing ? 
identity : CRC.ProjectTo(β) + + ∇batchnorm_affine_normalize_internal = LuxLib.Impl.@closure Δ -> begin + ∂x, ∂μ, ∂σ², ∂γ, ∂β = ∇batchnorm_affine_normalize(opmode, Δ, x, μ, σ², γ, β, ϵ, γ′) + ∂∅ = CRC.NoTangent() + return ∂∅, ∂∅, ∂∅, 𝒫x(∂x), 𝒫μ(∂μ), 𝒫σ²(∂σ²), 𝒫γ(∂γ), 𝒫β(∂β), ∂∅ + end + + return y, ∇batchnorm_affine_normalize_internal +end + +@from_rrule( + DefaultCtx, + Tuple{ + typeof(batchnorm_affine_normalize_internal), + AbstractInternalArrayOpMode, + typeof(identity), + AbstractArray, + AbstractVector, + AbstractVector, + LuxLib.Optional{<:AbstractVector}, + LuxLib.Optional{<:AbstractVector}, + Real, + }, +) + +@overlay mooncake_method_table function batchnorm_affine_normalize_internal( + opmode::LuxLib.AbstractInternalArrayOpMode, + act::F, + x::AbstractArray{xT, 3}, + μ::AbstractVector, + σ²::AbstractVector, + γ::Union{Nothing, AbstractVector}, + β::Union{Nothing, AbstractVector}, + ϵ::Real, +) where {F, xT} + y = batchnorm_affine_normalize_internal(opmode, identity, x, μ, σ², γ, β, ϵ) + LuxLib.Impl.activation!(y, opmode, act, y) + return y +end + +@overlay mooncake_method_table function batchnorm_affine_normalize_internal( + opmode::LuxLib.AbstractInternalArrayOpMode, + ::typeof(identity), + x::AbstractArray{xT, 3}, + μ::AbstractVector, + σ²::AbstractVector, + γ::Union{Nothing, AbstractVector}, + β::Union{Nothing, AbstractVector}, + ϵ::Real, +) where {xT} + y = similar(x, + promote_type( + safe_eltype(x), safe_eltype(μ), safe_eltype(σ²), safe_eltype(γ), safe_eltype(β) + ) + ) + batchnorm_affine_normalize_internal!(y, opmode, identity, x, μ, σ², γ, β, ϵ) + return y +end + +end diff --git a/ext/MooncakeNNlibExt.jl b/ext/MooncakeNNlibExt.jl new file mode 100644 index 00000000..fbd1b2fa --- /dev/null +++ b/ext/MooncakeNNlibExt.jl @@ -0,0 +1,65 @@ +module MooncakeNNlibExt + + using NNlib, Random, Mooncake + using Base: IEEEFloat + using NNlib: dropout + + using NNlib: conv, depthwiseconv + import Mooncake: @from_rrule, DefaultCtx, MinimalCtx + + @from_rrule( + 
MinimalCtx, + Tuple{typeof(batched_mul), Array{P, 3}, Array{P, 3}} where {P<:IEEEFloat}, + ) + @from_rrule( + MinimalCtx, + Tuple{typeof(dropout), AbstractRNG, Array{P}, P} where {P<:IEEEFloat}, + true, + ) + @from_rrule(MinimalCtx, Tuple{typeof(softmax), Array{<:IEEEFloat}}, true) + @from_rrule(MinimalCtx, Tuple{typeof(logsoftmax), Array{<:IEEEFloat}}, true) + @from_rrule(MinimalCtx, Tuple{typeof(logsumexp), Array{<:IEEEFloat}}, true) + @from_rrule( + MinimalCtx, + Tuple{typeof(upsample_nearest), Array{<:IEEEFloat}, NTuple{N, Int} where {N}}, + ) + @from_rrule( + MinimalCtx, + Tuple{ + typeof(NNlib.fold), Array{<:IEEEFloat}, NTuple{N, Int} where {N}, DenseConvDims, + }, + ) + @from_rrule( + MinimalCtx, Tuple{typeof(NNlib.unfold), Array{<:IEEEFloat}, DenseConvDims} + ) + @from_rrule( + MinimalCtx, + Tuple{typeof(NNlib.scatter), Any, Array, Array{<:Union{Integer, Tuple}}}, + true, + ) + for conv in [:conv, :depthwiseconv] + local ∇conv_data, ∇conv_filter = Symbol.(:∇, conv, [:_data, :_filter]) + + @eval @from_rrule( + MinimalCtx, + Tuple{typeof($conv), Array{P}, Array{P}, ConvDims} where {P<:IEEEFloat}, + true, + ) + @eval @from_rrule( + MinimalCtx, + Tuple{typeof($∇conv_data), Array{P}, Array{P}, ConvDims} where {P<:IEEEFloat}, + true, + ) + end + @eval @from_rrule( + MinimalCtx, + Tuple{typeof(∇conv_filter), Array{P}, Array{P}, ConvDims} where {P<:IEEEFloat}, + true, + ) + for pool in [:maxpool, :meanpool] + @eval @from_rrule( + MinimalCtx, Tuple{typeof($pool), Array{<:IEEEFloat}, PoolDims}, true + ) + end + @from_rrule(MinimalCtx, Tuple{typeof(pad_constant), Array, Any, Any}, true) +end diff --git a/src/Mooncake.jl b/src/Mooncake.jl index 0e4cd8a4..c9abae09 100644 --- a/src/Mooncake.jl +++ b/src/Mooncake.jl @@ -13,6 +13,7 @@ using Random, Setfield +# There are many clashing names, so we will always qualify uses of names from CRC. 
import ChainRulesCore using Base: @@ -85,7 +86,7 @@ include(joinpath("rrules", "misc.jl")) include(joinpath("rrules", "new.jl")) include(joinpath("rrules", "tasks.jl")) -include("chain_rules_macro.jl") +include("chain_rules_interop.jl") include("interface.jl") include("config.jl") diff --git a/src/chain_rules_interop.jl b/src/chain_rules_interop.jl new file mode 100644 index 00000000..ed027081 --- /dev/null +++ b/src/chain_rules_interop.jl @@ -0,0 +1,202 @@ + """ + to_cr_tangent(t) + +Convert a Mooncake tangent into a type that ChainRules.jl `rrule`s expect to see. +""" +to_cr_tangent(t::IEEEFloat) = t +to_cr_tangent(t::Array{<:IEEEFloat}) = t +to_cr_tangent(::NoTangent) = ChainRulesCore.NoTangent() + +""" + increment_and_get_rdata!(fdata, zero_rdata, cr_tangent) + +Increment `fdata` by the fdata component of the ChainRules.jl-style tangent, `cr_tangent`, +and return the rdata component of `cr_tangent` by adding it to `zero_rdata`. +""" +increment_and_get_rdata!(::NoFData, r::T, t::T) where {T<:IEEEFloat} = r + t +function increment_and_get_rdata!(f::Array{P}, ::NoRData, t::Array{P}) where {P<:IEEEFloat} + increment!!(f, t) + return NoRData() +end +increment_and_get_rdata!(::Any, r, ::ChainRulesCore.NoTangent) = r +function increment_and_get_rdata!(f, r, t::ChainRulesCore.Thunk) + return increment_and_get_rdata!(f, r, ChainRulesCore.unthunk(t)) +end + +@doc""" + rrule_wrapper(f::CoDual, args::CoDual...) + +Used to implement `rrule!!`s via `ChainRulesCore.rrule`. + +Given a function `foo`, argument types `arg_types`, and a method `ChainRulesCore.rrule` of +which applies to these, you can make use of this function as follows: +```julia +Mooncake.@is_primitive DefaultCtx Tuple{typeof(foo), arg_types...} +function Mooncake.rrule!!(f::CoDual{typeof(foo)}, args::CoDual...) + return rrule_wrapper(f, args...) 
+end +``` +Assumes that methods of `to_cr_tangent` and `to_mooncake_tangent` are defined such that you +can convert between the different representations of tangents that Mooncake and +ChainRulesCore expect. + +Furthermore, it is _essential_ that +1. `f(args)` does not mutate `f` or `args`, and +2. the result of `f(args)` does not alias any data stored in `f` or `args`. + +Subject to some constraints, you can use the [`@from_rrule`](@ref) macro to reduce the +amount of boilerplate code that you are required to write even further. +""" +function rrule_wrapper(fargs::Vararg{CoDual, N}) where {N} + + # Run forwards-pass. + primals = tuple_map(primal, fargs) + lazy_rdata = tuple_map(Mooncake.lazy_zero_rdata, primals) + y_primal, cr_pb = ChainRulesCore.rrule(primals...) + y_fdata = fdata(zero_tangent(y_primal)) + + function pb!!(y_rdata) + + # Construct tangent w.r.t. output. + cr_tangent = to_cr_tangent(tangent(y_fdata, y_rdata)) + + # Run reverse-pass using ChainRules. + cr_dfargs = cr_pb(cr_tangent) + + # Increment fdata and get rdata. + return map(fargs, lazy_rdata, cr_dfargs) do x, l_rdata, cr_dx + return increment_and_get_rdata!(tangent(x), instantiate(l_rdata), cr_dx) + end + end + return CoDual(y_primal, y_fdata), pb!! +end + +function rrule_wrapper(::CoDual{typeof(Core.kwcall)}, fargs::Vararg{CoDual, N}) where {N} + + # Run forwards-pass. + primals = tuple_map(primal, fargs) + lazy_rdata = tuple_map(lazy_zero_rdata, primals) + y_primal, cr_pb = Core.kwcall(primals[1], ChainRulesCore.rrule, primals[2:end]...) + y_fdata = fdata(zero_tangent(y_primal)) + + function pb!!(y_rdata) + + # Construct tangent w.r.t. output. + cr_tangent = to_cr_tangent(tangent(y_fdata, y_rdata)) + + # Run reverse-pass using ChainRules. + cr_dfargs = cr_pb(cr_tangent) + + # Increment fdata and compute rdata. 
+ kwargs_rdata = rdata(zero_tangent(fargs[1])) + args_rdata = map(fargs[2:end], lazy_rdata[2:end], cr_dfargs) do x, l_rdata, cr_dx + return increment_and_get_rdata!(tangent(x), instantiate(l_rdata), cr_dx) + end + return NoRData(), kwargs_rdata, args_rdata... + end + return CoDual(y_primal, y_fdata), pb!! +end + +@doc""" + @from_rrule ctx sig [has_kwargs=false] + +Convenience functionality to assist in using `ChainRulesCore.rrule`s to write `rrule!!`s. +This macro is a thin wrapper around [`rrule_wrapper`](@ref). + +For example, +```julia +@from_rrule DefaultCtx Tuple{typeof(sin), Float64} +``` +would define a `Mooncake.rrule!!` for `sin` of `Float64`s by calling `ChainRulesCore.rrule`. + +```julia +@from_rrule DefaultCtx Tuple{typeof(foo), Float64} true +``` +would define a method of `Mooncake.rrule!!` which can handle keyword arguments. + +Limitations: it is your responsibility to ensure that +1. calls with signature `sig` do not mutate their arguments, +2. the output of calls with signature `sig` does not alias any of the inputs. + +As with all hand-written rules, you should definitely make use of +[`TestUtils.test_rule`](@ref) to verify correctness on some test cases. + +# A Note On Type Constraints + +Many methods of `ChainRulesCore.rrule` are implemented with very loose type constraints. +For example, it would not be surprising to see a method of rrule with the signature +```julia +Tuple{typeof(rrule), typeof(foo), Real, AbstractVector{<:Real}} +``` +There are a variety of reasons for this way of doing things, and whether it is a good idea +to write rules for such generic objects has been debated at length. + +Suffice it to say, you should not write rules for this package which are so generically +typed. +Rather, you should create rules for the subset of types for which you believe that the +`ChainRulesCore.rrule` will work correctly, and leave this package to derive rules for the +rest. 
+For example, in the above case you might be confident that the rule will behave correctly +for input types `Tuple{typeof(foo), IEEEFloat, Vector{<:IEEEFloat}}`. You should therefore +only write a rule for these types: +```julia +@from_rrule DefaultCtx Tuple{typeof(foo), IEEEFloat, Vector{<:IEEEFloat}} +``` +""" +macro from_rrule(ctx, sig::Expr, has_kwargs::Bool=false) + + # Different parsing is required for `Tuple{...}` vs `Tuple{...} where ...`. + if sig.head == :curly + @assert sig.args[1] == :Tuple + arg_type_symbols = sig.args[2:end] + where_params = nothing + elseif sig.head == :where + @assert sig.args[1].args[1] == :Tuple + arg_type_symbols = sig.args[1].args[2:end] + where_params = sig.args[2:end] + else + throw(ArgumentError("Expected either a `Tuple{...}` or `Tuple{...} where {...}")) + end + + arg_names = map(n -> Symbol("x_$n"), eachindex(arg_type_symbols)) + arg_types = map(t -> :(Mooncake.CoDual{<:$t}), arg_type_symbols) + rule_expr = construct_def(arg_names, arg_types, where_params) + + if has_kwargs + kw_sig = Expr(:curly, :Tuple, :(typeof(Core.kwcall)), :NamedTuple, arg_type_symbols...) + kw_sig = where_params === nothing ? kw_sig : Expr(:where, kw_sig, where_params...) 
+ kw_is_primitive = :(Mooncake.is_primitive(::Type{$ctx}, ::Type{<:$kw_sig}) = true) + kwcall_type = :(Mooncake.CoDual{typeof(Core.kwcall)}) + nt_type = :(Mooncake.CoDual{<:NamedTuple}) + kwargs_rule_expr = construct_def( + vcat(:_kwcall, :kwargs, arg_names), + vcat(kwcall_type, nt_type, arg_types), + where_params, + ) + else + kw_is_primitive = nothing + kwargs_rule_expr = nothing + end + + ex = quote + Mooncake.is_primitive(::Type{$ctx}, ::Type{<:$sig}) = true + $rule_expr + $kw_is_primitive + $kwargs_rule_expr + end + return esc(ex) +end + +function construct_def(arg_names, arg_types, where_params) + arg_exprs = map((n, t) -> :($n::$t), arg_names, arg_types) + def = Dict( + :head => :function, + :name => :(Mooncake.rrule!!), + :args => arg_exprs, + :body => Expr(:call, rrule_wrapper, arg_names...), + ) + if where_params !== nothing + def[:whereparams] = where_params + end + return ExprTools.combinedef(def) +end \ No newline at end of file diff --git a/src/chain_rules_macro.jl b/src/chain_rules_macro.jl deleted file mode 100644 index 27006394..00000000 --- a/src/chain_rules_macro.jl +++ /dev/null @@ -1,74 +0,0 @@ -_to_rdata(::ChainRulesCore.NoTangent) = NoRData() -_to_rdata(dx::Float64) = dx - -@doc""" - @from_rrule ctx sig - -Creates a `Mooncake.rrule!!` from a `ChainRulesCore.rrule`. `ctx` is the type of the context in -which this rule should apply, and `sig` is the type-tuple which specifies which primal the -rule should apply to. - -For example, -```julia -@from_rrule DefaultCtx Tuple{typeof(sin), Float64} -``` -would define a `Mooncake.rrule!!` for `sin` of `Float64`s, by calling `ChainRulesCore.rrule`. - -Health warning: -Use this function with care. It has only been tested for `Float64` arguments and arguments -whose `tangent_type` is `NoTangent`, and it is entirely probable that it won't work for -arguments which aren't `Float64` or non-differentiable. 
- -You should definitely make use of [`TestUtils.test_rule`](@ref) to verify that the rule created -works as intended. -""" -macro from_rrule(ctx, sig) - - @assert sig.head == :curly - @assert sig.args[1] == :Tuple - arg_type_symbols = sig.args[2:end] - - arg_names = map(n -> Symbol("x_$n"), eachindex(arg_type_symbols)) - arg_types = map(t -> :(Mooncake.CoDual{<:$t}), arg_type_symbols) - arg_exprs = map((n, t) -> :($n::$t), arg_names, arg_types) - - call_rrule = Expr( - :call, - :(Mooncake.ChainRulesCore.rrule), - map(n -> :(Mooncake.primal($n)), arg_names)..., - ) - - pb_output_names = map(n -> Symbol("dx_$(n)_inc"), eachindex(arg_names)) - - call_pb = Expr(:(=), Expr(:tuple, pb_output_names...), :(pb(dy))) - incrementers = Expr(:tuple, map(b -> :(Mooncake._to_rdata($b)), pb_output_names)...) - - pb = ExprTools.combinedef(Dict( - :head => :function, - :name => :pb!!, - :args => [:dy], - :body => quote - $call_pb - return $incrementers - end, - )) - - rule_expr = ExprTools.combinedef( - Dict( - :head => :function, - :name => :(Mooncake.rrule!!), - :args => arg_exprs, - :body => quote - y, pb = $call_rrule - $pb - return Mooncake.zero_fcodual(y), pb!! 
- end, - ) - ) - - ex = quote - Mooncake.is_primitive(::Type{$ctx}, ::Type{$sig}) = true - $rule_expr - end - return esc(ex) -end diff --git a/src/interpreter/abstract_interpretation.jl b/src/interpreter/abstract_interpretation.jl index 745ed649..90f71621 100644 --- a/src/interpreter/abstract_interpretation.jl +++ b/src/interpreter/abstract_interpretation.jl @@ -5,7 +5,6 @@ # The most important bit of this code is `inlining_policy` -- the rest is copy + pasted # boiler plate, largely taken from https://github.com/JuliaLang/julia/blob/2fe4190b3d26b4eee52b2b1b1054ddd6e38a941e/test/compiler/newinterp.jl#L11 - struct ClosureCacheKey world_age::UInt key::Any @@ -17,6 +16,8 @@ end MooncakeCache() = MooncakeCache(IdDict{Core.MethodInstance, Core.CodeInstance}()) +Base.Experimental.@MethodTable mooncake_method_table + struct MooncakeInterpreter{C} <: CC.AbstractInterpreter meta # additional information world::UInt @@ -25,6 +26,7 @@ struct MooncakeInterpreter{C} <: CC.AbstractInterpreter inf_cache::Vector{CC.InferenceResult} code_cache::MooncakeCache oc_cache::Dict{ClosureCacheKey, Any} + method_table_to_overlay::CC.MethodTable function MooncakeInterpreter( ::Type{C}; meta=nothing, @@ -34,8 +36,18 @@ struct MooncakeInterpreter{C} <: CC.AbstractInterpreter inf_cache::Vector{CC.InferenceResult}=CC.InferenceResult[], code_cache::MooncakeCache=MooncakeCache(), oc_cache::Dict{ClosureCacheKey, Any}=Dict{ClosureCacheKey, Any}(), + method_table_to_overlay::CC.MethodTable=mooncake_method_table, ) where {C} - return new{C}(meta, world, inf_params, opt_params, inf_cache, code_cache, oc_cache) + return new{C}( + meta, + world, + inf_params, + opt_params, + inf_cache, + code_cache, + oc_cache, + method_table_to_overlay, + ) end end @@ -91,6 +103,9 @@ function CC.setindex!( ) return setindex!(wvc.cache.dict, ci, mi) end +function CC.method_table(interp::MooncakeInterpreter) + return CC.OverlayMethodTable(interp.world, interp.method_table_to_overlay) +end _type(x) = x _type(x::CC.Const) = 
_typeof(x.val) @@ -108,7 +123,9 @@ function CC.inlining_policy( # Do not inline away primitives. argtype_tuple = Tuple{map(_type, argtypes)...} - is_primitive(C, argtype_tuple) && return nothing + if is_primitive(C, argtype_tuple) + return nothing + end # If not a primitive, AD doesn't care about it. Use the usual inlining strategy. return @invoke CC.inlining_policy( diff --git a/src/interpreter/ir_utils.jl b/src/interpreter/ir_utils.jl index bae39e46..9370f11c 100644 --- a/src/interpreter/ir_utils.jl +++ b/src/interpreter/ir_utils.jl @@ -170,6 +170,9 @@ function optimise_ir!(ir::IRCode; show_ir=false, do_inline=true) return ir end +Base.iterate(x::CC.MethodLookupResult) = CC.iterate(x) +Base.iterate(x::CC.MethodLookupResult, n::Int) = CC.iterate(x, n) + """ lookup_ir( interp::AbstractInterpreter, @@ -181,18 +184,35 @@ there is no code found, or if more than one `IRCode` instance returned. Returns a tuple containing the `IRCode` and its return type. """ -function lookup_ir(interp::CC.AbstractInterpreter, sig::Type{<:Tuple}) - output = Base.code_ircode_by_type(sig; interp) - if isempty(output) - throw(ArgumentError("No methods found for signature $sig")) - elseif length(output) > 1 - throw(ArgumentError("$(length(output)) methods found for signature $sig")) +function lookup_ir(interp::CC.AbstractInterpreter, tt::Type{<:Tuple}; optimize_until=nothing) + matches = CC.findall(tt, CC.method_table(interp)) + asts = [] + for match in matches.matches + match = match::Core.MethodMatch + meth = Base.func_for_method_checked(match.method, tt, match.sparams) + (code, ty) = CC.typeinf_ircode( + interp, + meth, + match.spec_types, + match.sparams, + optimize_until, + ) + if code === nothing + push!(asts, match.method => Any) + else + push!(asts, code => ty) + end + end + if isempty(asts) + throw(ArgumentError("No methods found for signature $asts")) + elseif length(asts) > 1 + throw(ArgumentError("$(length(asts)) methods found for signature $sig")) end - return only(output) + 
return only(asts) end -function lookup_ir(interp::CC.AbstractInterpreter, mi::Core.MethodInstance) - return CC.typeinf_ircode(interp, mi.def, mi.specTypes, mi.sparam_vals, nothing) +function lookup_ir(interp::CC.AbstractInterpreter, mi::Core.MethodInstance; optimize_until=nothing) + return CC.typeinf_ircode(interp, mi.def, mi.specTypes, mi.sparam_vals, optimize_until) end """ diff --git a/src/interpreter/s2s_reverse_mode_ad.jl b/src/interpreter/s2s_reverse_mode_ad.jl index 039d94a0..472c8c0d 100644 --- a/src/interpreter/s2s_reverse_mode_ad.jl +++ b/src/interpreter/s2s_reverse_mode_ad.jl @@ -617,6 +617,7 @@ function make_ad_stmts!(stmt::Expr, line::ID, info::ADInfo) :leave, :pop_exception, :throw_undef_if_not, + :meta, ] # Expressions which do not require any special treatment. return ad_stmt_info(line, nothing, stmt, nothing) diff --git a/src/rrules/blas.jl b/src/rrules/blas.jl index 0fd47ac2..1aaf51ad 100644 --- a/src/rrules/blas.jl +++ b/src/rrules/blas.jl @@ -21,6 +21,22 @@ end const MatrixOrView{T} = Union{Matrix{T}, SubArray{T, 2, Matrix{T}}} +# +# Utility +# + +@is_primitive MinimalCtx Tuple{typeof(BLAS.get_num_threads)} +rrule!!(f::CoDual{typeof(BLAS.get_num_threads)}) = simple_zero_adjoint(f) + +@is_primitive MinimalCtx Tuple{typeof(BLAS.lbt_get_num_threads)} +rrule!!(f::CoDual{typeof(BLAS.lbt_get_num_threads)}) = simple_zero_adjoint(f) + +@is_primitive MinimalCtx Tuple{typeof(BLAS.set_num_threads), Union{Integer, Nothing}} +rrule!!(f::CoDual{typeof(BLAS.set_num_threads)}, x::CoDual) = simple_zero_adjoint(f, x) + +@is_primitive MinimalCtx Tuple{typeof(BLAS.lbt_set_num_threads), Any} +rrule!!(f::CoDual{typeof(BLAS.lbt_set_num_threads)}, x::CoDual) = simple_zero_adjoint(f, x) + # # LEVEL 1 # @@ -893,6 +909,12 @@ function generate_derived_rrule!!_test_cases(rng_ctor, ::Val{:blas}) test_cases = vcat( + # Utility + (false, :stability, nothing, BLAS.get_num_threads), + (false, :stability, nothing, BLAS.lbt_get_num_threads), + (false, :stability, nothing, 
BLAS.set_num_threads, 1), + (false, :stability, nothing, BLAS.lbt_set_num_threads, 1), + # # BLAS LEVEL 1 # diff --git a/src/rrules/fastmath.jl b/src/rrules/fastmath.jl index 93c7e17a..26811f6b 100644 --- a/src/rrules/fastmath.jl +++ b/src/rrules/fastmath.jl @@ -1,21 +1,21 @@ -@is_primitive MinimalCtx Tuple{typeof(Base.FastMath.exp_fast), Float64} -function rrule!!(::CoDual{typeof(Base.FastMath.exp_fast)}, x::CoDual{Float64}) +@is_primitive MinimalCtx Tuple{typeof(Base.FastMath.exp_fast), IEEEFloat} +function rrule!!(::CoDual{typeof(Base.FastMath.exp_fast)}, x::CoDual{P}) where {P<:IEEEFloat} yp = Base.FastMath.exp_fast(primal(x)) - exp_fast_pb!!(dy::Float64) = NoRData(), dy * yp + exp_fast_pb!!(dy::P) = NoRData(), dy * yp return CoDual(yp, NoFData()), exp_fast_pb!! end -@is_primitive MinimalCtx Tuple{typeof(Base.FastMath.exp2_fast), Float64} -function rrule!!(::CoDual{typeof(Base.FastMath.exp2_fast)}, x::CoDual{Float64}) +@is_primitive MinimalCtx Tuple{typeof(Base.FastMath.exp2_fast), IEEEFloat} +function rrule!!(::CoDual{typeof(Base.FastMath.exp2_fast)}, x::CoDual{P}) where {P<:IEEEFloat} yp = Base.FastMath.exp2_fast(primal(x)) - exp2_fast_pb!!(dy::Float64) = NoRData(), dy * yp * log(2) + exp2_fast_pb!!(dy::P) = NoRData(), dy * yp * log(2) return CoDual(yp, NoFData()), exp2_fast_pb!! end -@is_primitive MinimalCtx Tuple{typeof(Base.FastMath.exp10_fast), Float64} -function rrule!!(::CoDual{typeof(Base.FastMath.exp10_fast)}, x::CoDual{Float64}) +@is_primitive MinimalCtx Tuple{typeof(Base.FastMath.exp10_fast), IEEEFloat} +function rrule!!(::CoDual{typeof(Base.FastMath.exp10_fast)}, x::CoDual{P}) where {P<:IEEEFloat} yp = Base.FastMath.exp10_fast(primal(x)) - exp2_fast_pb!!(dy::Float64) = NoRData(), dy * yp * log(10) + exp2_fast_pb!!(dy::P) = NoRData(), dy * yp * log(10) return CoDual(yp, NoFData()), exp2_fast_pb!! 
end diff --git a/test/chain_rules_interop.jl b/test/chain_rules_interop.jl new file mode 100644 index 00000000..23f7cf52 --- /dev/null +++ b/test/chain_rules_interop.jl @@ -0,0 +1,112 @@ +module ChainRulesInteropTestResources + +using ChainRulesCore, LinearAlgebra, Mooncake + +using Base: IEEEFloat +using Mooncake: DefaultCtx, @from_rrule + +# Test case with isbits data. + +bleh(x::Float64, y::Int) = x * y + +function ChainRulesCore.rrule(::typeof(bleh), x::Float64, y::Int) + return x * y, dz -> (ChainRulesCore.NoTangent(), dz * y, ChainRulesCore.NoTangent()) +end + +@from_rrule DefaultCtx Tuple{typeof(bleh), Float64, Int} false + +# Test case with heap-allocated input. + +test_sum(x) = sum(x) + +function ChainRulesCore.rrule(::typeof(test_sum), x::AbstractArray{<:Real}) + test_sum_pb(dy::Real) = ChainRulesCore.NoTangent(), fill(dy, size(x)) + return test_sum(x), test_sum_pb +end + +@from_rrule DefaultCtx Tuple{typeof(test_sum), Array{<:Base.IEEEFloat}} false + +# Test case with heap-allocated output. + +test_scale(x::Real, y::AbstractVector{<:Real}) = x * y + +function ChainRulesCore.rrule(::typeof(test_scale), x::Real, y::AbstractVector{<:Real}) + function test_scale_pb(dout::AbstractVector{<:Real}) + return ChainRulesCore.NoTangent(), dot(dout, y), dout * x + end + return x * y, test_scale_pb +end + +@from_rrule( + DefaultCtx, Tuple{typeof(test_scale), Base.IEEEFloat, Vector{<:Base.IEEEFloat}}, false +) + +# Test case with non-differentiable type as output. + +test_nothing() = nothing + +function ChainRulesCore.rrule(::typeof(test_nothing)) + test_nothing_pb(::ChainRulesCore.NoTangent) = (ChainRulesCore.NoTangent(),) + return nothing, test_nothing_pb +end + +@from_rrule DefaultCtx Tuple{typeof(test_nothing)} false + +# Test case in which ChainRulesCore returns a tangent which is of the "wrong" type from the +# perspective of Mooncake.jl. In this instance, some kind of error should be thrown, rather +# than it being possible for the error to propagate. 
+ +test_bad_rdata(x::Real) = 5x + +function ChainRulesCore.rrule(::typeof(test_bad_rdata), x::Float64) + test_bad_rdata_pb(dy::Float64) = ChainRulesCore.NoTangent(), Float32(dy * 5) + return 5x, test_bad_rdata_pb +end + +@from_rrule DefaultCtx Tuple{typeof(test_bad_rdata), Float64} false + +# Test case for rule with diagonal dispatch. +test_add(x, y) = x + y +function ChainRulesCore.rrule(::typeof(test_add), x, y) + test_add_pb(dout) = ChainRulesCore.NoTangent(), dout, dout + return x + y, test_add_pb +end +@from_rrule DefaultCtx Tuple{typeof(test_add), T, T} where {T<:IEEEFloat} false + +# Test case for rule with kwargs. +test_kwargs(x; y::Bool=false) = y ? x : 2x + +function ChainRulesCore.rrule(::typeof(test_kwargs), x::Float64; y::Bool=false) + test_kwargs_pb(dz::Float64) = ChainRulesCore.NoTangent(), y ? dz : 2dz + return y ? x : 2x, test_kwargs_pb +end + +@from_rrule(DefaultCtx, Tuple{typeof(test_kwargs), Float64}, true) + +end + +@testset "chain_rules_macro" begin + @testset "to_cr_tangent" for (t, t_cr) in Any[ + (5.0, 5.0), + (ones(5), ones(5)), + (NoTangent(), ChainRulesCore.NoTangent()), + ] + @test Mooncake.to_cr_tangent(t) == t_cr + end + @testset "rules: $(typeof(fargs))" for fargs in Any[ + (ChainRulesInteropTestResources.bleh, 5.0, 4), + (ChainRulesInteropTestResources.test_sum, ones(5)), + (ChainRulesInteropTestResources.test_scale, 5.0, randn(3)), + (ChainRulesInteropTestResources.test_nothing,), + (Core.kwcall, (y=true, ), ChainRulesInteropTestResources.test_kwargs, 5.0), + (Core.kwcall, (y=false, ), ChainRulesInteropTestResources.test_kwargs, 5.0), + (ChainRulesInteropTestResources.test_kwargs, 5.0), + ] + test_rule(sr(1), fargs...; perf_flag=:stability, is_primitive=true) + end + @testset "bad rdata" begin + f = ChainRulesInteropTestResources.test_bad_rdata + out, pb!! 
= Mooncake.rrule!!(zero_fcodual(f), zero_fcodual(3.0)) + @test_throws MethodError pb!!(5.0) + end +end diff --git a/test/chain_rules_macro.jl b/test/chain_rules_macro.jl deleted file mode 100644 index 16f8c6a4..00000000 --- a/test/chain_rules_macro.jl +++ /dev/null @@ -1,11 +0,0 @@ -bleh(x::Float64, y::Int) = x * y - -function ChainRulesCore.rrule(::typeof(bleh), x::Float64, y::Int) - return x * y, dz -> (ChainRulesCore.NoTangent(), dz * y, ChainRulesCore.NoTangent()) -end - -Mooncake.@from_rrule DefaultCtx Tuple{typeof(bleh), Float64, Int} - -@testset "chain_rules_macro" begin - Mooncake.TestUtils.test_rule(Xoshiro(1), bleh, 5.0, 4; perf_flag=:stability) -end diff --git a/test/integration_testing/cuda.jl b/test/ext/cuda.jl similarity index 100% rename from test/integration_testing/cuda.jl rename to test/ext/cuda.jl diff --git a/test/integration_testing/dynamic_ppl.jl b/test/ext/dynamic_ppl.jl similarity index 100% rename from test/integration_testing/dynamic_ppl.jl rename to test/ext/dynamic_ppl.jl diff --git a/test/integration_testing/logdensityproblemsad_interop.jl b/test/ext/logdensityproblemsad.jl similarity index 100% rename from test/integration_testing/logdensityproblemsad_interop.jl rename to test/ext/logdensityproblemsad.jl diff --git a/test/ext/luxlib.jl b/test/ext/luxlib.jl new file mode 100644 index 00000000..1748d37f --- /dev/null +++ b/test/ext/luxlib.jl @@ -0,0 +1,40 @@ +@testset "luxlib" begin + @testset "$(typeof(fargs))" for (interface_only, perf_flag, is_primitive, fargs...) 
in vcat( + Any[ + (false, :none, true, LuxLib.Impl.matmul, randn(5, 4), randn(4, 3)), + (false, :none, true, LuxLib.Impl.matmuladd, randn(5, 4), randn(4, 3), randn(5)), + (false, :none, true, LuxLib.Impl.batched_matmul, randn(5, 4, 3), randn(4, 3, 3)), + (false, :none, false, LuxLib.Impl.activation, Lux.relu, randn(5, 4)), + ( + false, :none, false, + LuxLib.Impl.activation_loop!, randn(5, 3), NNlib.gelu, randn(5, 3), + ), + (false, :stability_and_allocs, true, SLEEFActivations.sigmoid_fast, randn()), + (false, :stability_and_allocs, true, SLEEFActivations.softplus, randn()), + (false, :stability_and_allocs, true, SLEEFActivations.logsigmoid, randn()), + (false, :stability_and_allocs, true, SLEEFActivations.swish, randn()), + (false, :stability_and_allocs, true, SLEEFActivations.lisht, randn()), + (false, :stability_and_allocs, true, SLEEFActivations.tanh, randn()), + (false, :stability_and_allocs, true, SLEEFActivations.tanh_fast, randn()), + ( + false, :stability_and_allocs, true, + LuxLib.Utils.static_training_mode_check, + nothing, + LuxLib.Utils.True(), + LuxLib.Utils.True(), + ), + ], + vec(map(Iterators.product( + [LuxLib.LoopedArrayOp(), LuxLib.GenericBroadcastOp()], + [randn(5), nothing], + [Lux.relu, tanh, NNlib.gelu], + )) do (opmode, bias, activation) + ( + false, :none, false, + LuxLib.Impl.fused_dense, opmode, activation, randn(5, 4), randn(4, 2), bias, + ) + end), + ) + test_rule(sr(1), fargs...; perf_flag, is_primitive, interface_only) + end +end diff --git a/test/ext/nnlib.jl b/test/ext/nnlib.jl new file mode 100644 index 00000000..2a3c3fde --- /dev/null +++ b/test/ext/nnlib.jl @@ -0,0 +1,104 @@ +@testset "nnlib" begin + x = randn(5, 4, 3, 2) + w = randn(2, 2, 3, 3) + dense_cdims = DenseConvDims(x, w) + sep_cdims = DepthwiseConvDims(x, w) + y = conv(x, w, dense_cdims) + y_sep = depthwiseconv(x, w, sep_cdims) + + pool_dims = PoolDims(size(x), 2) + + grid = Array{Float64}(undef, 2, 2, 2, 1) + grid[:, 1, 1, 1] .= (-1, -1) + grid[:, 2, 1, 1] .= (1, -1) 
+ grid[:, 1, 2, 1] .= (-1, 1) + grid[:, 2, 2, 1] .= (1, 1) + + @testset "$(typeof(fargs))" for (interface_only, perf_flag, is_primitive, fargs...) in Any[ + + # batched_mul + (false, :none, true, batched_mul, randn(3, 2, 3), randn(2, 5, 3)), + + # dropout + (true, :none, false, (x, p) -> dropout(sr(1), x, p; dims=1), randn(2, 2), 0.5), + (true, :none, false, (x, p) -> dropout(sr(1), x, p; dims=2), randn(2, 2), 0.1), + (true, :none, false, (x, p) -> dropout(sr(1), x, p; dims=(1, 2)), randn(2, 2), 0.4), + + # softmax + (false, :stability, true, softmax, randn(2)), + (false, :stability, true, softmax, randn(2, 2)), + (false, :stability, true, Core.kwcall, (dims=1, ), softmax, randn(2,)), + (false, :stability, true, Core.kwcall, (dims=1, ), softmax, randn(3, 3)), + (false, :stability, true, Core.kwcall, (dims=2, ), softmax, randn(3, 3)), + (false, :stability, true, Core.kwcall, (dims=(1, 2), ), softmax, randn(3, 3)), + (false, :stability, true, Core.kwcall, (dims=(1, 2), ), softmax, randn(3, 3, 2)), + (false, :none, false, x -> softmax(5x), randn(3, 2)), + (false, :none, false, x -> softmax(x; dims=1), randn(3, 2)), + (false, :none, false, x -> softmax(x; dims=2), randn(3, 2)), + (false, :none, false, x -> softmax(x; dims=(1, 2)), randn(3, 2)), + + # logsoftmax + (false, :stability, true, logsoftmax, randn(2)), + (false, :stability, true, logsoftmax, randn(2, 3)), + (false, :stability, true, logsoftmax, randn(2, 3, 2)), + (false, :stability, true, Core.kwcall, (dims=1, ), logsoftmax, randn(2,)), + (false, :stability, true, Core.kwcall, (dims=1, ), logsoftmax, randn(3, 3)), + (false, :stability, true, Core.kwcall, (dims=2, ), logsoftmax, randn(3, 3)), + (false, :stability, true, Core.kwcall, (dims=(1, 2), ), logsoftmax, randn(3, 3)), + (false, :stability, true, Core.kwcall, (dims=(1, 2), ), logsoftmax, randn(3, 3, 2)), + + # logsumexp + (false, :stability, true, logsumexp, randn(2,)), + (false, :stability, true, logsumexp, randn(3, 3)), + (false, :stability, true, 
logsumexp, randn(3, 3, 2)), + (false, :stability, true, Core.kwcall, (dims=1, ), logsumexp, randn(2,)), + (false, :stability, true, Core.kwcall, (dims=1, ), logsumexp, randn(3, 3)), + (false, :stability, true, Core.kwcall, (dims=2, ), logsumexp, randn(3, 3)), + (false, :stability, true, Core.kwcall, (dims=(1, 2), ), logsumexp, randn(3, 3)), + (false, :stability, true, Core.kwcall, (dims=(1, 2), ), logsumexp, randn(3, 3, 2)), + + # upsample_nearest + (false, :stability, true, upsample_nearest, randn(3), (2,)), + (false, :stability, true, upsample_nearest, randn(3, 2), (2, 2)), + (false, :stability, true, upsample_nearest, randn(3, 2, 3), (2, 2, 5)), + + # fold + (false, :none, true, NNlib.fold, randn(12, 12, 2), size(x), dense_cdims), + + # unfold + (false, :none, true, NNlib.unfold, x, dense_cdims), + + # scatter + (false, :stability, true, NNlib.scatter, +, randn(2), [1, 3]), + (false, :stability, true, Core.kwcall, (;), NNlib.scatter, +, randn(2), [1, 3]), + + # conv + (false, :none, true, Core.kwcall, (;), conv, x, w, dense_cdims), + (false, :none, true, conv, x, w, dense_cdims), + (false, :none, true, Core.kwcall, (;), depthwiseconv, x, w, sep_cdims), + (false, :none, true, depthwiseconv, x, w, sep_cdims), + + # ∇conv_data + (false, :none, true, Core.kwcall, (;), ∇conv_data, y, w, dense_cdims), + (false, :none, true, ∇conv_data, y, w, dense_cdims), + (false, :none, true, Core.kwcall, (;), ∇depthwiseconv_data, y_sep, w, sep_cdims), + (false, :none, true, ∇depthwiseconv_data, y_sep, w, sep_cdims), + + # ∇conv_filter + (false, :none, true, Core.kwcall, (;), ∇conv_filter, x, y, dense_cdims), + (false, :none, true, ∇conv_filter, x, y, dense_cdims), + + # pooling + (false, :none, true, maxpool, x, pool_dims), + (false, :none, true, Core.kwcall, (;), maxpool, x, pool_dims), + (false, :none, true, meanpool, x, pool_dims), + (false, :none, true, Core.kwcall, (;), meanpool, x, pool_dims), + + # padding + (false, :none, false, x -> pad_constant(x, 1, 2.0), x), + (false, 
:none, false, x -> pad_constant(x, 1, 2.0; dims=:), x), + ] + @info "$(typeof(fargs))" + test_rule(sr(1), fargs...; perf_flag, is_primitive, interface_only) + end +end diff --git a/test/integration_testing/special_functions.jl b/test/ext/special_functions.jl similarity index 100% rename from test/integration_testing/special_functions.jl rename to test/ext/special_functions.jl diff --git a/test/front_matter.jl b/test/front_matter.jl index 9b32e431..78fe0a3b 100644 --- a/test/front_matter.jl +++ b/test/front_matter.jl @@ -5,6 +5,9 @@ using FillArrays, JET, LinearAlgebra, + Lux, + LuxLib, + NNlib, PDMats, Random, SpecialFunctions, @@ -14,12 +17,15 @@ using import ChainRulesCore -using Base: unsafe_load, pointer_from_objref +using Base: unsafe_load, pointer_from_objref, IEEEFloat using Base.Iterators: product using Core: bitcast, svec, ReturnNode, PhiNode, PiNode, GotoIfNot, GotoNode, SSAValue, Argument using Core.Intrinsics: pointerref, pointerset +using NNlib: dropout +using LuxLib.Impl: SLEEFActivations + using Mooncake: CC, IntrinsicsWrappers, diff --git a/test/integration_testing/lux.jl b/test/integration_testing/lux.jl new file mode 100644 index 00000000..15a90245 --- /dev/null +++ b/test/integration_testing/lux.jl @@ -0,0 +1,45 @@ +@testset "lux" begin + @testset "$(typeof(f))" for (f, x_f32) in Any[ + (Dense(2, 4), randn(Float32, 2, 3)), + (Dense(2, 4, gelu), randn(Float32, 2, 3)), + (Dense(2, 4, gelu; use_bias=false), randn(Float32, 2, 3)), + (Chain(Dense(2, 4, relu), Dense(4, 3)), randn(Float32, 2, 3)), + (Scale(2), randn(Float32, 2, 3)), + (Conv((3, 3), 2 => 3), randn(Float32, 3, 3, 2, 2)), + (Conv((3, 3), 2 => 3, gelu; pad=SamePad()), randn(Float32, 3, 3, 2, 2)), + (Conv((3, 3), 2 => 3, relu; use_bias=false, pad=SamePad()), randn(Float32, 3, 3, 2, 2)), + (Chain(Conv((3, 3), 2 => 3, gelu), Conv((3, 3), 3 => 1, gelu)), rand(Float32, 5, 5, 2, 2)), + (Chain(Conv((4, 4), 2 => 2; pad=SamePad()), MeanPool((5, 5); pad=SamePad())), rand(Float32, 5, 5, 2, 2)), + 
(Chain(Conv((3, 3), 2 => 3, relu; pad=SamePad()), MaxPool((2, 2))), rand(Float32, 5, 5, 2, 2)), + (Maxout(() -> Dense(5 => 4, tanh), 3), randn(Float32, 5, 2)), + (Bilinear((2, 2) => 3), randn(Float32, 2, 3)), + (SkipConnection(Dense(2 => 2), vcat), randn(Float32, 2, 3)), + (ConvTranspose((3, 3), 3 => 2; stride=2), rand(Float32, 5, 5, 3, 1)), + (StatefulRecurrentCell(RNNCell(3 => 5)), rand(Float32, 3, 2)), + (StatefulRecurrentCell(RNNCell(3 => 5, gelu)), rand(Float32, 3, 2)), + (StatefulRecurrentCell(RNNCell(3 => 5, gelu; use_bias=false)), rand(Float32, 3, 2)), + (Chain(StatefulRecurrentCell(RNNCell(3 => 5)), StatefulRecurrentCell(RNNCell(5 => 3))), rand(Float32, 3, 2)), + (StatefulRecurrentCell(LSTMCell(3 => 5)), rand(Float32, 3, 2)), + (Chain(StatefulRecurrentCell(LSTMCell(3 => 5)), StatefulRecurrentCell(LSTMCell(5 => 3))), rand(Float32, 3, 2)), + (StatefulRecurrentCell(GRUCell(3 => 5)), rand(Float32, 3, 10)), + (Chain(StatefulRecurrentCell(GRUCell(3 => 5)), StatefulRecurrentCell(GRUCell(5 => 3))), rand(Float32, 3, 10)), + (Chain(Dense(2, 4), BatchNorm(4)), randn(Float32, 2, 3)), + (Chain(Dense(2, 4), BatchNorm(4, gelu)), randn(Float32, 2, 3)), + (Chain(Dense(2, 4), BatchNorm(4, gelu; track_stats=false)), randn(Float32, 2, 3)), + (Chain(Conv((3, 3), 2 => 6), BatchNorm(6)), randn(Float32, 6, 6, 2, 2)), + (Chain(Conv((3, 3), 2 => 6, tanh), BatchNorm(6)), randn(Float32, 6, 6, 2, 2)), + (Chain(Dense(2, 4), GroupNorm(4, 2, gelu)), randn(Float32, 2, 3)), + (Chain(Dense(2, 4), GroupNorm(4, 2)), randn(Float32, 2, 3)), + (Chain(Conv((3, 3), 2 => 6), GroupNorm(6, 3)), randn(Float32, 6, 6, 2, 2)), + (Chain(Conv((3, 3), 2 => 6, tanh), GroupNorm(6, 3)), randn(Float32, 6, 6, 2, 2)), + (Chain(Conv((3, 3), 2 => 3, gelu), LayerNorm((1, 1, 3))), randn(Float32, 4, 4, 2, 2)), + (InstanceNorm(6), randn(Float32, 6, 6, 2, 2)), + (Chain(Conv((3, 3), 2 => 6), InstanceNorm(6)), randn(Float32, 6, 6, 2, 2)), + (Chain(Conv((3, 3), 2 => 6, tanh), InstanceNorm(6)), randn(Float32, 6, 6, 2, 2)), 
+ ] + @info "$(_typeof((f, x_f32...)))" + ps, st = f32(Lux.setup(sr(123456), f)) + x = f32(x_f32) + test_rule(sr(123456), f, x, ps, st; is_primitive=false, interface_only=true) + end +end diff --git a/test/runtests.jl b/test/runtests.jl index 5daa8397..75d73aa5 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -45,12 +45,15 @@ include("front_matter.jl") @info "tasks" include(joinpath("rrules", "tasks.jl")) end - include("chain_rules_macro.jl") + include("chain_rules_interop.jl") elseif test_group == "integration_testing/misc" include(joinpath("integration_testing", "battery_tests.jl")) - include(joinpath("integration_testing", "dynamic_ppl.jl")) - include(joinpath("integration_testing", "logdensityproblemsad_interop.jl")) - include(joinpath("integration_testing", "special_functions.jl")) + include(joinpath("ext", "dynamic_ppl.jl")) + include(joinpath("ext", "logdensityproblemsad.jl")) + include(joinpath("ext", "luxlib.jl")) + include(joinpath("ext", "nnlib.jl")) + include(joinpath("ext", "special_functions.jl")) + include(joinpath("integration_testing", "lux.jl")) elseif test_group == "integration_testing/misc_abstract_array" include(joinpath("integration_testing", "misc_abstract_array.jl")) elseif test_group == "integration_testing/diff_tests" @@ -66,7 +69,7 @@ include("front_matter.jl") elseif test_group == "integration_testing/temporalgps" include(joinpath("integration_testing", "temporalgps.jl")) elseif test_group == "gpu" - include(joinpath("integration_testing", "cuda.jl")) + include(joinpath("ext", "cuda.jl")) else throw(error("test_group=$(test_group) is not recognised")) end