Skip to content

Commit

Permalink
fix: qa testing
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Jul 13, 2024
1 parent e858e21 commit 8144727
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 8 deletions.
6 changes: 3 additions & 3 deletions src/chainrules.jl
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
# Non Differentiable Functions
CRC.@non_differentiable replicate(::Any)
CRC.@non_differentiable replicate(::Any) # TODO: move to LuxCore.jl
CRC.@non_differentiable compute_adaptive_pooling_dims(::Any, ::Any)
CRC.@non_differentiable istraining(::Any)
CRC.@non_differentiable _get_norm_except_dims(::Any, ::Any)
CRC.@non_differentiable _affine(::Any)
CRC.@non_differentiable _track_stats(::Any)
CRC.@non_differentiable _conv_transpose_dims(::Any...)
CRC.@non_differentiable _calc_padding(::Any...)
CRC.@non_differentiable Base.printstyled(::Any...)
CRC.@non_differentiable fieldcount(::Any) ## Type Piracy: Needs upstreaming
CRC.@non_differentiable Base.printstyled(::Any...) # TODO: Move to ChainRules.jl
CRC.@non_differentiable fieldcount(::Any) # TODO: Move to ChainRules.jl
CRC.@non_differentiable __check_sizes(ŷ::Any, y::Any)
CRC.@non_differentiable __set_refval!(::Any...)
CRC.@non_differentiable __state_if_stateful(::Any)
Expand Down
5 changes: 2 additions & 3 deletions src/forwarddiff/nested_ad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,10 @@ for type in (:Gradient, :Jacobian)
end

rrule_call = if type == :Gradient
:((res, pb_f) = CRC.rrule_via_ad(
cfg, Lux.__internal_ad_gradient_call, grad_fn, f, x, y))
:((res, pb_f) = CRC.rrule_via_ad(cfg, __internal_ad_gradient_call, grad_fn, f, x, y))
else
:((res, pb_f) = CRC.rrule_via_ad(
cfg, Lux.__internal_ad_jacobian_call, ForwardDiff.$(fname), grad_fn, f, x, y))
cfg, __internal_ad_jacobian_call, ForwardDiff.$(fname), grad_fn, f, x, y))
end
ret_expr = type == :Gradient ? :(only(res)) : :(res)
@eval begin
Expand Down
6 changes: 4 additions & 2 deletions test/qa_tests.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
@testitem "Aqua: Quality Assurance" tags=[:others] begin
using Aqua, ChainRulesCore
using Aqua, ChainRulesCore, ForwardDiff

Aqua.test_all(Lux; piracies=false, ambiguities=false)
Aqua.test_ambiguities(Lux; recursive=false)
Aqua.test_ambiguities(Lux;
exclude=[ForwardDiff.jacobian, ForwardDiff.gradient,
Lux.__batched_jacobian, Lux.__jacobian_vector_product_impl])
Aqua.test_piracies(
Lux; treat_as_own=[ChainRulesCore.frule, ChainRulesCore.rrule, Core.kwcall])
end
Expand Down

1 comment on commit 8144727

@github-actions
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Benchmark Results

