From cef527fda4a8cb3bae88ba33e2d518081191757a Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Thu, 12 Sep 2024 14:26:03 +0100 Subject: [PATCH 01/39] Bump patch version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 255943b8..f135556c 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Tapir" uuid = "07d77754-e150-4737-8c94-cd238a1fb45b" authors = ["Will Tebbutt, Hong Ge, and contributors"] -version = "0.2.49" +version = "0.2.50" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" From b9c3f650ddbff926482bdb27b68efe426e829d45 Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Thu, 12 Sep 2024 14:37:49 +0100 Subject: [PATCH 02/39] Fix usage with benchmarktools --- src/interpreter/s2s_reverse_mode_ad.jl | 38 ++++++++++++------------- test/interpreter/s2s_reverse_mode_ad.jl | 6 ++++ 2 files changed, 24 insertions(+), 20 deletions(-) diff --git a/src/interpreter/s2s_reverse_mode_ad.jl b/src/interpreter/s2s_reverse_mode_ad.jl index b07d79c8..92659df4 100644 --- a/src/interpreter/s2s_reverse_mode_ad.jl +++ b/src/interpreter/s2s_reverse_mode_ad.jl @@ -495,9 +495,9 @@ function make_ad_stmts!(stmt::Expr, line::ID, info::ADInfo) rrule!! # intrinsic / builtin / thing we provably have rule for elseif is_invoke mi = stmt.args[1]::Core.MethodInstance - LazyDerivedRule(info.interp, mi, info.safety_on) # Static dispatch + LazyDerivedRule(mi, info.safety_on) # Static dispatch else - DynamicDerivedRule(info.interp, info.safety_on) # Dynamic dispatch + DynamicDerivedRule(info.safety_on) # Dynamic dispatch end # Wrap the raw rule in a struct which ensures that any `ZeroRData`s are stripped @@ -1420,23 +1420,20 @@ of its arguments. Stores rules in an internal cache to avoid re-deriving. This is used to implement dynamic dispatch. =# -struct DynamicDerivedRule{T, V} - interp::T +struct DynamicDerivedRule{V} cache::V safety_on::Bool end -function DynamicDerivedRule(interp::TapirInterpreter, safety_on::Bool) - return DynamicDerivedRule(interp, Dict{Any, Any}(), safety_on) -end +DynamicDerivedRule(safety_on::Bool) = DynamicDerivedRule(Dict{Any, Any}(), safety_on) -_copy(x::P) where {P<:DynamicDerivedRule} = P(x.interp, Dict{Any, Any}(), x.safety_on) +_copy(x::P) where {P<:DynamicDerivedRule} = P(Dict{Any, Any}(), x.safety_on) function (dynamic_rule::DynamicDerivedRule)(args::Vararg{Any, N}) where {N} sig = Tuple{map(_typeof ∘ primal, args)...} rule = get(dynamic_rule.cache, sig, nothing) if rule === nothing - rule = build_rrule(dynamic_rule.interp, sig; safety_on=dynamic_rule.safety_on) + rule = build_rrule(get_tapir_interpreter(), sig; safety_on=dynamic_rule.safety_on) dynamic_rule.cache[sig] = rule end return rule(args...) @@ -1460,26 +1457,27 @@ reason to keep this around is for debugging -- it is very helpful to have this t in the stack trace when something goes wrong, as it allows you to trivially determine which bit of your code is the culprit. 
=# -mutable struct LazyDerivedRule{Tinterp<:TapirInterpreter, primal_sig, Trule} - interp::Tinterp +mutable struct LazyDerivedRule{primal_sig, Trule} safety_on::Bool mi::Core.MethodInstance rule::Trule - function LazyDerivedRule(interp::A, mi::Core.MethodInstance, safety_on::Bool) where {A} - return new{A, mi.specTypes, rule_type(interp, mi; safety_on)}(interp, safety_on, mi) + function LazyDerivedRule(mi::Core.MethodInstance, safety_on::Bool) + interp = get_tapir_interpreter() + return new{mi.specTypes, rule_type(interp, mi; safety_on)}(safety_on, mi) end - function LazyDerivedRule{Tinterp, Tprimal_sig, Trule}( - interp::Tinterp, mi::Core.MethodInstance, safety_on::Bool - ) where {Tinterp, Tprimal_sig, Trule} - return new{Tinterp, Tprimal_sig, Trule}(interp, safety_on, mi) + function LazyDerivedRule{Tprimal_sig, Trule}( + mi::Core.MethodInstance, safety_on::Bool + ) where {Tprimal_sig, Trule} + return new{Tprimal_sig, Trule}(safety_on, mi) end end -_copy(x::P) where {P<:LazyDerivedRule} = P(x.interp, x.mi, x.safety_on) +_copy(x::P) where {P<:LazyDerivedRule} = P(x.mi, x.safety_on) -function (rule::LazyDerivedRule{T, sig, Trule})(args::Vararg{Any, N}) where {N, T, sig, Trule} +function (rule::LazyDerivedRule{sig, Trule})(args::Vararg{Any, N}) where {N, sig, Trule} if !isdefined(rule, :rule) - derived_rule = build_rrule(rule.interp, rule.mi; safety_on=rule.safety_on) + interp = get_tapir_interpreter() + derived_rule = build_rrule(interp, rule.mi; safety_on=rule.safety_on) if derived_rule isa Trule rule.rule = derived_rule else diff --git a/test/interpreter/s2s_reverse_mode_ad.jl b/test/interpreter/s2s_reverse_mode_ad.jl index 75b5a9ba..be9cd820 100644 --- a/test/interpreter/s2s_reverse_mode_ad.jl +++ b/test/interpreter/s2s_reverse_mode_ad.jl @@ -274,5 +274,11 @@ end Xoshiro(123456), S2SGlobals.f, S2SGlobals.A(2 * ones(3)), ones(3); interface_only=false, is_primitive=false, ) + + # BenchmarkTools not working due to world age problems. Provided that this code + # runs successfully, everything is okay -- no need to check anything specific. + f(x) = sin(cos(x)) + rule = Tapir.build_rrule(f, 0.0) + @benchmark Tapir.value_and_gradient!!($rule, $f, $(Ref(0.0))[]) end end From 8f0f75d13e35f338f203ba47863ac6d593e5d5d0 Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Fri, 13 Sep 2024 11:41:31 +0100 Subject: [PATCH 03/39] Initial pass --- src/Tapir.jl | 1 + src/chain_rules_macro.jl | 81 +++++++++++++++++++++++++-------------- test/chain_rules_macro.jl | 17 ++++++++ 3 files changed, 71 insertions(+), 28 deletions(-) diff --git a/src/Tapir.jl b/src/Tapir.jl index f4656ce3..89d8f61a 100644 --- a/src/Tapir.jl +++ b/src/Tapir.jl @@ -13,6 +13,7 @@ using Random, Setfield +# There are many clashing names, so we will always qualify uses of names from CRC. import ChainRulesCore using Base: diff --git a/src/chain_rules_macro.jl b/src/chain_rules_macro.jl index de5c64a0..502feffe 100644 --- a/src/chain_rules_macro.jl +++ b/src/chain_rules_macro.jl @@ -1,5 +1,55 @@ -_to_rdata(::ChainRulesCore.NoTangent) = NoRData() -_to_rdata(dx::Float64) = dx +""" + to_cr_tangent(t) + +Convert a Tapir tangent into a type that ChainRules.jl `rrule`s expect to see. +Inverse of `to_tapir_tangent`. +""" +to_cr_tangent(t::IEEEFloat) = t +to_cr_tangent(t::Array{<:IEEEFloat}) = t +to_cr_tangent(::NoTangent) = ChainRulesCore.NoTangent() + +""" + to_tapir_tangent(cr_t) + +Convert a ChainRules.jl tangent, `cr_t`, into the corresponding Tapir tangent. +Inverse of `to_cr_tangent`. 
+"""
+to_tapir_tangent(t::IEEEFloat) = t
+to_tapir_tangent(t::Array{<:IEEEFloat}) = t
+to_tapir_tangent(::ChainRulesCore.NoTangent) = NoTangent()
+
+"""
+    rrule_wrapper_implementation(fargs::Vararg{CoDual, N}) where {N}
+
+Used to implement `rrule!!`s via `ChainRulesCore.rrule`.
+
+Given a function `foo`, argument types `arg_types`, and a method of `ChainRulesCore.rrule`
+which applies to these, you can make use of this function as follows:
+```julia
+Tapir.@is_primitive DefaultCtx Tuple{typeof(foo), arg_types...}
+function Tapir.rrule!!(f::CoDual{typeof(foo)}, args::CoDual...)
+    return rrule_wrapper_implementation(f, args...)
+end
+```
+Assumes that methods of `to_cr_tangent` and `to_tapir_tangent` are defined such that you
+can convert between the different representations of tangents that Tapir and ChainRulesCore
+expect.
+
+Subject to some constraints, you can use the [`@from_rrule`](@ref) macro to reduce the
+amount of boilerplate code that you are required to write even further.
+"""
+function rrule_wrapper_implementation(fargs::Vararg{CoDual, N}) where {N}
+    y_primal, cr_pb = ChainRulesCore.rrule(tuple_map(primal, fargs)...)
+    y_fdata = fdata(zero_tangent(y_primal))
+    function pb!!(y_rdata)
+        cr_tangent = to_cr_tangent(tangent(y_fdata, y_rdata))
+        cr_dfargs = cr_pb(cr_tangent)
+        dfargs = tuple_map(to_tapir_tangent, cr_dfargs)
+        tuple_map((x, dx) -> increment!!(tangent(x), fdata(dx)), fargs, dfargs)
+        return tuple_map(rdata, dfargs)
+    end
+    return CoDual(y_primal, y_fdata), pb!!
+end

@doc"""
    @from_rrule ctx sig
@@ -32,37 +82,12 @@ macro from_rrule(ctx, sig)
    arg_types = map(t -> :(Tapir.CoDual{<:$t}), arg_type_symbols)
    arg_exprs = map((n, t) -> :($n::$t), arg_names, arg_types)

-    call_rrule = Expr(
-        :call,
-        :(Tapir.ChainRulesCore.rrule),
-        map(n -> :(Tapir.primal($n)), arg_names)...,
-    )
-
-    pb_output_names = map(n -> Symbol("dx_$(n)_inc"), eachindex(arg_names))
-
-    call_pb = Expr(:(=), Expr(:tuple, pb_output_names...), :(pb(dy)))
-    incrementers = Expr(:tuple, map(b -> :(Tapir._to_rdata($b)), pb_output_names)...)
-
-    pb = ExprTools.combinedef(Dict(
-        :head => :function,
-        :name => :pb!!,
-        :args => [:dy],
-        :body => quote
-            $call_pb
-            return $incrementers
-        end,
-    ))
-
    rule_expr = ExprTools.combinedef(
        Dict(
            :head => :function,
            :name => :(Tapir.rrule!!),
            :args => arg_exprs,
-            :body => quote
-                y, pb = $call_rrule
-                $pb
-                return Tapir.zero_fcodual(y), pb!!
-            end,
+            :body => Expr(:call, rrule_wrapper_implementation, arg_names...),
        )
    )

diff --git a/test/chain_rules_macro.jl b/test/chain_rules_macro.jl
index 920cd0dd..98f07161 100644
--- a/test/chain_rules_macro.jl
+++ b/test/chain_rules_macro.jl
@@ -1,3 +1,5 @@
+# Test case with isbits data.
+
bleh(x::Float64, y::Int) = x * y

function ChainRulesCore.rrule(::typeof(bleh), x::Float64, y::Int)
@@ -6,6 +8,21 @@ end

Tapir.@from_rrule DefaultCtx Tuple{typeof(bleh), Float64, Int}

+# Test case with heap-allocated data.
+ +test_sum(x) = sum(x) + +function ChainRulesCore.rrule(::typeof(test_sum), x::AbstractArray{<:Real}) + test_sum_pb(dy::Real) = ChainRulesCore.NoTangent(), fill(dy, size(x)) + return test_sum(x), test_sum_pb +end + +Tapir.@is_primitive DefaultCtx Tuple{typeof(test_sum), Array{<:Base.IEEEFloat}} +function Tapir.rrule!!(f::CoDual{typeof(test_sum)}, x::CoDual{<:Array{<:Base.IEEEFloat}}) + return Tapir.rrule_wrapper_implementation(f, x) +end + @testset "chain_rules_macro" begin Tapir.TestUtils.test_rule(Xoshiro(1), bleh, 5.0, 4; perf_flag=:stability) + Tapir.TestUtils.test_rule(Xoshiro(1), test_sum, ones(5); perf_flag=:stability) end From e791ceff1af48d4d86dbe379c28279657a4292c3 Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Fri, 13 Sep 2024 11:41:42 +0100 Subject: [PATCH 04/39] Bump patch --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index f135556c..65294a51 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Tapir" uuid = "07d77754-e150-4737-8c94-cd238a1fb45b" authors = ["Will Tebbutt, Hong Ge, and contributors"] -version = "0.2.50" +version = "0.2.51" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" From f45456e4947760a32a94bfe457c262b7c62bcb54 Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Fri, 13 Sep 2024 12:42:59 +0100 Subject: [PATCH 05/39] Unit test to_tapir_tangent and to_cr_tangent --- test/chain_rules_macro.jl | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/test/chain_rules_macro.jl b/test/chain_rules_macro.jl index 98f07161..ad4b5ab3 100644 --- a/test/chain_rules_macro.jl +++ b/test/chain_rules_macro.jl @@ -23,6 +23,18 @@ function Tapir.rrule!!(f::CoDual{typeof(test_sum)}, x::CoDual{<:Array{<:Base.IEE end @testset "chain_rules_macro" begin - Tapir.TestUtils.test_rule(Xoshiro(1), bleh, 5.0, 4; perf_flag=:stability) - Tapir.TestUtils.test_rule(Xoshiro(1), test_sum, ones(5); perf_flag=:stability) + @testset "to_cr_tangent and to_tapir_tangent" for (t, t_cr) in Any[ + (5.0, 5.0), + (ones(5), ones(5)), + (NoTangent(), ChainRulesCore.NoTangent()), + ] + @test Tapir.to_cr_tangent(t) == t_cr + @test Tapir.to_tapir_tangent(t_cr) == t + @test Tapir.to_tapir_tangent(Tapir.to_cr_tangent(t)) == t + @test Tapir.to_cr_tangent(Tapir.to_tapir_tangent(t_cr)) == t_cr + end + @testset "rules" begin + Tapir.TestUtils.test_rule(Xoshiro(1), bleh, 5.0, 4; perf_flag=:stability) + Tapir.TestUtils.test_rule(Xoshiro(1), test_sum, ones(5); perf_flag=:stability) + end end From bec9f06ef919be3e40c4d1bf859ff09de6df0579 Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Fri, 13 Sep 2024 12:43:49 +0100 Subject: [PATCH 06/39] Make use of macro --- test/chain_rules_macro.jl | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/test/chain_rules_macro.jl b/test/chain_rules_macro.jl index ad4b5ab3..4b185732 100644 --- a/test/chain_rules_macro.jl +++ b/test/chain_rules_macro.jl @@ -17,10 +17,7 @@ function ChainRulesCore.rrule(::typeof(test_sum), x::AbstractArray{<:Real}) return test_sum(x), test_sum_pb end -Tapir.@is_primitive DefaultCtx Tuple{typeof(test_sum), Array{<:Base.IEEEFloat}} -function Tapir.rrule!!(f::CoDual{typeof(test_sum)}, x::CoDual{<:Array{<:Base.IEEEFloat}}) - return Tapir.rrule_wrapper_implementation(f, x) -end +Tapir.@from_rrule DefaultCtx Tuple{typeof(test_sum), Array{<:Base.IEEEFloat}} @testset "chain_rules_macro" begin @testset "to_cr_tangent and to_tapir_tangent" for (t, t_cr) in Any[ From d03710178cec3b79804ebbf47e4fb4d09da13996 Mon Sep 17 00:00:00 2001 
From: Will Tebbutt Date: Fri, 13 Sep 2024 12:59:43 +0100 Subject: [PATCH 07/39] More testing and tidying up --- src/chain_rules_macro.jl | 2 +- test/chain_rules_macro.jl | 49 ++++++++++++++++++++++++++++++++++----- test/front_matter.jl | 2 +- 3 files changed, 45 insertions(+), 8 deletions(-) diff --git a/src/chain_rules_macro.jl b/src/chain_rules_macro.jl index 502feffe..097830a4 100644 --- a/src/chain_rules_macro.jl +++ b/src/chain_rules_macro.jl @@ -92,7 +92,7 @@ macro from_rrule(ctx, sig) ) ex = quote - Tapir.is_primitive(::Type{$ctx}, ::Type{$sig}) = true + Tapir.is_primitive(::Type{$ctx}, ::Type{<:$sig}) = true $rule_expr end return esc(ex) diff --git a/test/chain_rules_macro.jl b/test/chain_rules_macro.jl index 4b185732..9ab1ec98 100644 --- a/test/chain_rules_macro.jl +++ b/test/chain_rules_macro.jl @@ -1,3 +1,10 @@ +module ChainRulesInteropTestResources + +using ChainRulesCore, LinearAlgebra, Tapir + +using Base: IEEEFloat +using Tapir: DefaultCtx, @from_rrule + # Test case with isbits data. bleh(x::Float64, y::Int) = x * y @@ -6,9 +13,9 @@ function ChainRulesCore.rrule(::typeof(bleh), x::Float64, y::Int) return x * y, dz -> (ChainRulesCore.NoTangent(), dz * y, ChainRulesCore.NoTangent()) end -Tapir.@from_rrule DefaultCtx Tuple{typeof(bleh), Float64, Int} +@from_rrule DefaultCtx Tuple{typeof(bleh), Float64, Int} -# Test case with heap-allocated data. +# Test case with heap-allocated input. test_sum(x) = sum(x) @@ -17,7 +24,33 @@ function ChainRulesCore.rrule(::typeof(test_sum), x::AbstractArray{<:Real}) return test_sum(x), test_sum_pb end -Tapir.@from_rrule DefaultCtx Tuple{typeof(test_sum), Array{<:Base.IEEEFloat}} +@from_rrule DefaultCtx Tuple{typeof(test_sum), Array{<:Base.IEEEFloat}} + +# Test case with heap-allocated output. + +test_scale(x::Real, y::AbstractVector{<:Real}) = x * y + +function ChainRulesCore.rrule(::typeof(test_scale), x::Real, y::AbstractVector{<:Real}) + function test_scale_pb(dout::AbstractVector{<:Real}) + return ChainRulesCore.NoTangent(), dot(dout, y), dout * x + end + return x * y, test_scale_pb +end + +@from_rrule DefaultCtx Tuple{typeof(test_scale), Base.IEEEFloat, Vector{<:Base.IEEEFloat}} + +# Test case with non-differentiable type as output. 
+
+test_nothing() = nothing
+
+function ChainRulesCore.rrule(::typeof(test_nothing))
+    test_nothing_pb(::ChainRulesCore.NoTangent) = (ChainRulesCore.NoTangent(),)
+    return nothing, test_nothing_pb
+end
+
+@from_rrule DefaultCtx Tuple{typeof(test_nothing)}
+
+end

@testset "chain_rules_macro" begin
    @testset "to_cr_tangent and to_tapir_tangent" for (t, t_cr) in Any[
        (5.0, 5.0),
        (ones(5), ones(5)),
        (NoTangent(), ChainRulesCore.NoTangent()),
    ]
        @test Tapir.to_cr_tangent(t) == t_cr
        @test Tapir.to_tapir_tangent(t_cr) == t
        @test Tapir.to_tapir_tangent(Tapir.to_cr_tangent(t)) == t
        @test Tapir.to_cr_tangent(Tapir.to_tapir_tangent(t_cr)) == t_cr
    end
-    @testset "rules" begin
-        Tapir.TestUtils.test_rule(Xoshiro(1), bleh, 5.0, 4; perf_flag=:stability)
-        Tapir.TestUtils.test_rule(Xoshiro(1), test_sum, ones(5); perf_flag=:stability)
+    @testset "rules: $(typeof(fargs))" for fargs in Any[
+        (ChainRulesInteropTestResources.bleh, 5.0, 4),
+        (ChainRulesInteropTestResources.test_sum, ones(5)),
+        (ChainRulesInteropTestResources.test_scale, 5.0, randn(3)),
+        (ChainRulesInteropTestResources.test_nothing,),
+    ]
+        test_rule(sr(1), fargs...; perf_flag=:stability, is_primitive=true)
    end
end
diff --git a/test/front_matter.jl b/test/front_matter.jl
index b42f6693..f78ca47e 100644
--- a/test/front_matter.jl
+++ b/test/front_matter.jl
@@ -14,7 +14,7 @@ using
import ChainRulesCore

-using Base: unsafe_load, pointer_from_objref
+using Base: unsafe_load, pointer_from_objref, IEEEFloat
using Base.Iterators: product
using Core:
    bitcast, svec, ReturnNode, PhiNode, PiNode, GotoIfNot, GotoNode, SSAValue, Argument

From 54947f0b69b59288abbcde7bb1ff21ccc4386faa Mon Sep 17 00:00:00 2001
From: Will Tebbutt
Date: Fri, 13 Sep 2024 13:10:35 +0100
Subject: [PATCH 08/39] Add some basic type checking and a test

---
 src/chain_rules_macro.jl  |  7 +++++--
 test/chain_rules_macro.jl | 18 ++++++++++++++++++
 2 files changed, 23 insertions(+), 2 deletions(-)

diff --git a/src/chain_rules_macro.jl b/src/chain_rules_macro.jl
index 097830a4..5fab9da4 100644
--- a/src/chain_rules_macro.jl
+++ b/src/chain_rules_macro.jl
@@ -39,12 +39,15 @@ Subject to some constraints, you can use the [`@from_rrule`](@ref) macro to redu
amount of boilerplate code that you are required to write even further.
"""
function rrule_wrapper_implementation(fargs::Vararg{CoDual, N}) where {N}
-    y_primal, cr_pb = ChainRulesCore.rrule(tuple_map(primal, fargs)...)
+    primals = tuple_map(primal, fargs)
+    tangent_types = tuple_map(x -> tangent_type(typeof(x)), primals)
+    y_primal, cr_pb = ChainRulesCore.rrule(primals...)
    y_fdata = fdata(zero_tangent(y_primal))
    function pb!!(y_rdata)
        cr_tangent = to_cr_tangent(tangent(y_fdata, y_rdata))
        cr_dfargs = cr_pb(cr_tangent)
-        dfargs = tuple_map(to_tapir_tangent, cr_dfargs)
+        dfargs_unvalidated = tuple_map(to_tapir_tangent, cr_dfargs)
+        dfargs = tuple_map(typeassert, dfargs_unvalidated, tangent_types)
        tuple_map((x, dx) -> increment!!(tangent(x), fdata(dx)), fargs, dfargs)
        return tuple_map(rdata, dfargs)
    end
diff --git a/test/chain_rules_macro.jl b/test/chain_rules_macro.jl
index 4b185732..7512b3e9 100644
--- a/test/chain_rules_macro.jl
+++ b/test/chain_rules_macro.jl
@@ -50,6 +50,19 @@ end

@from_rrule DefaultCtx Tuple{typeof(test_nothing)}

+# Test case in which ChainRulesCore returns a tangent which is of the "wrong" type from the
+# perspective of Tapir.jl. In this instance, an error should be thrown, rather than the
+# incorrect tangent being allowed to propagate.
+ +test_bad_rdata(x::Real) = 5x + +function ChainRulesCore.rrule(::typeof(test_bad_rdata), x::Float64) + test_bad_rdata_pb(dy::Float64) = ChainRulesCore.NoTangent(), Float32(dy * 5) + return 5x, test_bad_rdata_pb +end + +@from_rrule DefaultCtx Tuple{typeof(test_bad_rdata), Float64} + end @testset "chain_rules_macro" begin @@ -71,4 +84,9 @@ end ] test_rule(sr(1), fargs...; perf_flag=:stability, is_primitive=true) end + @testset "bad rdata" begin + f = ChainRulesInteropTestResources.test_bad_rdata + out, pb!! = Tapir.rrule!!(zero_fcodual(f), zero_fcodual(3.0)) + @test_throws TypeError pb!!(5.0) + end end From bc88483ea66be7df16f1b5afcddbea4ec0c5c710 Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Fri, 13 Sep 2024 13:19:27 +0100 Subject: [PATCH 09/39] Improve formatting and commenting --- src/chain_rules_macro.jl | 28 +++++++++++++++++++++++++--- 1 file changed, 25 insertions(+), 3 deletions(-) diff --git a/src/chain_rules_macro.jl b/src/chain_rules_macro.jl index 5fab9da4..b72f21b7 100644 --- a/src/chain_rules_macro.jl +++ b/src/chain_rules_macro.jl @@ -38,17 +38,39 @@ expect. Subject to some constraints, you can use the [`@from_rrule`](@ref) macro to reduce the amount of boilerplate code that you are required to write even further. """ -function rrule_wrapper_implementation(fargs::Vararg{CoDual, N}) where {N} +@inline function rrule_wrapper_implementation(fargs::Vararg{CoDual, N}) where {N} + + # Run forwards-pass. primals = tuple_map(primal, fargs) - tangent_types = tuple_map(x -> tangent_type(typeof(x)), primals) y_primal, cr_pb = ChainRulesCore.rrule(primals...) y_fdata = fdata(zero_tangent(y_primal)) + + # Construct functions which, when applied to the tangent types returned on the + # reverse-pass, will check that they are of the expected type. This will pick up on + # obvious problems, but is intended to be fast / optimised away when things go well. + # As such, you should think of this as a lightweight version of "debug_mode". + tangent_type_assertions = tuple_map( + x -> Base.Fix2(typeassert, tangent_type(typeof(x))), primals + ) + function pb!!(y_rdata) + + # Construct tangent w.r.t. output. cr_tangent = to_cr_tangent(tangent(y_fdata, y_rdata)) + + # Run reverse-pass using ChainRules. cr_dfargs = cr_pb(cr_tangent) + + # Convert output into tangent types appropriate for Tapir. dfargs_unvalidated = tuple_map(to_tapir_tangent, cr_dfargs) - dfargs = tuple_map(typeassert, dfargs_unvalidated, tangent_types) + + # Apply type assertions. + dfargs = tuple_map((x, T) -> T(x), dfargs_unvalidated, tangent_type_assertions) + + # Increment the fdata. tuple_map((x, dx) -> increment!!(tangent(x), fdata(dx)), fargs, dfargs) + + # Return the rdata. return tuple_map(rdata, dfargs) end return CoDual(y_primal, y_fdata), pb!! From f29b8f31c3c7ee022f89da101d1a59cf205dfe39 Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Fri, 13 Sep 2024 14:22:10 +0100 Subject: [PATCH 10/39] Formatting --- src/chain_rules_macro.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/chain_rules_macro.jl b/src/chain_rules_macro.jl index b72f21b7..10316993 100644 --- a/src/chain_rules_macro.jl +++ b/src/chain_rules_macro.jl @@ -94,8 +94,8 @@ Use this function with care. It has only been tested for `Float64` arguments and whose `tangent_type` is `NoTangent`, and it is entirely probable that it won't work for arguments which aren't `Float64` or non-differentiable. -You should definitely make use of [`TestUtils.test_rule`](@ref) to verify that the rule created -works as intended. 
+You should definitely make use of [`TestUtils.test_rule`](@ref) to verify that the rule +created works as intended. """ macro from_rrule(ctx, sig) From 50d7dd83ade66e4dd1ba434736bc2e6b721c3ce6 Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Fri, 13 Sep 2024 14:31:34 +0100 Subject: [PATCH 11/39] Improve documentation --- src/chain_rules_macro.jl | 28 +++++++++++++++++----------- 1 file changed, 17 insertions(+), 11 deletions(-) diff --git a/src/chain_rules_macro.jl b/src/chain_rules_macro.jl index 10316993..9185eb57 100644 --- a/src/chain_rules_macro.jl +++ b/src/chain_rules_macro.jl @@ -18,8 +18,8 @@ to_tapir_tangent(t::IEEEFloat) = t to_tapir_tangent(t::Array{<:IEEEFloat}) = t to_tapir_tangent(::ChainRulesCore.NoTangent) = NoTangent() -""" - rrule_wrapper_implementation(fargs::Vararg{CoDual, N}) where {N} +@doc""" + rrule_wrapper_implementation(f::CoDual, args::CoDual...) Used to implement `rrule!!`s via `ChainRulesCore.rrule`. @@ -35,6 +35,10 @@ Assumes that methods of `to_cr_tangent` and `to_tapir_tangent` are defined such can convert between the different representations of tangents that Tapir and ChainRulesCore expect. +Furthermore, it is _essential_ that +1. `f(args)` does not mutate `f` or `args`, and +2. the result of `f(args)` does not alias any data stored in `f` or `args`. + Subject to some constraints, you can use the [`@from_rrule`](@ref) macro to reduce the amount of boilerplate code that you are required to write even further. """ @@ -79,9 +83,8 @@ end @doc""" @from_rrule ctx sig -Creates a `Tapir.rrule!!` from a `ChainRulesCore.rrule`. `ctx` is the type of the context in -which this rule should apply, and `sig` is the type-tuple which specifies which primal the -rule should apply to. +Convenience functionality to assist in using `ChainRulesCore.rrule`s to write `rrule!!`s. +This macro is a thin wrapper around [`rrule_wrapper_implementation`](@ref). For example, ```julia @@ -89,13 +92,16 @@ For example, ``` would define a `Tapir.rrule!!` for `sin` of `Float64`s, by calling `ChainRulesCore.rrule`. -Health warning: -Use this function with care. It has only been tested for `Float64` arguments and arguments -whose `tangent_type` is `NoTangent`, and it is entirely probable that it won't work for -arguments which aren't `Float64` or non-differentiable. +Limitations: it is your responsibility to ensure that +1. calls with signature `sig` do not mutate their arguments, +2. the output of calls with signature `sig` does not alias any of the inputs, +3. `sig` is a `Tuple{...}`, not a `Tuple{...} where {...}`. + +This last point is a limitation of the current implementation, rather than something +fundamental, whereas the first two points are more basic points. -You should definitely make use of [`TestUtils.test_rule`](@ref) to verify that the rule -created works as intended. +As with all hand-written rules, you should definitely make use of +[`TestUtils.test_rule`](@ref) to verify correctness on some test cases. 
""" macro from_rrule(ctx, sig) From 1788c07d6298fd4121c2ec1a6b0cc54dfb0f1049 Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Fri, 13 Sep 2024 14:52:10 +0100 Subject: [PATCH 12/39] Explain how not to use rrule functionality --- docs/make.jl | 7 +++++-- docs/src/using_chain_rules.md | 13 +++++++++++++ src/chain_rules_macro.jl | 22 ++++++++++++++++++++++ 3 files changed, 40 insertions(+), 2 deletions(-) create mode 100644 docs/src/using_chain_rules.md diff --git a/docs/make.jl b/docs/make.jl index 88f35278..42aa87b3 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -32,9 +32,12 @@ makedocs( "Algorithmic Differentiation" => "algorithmic_differentiation.md", "Tapir.jl's Rule System" => "mathematical_interpretation.md", ], + "Utilities" => [ + "Using ChainRules" => "using_chain_rules.md", + "Safe Mode" => "safe_mode.md", + "Debugging and MWEs" => "debugging_and_mwes.md", + ], "Known Limitations" => "known_limitations.md", - "Safe Mode" => "safe_mode.md", - "Debugging and MWEs" => "debugging_and_mwes.md", ] ) diff --git a/docs/src/using_chain_rules.md b/docs/src/using_chain_rules.md new file mode 100644 index 00000000..531d4c4c --- /dev/null +++ b/docs/src/using_chain_rules.md @@ -0,0 +1,13 @@ +# Using ChainRules.jl + +[ChainRules.jl](https://github.com/JuliaDiff/ChainRules.jl) provides a large number of rules for differentiating functions in reverse-mode. +These rules are methods of the `ChainRulesCore.rrule` function. +There are some instances where there is it most convenient to implement a `Tapir.rrule!!` by wrapping an existing `ChainRulesCore.rrule`. + +There is enough similarity between these two systems that most of the boilerplate code can be avoided. +The docstrings below explain this functionality, and how it should / should not be used. + +```@docs +Tapir.@from_rrule +Tapir.rrule_wrapper_implementation +``` diff --git a/src/chain_rules_macro.jl b/src/chain_rules_macro.jl index 9185eb57..9a625125 100644 --- a/src/chain_rules_macro.jl +++ b/src/chain_rules_macro.jl @@ -102,6 +102,28 @@ fundamental, whereas the first two points are more basic points. As with all hand-written rules, you should definitely make use of [`TestUtils.test_rule`](@ref) to verify correctness on some test cases. + +# A Note On Type Constraints + +Many methods of `ChainRuleCore.rrule` are implemented with very loose type constraints. +For example, it would not be surprising to see a method of rrule with the signature +```julia +Tuple{typeof(rrule), typeof(foo), Real, AbstractVector{<:Real}} +``` +There are a variety of reasons for this way of doing things, and whether it is a good idea +to write rules for such generic objects has been debated at length. + +Suffice it to say, you should not write rules for this package which are so generically +typed. +Rather, you should create rules for the subset of types for which you believe that the +`ChainRulesCore.rrule` will work correctly, and leave this package to derive rules for the +rest. +For example, in the above case you might be confident that the rule will behave correctly +for input types `Tuple{typeof(foo), IEEEFloat, Vector{<:IEEEFloat}}`. 
You should therefore +only write a rule for these types: +```julia +@from_rrule DefaultCtx Tuple{typeof(foo), IEEEFloat, Vector{<:IEEEFloat}} +``` """ macro from_rrule(ctx, sig) From b4e80bc0ca8f5f27c8756eab103f55d8b2f714c9 Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Fri, 13 Sep 2024 15:05:39 +0100 Subject: [PATCH 13/39] Add rules for BLAS utilities --- src/rrules/blas.jl | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/src/rrules/blas.jl b/src/rrules/blas.jl index 5b15fb08..5a995485 100644 --- a/src/rrules/blas.jl +++ b/src/rrules/blas.jl @@ -19,7 +19,21 @@ function tri!(A, u::Char, d::Char) return u == 'L' ? tril!(A, d == 'U' ? -1 : 0) : triu!(A, d == 'U' ? 1 : 0) end +# +# Utility +# + +@is_primitive MinimalCtx Tuple{typeof(BLAS.get_num_threads)} +rrule!!(f::CoDual{typeof(BLAS.get_num_threads)}) = simple_zero_adjoint(f) +@is_primitive MinimalCtx Tuple{typeof(BLAS.lbt_get_num_threads)} +rrule!!(f::CoDual{typeof(BLAS.lbt_get_num_threads)}) = simple_zero_adjoint(f) + +@is_primitive MinimalCtx Tuple{typeof(BLAS.set_num_threads), Union{Integer, Nothing}} +rrule!!(f::CoDual{typeof(BLAS.set_num_threads)}, x::CoDual) = simple_zero_adjoint(f, x) + +@is_primitive MinimalCtx Tuple{typeof(BLAS.lbt_set_num_threads), Any} +rrule!!(f::CoDual{typeof(BLAS.lbt_set_num_threads)}, x::CoDual) = simple_zero_adjoint(f, x) # # LEVEL 1 @@ -793,6 +807,12 @@ function generate_derived_rrule!!_test_cases(rng_ctor, ::Val{:blas}) test_cases = vcat( + # Utility + (false, :stability, nothing, BLAS.get_num_threads), + (false, :stability, nothing, BLAS.lbt_get_num_threads), + (false, :stability, nothing, BLAS.set_num_threads, 1), + (false, :stability, nothing, BLAS.lbt_set_num_threads, 1), + # # BLAS LEVEL 1 # From 4a2b8e0890d019ad54ba636b8391a9f1d73b66a2 Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Fri, 13 Sep 2024 16:01:27 +0100 Subject: [PATCH 14/39] Initial NNlib integration --- Project.toml | 6 +++++- ext/TapirNNlibExt.jl | 12 ++++++++++++ test/front_matter.jl | 1 + test/integration_testing/nnlib.jl | 9 +++++++++ test/runtests.jl | 1 + 5 files changed, 28 insertions(+), 1 deletion(-) create mode 100644 ext/TapirNNlibExt.jl create mode 100644 test/integration_testing/nnlib.jl diff --git a/Project.toml b/Project.toml index 65294a51..3e63a0a8 100644 --- a/Project.toml +++ b/Project.toml @@ -22,6 +22,7 @@ CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8" JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1" +NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" [extensions] @@ -29,6 +30,7 @@ TapirCUDAExt = "CUDA" TapirDynamicPPLExt = "DynamicPPL" TapirJETExt = "JET" TapirLogDensityProblemsADExt = "LogDensityProblemsAD" +TapirNNlibExt = "NNlib" TapirSpecialFunctionsExt = "SpecialFunctions" [compat] @@ -47,6 +49,7 @@ Graphs = "1" JET = "0.9" LogDensityProblemsAD = "1" MistyClosures = "1" +NNlib = "0.9" PDMats = "0.11" Setfield = "1" SpecialFunctions = "2" @@ -67,6 +70,7 @@ FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" KernelFunctions = "ec8451be-7e33-11e9-00cf-bbf324bd1392" LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1" +NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" @@ -76,4 +80,4 @@ Test = 
"8dfed614-e22c-5e08-85e1-65c5234f0b40" Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" [targets] -test = ["AbstractGPs", "BenchmarkTools", "CUDA", "DiffTests", "Distributions", "Documenter", "DynamicPPL", "FillArrays", "KernelFunctions", "JET", "LogDensityProblemsAD", "PDMats", "ReverseDiff", "SpecialFunctions", "StableRNGs", "Test", "Turing", "TemporalGPs"] +test = ["AbstractGPs", "BenchmarkTools", "CUDA", "DiffTests", "Distributions", "Documenter", "DynamicPPL", "FillArrays", "KernelFunctions", "JET", "LogDensityProblemsAD", "NNlib", "PDMats", "ReverseDiff", "SpecialFunctions", "StableRNGs", "Test", "Turing", "TemporalGPs"] diff --git a/ext/TapirNNlibExt.jl b/ext/TapirNNlibExt.jl new file mode 100644 index 00000000..2b074a7d --- /dev/null +++ b/ext/TapirNNlibExt.jl @@ -0,0 +1,12 @@ +module TapirNNlibExt + + using NNlib, Tapir + using Base: IEEEFloat + + import Tapir: @from_rrule, DefaultCtx + + @from_rrule( + DefaultCtx, + Tuple{typeof(upsample_nearest), Array{<:IEEEFloat}, NTuple{N, Int} where {N}}, + ) +end diff --git a/test/front_matter.jl b/test/front_matter.jl index f78ca47e..ec9bbdab 100644 --- a/test/front_matter.jl +++ b/test/front_matter.jl @@ -5,6 +5,7 @@ using FillArrays, JET, LinearAlgebra, + NNlib, PDMats, Random, SpecialFunctions, diff --git a/test/integration_testing/nnlib.jl b/test/integration_testing/nnlib.jl new file mode 100644 index 00000000..b1215077 --- /dev/null +++ b/test/integration_testing/nnlib.jl @@ -0,0 +1,9 @@ +@testset "nnlib" begin + @testset "$(typeof(fargs))" for (perf_flag, fargs...) in Any[ + (:stability, NNlib.upsample_nearest, randn(3), (2,)), + (:stability, NNlib.upsample_nearest, randn(3, 2), (2, 2)), + (:stability, NNlib.upsample_nearest, randn(3, 2, 3), (2, 2, 5)), + ] + test_rule(sr(1), fargs...; is_primitive=true, perf_flag) + end +end diff --git a/test/runtests.jl b/test/runtests.jl index e921cc49..73cb208f 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -49,6 +49,7 @@ include("front_matter.jl") include(joinpath("integration_testing", "battery_tests.jl")) include(joinpath("integration_testing", "dynamic_ppl.jl")) include(joinpath("integration_testing", "logdensityproblemsad_interop.jl")) + include(joinpath("integration_testing", "nnlib.jl")) include(joinpath("integration_testing", "special_functions.jl")) elseif test_group == "integration_testing/misc_abstract_array" include(joinpath("integration_testing", "misc_abstract_array.jl")) From d1d9fae42b7e83124b07d87b3f37b7d57bd4ae83 Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Fri, 13 Sep 2024 16:15:13 +0100 Subject: [PATCH 15/39] Thunks and batched_mul --- src/chain_rules_macro.jl | 1 + test/chain_rules_macro.jl | 6 ++++++ test/integration_testing/nnlib.jl | 1 + 3 files changed, 8 insertions(+) diff --git a/src/chain_rules_macro.jl b/src/chain_rules_macro.jl index 9a625125..a3e88f39 100644 --- a/src/chain_rules_macro.jl +++ b/src/chain_rules_macro.jl @@ -17,6 +17,7 @@ Inverse of `to_cr_tangent`. to_tapir_tangent(t::IEEEFloat) = t to_tapir_tangent(t::Array{<:IEEEFloat}) = t to_tapir_tangent(::ChainRulesCore.NoTangent) = NoTangent() +to_tapir_tangent(t::ChainRulesCore.Thunk) = to_tapir_tangent(ChainRulesCore.unthunk(t)) @doc""" rrule_wrapper_implementation(f::CoDual, args::CoDual...) 
diff --git a/test/chain_rules_macro.jl b/test/chain_rules_macro.jl index 7512b3e9..e0c85d4b 100644 --- a/test/chain_rules_macro.jl +++ b/test/chain_rules_macro.jl @@ -76,6 +76,12 @@ end @test Tapir.to_tapir_tangent(Tapir.to_cr_tangent(t)) == t @test Tapir.to_cr_tangent(Tapir.to_tapir_tangent(t_cr)) == t_cr end + + # The fact that I'm testing this separately suggests to me that there's something that + # I've not quite gotten right about the abstractions involved here. + @testset "ChainRulesCore.thunk" begin + @test Tapir.to_tapir_tangent(ChainRulesCore.Thunk(() -> ones(5))) == ones(5) + end @testset "rules: $(typeof(fargs))" for fargs in Any[ (ChainRulesInteropTestResources.bleh, 5.0, 4), (ChainRulesInteropTestResources.test_sum, ones(5)), diff --git a/test/integration_testing/nnlib.jl b/test/integration_testing/nnlib.jl index b1215077..11644344 100644 --- a/test/integration_testing/nnlib.jl +++ b/test/integration_testing/nnlib.jl @@ -1,5 +1,6 @@ @testset "nnlib" begin @testset "$(typeof(fargs))" for (perf_flag, fargs...) in Any[ + (:none, NNlib.batched_mul, randn(3, 2, 3), randn(2, 5, 3)), (:stability, NNlib.upsample_nearest, randn(3), (2,)), (:stability, NNlib.upsample_nearest, randn(3, 2), (2, 2)), (:stability, NNlib.upsample_nearest, randn(3, 2, 3), (2, 2, 5)), From 6f036adcab39bf3b18105cf2f7cd22c228e17f60 Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Fri, 13 Sep 2024 17:18:23 +0100 Subject: [PATCH 16/39] More rules + kwargs + rename --- ext/TapirNNlibExt.jl | 15 ++++++++++ src/chain_rules_macro.jl | 49 +++++++++++++++++++++++++++---- test/integration_testing/nnlib.jl | 39 ++++++++++++++++++++---- 3 files changed, 93 insertions(+), 10 deletions(-) diff --git a/ext/TapirNNlibExt.jl b/ext/TapirNNlibExt.jl index 2b074a7d..75306f6e 100644 --- a/ext/TapirNNlibExt.jl +++ b/ext/TapirNNlibExt.jl @@ -5,6 +5,21 @@ module TapirNNlibExt import Tapir: @from_rrule, DefaultCtx + @from_rrule( + DefaultCtx, Tuple{typeof(batched_mul), Array{<:IEEEFloat, 3}, Array{<:IEEEFloat, 3}} + ) + @from_rrule( + DefaultCtx, + Tuple{typeof(Core.kwcall), NamedTuple, typeof(softmax), Array{<:IEEEFloat}}, + ) + @from_rrule( + DefaultCtx, + Tuple{typeof(Core.kwcall), NamedTuple, typeof(logsoftmax), Array{<:IEEEFloat}}, + ) + @from_rrule( + DefaultCtx, + Tuple{typeof(Core.kwcall), NamedTuple, typeof(logsumexp), Array{<:IEEEFloat}}, + ) @from_rrule( DefaultCtx, Tuple{typeof(upsample_nearest), Array{<:IEEEFloat}, NTuple{N, Int} where {N}}, diff --git a/src/chain_rules_macro.jl b/src/chain_rules_macro.jl index a3e88f39..5983908e 100644 --- a/src/chain_rules_macro.jl +++ b/src/chain_rules_macro.jl @@ -20,7 +20,7 @@ to_tapir_tangent(::ChainRulesCore.NoTangent) = NoTangent() to_tapir_tangent(t::ChainRulesCore.Thunk) = to_tapir_tangent(ChainRulesCore.unthunk(t)) @doc""" - rrule_wrapper_implementation(f::CoDual, args::CoDual...) + rrule_wrapper(f::CoDual, args::CoDual...) Used to implement `rrule!!`s via `ChainRulesCore.rrule`. @@ -29,7 +29,7 @@ which applies to these, you can make use of this function as follows: ```julia Tapir.@is_primitive DefaultCtx Tuple{typeof(foo), arg_types...} function Tapir.rrule!!(f::CoDual{typeof(foo)}, args::CoDual...) - return rrule_wrapper_implementation(f, args...) + return rrule_wrapper(f, args...) 
end ``` Assumes that methods of `to_cr_tangent` and `to_tapir_tangent` are defined such that you @@ -43,7 +43,7 @@ Furthermore, it is _essential_ that Subject to some constraints, you can use the [`@from_rrule`](@ref) macro to reduce the amount of boilerplate code that you are required to write even further. """ -@inline function rrule_wrapper_implementation(fargs::Vararg{CoDual, N}) where {N} +function rrule_wrapper(fargs::Vararg{CoDual, N}) where {N} # Run forwards-pass. primals = tuple_map(primal, fargs) @@ -81,11 +81,50 @@ amount of boilerplate code that you are required to write even further. return CoDual(y_primal, y_fdata), pb!! end +function rrule_wrapper(::CoDual{typeof(Core.kwcall)}, fargs::Vararg{CoDual, N}) where {N} + + # Run forwards-pass. + primals = tuple_map(primal, fargs) + y_primal, cr_pb = Core.kwcall(primals[1], ChainRulesCore.rrule, primals[2:end]...) + y_fdata = fdata(zero_tangent(y_primal)) + + # Construct functions which, when applied to the tangent types returned on the + # reverse-pass, will check that they are of the expected type. This will pick up on + # obvious problems, but is intended to be fast / optimised away when things go well. + # As such, you should think of this as a lightweight version of "debug_mode". + tangent_type_assertions = tuple_map( + x -> Base.Fix2(typeassert, tangent_type(typeof(x))), primals[2:end] + ) + + function pb!!(y_rdata) + + # Construct tangent w.r.t. output. + cr_tangent = to_cr_tangent(tangent(y_fdata, y_rdata)) + + # Run reverse-pass using ChainRules. + cr_dfargs = cr_pb(cr_tangent) + + # Convert output into tangent types appropriate for Tapir. + dfargs_unvalidated = tuple_map(to_tapir_tangent, cr_dfargs) + + # Apply type assertions. + dfargs = tuple_map((x, T) -> T(x), dfargs_unvalidated, tangent_type_assertions) + + # Increment the fdata. + tuple_map((x, dx) -> increment!!(tangent(x), fdata(dx)), fargs[2:end], dfargs) + + # Return the rdata. + kwargs_rdata = rdata(zero_tangent(fargs[1])) + return NoRData(), kwargs_rdata, tuple_map(rdata, dfargs)... + end + return CoDual(y_primal, y_fdata), pb!! +end + @doc""" @from_rrule ctx sig Convenience functionality to assist in using `ChainRulesCore.rrule`s to write `rrule!!`s. -This macro is a thin wrapper around [`rrule_wrapper_implementation`](@ref). +This macro is a thin wrapper around [`rrule_wrapper`](@ref). For example, ```julia @@ -141,7 +180,7 @@ macro from_rrule(ctx, sig) :head => :function, :name => :(Tapir.rrule!!), :args => arg_exprs, - :body => Expr(:call, rrule_wrapper_implementation, arg_names...), + :body => Expr(:call, rrule_wrapper, arg_names...), ) ) diff --git a/test/integration_testing/nnlib.jl b/test/integration_testing/nnlib.jl index 11644344..4a615774 100644 --- a/test/integration_testing/nnlib.jl +++ b/test/integration_testing/nnlib.jl @@ -1,10 +1,39 @@ @testset "nnlib" begin @testset "$(typeof(fargs))" for (perf_flag, fargs...) 
in Any[ - (:none, NNlib.batched_mul, randn(3, 2, 3), randn(2, 5, 3)), - (:stability, NNlib.upsample_nearest, randn(3), (2,)), - (:stability, NNlib.upsample_nearest, randn(3, 2), (2, 2)), - (:stability, NNlib.upsample_nearest, randn(3, 2, 3), (2, 2, 5)), + + # batched_mul + (:none, batched_mul, randn(3, 2, 3), randn(2, 5, 3)), + + # softmax + (:stability, Core.kwcall, (dims=1, ), softmax, randn(2,)), + (:stability, Core.kwcall, (dims=1, ), softmax, randn(3, 3)), + (:stability, Core.kwcall, (dims=2, ), softmax, randn(3, 3)), + (:stability, Core.kwcall, (dims=(1, 2), ), softmax, randn(3, 3)), + (:stability, Core.kwcall, (dims=(1, 2), ), softmax, randn(3, 3, 2)), + (:none, x -> softmax(5x), randn(3, 2)), + (:none, x -> softmax(x; dims=1), randn(3, 2)), + (:none, x -> softmax(x; dims=2), randn(3, 2)), + (:none, x -> softmax(x; dims=(1, 2)), randn(3, 2)), + + # logsoftmax + (:stability, Core.kwcall, (dims=1, ), logsoftmax, randn(2,)), + (:stability, Core.kwcall, (dims=1, ), logsoftmax, randn(3, 3)), + (:stability, Core.kwcall, (dims=2, ), logsoftmax, randn(3, 3)), + (:stability, Core.kwcall, (dims=(1, 2), ), logsoftmax, randn(3, 3)), + (:stability, Core.kwcall, (dims=(1, 2), ), logsoftmax, randn(3, 3, 2)), + + # logsumexp + (:stability, Core.kwcall, (dims=1, ), logsumexp, randn(2,)), + (:stability, Core.kwcall, (dims=1, ), logsumexp, randn(3, 3)), + (:stability, Core.kwcall, (dims=2, ), logsumexp, randn(3, 3)), + (:stability, Core.kwcall, (dims=(1, 2), ), logsumexp, randn(3, 3)), + (:stability, Core.kwcall, (dims=(1, 2), ), logsumexp, randn(3, 3, 2)), + + # upsample_nearest + (:stability, upsample_nearest, randn(3), (2,)), + (:stability, upsample_nearest, randn(3, 2), (2, 2)), + (:stability, upsample_nearest, randn(3, 2, 3), (2, 2, 5)), ] - test_rule(sr(1), fargs...; is_primitive=true, perf_flag) + test_rule(sr(1), fargs...; is_primitive=false, perf_flag) end end From e225a0adafea729179d8255598ac8bad9873c2b9 Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Fri, 13 Sep 2024 18:33:04 +0100 Subject: [PATCH 17/39] Fix link in docs --- docs/src/using_chain_rules.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/src/using_chain_rules.md b/docs/src/using_chain_rules.md index 531d4c4c..e1726aad 100644 --- a/docs/src/using_chain_rules.md +++ b/docs/src/using_chain_rules.md @@ -9,5 +9,5 @@ The docstrings below explain this functionality, and how it should / should not ```@docs Tapir.@from_rrule -Tapir.rrule_wrapper_implementation +Tapir.rrule_wrapper ``` From 3bba38ebbc9d53dd80e43574199c8ca9570eb6e0 Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Fri, 13 Sep 2024 18:36:35 +0100 Subject: [PATCH 18/39] Rename chain_rules_macro to chain_rules_interop --- src/{chain_rules_macro.jl => chain_rules_interop.jl} | 0 test/{chain_rules_macro.jl => chain_rules_interop.jl} | 0 2 files changed, 0 insertions(+), 0 deletions(-) rename src/{chain_rules_macro.jl => chain_rules_interop.jl} (100%) rename test/{chain_rules_macro.jl => chain_rules_interop.jl} (100%) diff --git a/src/chain_rules_macro.jl b/src/chain_rules_interop.jl similarity index 100% rename from src/chain_rules_macro.jl rename to src/chain_rules_interop.jl diff --git a/test/chain_rules_macro.jl b/test/chain_rules_interop.jl similarity index 100% rename from test/chain_rules_macro.jl rename to test/chain_rules_interop.jl From 619f0ce9ed2bb069fb05b28b61884e402d6bb97b Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Mon, 16 Sep 2024 11:56:24 +0100 Subject: [PATCH 19/39] Complete rename of chain rules interop file --- src/Tapir.jl | 
2 +- test/runtests.jl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/Tapir.jl b/src/Tapir.jl index 89d8f61a..1af51cb0 100644 --- a/src/Tapir.jl +++ b/src/Tapir.jl @@ -86,7 +86,7 @@ include(joinpath("rrules", "misc.jl")) include(joinpath("rrules", "new.jl")) include(joinpath("rrules", "tasks.jl")) -include("chain_rules_macro.jl") +include("chain_rules_interop.jl") include("interface.jl") export diff --git a/test/runtests.jl b/test/runtests.jl index 73cb208f..619f6ea9 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -44,7 +44,7 @@ include("front_matter.jl") @info "tasks" include(joinpath("rrules", "tasks.jl")) end - include("chain_rules_macro.jl") + include("chain_rules_interop.jl") elseif test_group == "integration_testing/misc" include(joinpath("integration_testing", "battery_tests.jl")) include(joinpath("integration_testing", "dynamic_ppl.jl")) From 345c46a0e7d58338b56a13c999e3817ebae7c919 Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Mon, 16 Sep 2024 11:58:37 +0100 Subject: [PATCH 20/39] Refactor chain rules interop --- src/chain_rules_interop.jl | 69 ++++++++++++------------------------- test/chain_rules_interop.jl | 14 +++++++- 2 files changed, 35 insertions(+), 48 deletions(-) diff --git a/src/chain_rules_interop.jl b/src/chain_rules_interop.jl index 5983908e..4bf5c90c 100644 --- a/src/chain_rules_interop.jl +++ b/src/chain_rules_interop.jl @@ -1,23 +1,25 @@ -""" + """ to_cr_tangent(t) Convert a Tapir tangent into a type that ChainRules.jl `rrule`s expect to see. -Inverse of `to_tapir_tangent`. """ to_cr_tangent(t::IEEEFloat) = t to_cr_tangent(t::Array{<:IEEEFloat}) = t to_cr_tangent(::NoTangent) = ChainRulesCore.NoTangent() """ - to_tapir_tangent(cr_t) + increment_and_get_rdata!(fdata, rdata, cr_tangent) -Convert a ChainRules.jl tangent, `cr_t`, into the corresponding Tapir tangent. -Inverse of `to_cr_tangent`. """ -to_tapir_tangent(t::IEEEFloat) = t -to_tapir_tangent(t::Array{<:IEEEFloat}) = t -to_tapir_tangent(::ChainRulesCore.NoTangent) = NoTangent() -to_tapir_tangent(t::ChainRulesCore.Thunk) = to_tapir_tangent(ChainRulesCore.unthunk(t)) +increment_and_get_rdata!(::NoFData, r::T, t::T) where {T<:IEEEFloat} = r + t +function increment_and_get_rdata!(f::Array{P}, ::NoRData, t::Array{P}) where {P<:IEEEFloat} + increment!!(f, t) + return NoRData() +end +increment_and_get_rdata!(::Any, r, ::ChainRulesCore.NoTangent) = r +function increment_and_get_rdata!(f, r, t::ChainRulesCore.Thunk) + return increment_and_get_rdata!(f, r, ChainRulesCore.unthunk(t)) +end @doc""" rrule_wrapper(f::CoDual, args::CoDual...) @@ -47,17 +49,10 @@ function rrule_wrapper(fargs::Vararg{CoDual, N}) where {N} # Run forwards-pass. primals = tuple_map(primal, fargs) + lazy_rdata = tuple_map(Tapir.lazy_zero_rdata, primals) y_primal, cr_pb = ChainRulesCore.rrule(primals...) y_fdata = fdata(zero_tangent(y_primal)) - # Construct functions which, when applied to the tangent types returned on the - # reverse-pass, will check that they are of the expected type. This will pick up on - # obvious problems, but is intended to be fast / optimised away when things go well. - # As such, you should think of this as a lightweight version of "debug_mode". - tangent_type_assertions = tuple_map( - x -> Base.Fix2(typeassert, tangent_type(typeof(x))), primals - ) - function pb!!(y_rdata) # Construct tangent w.r.t. output. @@ -66,17 +61,10 @@ function rrule_wrapper(fargs::Vararg{CoDual, N}) where {N} # Run reverse-pass using ChainRules. 
cr_dfargs = cr_pb(cr_tangent) - # Convert output into tangent types appropriate for Tapir. - dfargs_unvalidated = tuple_map(to_tapir_tangent, cr_dfargs) - - # Apply type assertions. - dfargs = tuple_map((x, T) -> T(x), dfargs_unvalidated, tangent_type_assertions) - - # Increment the fdata. - tuple_map((x, dx) -> increment!!(tangent(x), fdata(dx)), fargs, dfargs) - - # Return the rdata. - return tuple_map(rdata, dfargs) + # Increment fdata and get rdata. + return map(fargs, lazy_rdata, cr_dfargs) do x, l_rdata, cr_dx + return increment_and_get_rdata!(tangent(x), instantiate(l_rdata), cr_dx) + end end return CoDual(y_primal, y_fdata), pb!! end @@ -85,17 +73,10 @@ function rrule_wrapper(::CoDual{typeof(Core.kwcall)}, fargs::Vararg{CoDual, N}) # Run forwards-pass. primals = tuple_map(primal, fargs) + lazy_rdata = tuple_map(lazy_zero_rdata, primals) y_primal, cr_pb = Core.kwcall(primals[1], ChainRulesCore.rrule, primals[2:end]...) y_fdata = fdata(zero_tangent(y_primal)) - # Construct functions which, when applied to the tangent types returned on the - # reverse-pass, will check that they are of the expected type. This will pick up on - # obvious problems, but is intended to be fast / optimised away when things go well. - # As such, you should think of this as a lightweight version of "debug_mode". - tangent_type_assertions = tuple_map( - x -> Base.Fix2(typeassert, tangent_type(typeof(x))), primals[2:end] - ) - function pb!!(y_rdata) # Construct tangent w.r.t. output. @@ -104,18 +85,12 @@ function rrule_wrapper(::CoDual{typeof(Core.kwcall)}, fargs::Vararg{CoDual, N}) # Run reverse-pass using ChainRules. cr_dfargs = cr_pb(cr_tangent) - # Convert output into tangent types appropriate for Tapir. - dfargs_unvalidated = tuple_map(to_tapir_tangent, cr_dfargs) - - # Apply type assertions. - dfargs = tuple_map((x, T) -> T(x), dfargs_unvalidated, tangent_type_assertions) - - # Increment the fdata. - tuple_map((x, dx) -> increment!!(tangent(x), fdata(dx)), fargs[2:end], dfargs) - - # Return the rdata. + # Increment fdata and compute rdata. kwargs_rdata = rdata(zero_tangent(fargs[1])) - return NoRData(), kwargs_rdata, tuple_map(rdata, dfargs)... + args_rdata = map(fargs[2:end], lazy_rdata[2:end], cr_dfargs) do x, l_rdata, cr_dx + return increment_and_get_rdata!(tangent(x), instantiate(l_rdata), cr_dx) + end + return NoRData(), kwargs_rdata, args_rdata... end return CoDual(y_primal, y_fdata), pb!! end diff --git a/test/chain_rules_interop.jl b/test/chain_rules_interop.jl index e0c85d4b..66c56bec 100644 --- a/test/chain_rules_interop.jl +++ b/test/chain_rules_interop.jl @@ -63,6 +63,16 @@ end @from_rrule DefaultCtx Tuple{typeof(test_bad_rdata), Float64} +# Test case for rule with kwargs. +test_kwargs(x; y::Bool) = y ? x : 2x + +function ChainRulesCore.rrule(::typeof(test_kwargs), x::Float64; y::Bool) + test_kwargs_pb(dz::Float64) = ChainRulesCore.NoTangent(), y ? dz : 2dz + return y ? 
x : 2x, test_kwargs_pb +end + +@from_rrule DefaultCtx Tuple{typeof(Core.kwcall), NamedTuple, typeof(test_kwargs), Float64} + end @testset "chain_rules_macro" begin @@ -87,12 +97,14 @@ end (ChainRulesInteropTestResources.test_sum, ones(5)), (ChainRulesInteropTestResources.test_scale, 5.0, randn(3)), (ChainRulesInteropTestResources.test_nothing,), + (Core.kwcall, (y=true, ), ChainRulesInteropTestResources.test_kwargs, 5.0), + (Core.kwcall, (y=false, ), ChainRulesInteropTestResources.test_kwargs, 5.0), ] test_rule(sr(1), fargs...; perf_flag=:stability, is_primitive=true) end @testset "bad rdata" begin f = ChainRulesInteropTestResources.test_bad_rdata out, pb!! = Tapir.rrule!!(zero_fcodual(f), zero_fcodual(3.0)) - @test_throws TypeError pb!!(5.0) + @test_throws MethodError pb!!(5.0) end end From 8e87d116983c0c8d3d0bc874ecdafb56b3faa1ad Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Mon, 16 Sep 2024 11:58:46 +0100 Subject: [PATCH 21/39] Add more nnlib functionality --- ext/TapirNNlibExt.jl | 82 +++++++++++++++++++++++- test/integration_testing/nnlib.jl | 101 ++++++++++++++++++++++-------- 2 files changed, 157 insertions(+), 26 deletions(-) diff --git a/ext/TapirNNlibExt.jl b/ext/TapirNNlibExt.jl index 75306f6e..8edd69f4 100644 --- a/ext/TapirNNlibExt.jl +++ b/ext/TapirNNlibExt.jl @@ -1,6 +1,6 @@ module TapirNNlibExt - using NNlib, Tapir + using NNlib, Random, Tapir using Base: IEEEFloat import Tapir: @from_rrule, DefaultCtx @@ -8,14 +8,28 @@ module TapirNNlibExt @from_rrule( DefaultCtx, Tuple{typeof(batched_mul), Array{<:IEEEFloat, 3}, Array{<:IEEEFloat, 3}} ) + @from_rrule( + DefaultCtx, + Tuple{ + typeof(Core.kwcall), + NamedTuple, + typeof(dropout), + AbstractRNG, + Array{<:IEEEFloat}, + IEEEFloat, + }, + ) + @from_rrule(DefaultCtx, Tuple{typeof(softmax), Array{<:IEEEFloat}}) @from_rrule( DefaultCtx, Tuple{typeof(Core.kwcall), NamedTuple, typeof(softmax), Array{<:IEEEFloat}}, ) + @from_rrule(DefaultCtx, Tuple{typeof(logsoftmax), Array{<:IEEEFloat}}) @from_rrule( DefaultCtx, Tuple{typeof(Core.kwcall), NamedTuple, typeof(logsoftmax), Array{<:IEEEFloat}}, ) + @from_rrule(DefaultCtx, Tuple{typeof(logsumexp), Array{<:IEEEFloat}}) @from_rrule( DefaultCtx, Tuple{typeof(Core.kwcall), NamedTuple, typeof(logsumexp), Array{<:IEEEFloat}}, @@ -24,4 +38,70 @@ module TapirNNlibExt DefaultCtx, Tuple{typeof(upsample_nearest), Array{<:IEEEFloat}, NTuple{N, Int} where {N}}, ) + @from_rrule( + DefaultCtx, + Tuple{ + typeof(NNlib.fold), + Array{<:IEEEFloat}, + NTuple{N, Int} where {N}, + DenseConvDims, + }, + ) + @from_rrule(DefaultCtx, Tuple{typeof(NNlib.unfold), Array{<:IEEEFloat}, DenseConvDims}) + @from_rrule( + DefaultCtx, + Tuple{typeof(NNlib.scatter), Any, Array, Array{<:Union{Integer, Tuple}}}, + ) + @from_rrule( + DefaultCtx, + Tuple{ + typeof(Core.kwcall), + NamedTuple, + typeof(NNlib.scatter), + Any, + Array, + Array{<:Union{Integer, Tuple}}, + }, + ) + + for backend in (Symbol(), :_direct, :_im2col), name in (:conv, :depthwiseconv) + @eval @from_rrule( + DefaultCtx, + Tuple{ + typeof(Core.kwcall), + NamedTuple, + typeof(NNlib.$(Symbol("$name$(backend)"))), + Array{<:IEEEFloat}, + Array{<:IEEEFloat}, + ConvDims, + }, + ) + @eval @from_rrule( + DefaultCtx, + Tuple{ + typeof(NNlib.$(Symbol("$name$(backend)"))), + Array{<:IEEEFloat}, + Array{<:IEEEFloat}, + ConvDims, + }, + ) + end + for pool in [:maxpool, :meanpool] + @eval @from_rrule(DefaultCtx, Tuple{typeof($pool), Array{<:IEEEFloat}, PoolDims}) + @eval @from_rrule( + DefaultCtx, + Tuple{ + typeof(Core.kwcall), + NamedTuple, + typeof($pool), + 
Array{<:IEEEFloat}, + PoolDims, + }, + ) + end + @from_rrule(DefaultCtx, Tuple{typeof(pad_constant), Array, Any, Any}) + @from_rrule( + DefaultCtx, + Tuple{typeof(Core.kwcall), NamedTuple, typeof(pad_constant), Array, Any, Any}, + ) end diff --git a/test/integration_testing/nnlib.jl b/test/integration_testing/nnlib.jl index 4a615774..0cfdbb4b 100644 --- a/test/integration_testing/nnlib.jl +++ b/test/integration_testing/nnlib.jl @@ -1,39 +1,90 @@ @testset "nnlib" begin - @testset "$(typeof(fargs))" for (perf_flag, fargs...) in Any[ + x = randn(5, 4, 3, 2) + w = randn(2, 2, 3, 3) + dense_cdims = DenseConvDims(x, w) + sep_cdims = DepthwiseConvDims(x, w) + pool_dims = PoolDims(size(x), 2) + + grid = Array{Float64}(undef, 2, 2, 2, 1) + grid[:, 1, 1, 1] .= (-1, -1) + grid[:, 2, 1, 1] .= (1, -1) + grid[:, 1, 2, 1] .= (-1, 1) + grid[:, 2, 2, 1] .= (1, 1) + + @testset "$(typeof(fargs))" for (interface_only, perf_flag, is_primitive, fargs...) in Any[ # batched_mul - (:none, batched_mul, randn(3, 2, 3), randn(2, 5, 3)), + (false, :none, true, batched_mul, randn(3, 2, 3), randn(2, 5, 3)), + + # dropout + (true, :none, false, (x, p) -> dropout(sr(1), x, p; dims=1), randn(2, 2), 0.5), + (true, :none, false, (x, p) -> dropout(sr(1), x, p; dims=2), randn(2, 2), 0.1), + (true, :none, false, (x, p) -> dropout(sr(1), x, p; dims=(1, 2)), randn(2, 2), 0.4), # softmax - (:stability, Core.kwcall, (dims=1, ), softmax, randn(2,)), - (:stability, Core.kwcall, (dims=1, ), softmax, randn(3, 3)), - (:stability, Core.kwcall, (dims=2, ), softmax, randn(3, 3)), - (:stability, Core.kwcall, (dims=(1, 2), ), softmax, randn(3, 3)), - (:stability, Core.kwcall, (dims=(1, 2), ), softmax, randn(3, 3, 2)), - (:none, x -> softmax(5x), randn(3, 2)), - (:none, x -> softmax(x; dims=1), randn(3, 2)), - (:none, x -> softmax(x; dims=2), randn(3, 2)), - (:none, x -> softmax(x; dims=(1, 2)), randn(3, 2)), + (false, :stability, true, softmax, randn(2)), + (false, :stability, true, softmax, randn(2, 2)), + (false, :stability, true, Core.kwcall, (dims=1, ), softmax, randn(2,)), + (false, :stability, true, Core.kwcall, (dims=1, ), softmax, randn(3, 3)), + (false, :stability, true, Core.kwcall, (dims=2, ), softmax, randn(3, 3)), + (false, :stability, true, Core.kwcall, (dims=(1, 2), ), softmax, randn(3, 3)), + (false, :stability, true, Core.kwcall, (dims=(1, 2), ), softmax, randn(3, 3, 2)), + (false, :none, false, x -> softmax(5x), randn(3, 2)), + (false, :none, false, x -> softmax(x; dims=1), randn(3, 2)), + (false, :none, false, x -> softmax(x; dims=2), randn(3, 2)), + (false, :none, false, x -> softmax(x; dims=(1, 2)), randn(3, 2)), # logsoftmax - (:stability, Core.kwcall, (dims=1, ), logsoftmax, randn(2,)), - (:stability, Core.kwcall, (dims=1, ), logsoftmax, randn(3, 3)), - (:stability, Core.kwcall, (dims=2, ), logsoftmax, randn(3, 3)), - (:stability, Core.kwcall, (dims=(1, 2), ), logsoftmax, randn(3, 3)), - (:stability, Core.kwcall, (dims=(1, 2), ), logsoftmax, randn(3, 3, 2)), + (false, :stability, true, logsoftmax, randn(2)), + (false, :stability, true, logsoftmax, randn(2, 3)), + (false, :stability, true, logsoftmax, randn(2, 3, 2)), + (false, :stability, true, Core.kwcall, (dims=1, ), logsoftmax, randn(2,)), + (false, :stability, true, Core.kwcall, (dims=1, ), logsoftmax, randn(3, 3)), + (false, :stability, true, Core.kwcall, (dims=2, ), logsoftmax, randn(3, 3)), + (false, :stability, true, Core.kwcall, (dims=(1, 2), ), logsoftmax, randn(3, 3)), + (false, :stability, true, Core.kwcall, (dims=(1, 2), ), logsoftmax, randn(3, 3, 2)), # 
logsumexp - (:stability, Core.kwcall, (dims=1, ), logsumexp, randn(2,)), - (:stability, Core.kwcall, (dims=1, ), logsumexp, randn(3, 3)), - (:stability, Core.kwcall, (dims=2, ), logsumexp, randn(3, 3)), - (:stability, Core.kwcall, (dims=(1, 2), ), logsumexp, randn(3, 3)), - (:stability, Core.kwcall, (dims=(1, 2), ), logsumexp, randn(3, 3, 2)), + (false, :stability, true, logsumexp, randn(2,)), + (false, :stability, true, logsumexp, randn(3, 3)), + (false, :stability, true, logsumexp, randn(3, 3, 2)), + (false, :stability, true, Core.kwcall, (dims=1, ), logsumexp, randn(2,)), + (false, :stability, true, Core.kwcall, (dims=1, ), logsumexp, randn(3, 3)), + (false, :stability, true, Core.kwcall, (dims=2, ), logsumexp, randn(3, 3)), + (false, :stability, true, Core.kwcall, (dims=(1, 2), ), logsumexp, randn(3, 3)), + (false, :stability, true, Core.kwcall, (dims=(1, 2), ), logsumexp, randn(3, 3, 2)), # upsample_nearest - (:stability, upsample_nearest, randn(3), (2,)), - (:stability, upsample_nearest, randn(3, 2), (2, 2)), - (:stability, upsample_nearest, randn(3, 2, 3), (2, 2, 5)), + (false, :stability, true, upsample_nearest, randn(3), (2,)), + (false, :stability, true, upsample_nearest, randn(3, 2), (2, 2)), + (false, :stability, true, upsample_nearest, randn(3, 2, 3), (2, 2, 5)), + + # fold + (false, :none, true, NNlib.fold, randn(12, 12, 2), size(x), dense_cdims), + + # unfold + (false, :none, true, NNlib.unfold, x, dense_cdims), + + # scatter + (false, :stability, true, NNlib.scatter, +, randn(2), [1, 3]), + (false, :stability, true, Core.kwcall, (;), NNlib.scatter, +, randn(2), [1, 3]), + + # conv + (false, :none, true, Core.kwcall, (;), conv, x, w, dense_cdims), + (false, :none, true, conv, x, w, dense_cdims), + (false, :none, true, Core.kwcall, (;), depthwiseconv, x, w, sep_cdims), + (false, :none, true, depthwiseconv, x, w, sep_cdims), + + # pooling + (false, :none, true, maxpool, x, pool_dims), + (false, :none, true, Core.kwcall, (;), maxpool, x, pool_dims), + (false, :none, true, meanpool, x, pool_dims), + (false, :none, true, Core.kwcall, (;), meanpool, x, pool_dims), + + # padding + (false, :none, false, x -> pad_constant(x, 1, 2.0), x), + (false, :none, false, x -> pad_constant(x, 1, 2.0; dims=:), x), ] - test_rule(sr(1), fargs...; is_primitive=false, perf_flag) + test_rule(sr(1), fargs...; perf_flag, is_primitive, interface_only) end end From d3459782d6cc66d03cc5581694a524d5ddcce4b2 Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Mon, 16 Sep 2024 14:00:13 +0100 Subject: [PATCH 22/39] Remove old tests --- test/chain_rules_interop.jl | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/test/chain_rules_interop.jl b/test/chain_rules_interop.jl index 66c56bec..368d683e 100644 --- a/test/chain_rules_interop.jl +++ b/test/chain_rules_interop.jl @@ -76,15 +76,12 @@ end end @testset "chain_rules_macro" begin - @testset "to_cr_tangent and to_tapir_tangent" for (t, t_cr) in Any[ + @testset "to_cr_tangent" for (t, t_cr) in Any[ (5.0, 5.0), (ones(5), ones(5)), (NoTangent(), ChainRulesCore.NoTangent()), ] @test Tapir.to_cr_tangent(t) == t_cr - @test Tapir.to_tapir_tangent(t_cr) == t - @test Tapir.to_tapir_tangent(Tapir.to_cr_tangent(t)) == t - @test Tapir.to_cr_tangent(Tapir.to_tapir_tangent(t_cr)) == t_cr end # The fact that I'm testing this separately suggests to me that there's something that From 0f3fe90af1280eb2f3c3b0aff81211067fa0a943 Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Mon, 16 Sep 2024 19:12:20 +0100 Subject: [PATCH 23/39] Some work --- Project.toml | 7 
++- ext/TapirLuxLibExt.jl | 21 +++++++++ src/interpreter/s2s_reverse_mode_ad.jl | 2 + test/{integration_testing => ext}/cuda.jl | 0 .../dynamic_ppl.jl | 0 .../logdensityproblemsad.jl} | 0 test/ext/luxlib.jl | 8 ++++ test/{integration_testing => ext}/nnlib.jl | 0 .../special_functions.jl | 0 test/front_matter.jl | 2 + test/integration_testing/lux.jl | 44 +++++++++++++++++++ test/runtests.jl | 11 ++--- 12 files changed, 89 insertions(+), 6 deletions(-) create mode 100644 ext/TapirLuxLibExt.jl rename test/{integration_testing => ext}/cuda.jl (100%) rename test/{integration_testing => ext}/dynamic_ppl.jl (100%) rename test/{integration_testing/logdensityproblemsad_interop.jl => ext/logdensityproblemsad.jl} (100%) create mode 100644 test/ext/luxlib.jl rename test/{integration_testing => ext}/nnlib.jl (100%) rename test/{integration_testing => ext}/special_functions.jl (100%) create mode 100644 test/integration_testing/lux.jl diff --git a/Project.toml b/Project.toml index 3e63a0a8..70b7646c 100644 --- a/Project.toml +++ b/Project.toml @@ -22,6 +22,7 @@ CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8" JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1" +LuxLib = "82251201-b29d-42c6-8e01-566dec8acb11" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" @@ -30,6 +31,7 @@ TapirCUDAExt = "CUDA" TapirDynamicPPLExt = "DynamicPPL" TapirJETExt = "JET" TapirLogDensityProblemsADExt = "LogDensityProblemsAD" +TapirLuxLibExt = "LuxLib" TapirNNlibExt = "NNlib" TapirSpecialFunctionsExt = "SpecialFunctions" @@ -48,6 +50,7 @@ FillArrays = "1" Graphs = "1" JET = "0.9" LogDensityProblemsAD = "1" +LuxLib = "1.2" MistyClosures = "1" NNlib = "0.9" PDMats = "0.11" @@ -70,6 +73,8 @@ FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" KernelFunctions = "ec8451be-7e33-11e9-00cf-bbf324bd1392" LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1" +Lux = "b2108857-7c20-44ae-9111-449ecde12c47" +LuxLib = "82251201-b29d-42c6-8e01-566dec8acb11" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" @@ -80,4 +85,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" [targets] -test = ["AbstractGPs", "BenchmarkTools", "CUDA", "DiffTests", "Distributions", "Documenter", "DynamicPPL", "FillArrays", "KernelFunctions", "JET", "LogDensityProblemsAD", "NNlib", "PDMats", "ReverseDiff", "SpecialFunctions", "StableRNGs", "Test", "Turing", "TemporalGPs"] +test = ["AbstractGPs", "BenchmarkTools", "CUDA", "DiffTests", "Distributions", "Documenter", "DynamicPPL", "FillArrays", "KernelFunctions", "JET", "LogDensityProblemsAD", "Lux", "LuxLib", "NNlib", "PDMats", "ReverseDiff", "SpecialFunctions", "StableRNGs", "Test", "Turing", "TemporalGPs"] diff --git a/ext/TapirLuxLibExt.jl b/ext/TapirLuxLibExt.jl new file mode 100644 index 00000000..83fe6230 --- /dev/null +++ b/ext/TapirLuxLibExt.jl @@ -0,0 +1,21 @@ +module TapirLuxLibExt + + using LuxLib, Random, Tapir + using Base: IEEEFloat + + import LuxLib.Impl: matmul, matmuladd, fused_dense + import Tapir: @from_rrule, DefaultCtx + + @from_rrule DefaultCtx Tuple{typeof(matmul), Array{<:IEEEFloat}, Array{<:IEEEFloat}} + @from_rrule( + DefaultCtx, + Tuple{typeof(matmuladd), Array{<:IEEEFloat}, Array{<:IEEEFloat}, Vector{<:IEEEFloat}}, + ) + + # The 
implementations of rrules for fused operations are not straightforward to
+    # incorporate into Tapir.jl, because they call back into AD.
+    # We take a simple approach to their implementation: differentiate an un-fused version
+    # of their implementation. This will likely hit performance, but it makes implementing
+    # rules much more straightforward, in that we only have to be able to implement their
+    # constituent parts, rather than the entire thing.
+end
diff --git a/src/interpreter/s2s_reverse_mode_ad.jl b/src/interpreter/s2s_reverse_mode_ad.jl
index 92659df4..217750aa 100644
--- a/src/interpreter/s2s_reverse_mode_ad.jl
+++ b/src/interpreter/s2s_reverse_mode_ad.jl
@@ -843,6 +843,8 @@ function build_rrule(
     interp::TapirInterpreter{C}, sig_or_mi; safety_on=false, silence_safety_messages=true
 ) where {C}
+    @show sig_or_mi
+
     # To avoid segfaults, ensure that we bail out if the interpreter's world age is greater
     # than the current world age.
     if Base.get_world_counter() > interp.world
diff --git a/test/integration_testing/cuda.jl b/test/ext/cuda.jl
similarity index 100%
rename from test/integration_testing/cuda.jl
rename to test/ext/cuda.jl
diff --git a/test/integration_testing/dynamic_ppl.jl b/test/ext/dynamic_ppl.jl
similarity index 100%
rename from test/integration_testing/dynamic_ppl.jl
rename to test/ext/dynamic_ppl.jl
diff --git a/test/integration_testing/logdensityproblemsad_interop.jl b/test/ext/logdensityproblemsad.jl
similarity index 100%
rename from test/integration_testing/logdensityproblemsad_interop.jl
rename to test/ext/logdensityproblemsad.jl
diff --git a/test/ext/luxlib.jl b/test/ext/luxlib.jl
new file mode 100644
index 00000000..ce0e15c5
--- /dev/null
+++ b/test/ext/luxlib.jl
@@ -0,0 +1,8 @@
+@testset "luxlib" begin
+    @testset "$(typeof(fargs))" for (interface_only, perf_flag, is_primitive, fargs...) in Any[
+        (false, :none, true, LuxLib.Impl.matmul, randn(5, 4), randn(4, 3)),
+        (false, :none, true, LuxLib.Impl.matmuladd, randn(5, 4), randn(4, 3), randn(5)),
+    ]
+        test_rule(sr(1), fargs...; perf_flag, is_primitive, interface_only)
+    end
+end
diff --git a/test/integration_testing/nnlib.jl b/test/ext/nnlib.jl
similarity index 100%
rename from test/integration_testing/nnlib.jl
rename to test/ext/nnlib.jl
diff --git a/test/integration_testing/special_functions.jl b/test/ext/special_functions.jl
similarity index 100%
rename from test/integration_testing/special_functions.jl
rename to test/ext/special_functions.jl
diff --git a/test/front_matter.jl b/test/front_matter.jl
index ec9bbdab..8bb01a51 100644
--- a/test/front_matter.jl
+++ b/test/front_matter.jl
@@ -5,6 +5,8 @@ using
     FillArrays,
     JET,
     LinearAlgebra,
+    Lux,
+    LuxLib,
     NNlib,
     PDMats,
     Random,
diff --git a/test/integration_testing/lux.jl b/test/integration_testing/lux.jl
new file mode 100644
index 00000000..3f4778e4
--- /dev/null
+++ b/test/integration_testing/lux.jl
@@ -0,0 +1,44 @@
+@testset "lux" begin
+    @testset "$(typeof(f))" for (f, x_f32) in Any[
+        (Dense(2, 4), randn(Float32, 2, 3)),
+        (Dense(2, 4, gelu), randn(Float32, 2, 3)),
+        # (Dense(2, 4, gelu; use_bias=false), randn(Float32, 2, 3)),
+        # (Chain(Dense(2, 4, relu), Dense(4, 3)), randn(Float32, 2, 3)),
+        # (Scale(2), randn(Float32, 2, 3)),
+        # (Conv((3, 3), 2 => 3), randn(Float32, 3, 3, 2, 2)), # uses a task, so has recurrence problem. needs rule
+        # (Conv((3, 3), 2 => 3, gelu; pad=SamePad()), randn(Float32, 3, 3, 2, 2)), # uses a task, so has recurrence problem.
needs rule + # (Conv((3, 3), 2 => 3, relu; use_bias=false, pad=SamePad()), randn(Float32, 3, 3, 2, 2)), # uses a task, so has recurrence problem. needs rule + # (Chain(Conv((3, 3), 2 => 3, gelu), Conv((3, 3), 3 => 1, gelu)), rand(Float32, 5, 5, 2, 2)), # uses a task, so has recurrence problem. needs rule + # (Chain(Conv((4, 4), 2 => 2; pad=SamePad()), MeanPool((5, 5); pad=SamePad())), rand(Float32, 5, 5, 2, 2)), # uses a task, so has recurrence problem. needs rule + # (Chain(Conv((3, 3), 2 => 3, relu; pad=SamePad()), MaxPool((2, 2))), rand(Float32, 5, 5, 2, 2)), # uses a task, so has recurrence problem. needs rule + # (Maxout(() -> Dense(5 => 4, tanh), 3), randn(Float32, 5, 2)), + # (Bilinear((2, 2) => 3), randn(Float32, 2, 3)), # missing intrinsic atomic_pointerref. Also might just need a rule + # (SkipConnection(Dense(2 => 2), vcat), randn(Float32, 2, 3)), + # (ConvTranspose((3, 3), 3 => 2; stride=2), rand(Float32, 5, 5, 3, 1)), # uses a task, so has recurrence problem. needs rule + # (StatefulRecurrentCell(RNNCell(3 => 5)), rand(Float32, 3, 2)), # fpext getting used here somehow + # (StatefulRecurrentCell(RNNCell(3 => 5, gelu)), rand(Float32, 3, 2)), # fpext getting used here somehow + # (StatefulRecurrentCell(RNNCell(3 => 5, gelu; use_bias=false)), rand(Float32, 3, 2)), # fpext getting used here somehow + # (Chain(StatefulRecurrentCell(RNNCell(3 => 5)), StatefulRecurrentCell(RNNCell(5 => 3))), rand(Float32, 3, 2)), # fpext getting used here somehow + # (StatefulRecurrentCell(LSTMCell(3 => 5)), rand(Float32, 3, 2)), # fpext getting used here somehow + # (Chain(StatefulRecurrentCell(LSTMCell(3 => 5)), StatefulRecurrentCell(LSTMCell(5 => 3))), rand(Float32, 3, 2)), # fpext getting used here somehow + # (StatefulRecurrentCell(GRUCell(3 => 5)), rand(Float32, 3, 10)), # fpext getting used here somehow + # (Chain(StatefulRecurrentCell(GRUCell(3 => 5)), StatefulRecurrentCell(GRUCell(5 => 3))), rand(Float32, 3, 10)), # fpext getting used here somehow + # (Chain(Dense(2, 4), BatchNorm(4)), randn(Float32, 2, 3)), # something about begin_optional -- this is an unrecognised expression + # (Chain(Dense(2, 4), BatchNorm(4, gelu)), randn(Float32, 2, 3)), # something about begin_optional -- this is an unrecognised expression + # (Chain(Dense(2, 4), BatchNorm(4, gelu; track_stats=false)), randn(Float32, 2, 3)), # something about begin_optional -- this is an unrecognised expression + # (Chain(Conv((3, 3), 2 => 6), BatchNorm(6)), randn(Float32, 6, 6, 2, 2)), # stack overflow. 
Probably task again
+        # (Chain(Conv((3, 3), 2 => 6, tanh), BatchNorm(6)), randn(Float32, 6, 6, 2, 2)),
+        # (Chain(Dense(2, 4), GroupNorm(4, 2, gelu)), randn(Float32, 2, 3)), # fpext again
+        # (Chain(Dense(2, 4), GroupNorm(4, 2)), randn(Float32, 2, 3)), # fpext again
+        # (Chain(Conv((3, 3), 2 => 6), GroupNorm(6, 3)), randn(Float32, 6, 6, 2, 2)),
+        # (Chain(Conv((3, 3), 2 => 6, tanh), GroupNorm(6, 3)), randn(Float32, 6, 6, 2, 2)),
+        # (Chain(Conv((3, 3), 2 => 3, gelu), LayerNorm((1, 1, 3))), randn(Float32, 4, 4, 2, 2)),
+        # (Chain(Conv((3, 3), 2 => 6), InstanceNorm(6)), randn(Float32, 6, 6, 2, 2)),
+        # (Chain(Conv((3, 3), 2 => 6, tanh), InstanceNorm(6)), randn(Float32, 6, 6, 2, 2)),
+    ]
+        @info "$(_typeof((f, x_f32...)))"
+        ps, st = f64(Lux.setup(sr(123456), f))
+        x = f64(x_f32)
+        test_rule(sr(123456), f, x, ps, st; is_primitive=false)
+    end
+end
diff --git a/test/runtests.jl b/test/runtests.jl
index 619f6ea9..fc168188 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -47,10 +47,11 @@ include("front_matter.jl")
         include("chain_rules_interop.jl")
     elseif test_group == "integration_testing/misc"
         include(joinpath("integration_testing", "battery_tests.jl"))
-        include(joinpath("integration_testing", "dynamic_ppl.jl"))
-        include(joinpath("integration_testing", "logdensityproblemsad_interop.jl"))
-        include(joinpath("integration_testing", "nnlib.jl"))
-        include(joinpath("integration_testing", "special_functions.jl"))
+        include(joinpath("ext", "dynamic_ppl.jl"))
+        include(joinpath("ext", "logdensityproblemsad.jl"))
+        include(joinpath("ext", "luxlib.jl"))
+        include(joinpath("ext", "nnlib.jl"))
+        include(joinpath("ext", "special_functions.jl"))
     elseif test_group == "integration_testing/misc_abstract_array"
         include(joinpath("integration_testing", "misc_abstract_array.jl"))
     elseif test_group == "integration_testing/diff_tests"
@@ -66,7 +67,7 @@ include("front_matter.jl")
     elseif test_group == "integration_testing/temporalgps"
         include(joinpath("integration_testing", "temporalgps.jl"))
     elseif test_group == "gpu"
-        include(joinpath("integration_testing", "cuda.jl"))
+        include(joinpath("ext", "cuda.jl"))
     else
         throw(error("test_group=$(test_group) is not recognised"))
     end

From ae93a27fcf1b105b69145e875e2082743f780b10 Mon Sep 17 00:00:00 2001
From: Will Tebbutt
Date: Tue, 17 Sep 2024 13:23:39 +0100
Subject: [PATCH 24/39] Remove errant show statement

---
 src/interpreter/s2s_reverse_mode_ad.jl | 2 --
 1 file changed, 2 deletions(-)

diff --git a/src/interpreter/s2s_reverse_mode_ad.jl b/src/interpreter/s2s_reverse_mode_ad.jl
index 217750aa..92659df4 100644
--- a/src/interpreter/s2s_reverse_mode_ad.jl
+++ b/src/interpreter/s2s_reverse_mode_ad.jl
@@ -843,8 +843,6 @@ function build_rrule(
     interp::TapirInterpreter{C}, sig_or_mi; safety_on=false, silence_safety_messages=true
 ) where {C}
-    @show sig_or_mi
-
     # To avoid segfaults, ensure that we bail out if the interpreter's world age is greater
     # than the current world age.
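    # (If the world has advanced since `interp` was constructed, any methods defined in
    # the meantime would be invisible to it, so it is not safe to keep using it.)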
if Base.get_world_counter() > interp.world From 82ecd82e41758debbac91c7f6fef78ac49d6e8cb Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Tue, 17 Sep 2024 13:36:23 +0100 Subject: [PATCH 25/39] Remove redundant test --- test/chain_rules_interop.jl | 6 ------ 1 file changed, 6 deletions(-) diff --git a/test/chain_rules_interop.jl b/test/chain_rules_interop.jl index 368d683e..5b2f72b8 100644 --- a/test/chain_rules_interop.jl +++ b/test/chain_rules_interop.jl @@ -83,12 +83,6 @@ end ] @test Tapir.to_cr_tangent(t) == t_cr end - - # The fact that I'm testing this separately suggests to me that there's something that - # I've not quite gotten right about the abstractions involved here. - @testset "ChainRulesCore.thunk" begin - @test Tapir.to_tapir_tangent(ChainRulesCore.Thunk(() -> ones(5))) == ones(5) - end @testset "rules: $(typeof(fargs))" for fargs in Any[ (ChainRulesInteropTestResources.bleh, 5.0, 4), (ChainRulesInteropTestResources.test_sum, ones(5)), From ca93535d1f055f3ee461780518423f8d59daff23 Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Tue, 17 Sep 2024 13:43:05 +0100 Subject: [PATCH 26/39] Support where --- src/chain_rules_interop.jl | 36 +++++++++++++++++++++--------------- test/chain_rules_interop.jl | 8 ++++++++ 2 files changed, 29 insertions(+), 15 deletions(-) diff --git a/src/chain_rules_interop.jl b/src/chain_rules_interop.jl index 4bf5c90c..8abc3e76 100644 --- a/src/chain_rules_interop.jl +++ b/src/chain_rules_interop.jl @@ -109,11 +109,7 @@ would define a `Tapir.rrule!!` for `sin` of `Float64`s, by calling `ChainRulesCo Limitations: it is your responsibility to ensure that 1. calls with signature `sig` do not mutate their arguments, -2. the output of calls with signature `sig` does not alias any of the inputs, -3. `sig` is a `Tuple{...}`, not a `Tuple{...} where {...}`. - -This last point is a limitation of the current implementation, rather than something -fundamental, whereas the first two points are more basic points. +2. the output of calls with signature `sig` does not alias any of the inputs. As with all hand-written rules, you should definitely make use of [`TestUtils.test_rule`](@ref) to verify correctness on some test cases. 
@@ -142,22 +138,32 @@
 """
 macro from_rrule(ctx, sig)
-    @assert sig.head == :curly
-    @assert sig.args[1] == :Tuple
-    arg_type_symbols = sig.args[2:end]
+    if sig.head == :curly
+        @assert sig.args[1] == :Tuple
+        arg_type_symbols = sig.args[2:end]
+        where_params = nothing
+    elseif sig.head == :where
+        @assert sig.args[1].args[1] == :Tuple
+        arg_type_symbols = sig.args[1].args[2:end]
+        where_params = sig.args[2:end]
+    else
+        throw(ArgumentError("Expected either a `Tuple{...}` or `Tuple{...} where {...}`"))
+    end

     arg_names = map(n -> Symbol("x_$n"), eachindex(arg_type_symbols))
     arg_types = map(t -> :(Tapir.CoDual{<:$t}), arg_type_symbols)
     arg_exprs = map((n, t) -> :($n::$t), arg_names, arg_types)

-    rule_expr = ExprTools.combinedef(
-        Dict(
-            :head => :function,
-            :name => :(Tapir.rrule!!),
-            :args => arg_exprs,
-            :body => Expr(:call, rrule_wrapper, arg_names...),
-        )
+    def = Dict(
+        :head => :function,
+        :name => :(Tapir.rrule!!),
+        :args => arg_exprs,
+        :body => Expr(:call, rrule_wrapper, arg_names...),
     )
+    if where_params !== nothing
+        def[:whereparams] = where_params
+    end
+    rule_expr = ExprTools.combinedef(def)

     ex = quote
         Tapir.is_primitive(::Type{$ctx}, ::Type{<:$sig}) = true
diff --git a/test/chain_rules_interop.jl b/test/chain_rules_interop.jl
index 5b2f72b8..039398bc 100644
--- a/test/chain_rules_interop.jl
+++ b/test/chain_rules_interop.jl
@@ -63,6 +63,14 @@ end

 @from_rrule DefaultCtx Tuple{typeof(test_bad_rdata), Float64}

+# Test case for rule with diagonal dispatch.
+test_add(x, y) = x + y
+function ChainRulesCore.rrule(::typeof(test_add), x, y)
+    test_add_pb(dout) = ChainRulesCore.NoTangent(), dout, dout
+    return x + y, test_add_pb
+end
+@from_rrule DefaultCtx Tuple{typeof(test_add), T, T} where {T<:IEEEFloat}
+
 # Test case for rule with kwargs.
 test_kwargs(x; y::Bool) = y ?
x : 2x From fc6c00fcf9a4ea1646a667f1057c1b44b1aaf5ae Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Tue, 17 Sep 2024 13:58:23 +0100 Subject: [PATCH 27/39] Make use of where params --- ext/TapirNNlibExt.jl | 23 +++++++++++------------ test/ext/nnlib.jl | 1 + test/front_matter.jl | 2 ++ 3 files changed, 14 insertions(+), 12 deletions(-) diff --git a/ext/TapirNNlibExt.jl b/ext/TapirNNlibExt.jl index 8edd69f4..4a1f2675 100644 --- a/ext/TapirNNlibExt.jl +++ b/ext/TapirNNlibExt.jl @@ -2,11 +2,13 @@ module TapirNNlibExt using NNlib, Random, Tapir using Base: IEEEFloat + using NNlib: dropout import Tapir: @from_rrule, DefaultCtx @from_rrule( - DefaultCtx, Tuple{typeof(batched_mul), Array{<:IEEEFloat, 3}, Array{<:IEEEFloat, 3}} + DefaultCtx, + Tuple{typeof(batched_mul), Array{P, 3}, Array{P, 3}} where {P<:IEEEFloat}, ) @from_rrule( DefaultCtx, @@ -15,9 +17,9 @@ module TapirNNlibExt NamedTuple, typeof(dropout), AbstractRNG, - Array{<:IEEEFloat}, - IEEEFloat, - }, + Array{P}, + P, + } where {P<:IEEEFloat}, ) @from_rrule(DefaultCtx, Tuple{typeof(softmax), Array{<:IEEEFloat}}) @from_rrule( @@ -71,19 +73,16 @@ module TapirNNlibExt typeof(Core.kwcall), NamedTuple, typeof(NNlib.$(Symbol("$name$(backend)"))), - Array{<:IEEEFloat}, - Array{<:IEEEFloat}, + Array{P}, + Array{P}, ConvDims, - }, + } where {P<:IEEEFloat}, ) @eval @from_rrule( DefaultCtx, Tuple{ - typeof(NNlib.$(Symbol("$name$(backend)"))), - Array{<:IEEEFloat}, - Array{<:IEEEFloat}, - ConvDims, - }, + typeof(NNlib.$(Symbol("$name$(backend)"))), Array{P}, Array{P}, ConvDims, + } where {P<:IEEEFloat}, ) end for pool in [:maxpool, :meanpool] diff --git a/test/ext/nnlib.jl b/test/ext/nnlib.jl index 0cfdbb4b..1f00d9f5 100644 --- a/test/ext/nnlib.jl +++ b/test/ext/nnlib.jl @@ -85,6 +85,7 @@ (false, :none, false, x -> pad_constant(x, 1, 2.0), x), (false, :none, false, x -> pad_constant(x, 1, 2.0; dims=:), x), ] + @info "$(typeof(fargs))" test_rule(sr(1), fargs...; perf_flag, is_primitive, interface_only) end end diff --git a/test/front_matter.jl b/test/front_matter.jl index 8bb01a51..56ac344e 100644 --- a/test/front_matter.jl +++ b/test/front_matter.jl @@ -23,6 +23,8 @@ using Core: bitcast, svec, ReturnNode, PhiNode, PiNode, GotoIfNot, GotoNode, SSAValue, Argument using Core.Intrinsics: pointerref, pointerset +using NNlib: dropout + using Tapir: CC, IntrinsicsWrappers, From 473bc0288cd5d50882a9ad99a09b65688c3cf4f6 Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Tue, 17 Sep 2024 14:53:31 +0100 Subject: [PATCH 28/39] Improve kwarg interface --- ext/TapirLuxLibExt.jl | 4 +- ext/TapirNNlibExt.jl | 76 ++++++------------------------------- src/chain_rules_interop.jl | 49 ++++++++++++++++++------ test/chain_rules_interop.jl | 16 ++++---- 4 files changed, 61 insertions(+), 84 deletions(-) diff --git a/ext/TapirLuxLibExt.jl b/ext/TapirLuxLibExt.jl index 83fe6230..37eb7272 100644 --- a/ext/TapirLuxLibExt.jl +++ b/ext/TapirLuxLibExt.jl @@ -6,10 +6,10 @@ module TapirLuxLibExt import LuxLib.Impl: matmul, matmuladd, fused_dense import Tapir: @from_rrule, DefaultCtx - @from_rrule DefaultCtx Tuple{typeof(matmul), Array{<:IEEEFloat}, Array{<:IEEEFloat}} + @from_rrule(DefaultCtx, Tuple{typeof(matmul), Array{P}, Array{P}} where {P<:IEEEFloat}) @from_rrule( DefaultCtx, - Tuple{typeof(matmuladd), Array{<:IEEEFloat}, Array{<:IEEEFloat}, Vector{<:IEEEFloat}}, + Tuple{typeof(matmuladd), Array{P}, Array{P}, Vector{P}} where {P<:IEEEFloat}, ) # The implementations of rrules for fused operations are not straightforward to diff --git a/ext/TapirNNlibExt.jl 
b/ext/TapirNNlibExt.jl index 4a1f2675..5bc4b1f1 100644 --- a/ext/TapirNNlibExt.jl +++ b/ext/TapirNNlibExt.jl @@ -12,30 +12,12 @@ module TapirNNlibExt ) @from_rrule( DefaultCtx, - Tuple{ - typeof(Core.kwcall), - NamedTuple, - typeof(dropout), - AbstractRNG, - Array{P}, - P, - } where {P<:IEEEFloat}, - ) - @from_rrule(DefaultCtx, Tuple{typeof(softmax), Array{<:IEEEFloat}}) - @from_rrule( - DefaultCtx, - Tuple{typeof(Core.kwcall), NamedTuple, typeof(softmax), Array{<:IEEEFloat}}, - ) - @from_rrule(DefaultCtx, Tuple{typeof(logsoftmax), Array{<:IEEEFloat}}) - @from_rrule( - DefaultCtx, - Tuple{typeof(Core.kwcall), NamedTuple, typeof(logsoftmax), Array{<:IEEEFloat}}, - ) - @from_rrule(DefaultCtx, Tuple{typeof(logsumexp), Array{<:IEEEFloat}}) - @from_rrule( - DefaultCtx, - Tuple{typeof(Core.kwcall), NamedTuple, typeof(logsumexp), Array{<:IEEEFloat}}, + Tuple{typeof(dropout), AbstractRNG, Array{P}, P} where {P<:IEEEFloat}, + true, ) + @from_rrule(DefaultCtx, Tuple{typeof(softmax), Array{<:IEEEFloat}}, true) + @from_rrule(DefaultCtx, Tuple{typeof(logsoftmax), Array{<:IEEEFloat}}, true) + @from_rrule(DefaultCtx, Tuple{typeof(logsumexp), Array{<:IEEEFloat}}, true) @from_rrule( DefaultCtx, Tuple{typeof(upsample_nearest), Array{<:IEEEFloat}, NTuple{N, Int} where {N}}, @@ -43,64 +25,30 @@ module TapirNNlibExt @from_rrule( DefaultCtx, Tuple{ - typeof(NNlib.fold), - Array{<:IEEEFloat}, - NTuple{N, Int} where {N}, - DenseConvDims, + typeof(NNlib.fold), Array{<:IEEEFloat}, NTuple{N, Int} where {N}, DenseConvDims, }, ) - @from_rrule(DefaultCtx, Tuple{typeof(NNlib.unfold), Array{<:IEEEFloat}, DenseConvDims}) @from_rrule( - DefaultCtx, - Tuple{typeof(NNlib.scatter), Any, Array, Array{<:Union{Integer, Tuple}}}, + DefaultCtx, Tuple{typeof(NNlib.unfold), Array{<:IEEEFloat}, DenseConvDims} ) @from_rrule( DefaultCtx, - Tuple{ - typeof(Core.kwcall), - NamedTuple, - typeof(NNlib.scatter), - Any, - Array, - Array{<:Union{Integer, Tuple}}, - }, + Tuple{typeof(NNlib.scatter), Any, Array, Array{<:Union{Integer, Tuple}}}, + true, ) - for backend in (Symbol(), :_direct, :_im2col), name in (:conv, :depthwiseconv) - @eval @from_rrule( - DefaultCtx, - Tuple{ - typeof(Core.kwcall), - NamedTuple, - typeof(NNlib.$(Symbol("$name$(backend)"))), - Array{P}, - Array{P}, - ConvDims, - } where {P<:IEEEFloat}, - ) @eval @from_rrule( DefaultCtx, Tuple{ typeof(NNlib.$(Symbol("$name$(backend)"))), Array{P}, Array{P}, ConvDims, } where {P<:IEEEFloat}, + true, ) end for pool in [:maxpool, :meanpool] - @eval @from_rrule(DefaultCtx, Tuple{typeof($pool), Array{<:IEEEFloat}, PoolDims}) @eval @from_rrule( - DefaultCtx, - Tuple{ - typeof(Core.kwcall), - NamedTuple, - typeof($pool), - Array{<:IEEEFloat}, - PoolDims, - }, + DefaultCtx, Tuple{typeof($pool), Array{<:IEEEFloat}, PoolDims}, true ) end - @from_rrule(DefaultCtx, Tuple{typeof(pad_constant), Array, Any, Any}) - @from_rrule( - DefaultCtx, - Tuple{typeof(Core.kwcall), NamedTuple, typeof(pad_constant), Array, Any, Any}, - ) + @from_rrule(DefaultCtx, Tuple{typeof(pad_constant), Array, Any, Any}, true) end diff --git a/src/chain_rules_interop.jl b/src/chain_rules_interop.jl index 8abc3e76..8beba59d 100644 --- a/src/chain_rules_interop.jl +++ b/src/chain_rules_interop.jl @@ -96,7 +96,7 @@ function rrule_wrapper(::CoDual{typeof(Core.kwcall)}, fargs::Vararg{CoDual, N}) end @doc""" - @from_rrule ctx sig + @from_rrule ctx sig [has_kwargs=false] Convenience functionality to assist in using `ChainRulesCore.rrule`s to write `rrule!!`s. This macro is a thin wrapper around [`rrule_wrapper`](@ref). 
@@ -107,6 +107,11 @@ For example, ``` would define a `Tapir.rrule!!` for `sin` of `Float64`s, by calling `ChainRulesCore.rrule`. +```julia +@from_rrule DefaultCtx Tuple{typeof(foo), Float64} true +``` +would define a method of `Tapir.rrule!!` which can handle keyword arguments. + Limitations: it is your responsibility to ensure that 1. calls with signature `sig` do not mutate their arguments, 2. the output of calls with signature `sig` does not alias any of the inputs. @@ -136,8 +141,9 @@ only write a rule for these types: @from_rrule DefaultCtx Tuple{typeof(foo), IEEEFloat, Vector{<:IEEEFloat}} ``` """ -macro from_rrule(ctx, sig) +macro from_rrule(ctx, sig::Expr, has_kwargs::Bool=false) + # Different parsing is required for `Tuple{...}` vs `Tuple{...} where ...`. if sig.head == :curly @assert sig.args[1] == :Tuple arg_type_symbols = sig.args[2:end] @@ -152,8 +158,35 @@ macro from_rrule(ctx, sig) arg_names = map(n -> Symbol("x_$n"), eachindex(arg_type_symbols)) arg_types = map(t -> :(Tapir.CoDual{<:$t}), arg_type_symbols) - arg_exprs = map((n, t) -> :($n::$t), arg_names, arg_types) + rule_expr = construct_def(arg_names, arg_types, where_params) + + if has_kwargs + kw_sig = Expr(:curly, :Tuple, :(typeof(Core.kwcall)), :NamedTuple, arg_type_symbols...) + kw_sig = where_params === nothing ? kw_sig : Expr(:where, kw_sig, where_params...) + kw_is_primitive = :(Tapir.is_primitive(::Type{$ctx}, ::Type{<:$kw_sig}) = true) + kwcall_type = :(Tapir.CoDual{typeof(Core.kwcall)}) + nt_type = :(Tapir.CoDual{<:NamedTuple}) + kwargs_rule_expr = construct_def( + vcat(:_kwcall, :kwargs, arg_names), + vcat(kwcall_type, nt_type, arg_types), + where_params, + ) + else + kw_is_primitive = nothing + kwargs_rule_expr = nothing + end + ex = quote + Tapir.is_primitive(::Type{$ctx}, ::Type{<:$sig}) = true + $rule_expr + $kw_is_primitive + $kwargs_rule_expr + end + return esc(ex) +end + +function construct_def(arg_names, arg_types, where_params) + arg_exprs = map((n, t) -> :($n::$t), arg_names, arg_types) def = Dict( :head => :function, :name => :(Tapir.rrule!!), @@ -163,11 +196,5 @@ macro from_rrule(ctx, sig) if where_params !== nothing def[:whereparams] = where_params end - rule_expr = ExprTools.combinedef(def) - - ex = quote - Tapir.is_primitive(::Type{$ctx}, ::Type{<:$sig}) = true - $rule_expr - end - return esc(ex) -end + return ExprTools.combinedef(def) +end \ No newline at end of file diff --git a/test/chain_rules_interop.jl b/test/chain_rules_interop.jl index 039398bc..bdb10fd1 100644 --- a/test/chain_rules_interop.jl +++ b/test/chain_rules_interop.jl @@ -13,7 +13,7 @@ function ChainRulesCore.rrule(::typeof(bleh), x::Float64, y::Int) return x * y, dz -> (ChainRulesCore.NoTangent(), dz * y, ChainRulesCore.NoTangent()) end -@from_rrule DefaultCtx Tuple{typeof(bleh), Float64, Int} +@from_rrule DefaultCtx Tuple{typeof(bleh), Float64, Int} false # Test case with heap-allocated input. @@ -24,7 +24,7 @@ function ChainRulesCore.rrule(::typeof(test_sum), x::AbstractArray{<:Real}) return test_sum(x), test_sum_pb end -@from_rrule DefaultCtx Tuple{typeof(test_sum), Array{<:Base.IEEEFloat}} +@from_rrule DefaultCtx Tuple{typeof(test_sum), Array{<:Base.IEEEFloat}} false # Test case with heap-allocated output. 
@@ -37,7 +37,9 @@ function ChainRulesCore.rrule(::typeof(test_scale), x::Real, y::AbstractVector{< return x * y, test_scale_pb end -@from_rrule DefaultCtx Tuple{typeof(test_scale), Base.IEEEFloat, Vector{<:Base.IEEEFloat}} +@from_rrule( + DefaultCtx, Tuple{typeof(test_scale), Base.IEEEFloat, Vector{<:Base.IEEEFloat}}, false +) # Test case with non-differentiable type as output. @@ -48,7 +50,7 @@ function ChainRulesCore.rrule(::typeof(test_nothing)) return nothing, test_nothing_pb end -@from_rrule DefaultCtx Tuple{typeof(test_nothing)} +@from_rrule DefaultCtx Tuple{typeof(test_nothing)} false # Test case in which ChainRulesCore returns a tangent which is of the "wrong" type from the # perspective of Tapir.jl. In this instance, some kind of error should be thrown, rather @@ -61,7 +63,7 @@ function ChainRulesCore.rrule(::typeof(test_bad_rdata), x::Float64) return 5x, test_bad_rdata_pb end -@from_rrule DefaultCtx Tuple{typeof(test_bad_rdata), Float64} +@from_rrule DefaultCtx Tuple{typeof(test_bad_rdata), Float64} false # Test case for rule with diagonal dispatch. test_add(x, y) = x + y @@ -69,7 +71,7 @@ function ChainRulesCore.rrule(::typeof(test_add), x, y) test_add_pb(dout) = ChainRulesCore.NoTangent(), dout, dout return x + y, test_add_pb end -@from_rrule DefaultCtx Tuple{typeof(test_add), T, T} where {T<:IEEEFloat} +@from_rrule DefaultCtx Tuple{typeof(test_add), T, T} where {T<:IEEEFloat} false # Test case for rule with kwargs. test_kwargs(x; y::Bool) = y ? x : 2x @@ -79,7 +81,7 @@ function ChainRulesCore.rrule(::typeof(test_kwargs), x::Float64; y::Bool) return y ? x : 2x, test_kwargs_pb end -@from_rrule DefaultCtx Tuple{typeof(Core.kwcall), NamedTuple, typeof(test_kwargs), Float64} +@from_rrule(DefaultCtx, Tuple{typeof(test_kwargs), Float64}, true) end From 1cfbfcca31b5d49161732c252a40f081343d9e64 Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Tue, 17 Sep 2024 14:54:41 +0100 Subject: [PATCH 29/39] Default kwargs test --- test/chain_rules_interop.jl | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/test/chain_rules_interop.jl b/test/chain_rules_interop.jl index bdb10fd1..1f77d54a 100644 --- a/test/chain_rules_interop.jl +++ b/test/chain_rules_interop.jl @@ -74,9 +74,9 @@ end @from_rrule DefaultCtx Tuple{typeof(test_add), T, T} where {T<:IEEEFloat} false # Test case for rule with kwargs. -test_kwargs(x; y::Bool) = y ? x : 2x +test_kwargs(x; y::Bool=false) = y ? x : 2x -function ChainRulesCore.rrule(::typeof(test_kwargs), x::Float64; y::Bool) +function ChainRulesCore.rrule(::typeof(test_kwargs), x::Float64; y::Bool=false) test_kwargs_pb(dz::Float64) = ChainRulesCore.NoTangent(), y ? dz : 2dz return y ? 
x : 2x, test_kwargs_pb
 end

@@ -100,6 +100,7 @@ end
         (ChainRulesInteropTestResources.test_nothing,),
         (Core.kwcall, (y=true, ), ChainRulesInteropTestResources.test_kwargs, 5.0),
         (Core.kwcall, (y=false, ), ChainRulesInteropTestResources.test_kwargs, 5.0),
+        (ChainRulesInteropTestResources.test_kwargs, 5.0),
     ]
         test_rule(sr(1), fargs...; perf_flag=:stability, is_primitive=true)
     end

From 8ac290342d75f86b7ea963f0c4ac52a87adb4fa1 Mon Sep 17 00:00:00 2001
From: Will Tebbutt
Date: Tue, 17 Sep 2024 14:59:38 +0100
Subject: [PATCH 30/39] Improve docstring

---
 src/chain_rules_interop.jl | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/src/chain_rules_interop.jl b/src/chain_rules_interop.jl
index 8beba59d..4b3fc7e7 100644
--- a/src/chain_rules_interop.jl
+++ b/src/chain_rules_interop.jl
@@ -8,8 +8,10 @@ to_cr_tangent(t::Array{<:IEEEFloat}) = t
 to_cr_tangent(::NoTangent) = ChainRulesCore.NoTangent()

 """
-    increment_and_get_rdata!(fdata, rdata, cr_tangent)
+    increment_and_get_rdata!(fdata, zero_rdata, cr_tangent)

+Increment `fdata` by the fdata component of the ChainRules.jl-style tangent, `cr_tangent`,
+and return the rdata component of `cr_tangent` by adding it to `zero_rdata`.
 """
 increment_and_get_rdata!(::NoFData, r::T, t::T) where {T<:IEEEFloat} = r + t
 function increment_and_get_rdata!(f::Array{P}, ::NoRData, t::Array{P}) where {P<:IEEEFloat}

From ce5afd9e61f692a47cedcdcf7bd256ea65e24f44 Mon Sep 17 00:00:00 2001
From: Will Tebbutt
Date: Wed, 25 Sep 2024 09:24:00 +0100
Subject: [PATCH 31/39] Some work

---
 ext/MooncakeLuxLibExt.jl | 39 ++++++++++++++++++++++++---------------
 1 file changed, 24 insertions(+), 15 deletions(-)

diff --git a/ext/MooncakeLuxLibExt.jl b/ext/MooncakeLuxLibExt.jl
index 3bf39870..a7ac94e1 100644
--- a/ext/MooncakeLuxLibExt.jl
+++ b/ext/MooncakeLuxLibExt.jl
@@ -1,21 +1,30 @@
 module MooncakeLuxLibExt

-    using LuxLib, Random, Mooncake
-    using Base: IEEEFloat
+using LuxLib, Random, Mooncake
+using Base: IEEEFloat

-    import LuxLib.Impl: matmul, matmuladd, fused_dense
-    import Mooncake: @from_rrule, DefaultCtx
+import LuxLib.Impl: matmul, matmuladd
+import Mooncake: @from_rrule, DefaultCtx, MooncakeInterpreter

-    @from_rrule(DefaultCtx, Tuple{typeof(matmul), Array{P}, Array{P}} where {P<:IEEEFloat})
-    @from_rrule(
-        DefaultCtx,
-        Tuple{typeof(matmuladd), Array{P}, Array{P}, Vector{P}} where {P<:IEEEFloat},
-    )
+@from_rrule(DefaultCtx, Tuple{typeof(matmul), Array{P}, Array{P}} where {P<:IEEEFloat})
+@from_rrule(
+    DefaultCtx,
+    Tuple{typeof(matmuladd), Array{P}, Array{P}, Vector{P}} where {P<:IEEEFloat},
+)
+
+# Unfused version of `fused_dense`, which `build_rrule` makes use of.
+function unfused_dense(
+    opmode,
+    act::F,
+    weight::AbstractMatrix,
+    x::AbstractMatrix,
+    b::LuxLib.Optional{<:AbstractVector},
+) where {F}
+    return bias_activation(act, matmul(opmode, weight, x), b)
+end
+
+function Mooncake.build_rrule(interp::MooncakeInterpreter, sig_or_mi; kwargs...)
+    return Mooncake.build
+end

-    # The implementations of rrules for fused operations are not straightforward to
-    # incorporate into Mooncake.jl, because they call back into AD.
-    # We take a simple approach to their implementation: differentiate an un-fused version
-    # of their implementation. This will likely hit performance, but it makes implementing
-    # rules much more straightforward, in that we only have to be able to implement their
-    # constituent parts, rather than the entire thing.
end From 6edc9a4677e5686605c353397ebe28326820616f Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Mon, 30 Sep 2024 18:52:17 +0100 Subject: [PATCH 32/39] Some work --- ext/MooncakeLuxLibExt.jl | 38 +++++++++++++++++----- src/interpreter/abstract_interpretation.jl | 23 +++++++++++-- src/interpreter/ir_utils.jl | 38 +++++++++++++++++----- src/rrules/fastmath.jl | 18 +++++----- test/ext/luxlib.jl | 25 +++++++++++--- test/integration_testing/lux.jl | 32 +++++++++--------- test/runtests.jl | 1 + 7 files changed, 126 insertions(+), 49 deletions(-) diff --git a/ext/MooncakeLuxLibExt.jl b/ext/MooncakeLuxLibExt.jl index 3add54e0..5ce8ba8f 100644 --- a/ext/MooncakeLuxLibExt.jl +++ b/ext/MooncakeLuxLibExt.jl @@ -2,9 +2,10 @@ module MooncakeLuxLibExt using LuxLib, Random, Mooncake using Base: IEEEFloat +using Base.Experimental: @overlay -import LuxLib.Impl: matmul, matmuladd, fused_dense -import Mooncake: @from_rrule, DefaultCtx, MooncakeInterpreter +import LuxLib.Impl: matmul, matmuladd +import Mooncake: @from_rrule, DefaultCtx, MooncakeInterpreter, mooncake_method_table @from_rrule(DefaultCtx, Tuple{typeof(matmul), Array{P}, Array{P}} where {P<:IEEEFloat}) @from_rrule( @@ -12,19 +13,40 @@ import Mooncake: @from_rrule, DefaultCtx, MooncakeInterpreter Tuple{typeof(matmuladd), Array{P}, Array{P}, Vector{P}} where {P<:IEEEFloat}, ) -# Unfused version of `fused_dense`, which `build_rrule` makes use of. -function unfused_dense( +# Re-implement a bunch of methods to ensure that Mooncake can differentiate them. +@overlay mooncake_method_table function LuxLib.Impl.fused_dense( opmode, act::F, weight::AbstractMatrix, x::AbstractMatrix, b::LuxLib.Optional{<:AbstractVector}, ) where {F} - return bias_activation(act, matmul(opmode, weight, x), b) + return bias_activation(act, matmul(weight, x), b) end -# function Mooncake.build_rrule(interp::MooncakeInterpreter, sig_or_mi; kwargs...) -# return Mooncake.build -# end +@overlay mooncake_method_table function LuxLib.Impl.bias_activation_loop!( + y::AbstractArray{yT, 3}, σ::F, x::AbstractArray{xT, 3}, bias::AbstractVector +) where {F, xT, yT} + return LuxLib.Impl.bias_activation_simd_loop!(y, σ, x, bias) +end + +@overlay mooncake_method_table function LuxLib.Impl.activation_loop!( + y::AbstractArray, σ::F, x::AbstractArray +) where {F} + return LuxLib.Impl.activation_simd_loop!(y, σ, x) +end + +@overlay mooncake_method_table function LuxLib.Impl.fused_conv( + ::LuxLib.Impl.AbstractInternalArrayOpMode, + act::F, + weight::AbstractArray{wT, N}, + x::AbstractArray{xT, N}, + bias::LuxLib.Optional{<:AbstractVector}, + cdims::LuxLib.Impl.ConvDims, +) where {F, wT, xT, N} + return LuxLib.Impl.bias_activation(act, LuxLib.Impl.conv(x, weight, cdims), bias) +end + +# IMPORT SLEEFPirates RULES! Use a loop. 
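# As a sketch, the loop hinted at above could look something like the following. The
# `SLEEFActivations` names below are assumptions about LuxLib's internals, shown for
# illustration only:
#
#     for f in [LuxLib.Impl.SLEEFActivations.sigmoid_fast, LuxLib.Impl.SLEEFActivations.tanh_fast]
#         @from_rrule DefaultCtx Tuple{typeof(f), IEEEFloat}
#     end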
end diff --git a/src/interpreter/abstract_interpretation.jl b/src/interpreter/abstract_interpretation.jl index 745ed649..90f71621 100644 --- a/src/interpreter/abstract_interpretation.jl +++ b/src/interpreter/abstract_interpretation.jl @@ -5,7 +5,6 @@ # The most important bit of this code is `inlining_policy` -- the rest is copy + pasted # boiler plate, largely taken from https://github.com/JuliaLang/julia/blob/2fe4190b3d26b4eee52b2b1b1054ddd6e38a941e/test/compiler/newinterp.jl#L11 - struct ClosureCacheKey world_age::UInt key::Any @@ -17,6 +16,8 @@ end MooncakeCache() = MooncakeCache(IdDict{Core.MethodInstance, Core.CodeInstance}()) +Base.Experimental.@MethodTable mooncake_method_table + struct MooncakeInterpreter{C} <: CC.AbstractInterpreter meta # additional information world::UInt @@ -25,6 +26,7 @@ struct MooncakeInterpreter{C} <: CC.AbstractInterpreter inf_cache::Vector{CC.InferenceResult} code_cache::MooncakeCache oc_cache::Dict{ClosureCacheKey, Any} + method_table_to_overlay::CC.MethodTable function MooncakeInterpreter( ::Type{C}; meta=nothing, @@ -34,8 +36,18 @@ struct MooncakeInterpreter{C} <: CC.AbstractInterpreter inf_cache::Vector{CC.InferenceResult}=CC.InferenceResult[], code_cache::MooncakeCache=MooncakeCache(), oc_cache::Dict{ClosureCacheKey, Any}=Dict{ClosureCacheKey, Any}(), + method_table_to_overlay::CC.MethodTable=mooncake_method_table, ) where {C} - return new{C}(meta, world, inf_params, opt_params, inf_cache, code_cache, oc_cache) + return new{C}( + meta, + world, + inf_params, + opt_params, + inf_cache, + code_cache, + oc_cache, + method_table_to_overlay, + ) end end @@ -91,6 +103,9 @@ function CC.setindex!( ) return setindex!(wvc.cache.dict, ci, mi) end +function CC.method_table(interp::MooncakeInterpreter) + return CC.OverlayMethodTable(interp.world, interp.method_table_to_overlay) +end _type(x) = x _type(x::CC.Const) = _typeof(x.val) @@ -108,7 +123,9 @@ function CC.inlining_policy( # Do not inline away primitives. argtype_tuple = Tuple{map(_type, argtypes)...} - is_primitive(C, argtype_tuple) && return nothing + if is_primitive(C, argtype_tuple) + return nothing + end # If not a primitive, AD doesn't care about it. Use the usual inlining strategy. return @invoke CC.inlining_policy( diff --git a/src/interpreter/ir_utils.jl b/src/interpreter/ir_utils.jl index bae39e46..9370f11c 100644 --- a/src/interpreter/ir_utils.jl +++ b/src/interpreter/ir_utils.jl @@ -170,6 +170,9 @@ function optimise_ir!(ir::IRCode; show_ir=false, do_inline=true) return ir end +Base.iterate(x::CC.MethodLookupResult) = CC.iterate(x) +Base.iterate(x::CC.MethodLookupResult, n::Int) = CC.iterate(x, n) + """ lookup_ir( interp::AbstractInterpreter, @@ -181,18 +184,35 @@ there is no code found, or if more than one `IRCode` instance returned. Returns a tuple containing the `IRCode` and its return type. 
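For example (illustrative only, assuming `interp` is a `MooncakeInterpreter`):

    ir, ret_type = lookup_ir(interp, Tuple{typeof(sin), Float64})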
""" -function lookup_ir(interp::CC.AbstractInterpreter, sig::Type{<:Tuple}) - output = Base.code_ircode_by_type(sig; interp) - if isempty(output) - throw(ArgumentError("No methods found for signature $sig")) - elseif length(output) > 1 - throw(ArgumentError("$(length(output)) methods found for signature $sig")) +function lookup_ir(interp::CC.AbstractInterpreter, tt::Type{<:Tuple}; optimize_until=nothing) + matches = CC.findall(tt, CC.method_table(interp)) + asts = [] + for match in matches.matches + match = match::Core.MethodMatch + meth = Base.func_for_method_checked(match.method, tt, match.sparams) + (code, ty) = CC.typeinf_ircode( + interp, + meth, + match.spec_types, + match.sparams, + optimize_until, + ) + if code === nothing + push!(asts, match.method => Any) + else + push!(asts, code => ty) + end + end + if isempty(asts) + throw(ArgumentError("No methods found for signature $asts")) + elseif length(asts) > 1 + throw(ArgumentError("$(length(asts)) methods found for signature $sig")) end - return only(output) + return only(asts) end -function lookup_ir(interp::CC.AbstractInterpreter, mi::Core.MethodInstance) - return CC.typeinf_ircode(interp, mi.def, mi.specTypes, mi.sparam_vals, nothing) +function lookup_ir(interp::CC.AbstractInterpreter, mi::Core.MethodInstance; optimize_until=nothing) + return CC.typeinf_ircode(interp, mi.def, mi.specTypes, mi.sparam_vals, optimize_until) end """ diff --git a/src/rrules/fastmath.jl b/src/rrules/fastmath.jl index 93c7e17a..26811f6b 100644 --- a/src/rrules/fastmath.jl +++ b/src/rrules/fastmath.jl @@ -1,21 +1,21 @@ -@is_primitive MinimalCtx Tuple{typeof(Base.FastMath.exp_fast), Float64} -function rrule!!(::CoDual{typeof(Base.FastMath.exp_fast)}, x::CoDual{Float64}) +@is_primitive MinimalCtx Tuple{typeof(Base.FastMath.exp_fast), IEEEFloat} +function rrule!!(::CoDual{typeof(Base.FastMath.exp_fast)}, x::CoDual{P}) where {P<:IEEEFloat} yp = Base.FastMath.exp_fast(primal(x)) - exp_fast_pb!!(dy::Float64) = NoRData(), dy * yp + exp_fast_pb!!(dy::P) = NoRData(), dy * yp return CoDual(yp, NoFData()), exp_fast_pb!! end -@is_primitive MinimalCtx Tuple{typeof(Base.FastMath.exp2_fast), Float64} -function rrule!!(::CoDual{typeof(Base.FastMath.exp2_fast)}, x::CoDual{Float64}) +@is_primitive MinimalCtx Tuple{typeof(Base.FastMath.exp2_fast), IEEEFloat} +function rrule!!(::CoDual{typeof(Base.FastMath.exp2_fast)}, x::CoDual{P}) where {P<:IEEEFloat} yp = Base.FastMath.exp2_fast(primal(x)) - exp2_fast_pb!!(dy::Float64) = NoRData(), dy * yp * log(2) + exp2_fast_pb!!(dy::P) = NoRData(), dy * yp * log(2) return CoDual(yp, NoFData()), exp2_fast_pb!! end -@is_primitive MinimalCtx Tuple{typeof(Base.FastMath.exp10_fast), Float64} -function rrule!!(::CoDual{typeof(Base.FastMath.exp10_fast)}, x::CoDual{Float64}) +@is_primitive MinimalCtx Tuple{typeof(Base.FastMath.exp10_fast), IEEEFloat} +function rrule!!(::CoDual{typeof(Base.FastMath.exp10_fast)}, x::CoDual{P}) where {P<:IEEEFloat} yp = Base.FastMath.exp10_fast(primal(x)) - exp2_fast_pb!!(dy::Float64) = NoRData(), dy * yp * log(10) + exp2_fast_pb!!(dy::P) = NoRData(), dy * yp * log(10) return CoDual(yp, NoFData()), exp2_fast_pb!! end diff --git a/test/ext/luxlib.jl b/test/ext/luxlib.jl index ce0e15c5..46c6396a 100644 --- a/test/ext/luxlib.jl +++ b/test/ext/luxlib.jl @@ -1,8 +1,25 @@ @testset "luxlib" begin - @testset "$(typeof(fargs))" for (interface_only, perf_flag, is_primitive, fargs...) 
in Any[ - (false, :none, true, LuxLib.Impl.matmul, randn(5, 4), randn(4, 3)), - (false, :none, true, LuxLib.Impl.matmuladd, randn(5, 4), randn(4, 3), randn(5)), - ] + @testset "$(typeof(fargs))" for (interface_only, perf_flag, is_primitive, fargs...) in vcat( + Any[ + (false, :none, true, LuxLib.Impl.matmul, randn(5, 4), randn(4, 3)), + (false, :none, true, LuxLib.Impl.matmuladd, randn(5, 4), randn(4, 3), randn(5)), + (false, :none, false, LuxLib.Impl.activation, Lux.relu, randn(5, 4)), + ( + false, :none, false, + LuxLib.Impl.activation_loop!, randn(5, 3), NNlib.gelu, randn(5, 3), + ), + ], + vec(map(Iterators.product( + [LuxLib.LoopedArrayOp(), LuxLib.GenericBroadcastOp()], + [randn(5), nothing], + [Lux.relu, tanh, NNlib.gelu], + )) do (opmode, bias, activation) + ( + false, :none, false, + LuxLib.Impl.fused_dense, opmode, activation, randn(5, 4), randn(4, 2), bias, + ) + end), + ) test_rule(sr(1), fargs...; perf_flag, is_primitive, interface_only) end end diff --git a/test/integration_testing/lux.jl b/test/integration_testing/lux.jl index 3f4778e4..e191751b 100644 --- a/test/integration_testing/lux.jl +++ b/test/integration_testing/lux.jl @@ -2,18 +2,18 @@ @testset "$(typeof(f))" for (f, x_f32) in Any[ (Dense(2, 4), randn(Float32, 2, 3)), (Dense(2, 4, gelu), randn(Float32, 2, 3)), - # (Dense(2, 4, gelu; use_bias=false), randn(Float32, 2, 3)), - # (Chain(Dense(2, 4, relu), Dense(4, 3)), randn(Float32, 2, 3)), - # (Scale(2), randn(Float32, 2, 3)), - # (Conv((3, 3), 2 => 3), randn(Float32, 3, 3, 2, 2)), # uses a task, so has recurrence problem. needs rule - # (Conv((3, 3), 2 => 3, gelu; pad=SamePad()), randn(Float32, 3, 3, 2, 2)), # uses a task, so has recurrence problem. needs rule - # (Conv((3, 3), 2 => 3, relu; use_bias=false, pad=SamePad()), randn(Float32, 3, 3, 2, 2)), # uses a task, so has recurrence problem. needs rule - # (Chain(Conv((3, 3), 2 => 3, gelu), Conv((3, 3), 3 => 1, gelu)), rand(Float32, 5, 5, 2, 2)), # uses a task, so has recurrence problem. needs rule - # (Chain(Conv((4, 4), 2 => 2; pad=SamePad()), MeanPool((5, 5); pad=SamePad())), rand(Float32, 5, 5, 2, 2)), # uses a task, so has recurrence problem. needs rule - # (Chain(Conv((3, 3), 2 => 3, relu; pad=SamePad()), MaxPool((2, 2))), rand(Float32, 5, 5, 2, 2)), # uses a task, so has recurrence problem. needs rule - # (Maxout(() -> Dense(5 => 4, tanh), 3), randn(Float32, 5, 2)), + (Dense(2, 4, gelu; use_bias=false), randn(Float32, 2, 3)), + (Chain(Dense(2, 4, relu), Dense(4, 3)), randn(Float32, 2, 3)), + (Scale(2), randn(Float32, 2, 3)), + (Conv((3, 3), 2 => 3), randn(Float32, 3, 3, 2, 2)), + (Conv((3, 3), 2 => 3, gelu; pad=SamePad()), randn(Float32, 3, 3, 2, 2)), + (Conv((3, 3), 2 => 3, relu; use_bias=false, pad=SamePad()), randn(Float32, 3, 3, 2, 2)), + (Chain(Conv((3, 3), 2 => 3, gelu), Conv((3, 3), 3 => 1, gelu)), rand(Float32, 5, 5, 2, 2)), + (Chain(Conv((4, 4), 2 => 2; pad=SamePad()), MeanPool((5, 5); pad=SamePad())), rand(Float32, 5, 5, 2, 2)), + (Chain(Conv((3, 3), 2 => 3, relu; pad=SamePad()), MaxPool((2, 2))), rand(Float32, 5, 5, 2, 2)), + (Maxout(() -> Dense(5 => 4, tanh), 3), randn(Float32, 5, 2)), # (Bilinear((2, 2) => 3), randn(Float32, 2, 3)), # missing intrinsic atomic_pointerref. Also might just need a rule - # (SkipConnection(Dense(2 => 2), vcat), randn(Float32, 2, 3)), + (SkipConnection(Dense(2 => 2), vcat), randn(Float32, 2, 3)), # (ConvTranspose((3, 3), 3 => 2; stride=2), rand(Float32, 5, 5, 3, 1)), # uses a task, so has recurrence problem. 
needs rule # (StatefulRecurrentCell(RNNCell(3 => 5)), rand(Float32, 3, 2)), # fpext getting used here somehow # (StatefulRecurrentCell(RNNCell(3 => 5, gelu)), rand(Float32, 3, 2)), # fpext getting used here somehow @@ -26,8 +26,8 @@ # (Chain(Dense(2, 4), BatchNorm(4)), randn(Float32, 2, 3)), # something about begin_optional -- this is an unrecognised expression # (Chain(Dense(2, 4), BatchNorm(4, gelu)), randn(Float32, 2, 3)), # something about begin_optional -- this is an unrecognised expression # (Chain(Dense(2, 4), BatchNorm(4, gelu; track_stats=false)), randn(Float32, 2, 3)), # something about begin_optional -- this is an unrecognised expression - # (Chain(Conv((3, 3), 2 => 6), BatchNorm(6)), randn(Float32, 6, 6, 2, 2)), # stack overflow. Probably task again - # (Chain(Conv((3, 3), 2 => 6, tanh), BatchNorm(6)), randn(Float32, 6, 6, 2, 2)), + # (Chain(Conv((3, 3), 2 => 6), BatchNorm(6)), randn(Float32, 6, 6, 2, 2)), # another task problem + # (Chain(Conv((3, 3), 2 => 6, tanh), BatchNorm(6)), randn(Float32, 6, 6, 2, 2)), # task again # (Chain(Dense(2, 4), GroupNorm(4, 2, gelu)), randn(Float32, 2, 3)), # fpext again # (Chain(Dense(2, 4), GroupNorm(4, 2)), randn(Float32, 2, 3)), # fpext again # (Chain(Conv((3, 3), 2 => 6), GroupNorm(6, 3)), randn(Float32, 6, 6, 2, 2)), @@ -37,8 +37,8 @@ # (Chain(Conv((3, 3), 2 => 6, tanh), InstanceNorm(6)), randn(Float32, 6, 6, 2, 2)), ] @info "$(_typeof((f, x_f32...)))" - ps, st = f64(Lux.setup(sr(123456), f)) - x = f64(x_f32) - test_rule(sr(123456), f, x, ps, st; is_primitive=false) + ps, st = f32(Lux.setup(sr(123456), f)) + x = f32(x_f32) + test_rule(sr(123456), f, x, ps, st; is_primitive=false, interface_only=true) end end diff --git a/test/runtests.jl b/test/runtests.jl index 24403095..75d73aa5 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -53,6 +53,7 @@ include("front_matter.jl") include(joinpath("ext", "luxlib.jl")) include(joinpath("ext", "nnlib.jl")) include(joinpath("ext", "special_functions.jl")) + include(joinpath("integration_testing", "lux.jl")) elseif test_group == "integration_testing/misc_abstract_array" include(joinpath("integration_testing", "misc_abstract_array.jl")) elseif test_group == "integration_testing/diff_tests" From f66cc9cdb1e5bb45a65657c6e41991fe4587cbe6 Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Tue, 1 Oct 2024 10:15:37 +0100 Subject: [PATCH 33/39] Better conv support in nnlib rules --- ext/MooncakeNNlibExt.jl | 45 +++++++++++++++++++++++++---------------- test/ext/nnlib.jl | 13 ++++++++++++ 2 files changed, 41 insertions(+), 17 deletions(-) diff --git a/ext/MooncakeNNlibExt.jl b/ext/MooncakeNNlibExt.jl index 706a627b..fbd1b2fa 100644 --- a/ext/MooncakeNNlibExt.jl +++ b/ext/MooncakeNNlibExt.jl @@ -4,51 +4,62 @@ module MooncakeNNlibExt using Base: IEEEFloat using NNlib: dropout - import Mooncake: @from_rrule, DefaultCtx + using NNlib: conv, depthwiseconv + import Mooncake: @from_rrule, DefaultCtx, MinimalCtx @from_rrule( - DefaultCtx, + MinimalCtx, Tuple{typeof(batched_mul), Array{P, 3}, Array{P, 3}} where {P<:IEEEFloat}, ) @from_rrule( - DefaultCtx, + MinimalCtx, Tuple{typeof(dropout), AbstractRNG, Array{P}, P} where {P<:IEEEFloat}, true, ) - @from_rrule(DefaultCtx, Tuple{typeof(softmax), Array{<:IEEEFloat}}, true) - @from_rrule(DefaultCtx, Tuple{typeof(logsoftmax), Array{<:IEEEFloat}}, true) - @from_rrule(DefaultCtx, Tuple{typeof(logsumexp), Array{<:IEEEFloat}}, true) + @from_rrule(MinimalCtx, Tuple{typeof(softmax), Array{<:IEEEFloat}}, true) + @from_rrule(MinimalCtx, Tuple{typeof(logsoftmax), 
Array{<:IEEEFloat}}, true) + @from_rrule(MinimalCtx, Tuple{typeof(logsumexp), Array{<:IEEEFloat}}, true) @from_rrule( - DefaultCtx, + MinimalCtx, Tuple{typeof(upsample_nearest), Array{<:IEEEFloat}, NTuple{N, Int} where {N}}, ) @from_rrule( - DefaultCtx, + MinimalCtx, Tuple{ typeof(NNlib.fold), Array{<:IEEEFloat}, NTuple{N, Int} where {N}, DenseConvDims, }, ) @from_rrule( - DefaultCtx, Tuple{typeof(NNlib.unfold), Array{<:IEEEFloat}, DenseConvDims} + MinimalCtx, Tuple{typeof(NNlib.unfold), Array{<:IEEEFloat}, DenseConvDims} ) @from_rrule( - DefaultCtx, + MinimalCtx, Tuple{typeof(NNlib.scatter), Any, Array, Array{<:Union{Integer, Tuple}}}, true, ) - for backend in (Symbol(), :_direct, :_im2col), name in (:conv, :depthwiseconv) + for conv in [:conv, :depthwiseconv] + local ∇conv_data, ∇conv_filter = Symbol.(:∇, conv, [:_data, :_filter]) + + @eval @from_rrule( + MinimalCtx, + Tuple{typeof($conv), Array{P}, Array{P}, ConvDims} where {P<:IEEEFloat}, + true, + ) @eval @from_rrule( - DefaultCtx, - Tuple{ - typeof(NNlib.$(Symbol("$name$(backend)"))), Array{P}, Array{P}, ConvDims, - } where {P<:IEEEFloat}, + MinimalCtx, + Tuple{typeof($∇conv_data), Array{P}, Array{P}, ConvDims} where {P<:IEEEFloat}, true, ) end + @eval @from_rrule( + MinimalCtx, + Tuple{typeof(∇conv_filter), Array{P}, Array{P}, ConvDims} where {P<:IEEEFloat}, + true, + ) for pool in [:maxpool, :meanpool] @eval @from_rrule( - DefaultCtx, Tuple{typeof($pool), Array{<:IEEEFloat}, PoolDims}, true + MinimalCtx, Tuple{typeof($pool), Array{<:IEEEFloat}, PoolDims}, true ) end - @from_rrule(DefaultCtx, Tuple{typeof(pad_constant), Array, Any, Any}, true) + @from_rrule(MinimalCtx, Tuple{typeof(pad_constant), Array, Any, Any}, true) end diff --git a/test/ext/nnlib.jl b/test/ext/nnlib.jl index 1f00d9f5..2a3c3fde 100644 --- a/test/ext/nnlib.jl +++ b/test/ext/nnlib.jl @@ -3,6 +3,9 @@ w = randn(2, 2, 3, 3) dense_cdims = DenseConvDims(x, w) sep_cdims = DepthwiseConvDims(x, w) + y = conv(x, w, dense_cdims) + y_sep = depthwiseconv(x, w, sep_cdims) + pool_dims = PoolDims(size(x), 2) grid = Array{Float64}(undef, 2, 2, 2, 1) @@ -75,6 +78,16 @@ (false, :none, true, Core.kwcall, (;), depthwiseconv, x, w, sep_cdims), (false, :none, true, depthwiseconv, x, w, sep_cdims), + # ∇conv_data + (false, :none, true, Core.kwcall, (;), ∇conv_data, y, w, dense_cdims), + (false, :none, true, ∇conv_data, y, w, dense_cdims), + (false, :none, true, Core.kwcall, (;), ∇depthwiseconv_data, y_sep, w, sep_cdims), + (false, :none, true, ∇depthwiseconv_data, y_sep, w, sep_cdims), + + # ∇conv_filter + (false, :none, true, Core.kwcall, (;), ∇conv_filter, x, y, dense_cdims), + (false, :none, true, ∇conv_filter, x, y, dense_cdims), + # pooling (false, :none, true, maxpool, x, pool_dims), (false, :none, true, Core.kwcall, (;), maxpool, x, pool_dims), From f865fde5f0292ca53c29f941cd7491e1d6ff1cdd Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Tue, 1 Oct 2024 13:27:10 +0100 Subject: [PATCH 34/39] More LuxLib rules --- ext/MooncakeLuxLibExt.jl | 136 ++++++++++++++++++++++++++++++-- test/ext/luxlib.jl | 47 ++++++++--- test/front_matter.jl | 1 + test/integration_testing/lux.jl | 45 +++++------ 4 files changed, 191 insertions(+), 38 deletions(-) diff --git a/ext/MooncakeLuxLibExt.jl b/ext/MooncakeLuxLibExt.jl index 5ce8ba8f..1bfa63db 100644 --- a/ext/MooncakeLuxLibExt.jl +++ b/ext/MooncakeLuxLibExt.jl @@ -4,13 +4,18 @@ using LuxLib, Random, Mooncake using Base: IEEEFloat using Base.Experimental: @overlay -import LuxLib.Impl: matmul, matmuladd -import Mooncake: @from_rrule, DefaultCtx, 
MooncakeInterpreter, mooncake_method_table +import LuxLib: Impl +import LuxLib.Utils: static_training_mode_check +import Mooncake: @from_rrule, DefaultCtx, MooncakeInterpreter, mooncake_method_table, CoDual -@from_rrule(DefaultCtx, Tuple{typeof(matmul), Array{P}, Array{P}} where {P<:IEEEFloat}) +@from_rrule(DefaultCtx, Tuple{typeof(Impl.matmul), Array{P}, Array{P}} where {P<:IEEEFloat}) @from_rrule( DefaultCtx, - Tuple{typeof(matmuladd), Array{P}, Array{P}, Vector{P}} where {P<:IEEEFloat}, + Tuple{typeof(Impl.matmuladd), Array{P}, Array{P}, Vector{P}} where {P<:IEEEFloat}, +) +@from_rrule( + DefaultCtx, + Tuple{typeof(Impl.batched_matmul), Array{P, 3}, Array{P, 3}} where {P<:IEEEFloat}, ) # Re-implement a bunch of methods to ensure that Mooncake can differentiate them. @@ -21,7 +26,7 @@ import Mooncake: @from_rrule, DefaultCtx, MooncakeInterpreter, mooncake_method_t x::AbstractMatrix, b::LuxLib.Optional{<:AbstractVector}, ) where {F} - return bias_activation(act, matmul(weight, x), b) + return bias_activation(act, Impl.matmul(weight, x), b) end @overlay mooncake_method_table function LuxLib.Impl.bias_activation_loop!( @@ -47,6 +52,125 @@ end return LuxLib.Impl.bias_activation(act, LuxLib.Impl.conv(x, weight, cdims), bias) end -# IMPORT SLEEFPirates RULES! Use a loop. +for f in [ + Impl.SLEEFActivations.sigmoid_fast, + Impl.SLEEFActivations.softplus, + Impl.SLEEFActivations.logsigmoid, + Impl.SLEEFActivations.swish, + Impl.SLEEFActivations.lisht, + Impl.SLEEFActivations.tanh, + Impl.SLEEFActivations.tanh_fast, +] + @from_rrule DefaultCtx Tuple{typeof(f), IEEEFloat} + @from_rrule( + DefaultCtx, + Tuple{typeof(Broadcast.broadcasted), typeof(f), Union{IEEEFloat, Array{<:IEEEFloat}}}, + ) +end + +Mooncake.@is_primitive(DefaultCtx, Tuple{typeof(static_training_mode_check), Vararg}) +function Mooncake.rrule!!(f::CoDual{typeof(static_training_mode_check)}, x::CoDual...) + return Mooncake.simple_zero_adjoint(f, x...) +end + + + + +# This is a really horrible hack that we need to do until Mooncake is able to support the +# call-back-into-ad interface that ChainRules exposes. + +import LuxLib.Impl: + safe_eltype, + batchnorm_affine_normalize_internal, + batchnorm_affine_normalize_internal!, + ∇batchnorm_affine_normalize, + AbstractInternalArrayOpMode + +import ChainRulesCore as CRC + +function CRC.rrule( + ::typeof(batchnorm_affine_normalize_internal), + opmode::AbstractInternalArrayOpMode, + ::typeof(identity), + x::AbstractArray{T, N}, + μ::AbstractVector, + σ²::AbstractVector, + γ::LuxLib.Optional{<:AbstractVector}, + β::LuxLib.Optional{<:AbstractVector}, + ϵ::Real, +) where {T, N} + y = similar( + x, + promote_type( + safe_eltype(x), safe_eltype(μ), safe_eltype(σ²), safe_eltype(γ), safe_eltype(β) + ) + ) + γ′ = similar( + x, promote_type(safe_eltype(γ), safe_eltype(σ²), safe_eltype(ϵ)), size(x, N - 1) + ) + + batchnorm_affine_normalize_internal!(y, opmode, identity, x, μ, σ², γ, β, ϵ, γ′) + + 𝒫x, 𝒫μ, 𝒫σ² = CRC.ProjectTo(x), CRC.ProjectTo(μ), CRC.ProjectTo(σ²) + 𝒫γ = γ === nothing ? identity : CRC.ProjectTo(γ) + 𝒫β = β === nothing ? 
+
+    ∇batchnorm_affine_normalize_internal = LuxLib.Impl.@closure Δ -> begin
+        ∂x, ∂μ, ∂σ², ∂γ, ∂β = ∇batchnorm_affine_normalize(opmode, Δ, x, μ, σ², γ, β, ϵ, γ′)
+        ∂∅ = CRC.NoTangent()
+        return ∂∅, ∂∅, ∂∅, 𝒫x(∂x), 𝒫μ(∂μ), 𝒫σ²(∂σ²), 𝒫γ(∂γ), 𝒫β(∂β), ∂∅
+    end
+
+    return y, ∇batchnorm_affine_normalize_internal
+end
+
+@from_rrule(
+    DefaultCtx,
+    Tuple{
+        typeof(batchnorm_affine_normalize_internal),
+        AbstractInternalArrayOpMode,
+        typeof(identity),
+        AbstractArray,
+        AbstractVector,
+        AbstractVector,
+        LuxLib.Optional{<:AbstractVector},
+        LuxLib.Optional{<:AbstractVector},
+        Real,
+    },
+)
+
+@overlay mooncake_method_table function batchnorm_affine_normalize_internal(
+    opmode::LuxLib.AbstractInternalArrayOpMode,
+    act::F,
+    x::AbstractArray{xT, 3},
+    μ::AbstractVector,
+    σ²::AbstractVector,
+    γ::Union{Nothing, AbstractVector},
+    β::Union{Nothing, AbstractVector},
+    ϵ::Real,
+) where {F, xT}
+    y = batchnorm_affine_normalize_internal(opmode, identity, x, μ, σ², γ, β, ϵ)
+    LuxLib.Impl.activation!(y, opmode, act, y)
+    return y
+end
+
+@overlay mooncake_method_table function batchnorm_affine_normalize_internal(
+    opmode::LuxLib.AbstractInternalArrayOpMode,
+    ::typeof(identity),
+    x::AbstractArray{xT, 3},
+    μ::AbstractVector,
+    σ²::AbstractVector,
+    γ::Union{Nothing, AbstractVector},
+    β::Union{Nothing, AbstractVector},
+    ϵ::Real,
+) where {xT}
+    y = similar(x,
+        promote_type(
+            safe_eltype(x), safe_eltype(μ), safe_eltype(σ²), safe_eltype(γ), safe_eltype(β)
+        )
+    )
+    batchnorm_affine_normalize_internal!(y, opmode, identity, x, μ, σ², γ, β, ϵ)
+    return y
+end
 
 end
diff --git a/test/ext/luxlib.jl b/test/ext/luxlib.jl
index 46c6396a..1befaed8 100644
--- a/test/ext/luxlib.jl
+++ b/test/ext/luxlib.jl
@@ -3,22 +3,49 @@
     Any[
         (false, :none, true, LuxLib.Impl.matmul, randn(5, 4), randn(4, 3)),
         (false, :none, true, LuxLib.Impl.matmuladd, randn(5, 4), randn(4, 3), randn(5)),
+        (false, :none, true, LuxLib.Impl.batched_matmul, randn(5, 4, 3), randn(4, 3, 3)),
         (false, :none, false, LuxLib.Impl.activation, Lux.relu, randn(5, 4)),
         (
             false, :none, false,
             LuxLib.Impl.activation_loop!, randn(5, 3), NNlib.gelu, randn(5, 3),
         ),
+        (false, :stability_and_allocs, true, SLEEFActivations.sigmoid_fast, randn()),
+        (false, :stability_and_allocs, true, SLEEFActivations.softplus, randn()),
+        (false, :stability_and_allocs, true, SLEEFActivations.logsigmoid, randn()),
+        (false, :stability_and_allocs, true, SLEEFActivations.swish, randn()),
+        (false, :stability_and_allocs, true, SLEEFActivations.lisht, randn()),
+        (false, :stability_and_allocs, true, SLEEFActivations.tanh, randn()),
+        (false, :stability_and_allocs, true, SLEEFActivations.tanh_fast, randn()),
+        (
+            false, :stability_and_allocs, true,
+            LuxLib.Utils.static_training_mode_check,
+            nothing,
+            LuxLib.Utils.True(),
+            LuxLib.Utils.True(),
+        ),
+        (
+            false, :none, true,
+            LuxLib.Impl.batchnorm_affine_normalize_internal,
+            LuxLib.LoopedArrayOp(),
+            identity,
+            randn(5, 4, 3),
+            randn(4),
+            ones(4),
+            nothing,
+            nothing,
+            1.1,
+        ),
     ],
-    vec(map(Iterators.product(
-        [LuxLib.LoopedArrayOp(), LuxLib.GenericBroadcastOp()],
-        [randn(5), nothing],
-        [Lux.relu, tanh, NNlib.gelu],
-    )) do (opmode, bias, activation)
-        (
-            false, :none, false,
-            LuxLib.Impl.fused_dense, opmode, activation, randn(5, 4), randn(4, 2), bias,
-        )
-    end),
+    # vec(map(Iterators.product(
+    #     [LuxLib.LoopedArrayOp(), LuxLib.GenericBroadcastOp()],
+    #     [randn(5), nothing],
+    #     [Lux.relu, tanh, NNlib.gelu],
+    # )) do (opmode, bias, activation)
+    #     (
+    #         false, :none, false,
+    #         LuxLib.Impl.fused_dense, opmode, activation, randn(5, 4), randn(4, 2), bias,
+    #     )
+    # end),
 )
 test_rule(sr(1), fargs...; perf_flag, is_primitive, interface_only)
 end
diff --git a/test/front_matter.jl b/test/front_matter.jl
index 0cf19d69..78fe0a3b 100644
--- a/test/front_matter.jl
+++ b/test/front_matter.jl
@@ -24,6 +24,7 @@ using Core:
 using Core.Intrinsics: pointerref, pointerset
 
 using NNlib: dropout
+using LuxLib.Impl: SLEEFActivations
 
 using Mooncake:
     CC,
diff --git a/test/integration_testing/lux.jl b/test/integration_testing/lux.jl
index e191751b..15a90245 100644
--- a/test/integration_testing/lux.jl
+++ b/test/integration_testing/lux.jl
@@ -12,29 +12,30 @@
         (Chain(Conv((4, 4), 2 => 2; pad=SamePad()), MeanPool((5, 5); pad=SamePad())), rand(Float32, 5, 5, 2, 2)),
         (Chain(Conv((3, 3), 2 => 3, relu; pad=SamePad()), MaxPool((2, 2))), rand(Float32, 5, 5, 2, 2)),
         (Maxout(() -> Dense(5 => 4, tanh), 3), randn(Float32, 5, 2)),
-        # (Bilinear((2, 2) => 3), randn(Float32, 2, 3)), # missing intrinsic atomic_pointerref. Also might just need a rule
+        (Bilinear((2, 2) => 3), randn(Float32, 2, 3)),
         (SkipConnection(Dense(2 => 2), vcat), randn(Float32, 2, 3)),
-        # (ConvTranspose((3, 3), 3 => 2; stride=2), rand(Float32, 5, 5, 3, 1)), # uses a task, so has recurrence problem. needs rule
-        # (StatefulRecurrentCell(RNNCell(3 => 5)), rand(Float32, 3, 2)), # fpext getting used here somehow
-        # (StatefulRecurrentCell(RNNCell(3 => 5, gelu)), rand(Float32, 3, 2)), # fpext getting used here somehow
-        # (StatefulRecurrentCell(RNNCell(3 => 5, gelu; use_bias=false)), rand(Float32, 3, 2)), # fpext getting used here somehow
-        # (Chain(StatefulRecurrentCell(RNNCell(3 => 5)), StatefulRecurrentCell(RNNCell(5 => 3))), rand(Float32, 3, 2)), # fpext getting used here somehow
-        # (StatefulRecurrentCell(LSTMCell(3 => 5)), rand(Float32, 3, 2)), # fpext getting used here somehow
-        # (Chain(StatefulRecurrentCell(LSTMCell(3 => 5)), StatefulRecurrentCell(LSTMCell(5 => 3))), rand(Float32, 3, 2)), # fpext getting used here somehow
-        # (StatefulRecurrentCell(GRUCell(3 => 5)), rand(Float32, 3, 10)), # fpext getting used here somehow
-        # (Chain(StatefulRecurrentCell(GRUCell(3 => 5)), StatefulRecurrentCell(GRUCell(5 => 3))), rand(Float32, 3, 10)), # fpext getting used here somehow
-        # (Chain(Dense(2, 4), BatchNorm(4)), randn(Float32, 2, 3)), # something about begin_optional -- this is an unrecognised expression
-        # (Chain(Dense(2, 4), BatchNorm(4, gelu)), randn(Float32, 2, 3)), # something about begin_optional -- this is an unrecognised expression
-        # (Chain(Dense(2, 4), BatchNorm(4, gelu; track_stats=false)), randn(Float32, 2, 3)), # something about begin_optional -- this is an unrecognised expression
-        # (Chain(Conv((3, 3), 2 => 6), BatchNorm(6)), randn(Float32, 6, 6, 2, 2)), # another task problem
-        # (Chain(Conv((3, 3), 2 => 6, tanh), BatchNorm(6)), randn(Float32, 6, 6, 2, 2)), # task again
-        # (Chain(Dense(2, 4), GroupNorm(4, 2, gelu)), randn(Float32, 2, 3)), # fpext again
-        # (Chain(Dense(2, 4), GroupNorm(4, 2)), randn(Float32, 2, 3)), # fpext again
-        # (Chain(Conv((3, 3), 2 => 6), GroupNorm(6, 3)), randn(Float32, 6, 6, 2, 2)),
-        # (Chain(Conv((3, 3), 2 => 6, tanh), GroupNorm(6, 3)), randn(Float32, 6, 6, 2, 2)),
-        # (Chain(Conv((3, 3), 2 => 3, gelu), LayerNorm((1, 1, 3))), randn(Float32, 4, 4, 2, 2)),
-        # (Chain(Conv((3, 3), 2 => 6), InstanceNorm(6)), randn(Float32, 6, 6, 2, 2)),
-        # (Chain(Conv((3, 3), 2 => 6, tanh), InstanceNorm(6)), randn(Float32, 6, 6, 2, 2)),
+        (ConvTranspose((3, 3), 3 => 2; stride=2), rand(Float32, 5, 5, 3, 1)),
+        (StatefulRecurrentCell(RNNCell(3 => 5)), rand(Float32, 3, 2)),
+        (StatefulRecurrentCell(RNNCell(3 => 5, gelu)), rand(Float32, 3, 2)),
+        (StatefulRecurrentCell(RNNCell(3 => 5, gelu; use_bias=false)), rand(Float32, 3, 2)),
+        (Chain(StatefulRecurrentCell(RNNCell(3 => 5)), StatefulRecurrentCell(RNNCell(5 => 3))), rand(Float32, 3, 2)),
+        (StatefulRecurrentCell(LSTMCell(3 => 5)), rand(Float32, 3, 2)),
+        (Chain(StatefulRecurrentCell(LSTMCell(3 => 5)), StatefulRecurrentCell(LSTMCell(5 => 3))), rand(Float32, 3, 2)),
+        (StatefulRecurrentCell(GRUCell(3 => 5)), rand(Float32, 3, 10)),
+        (Chain(StatefulRecurrentCell(GRUCell(3 => 5)), StatefulRecurrentCell(GRUCell(5 => 3))), rand(Float32, 3, 10)),
+        (Chain(Dense(2, 4), BatchNorm(4)), randn(Float32, 2, 3)),
+        (Chain(Dense(2, 4), BatchNorm(4, gelu)), randn(Float32, 2, 3)),
+        (Chain(Dense(2, 4), BatchNorm(4, gelu; track_stats=false)), randn(Float32, 2, 3)),
+        (Chain(Conv((3, 3), 2 => 6), BatchNorm(6)), randn(Float32, 6, 6, 2, 2)),
+        (Chain(Conv((3, 3), 2 => 6, tanh), BatchNorm(6)), randn(Float32, 6, 6, 2, 2)),
+        (Chain(Dense(2, 4), GroupNorm(4, 2, gelu)), randn(Float32, 2, 3)),
+        (Chain(Dense(2, 4), GroupNorm(4, 2)), randn(Float32, 2, 3)),
+        (Chain(Conv((3, 3), 2 => 6), GroupNorm(6, 3)), randn(Float32, 6, 6, 2, 2)),
+        (Chain(Conv((3, 3), 2 => 6, tanh), GroupNorm(6, 3)), randn(Float32, 6, 6, 2, 2)),
+        (Chain(Conv((3, 3), 2 => 3, gelu), LayerNorm((1, 1, 3))), randn(Float32, 4, 4, 2, 2)),
+        (InstanceNorm(6), randn(Float32, 6, 6, 2, 2)),
+        (Chain(Conv((3, 3), 2 => 6), InstanceNorm(6)), randn(Float32, 6, 6, 2, 2)),
+        (Chain(Conv((3, 3), 2 => 6, tanh), InstanceNorm(6)), randn(Float32, 6, 6, 2, 2)),
     ]
     @info "$(_typeof((f, x_f32...)))"
     ps, st = f32(Lux.setup(sr(123456), f))

From 149e7b4d1941496fc70bd455c19da39e403094f2 Mon Sep 17 00:00:00 2001
From: Will Tebbutt
Date: Tue, 1 Oct 2024 13:27:24 +0100
Subject: [PATCH 35/39] Permit :meta nodes in IR

---
 src/interpreter/s2s_reverse_mode_ad.jl | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/interpreter/s2s_reverse_mode_ad.jl b/src/interpreter/s2s_reverse_mode_ad.jl
index 039d94a0..472c8c0d 100644
--- a/src/interpreter/s2s_reverse_mode_ad.jl
+++ b/src/interpreter/s2s_reverse_mode_ad.jl
@@ -617,6 +617,7 @@ function make_ad_stmts!(stmt::Expr, line::ID, info::ADInfo)
         :leave,
         :pop_exception,
         :throw_undef_if_not,
+        :meta,
     ]
         # Expressions which do not require any special treatment.
         return ad_stmt_info(line, nothing, stmt, nothing)

From 2dcd5350496e87225ba486cf9fbaa62065a8baa0 Mon Sep 17 00:00:00 2001
From: Will Tebbutt
Date: Tue, 1 Oct 2024 13:35:01 +0100
Subject: [PATCH 36/39] Remove redundant test

---
 test/ext/luxlib.jl | 12 ------------
 1 file changed, 12 deletions(-)

diff --git a/test/ext/luxlib.jl b/test/ext/luxlib.jl
index 1befaed8..5edc9610 100644
--- a/test/ext/luxlib.jl
+++ b/test/ext/luxlib.jl
@@ -23,18 +23,6 @@
             LuxLib.Utils.True(),
             LuxLib.Utils.True(),
         ),
-        (
-            false, :none, true,
-            LuxLib.Impl.batchnorm_affine_normalize_internal,
-            LuxLib.LoopedArrayOp(),
-            identity,
-            randn(5, 4, 3),
-            randn(4),
-            ones(4),
-            nothing,
-            nothing,
-            1.1,
-        ),
     ],

From 0933f37f61ccacd8baebf233a2dc23e53af5ebd7 Mon Sep 17 00:00:00 2001
From: Will Tebbutt
Date: Tue, 1 Oct 2024 13:46:09 +0100
Subject: [PATCH 37/39] Uncomment some tests

---
 test/ext/luxlib.jl | 20 ++++++++++----------
 1 file changed, 10 insertions(+), 10 deletions(-)

diff --git a/test/ext/luxlib.jl b/test/ext/luxlib.jl
index 5edc9610..1748d37f 100644
--- a/test/ext/luxlib.jl
+++ b/test/ext/luxlib.jl
@@ -24,16 +24,16 @@
             LuxLib.Utils.True(),
         ),
     ],
-    # vec(map(Iterators.product(
-    #     [LuxLib.LoopedArrayOp(), LuxLib.GenericBroadcastOp()],
-    #     [randn(5), nothing],
-    #     [Lux.relu, tanh, NNlib.gelu],
-    # )) do (opmode, bias, activation)
-    #     (
-    #         false, :none, false,
-    #         LuxLib.Impl.fused_dense, opmode, activation, randn(5, 4), randn(4, 2), bias,
-    #     )
-    # end),
+    vec(map(Iterators.product(
+        [LuxLib.LoopedArrayOp(), LuxLib.GenericBroadcastOp()],
+        [randn(5), nothing],
+        [Lux.relu, tanh, NNlib.gelu],
+    )) do (opmode, bias, activation)
+        (
+            false, :none, false,
+            LuxLib.Impl.fused_dense, opmode, activation, randn(5, 4), randn(4, 2), bias,
+        )
+    end),
 )
 test_rule(sr(1), fargs...; perf_flag, is_primitive, interface_only)
 end

From d217102f18a9a355fc47c668d0ba8c6d05ef6382 Mon Sep 17 00:00:00 2001
From: Will Tebbutt
Date: Tue, 1 Oct 2024 13:52:42 +0100
Subject: [PATCH 38/39] Rename chain rules doc

---
 docs/make.jl                                          | 2 +-
 docs/src/{using_chain_rules.md => tools_for_rules.md} | 0
 2 files changed, 1 insertion(+), 1 deletion(-)
 rename docs/src/{using_chain_rules.md => tools_for_rules.md} (100%)

diff --git a/docs/make.jl b/docs/make.jl
index 1a7d1179..f23ec609 100644
--- a/docs/make.jl
+++ b/docs/make.jl
@@ -33,7 +33,7 @@ makedocs(
             "Mooncake.jl's Rule System" => "mathematical_interpretation.md",
         ],
         "Utilities" => [
-            "Using ChainRules" => "using_chain_rules.md",
+            "Tools for Rules" => "tools_for_rules.md",
             "Debug Mode" => "debug_mode.md",
             "Debugging and MWEs" => "debugging_and_mwes.md",
         ],
diff --git a/docs/src/using_chain_rules.md b/docs/src/tools_for_rules.md
similarity index 100%
rename from docs/src/using_chain_rules.md
rename to docs/src/tools_for_rules.md

From c6f8cf01518b23d6a7328b26f587d01413f69a7e Mon Sep 17 00:00:00 2001
From: Will Tebbutt
Date: Tue, 1 Oct 2024 14:29:46 +0100
Subject: [PATCH 39/39] Add notes to docs on rule writing strategies

---
 docs/src/tools_for_rules.md | 39 ++++++++++++++++++++++++++++++++++++-
 1 file changed, 38 insertions(+), 1 deletion(-)

diff --git a/docs/src/tools_for_rules.md b/docs/src/tools_for_rules.md
index a7a07709..a3740e99 100644
--- a/docs/src/tools_for_rules.md
+++ b/docs/src/tools_for_rules.md
@@ -1,4 +1,41 @@
-# Using ChainRules.jl
+# Tools for Rules
+
+Most of the time, Mooncake.jl can just differentiate your code, but you will need to intervene if you make use of a language feature which is unsupported.
+However, this does not always necessitate writing your own `rrule!!` from scratch.
+In this section, we detail some useful strategies which can help you avoid having to write `rrule!!`s in many situations.
+
+## Simplifying Code via Overlays
+
+Suppose you have a function
+```julia
+foo(x::Float64) = bar(x)
+```
+where Mooncake.jl fails to differentiate `bar` for some reason.
+If you have access to another function `baz`, which does the same thing as `bar`, but does so in a way which Mooncake.jl can differentiate, you can simply write:
+```julia
+Base.Experimental.@overlay Mooncake.mooncake_method_table foo(x::Float64) = baz(x)
+```
+When looking up the code for `foo(::Float64)`, Mooncake.jl will see this method, rather than the original, and should successfully differentiate it.
+If you search for `@overlay` in the Mooncake.jl source code, you will see a variety of instances where this is used in practice.
+
+This approach is often very straightforward, and we recommend trying it before going down the path of writing rules.
+
+## Functions with Zero Derivative
+
+If the above strategy does not work, but you find yourself in the surprisingly common situation that the derivative of your function is always zero, you can very straightforwardly write a rule by making use of the following:
+```@docs
+Mooncake.simple_zero_adjoint
+```
+Suppose you have a function `foo(x, y, z)` whose derivative is zero; you would write an `rrule!!` for it as follows:
+```julia
+function Mooncake.rrule!!(f::CoDual{typeof(foo)}, x::CoDual, y::CoDual, z::CoDual)
+    return Mooncake.simple_zero_adjoint(f, x, y, z)
+end
+```
+Users of ChainRules.jl should be familiar with this functionality -- it is morally the same as `ChainRulesCore.@non_differentiable`.
+This approach is utilised often in Mooncake.jl's codebase.
+
+## Using ChainRules.jl
 
 [ChainRules.jl](https://github.com/JuliaDiff/ChainRules.jl) provides a large number of rules for differentiating functions in reverse-mode.
 These rules are methods of the `ChainRulesCore.rrule` function.
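+Mooncake.jl's `@from_rrule` macro, used throughout this patch series, derives a Mooncake.jl rule from one of these methods.
+The following is a minimal sketch of how the pieces fit together -- the function `my_scale` and its `rrule` are invented purely for illustration, while the `@from_rrule` invocation mirrors those found in Mooncake.jl's package extensions:
+```julia
+import ChainRulesCore
+import Mooncake: @from_rrule, DefaultCtx
+
+# A toy function, together with a hand-written ChainRules.jl rule for it.
+my_scale(x::Float64) = 2x
+
+function ChainRulesCore.rrule(::typeof(my_scale), x::Float64)
+    my_scale_pullback(dy) = (ChainRulesCore.NoTangent(), 2dy)
+    return my_scale(x), my_scale_pullback
+end
+
+# Derive a Mooncake.jl rule from the `rrule` method above, so that Mooncake.jl
+# uses it whenever it encounters a call to `my_scale`.
+@from_rrule(DefaultCtx, Tuple{typeof(my_scale), Float64})
+```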