Benchmark suite Current: 8144727 Previous: b5a412f Ratio
Dense(2 => 2)/cpu/reverse/ReverseDiff (compiled)/(2, 128) 3669.5 ns 3877.25 ns 0.95
Dense(2 => 2)/cpu/reverse/Zygote/(2, 128) 7158.5 ns 7136.666666666667 ns 1.00
Dense(2 => 2)/cpu/reverse/Tracker/(2, 128) 21270 ns 20909 ns 1.02
Dense(2 => 2)/cpu/reverse/ReverseDiff/(2, 128) 9868.6 ns 9758.4 ns 1.01
Dense(2 => 2)/cpu/reverse/Flux/(2, 128) 9075.2 ns 9005.625 ns 1.01
Dense(2 => 2)/cpu/reverse/SimpleChains/(2, 128) 4494.75 ns 4487.125 ns 1.00
Dense(2 => 2)/cpu/reverse/Enzyme/(2, 128) 1158.1323529411766 ns 1159.4113475177305 ns 1.00
Dense(2 => 2)/cpu/forward/NamedTuple/(2, 128) 1108.3333333333333 ns 1118.1447368421052 ns 0.99
Dense(2 => 2)/cpu/forward/ComponentArray/(2, 128) 1166.3758865248226 ns 1184.7333333333333 ns 0.98
Dense(2 => 2)/cpu/forward/Flux/(2, 128) 1783.5263157894738 ns 1787.7936507936508 ns 1.00
Dense(2 => 2)/cpu/forward/SimpleChains/(2, 128) 180.28349788434414 ns 179.43949930458973 ns 1.00
Dense(20 => 20)/cpu/reverse/ReverseDiff (compiled)/(20, 128) 17453 ns 17263 ns 1.01
Dense(20 => 20)/cpu/reverse/Zygote/(20, 128) 17072 ns 16751 ns 1.02
Dense(20 => 20)/cpu/reverse/Tracker/(20, 128) 37861 ns 37691 ns 1.00
Dense(20 => 20)/cpu/reverse/ReverseDiff/(20, 128) 28624 ns 29346 ns 0.98
Dense(20 => 20)/cpu/reverse/Flux/(20, 128) 20098 ns 21570 ns 0.93
Dense(20 => 20)/cpu/reverse/SimpleChains/(20, 128) 17413 ns 17322.5 ns 1.01
Dense(20 => 20)/cpu/reverse/Enzyme/(20, 128) 4321 ns 4315.285714285715 ns 1.00
Dense(20 => 20)/cpu/forward/NamedTuple/(20, 128) 3873.625 ns 3852.25 ns 1.01
Dense(20 => 20)/cpu/forward/ComponentArray/(20, 128) 3953.625 ns 3968.875 ns 1.00
Dense(20 => 20)/cpu/forward/Flux/(20, 128) 4959.428571428572 ns 4945.071428571428 ns 1.00
Dense(20 => 20)/cpu/forward/SimpleChains/(20, 128) 1653.1 ns 1659.1 ns 1.00
Conv((3, 3), 3 => 3)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 3, 128) 38465119 ns 39206882 ns 0.98
Conv((3, 3), 3 => 3)/cpu/reverse/Zygote/(64, 64, 3, 128) 57646876 ns 57767448 ns 1.00
Conv((3, 3), 3 => 3)/cpu/reverse/Tracker/(64, 64, 3, 128) 75865022 ns 76005484.5 ns 1.00
Conv((3, 3), 3 => 3)/cpu/reverse/ReverseDiff/(64, 64, 3, 128) 88423185.5 ns 88733451 ns 1.00
Conv((3, 3), 3 => 3)/cpu/reverse/Flux/(64, 64, 3, 128) 72301851 ns 72624222 ns 1.00
Conv((3, 3), 3 => 3)/cpu/reverse/SimpleChains/(64, 64, 3, 128) 11688195 ns 11618229 ns 1.01
Conv((3, 3), 3 => 3)/cpu/reverse/Enzyme/(64, 64, 3, 128) 6947657.5 ns 6948329 ns 1.00
Conv((3, 3), 3 => 3)/cpu/forward/NamedTuple/(64, 64, 3, 128) 7114060 ns 7100361 ns 1.00
Conv((3, 3), 3 => 3)/cpu/forward/ComponentArray/(64, 64, 3, 128) 7046153 ns 7045120 ns 1.00
Conv((3, 3), 3 => 3)/cpu/forward/Flux/(64, 64, 3, 128) 9919070 ns 10499396 ns 0.94
Conv((3, 3), 3 => 3)/cpu/forward/SimpleChains/(64, 64, 3, 128) 6383236 ns 6388556 ns 1.00
vgg16/cpu/reverse/Zygote/(32, 32, 3, 16) 700307873 ns 696069670 ns 1.01
vgg16/cpu/reverse/Zygote/(32, 32, 3, 64) 2573675425 ns 2560382099 ns 1.01
vgg16/cpu/reverse/Zygote/(32, 32, 3, 2) 146273034.5 ns 146896144 ns 1.00
vgg16/cpu/reverse/Tracker/(32, 32, 3, 16) 766244327 ns 753267556 ns 1.02
vgg16/cpu/reverse/Tracker/(32, 32, 3, 64) 2946277581 ns 3219866614 ns 0.92
vgg16/cpu/reverse/Tracker/(32, 32, 3, 2) 186150849.5 ns 189647046.5 ns 0.98
vgg16/cpu/reverse/Flux/(32, 32, 3, 16) 683705070 ns 650388253.5 ns 1.05
vgg16/cpu/reverse/Flux/(32, 32, 3, 64) 2646332482 ns 2641583439 ns 1.00
vgg16/cpu/reverse/Flux/(32, 32, 3, 2) 137195152 ns 126045986 ns 1.09
vgg16/cpu/forward/NamedTuple/(32, 32, 3, 16) 175100769.5 ns 175012423.5 ns 1.00
vgg16/cpu/forward/NamedTuple/(32, 32, 3, 64) 657088055 ns 654609813.5 ns 1.00
vgg16/cpu/forward/NamedTuple/(32, 32, 3, 2) 34808952 ns 45650343.5 ns 0.76
vgg16/cpu/forward/ComponentArray/(32, 32, 3, 16) 173079830 ns 165489868 ns 1.05
vgg16/cpu/forward/ComponentArray/(32, 32, 3, 64) 647810031.5 ns 645213518 ns 1.00
vgg16/cpu/forward/ComponentArray/(32, 32, 3, 2) 30215970.5 ns 30449566 ns 0.99
vgg16/cpu/forward/Flux/(32, 32, 3, 16) 186723548 ns 186588697.5 ns 1.00
vgg16/cpu/forward/Flux/(32, 32, 3, 64) 716777163.5 ns 762585936.5 ns 0.94
vgg16/cpu/forward/Flux/(32, 32, 3, 2) 35812469 ns 35592426 ns 1.01
Conv((3, 3), 64 => 64)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 64, 128) 1201776757.5 ns 1244392513.5 ns 0.97
Conv((3, 3), 64 => 64)/cpu/reverse/Zygote/(64, 64, 64, 128) 1867312418 ns 1860967941.5 ns 1.00
Conv((3, 3), 64 => 64)/cpu/reverse/Tracker/(64, 64, 64, 128) 2252712336 ns 2412383381 ns 0.93
Conv((3, 3), 64 => 64)/cpu/reverse/ReverseDiff/(64, 64, 64, 128) 2534981235 ns 2508398406 ns 1.01
Conv((3, 3), 64 => 64)/cpu/reverse/Flux/(64, 64, 64, 128) 1849417398.5 ns 1841004113 ns 1.00
Conv((3, 3), 64 => 64)/cpu/reverse/Enzyme/(64, 64, 64, 128) 319942263 ns 325387747 ns 0.98
Conv((3, 3), 64 => 64)/cpu/forward/NamedTuple/(64, 64, 64, 128) 322876057 ns 321376487 ns 1.00
Conv((3, 3), 64 => 64)/cpu/forward/ComponentArray/(64, 64, 64, 128) 322189133 ns 320382227 ns 1.01
Conv((3, 3), 64 => 64)/cpu/forward/Flux/(64, 64, 64, 128) 350507999 ns 351110878.5 ns 1.00
Conv((3, 3), 1 => 1)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 1, 128) 11709496.5 ns 11916214.5 ns 0.98
Conv((3, 3), 1 => 1)/cpu/reverse/Zygote/(64, 64, 1, 128) 17775143 ns 17847044.5 ns 1.00
Conv((3, 3), 1 => 1)/cpu/reverse/Tracker/(64, 64, 1, 128) 19047551 ns 18999477 ns 1.00
Conv((3, 3), 1 => 1)/cpu/reverse/ReverseDiff/(64, 64, 1, 128) 23749212 ns 23832058 ns 1.00
Conv((3, 3), 1 => 1)/cpu/reverse/Flux/(64, 64, 1, 128) 17809947 ns 17900491 ns 0.99
Conv((3, 3), 1 => 1)/cpu/reverse/SimpleChains/(64, 64, 1, 128) 1166579 ns 1156388 ns 1.01
Conv((3, 3), 1 => 1)/cpu/reverse/Enzyme/(64, 64, 1, 128) 2069792 ns 2068406 ns 1.00
Conv((3, 3), 1 => 1)/cpu/forward/NamedTuple/(64, 64, 1, 128) 2075303 ns 2074848 ns 1.00
Conv((3, 3), 1 => 1)/cpu/forward/ComponentArray/(64, 64, 1, 128) 2083859 ns 2080966 ns 1.00
Conv((3, 3), 1 => 1)/cpu/forward/Flux/(64, 64, 1, 128) 2070143 ns 2064530 ns 1.00
Conv((3, 3), 1 => 1)/cpu/forward/SimpleChains/(64, 64, 1, 128) 198384 ns 201207 ns 0.99
Dense(200 => 200)/cpu/reverse/ReverseDiff (compiled)/(200, 128) 293351.5 ns 291927 ns 1.00
Dense(200 => 200)/cpu/reverse/Zygote/(200, 128) 266071 ns 264766 ns 1.00
Dense(200 => 200)/cpu/reverse/Tracker/(200, 128) 372531 ns 367819.5 ns 1.01
Dense(200 => 200)/cpu/reverse/ReverseDiff/(200, 128) 411264 ns 407093 ns 1.01
Dense(200 => 200)/cpu/reverse/Flux/(200, 128) 274837 ns 274194 ns 1.00
Dense(200 => 200)/cpu/reverse/SimpleChains/(200, 128) 408429 ns 413023.5 ns 0.99
Dense(200 => 200)/cpu/reverse/Enzyme/(200, 128) 83357 ns 83176 ns 1.00
Dense(200 => 200)/cpu/forward/NamedTuple/(200, 128) 81022 ns 80792 ns 1.00
Dense(200 => 200)/cpu/forward/ComponentArray/(200, 128) 81623 ns 81824 ns 1.00
Dense(200 => 200)/cpu/forward/Flux/(200, 128) 86773 ns 86422 ns 1.00
Dense(200 => 200)/cpu/forward/SimpleChains/(200, 128) 104757 ns 104516 ns 1.00
Conv((3, 3), 16 => 16)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 16, 128) 185710969 ns 189753594 ns 0.98
Conv((3, 3), 16 => 16)/cpu/reverse/Zygote/(64, 64, 16, 128) 321911456 ns 326290151.5 ns 0.99
Conv((3, 3), 16 => 16)/cpu/reverse/Tracker/(64, 64, 16, 128) 392672252 ns 389674161 ns 1.01
Conv((3, 3), 16 => 16)/cpu/reverse/ReverseDiff/(64, 64, 16, 128) 459161840.5 ns 463032519.5 ns 0.99
Conv((3, 3), 16 => 16)/cpu/reverse/Flux/(64, 64, 16, 128) 370778856 ns 372310812 ns 1.00
Conv((3, 3), 16 => 16)/cpu/reverse/SimpleChains/(64, 64, 16, 128) 326381332.5 ns 321094560 ns 1.02
Conv((3, 3), 16 => 16)/cpu/reverse/Enzyme/(64, 64, 16, 128) 44291164.5 ns 44247080 ns 1.00
Conv((3, 3), 16 => 16)/cpu/forward/NamedTuple/(64, 64, 16, 128) 44361965.5 ns 44274719.5 ns 1.00
Conv((3, 3), 16 => 16)/cpu/forward/ComponentArray/(64, 64, 16, 128) 44088654 ns 43910438 ns 1.00
Conv((3, 3), 16 => 16)/cpu/forward/Flux/(64, 64, 16, 128) 50059502 ns 60159984.5 ns 0.83
Conv((3, 3), 16 => 16)/cpu/forward/SimpleChains/(64, 64, 16, 128) 28564487 ns 27695071 ns 1.03
Dense(2000 => 2000)/cpu/reverse/ReverseDiff (compiled)/(2000, 128) 19023815 ns 19050107 ns 1.00
Dense(2000 => 2000)/cpu/reverse/Zygote/(2000, 128) 19619630 ns 19567246 ns 1.00
Dense(2000 => 2000)/cpu/reverse/Tracker/(2000, 128) 23614610.5 ns 23330177 ns 1.01
Dense(2000 => 2000)/cpu/reverse/ReverseDiff/(2000, 128) 24290489.5 ns 24116204 ns 1.01
Dense(2000 => 2000)/cpu/reverse/Flux/(2000, 128) 19710266.5 ns 19653615 ns 1.00
Dense(2000 => 2000)/cpu/reverse/Enzyme/(2000, 128) 6516122 ns 6510513.5 ns 1.00
Dense(2000 => 2000)/cpu/forward/NamedTuple/(2000, 128) 6545687 ns 6528171 ns 1.00
Dense(2000 => 2000)/cpu/forward/ComponentArray/(2000, 128) 6517253 ns 6491807 ns 1.00
Dense(2000 => 2000)/cpu/forward/Flux/(2000, 128) 6496415 ns 6491543.5 ns 1.00

This comment was automatically generated by workflow using github-action-benchmark.

Please sign in to comment.