Skip to content

Commit

Permalink
test: set st to training
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Jul 13, 2024
1 parent 8144727 commit 9a5cc21
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions test/layers/normalize_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -77,10 +77,12 @@
@jet m(x, ps, st)

if affine
__f = (x, ps) -> sum(first(m(x, ps, st)))
st_train = Lux.trainmode(st)
__f = (x, ps) -> sum(first(m(x, ps, st_train)))
@eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu skip_finite_differences=true
else
__f = x -> sum(first(m(x, ps, st)))
st_train = Lux.trainmode(st)
__f = x -> sum(first(m(x, ps, st_train)))
@eval @test_gradients $__f $x atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu skip_finite_differences=true
end

Expand Down

1 comment on commit 9a5cc21

@github-actions
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Benchmark Results

Benchmark suite Current: 9a5cc21 Previous: d2e25f0 Ratio
Dense(2 => 2)/cpu/reverse/ReverseDiff (compiled)/(2, 128) 3689.375 ns 3679.375 ns 1.00
Dense(2 => 2)/cpu/reverse/Zygote/(2, 128) 7108.25 ns 7148.333333333333 ns 0.99
Dense(2 => 2)/cpu/reverse/Tracker/(2, 128) 20729 ns 20979.5 ns 0.99
Dense(2 => 2)/cpu/reverse/ReverseDiff/(2, 128) 9788.2 ns 9730.2 ns 1.01
Dense(2 => 2)/cpu/reverse/Flux/(2, 128) 8803 ns 8960.8 ns 0.98
Dense(2 => 2)/cpu/reverse/SimpleChains/(2, 128) 4445.75 ns 4434.5 ns 1.00
Dense(2 => 2)/cpu/reverse/Enzyme/(2, 128) 1157.5942028985507 ns 1164.5579710144928 ns 0.99
Dense(2 => 2)/cpu/forward/NamedTuple/(2, 128) 1115.3624161073826 ns 1111.753164556962 ns 1.00
Dense(2 => 2)/cpu/forward/ComponentArray/(2, 128) 1179.5149253731342 ns 1191.9761904761904 ns 0.99
Dense(2 => 2)/cpu/forward/Flux/(2, 128) 1789.469387755102 ns 1783 ns 1.00
Dense(2 => 2)/cpu/forward/SimpleChains/(2, 128) 180.56022408963585 ns 180.34978843441468 ns 1.00
Dense(20 => 20)/cpu/reverse/ReverseDiff (compiled)/(20, 128) 17263 ns 17352 ns 0.99
Dense(20 => 20)/cpu/reverse/Zygote/(20, 128) 16761 ns 16781 ns 1.00
Dense(20 => 20)/cpu/reverse/Tracker/(20, 128) 37309 ns 37459 ns 1.00
Dense(20 => 20)/cpu/reverse/ReverseDiff/(20, 128) 29264 ns 28894 ns 1.01
Dense(20 => 20)/cpu/reverse/Flux/(20, 128) 19877 ns 20018 ns 0.99
Dense(20 => 20)/cpu/reverse/SimpleChains/(20, 128) 17122 ns 17342 ns 0.99
Dense(20 => 20)/cpu/reverse/Enzyme/(20, 128) 4326.571428571428 ns 4352.285714285715 ns 0.99
Dense(20 => 20)/cpu/forward/NamedTuple/(20, 128) 3858.375 ns 3884.75 ns 0.99
Dense(20 => 20)/cpu/forward/ComponentArray/(20, 128) 3932.25 ns 3976.125 ns 0.99
Dense(20 => 20)/cpu/forward/Flux/(20, 128) 4983.571428571428 ns 4942.142857142857 ns 1.01
Dense(20 => 20)/cpu/forward/SimpleChains/(20, 128) 1657 ns 1658.1 ns 1.00
Conv((3, 3), 3 => 3)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 3, 128) 49278699 ns 46588233 ns 1.06
Conv((3, 3), 3 => 3)/cpu/reverse/Zygote/(64, 64, 3, 128) 57576751.5 ns 58353734.5 ns 0.99
Conv((3, 3), 3 => 3)/cpu/reverse/Tracker/(64, 64, 3, 128) 109160869.5 ns 94269644 ns 1.16
Conv((3, 3), 3 => 3)/cpu/reverse/ReverseDiff/(64, 64, 3, 128) 106447718.5 ns 102950471.5 ns 1.03
Conv((3, 3), 3 => 3)/cpu/reverse/Flux/(64, 64, 3, 128) 91530370 ns 96068744 ns 0.95
Conv((3, 3), 3 => 3)/cpu/reverse/SimpleChains/(64, 64, 3, 128) 11675389.5 ns 12047387.5 ns 0.97
Conv((3, 3), 3 => 3)/cpu/reverse/Enzyme/(64, 64, 3, 128) 6931919 ns 7142740 ns 0.97
Conv((3, 3), 3 => 3)/cpu/forward/NamedTuple/(64, 64, 3, 128) 7098291 ns 7293339.5 ns 0.97
Conv((3, 3), 3 => 3)/cpu/forward/ComponentArray/(64, 64, 3, 128) 7022403 ns 7061588 ns 0.99
Conv((3, 3), 3 => 3)/cpu/forward/Flux/(64, 64, 3, 128) 18360298 ns 18208981 ns 1.01
Conv((3, 3), 3 => 3)/cpu/forward/SimpleChains/(64, 64, 3, 128) 6385781 ns 6409324 ns 1.00
vgg16/cpu/reverse/Zygote/(32, 32, 3, 16) 705394593 ns 688972274 ns 1.02
vgg16/cpu/reverse/Zygote/(32, 32, 3, 64) 2616824850 ns 2539093110 ns 1.03
vgg16/cpu/reverse/Zygote/(32, 32, 3, 2) 144841156.5 ns 144000135 ns 1.01
vgg16/cpu/reverse/Tracker/(32, 32, 3, 16) 804574120 ns 863231618 ns 0.93
vgg16/cpu/reverse/Tracker/(32, 32, 3, 64) 3328481757 ns 3262428068 ns 1.02
vgg16/cpu/reverse/Tracker/(32, 32, 3, 2) 209149562 ns 191155096 ns 1.09
vgg16/cpu/reverse/Flux/(32, 32, 3, 16) 773576256 ns 789348108 ns 0.98
vgg16/cpu/reverse/Flux/(32, 32, 3, 64) 2806385145 ns 2900294256 ns 0.97
vgg16/cpu/reverse/Flux/(32, 32, 3, 2) 147363592 ns 138816084 ns 1.06
vgg16/cpu/forward/NamedTuple/(32, 32, 3, 16) 174721051 ns 195622452 ns 0.89
vgg16/cpu/forward/NamedTuple/(32, 32, 3, 64) 654809349.5 ns 654350370 ns 1.00
vgg16/cpu/forward/NamedTuple/(32, 32, 3, 2) 34172290.5 ns 35742014 ns 0.96
vgg16/cpu/forward/ComponentArray/(32, 32, 3, 16) 165141726 ns 165293052.5 ns 1.00
vgg16/cpu/forward/ComponentArray/(32, 32, 3, 64) 643839250 ns 638681895 ns 1.01
vgg16/cpu/forward/ComponentArray/(32, 32, 3, 2) 30005067.5 ns 30105238 ns 1.00
vgg16/cpu/forward/Flux/(32, 32, 3, 16) 231201192 ns 200923901 ns 1.15
vgg16/cpu/forward/Flux/(32, 32, 3, 64) 829336756 ns 776895072 ns 1.07
vgg16/cpu/forward/Flux/(32, 32, 3, 2) 37372900 ns 40449323 ns 0.92
Conv((3, 3), 64 => 64)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 64, 128) 1249505098.5 ns 1261619764.5 ns 0.99
Conv((3, 3), 64 => 64)/cpu/reverse/Zygote/(64, 64, 64, 128) 1858491430 ns 1870728099.5 ns 0.99
Conv((3, 3), 64 => 64)/cpu/reverse/Tracker/(64, 64, 64, 128) 2371943763 ns 2436576396 ns 0.97
Conv((3, 3), 64 => 64)/cpu/reverse/ReverseDiff/(64, 64, 64, 128) 2553087977 ns 2587480428 ns 0.99
Conv((3, 3), 64 => 64)/cpu/reverse/Flux/(64, 64, 64, 128) 1962392226 ns 1919797321 ns 1.02
Conv((3, 3), 64 => 64)/cpu/reverse/Enzyme/(64, 64, 64, 128) 322990127.5 ns 330099395 ns 0.98
Conv((3, 3), 64 => 64)/cpu/forward/NamedTuple/(64, 64, 64, 128) 319576885 ns 331178567 ns 0.96
Conv((3, 3), 64 => 64)/cpu/forward/ComponentArray/(64, 64, 64, 128) 317940427 ns 330139867 ns 0.96
Conv((3, 3), 64 => 64)/cpu/forward/Flux/(64, 64, 64, 128) 389407449 ns 370302896 ns 1.05
Conv((3, 3), 1 => 1)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 1, 128) 11696482 ns 11794641 ns 0.99
Conv((3, 3), 1 => 1)/cpu/reverse/Zygote/(64, 64, 1, 128) 17828357 ns 18029221 ns 0.99
Conv((3, 3), 1 => 1)/cpu/reverse/Tracker/(64, 64, 1, 128) 19008271 ns 19208246.5 ns 0.99
Conv((3, 3), 1 => 1)/cpu/reverse/ReverseDiff/(64, 64, 1, 128) 23738072.5 ns 23923015 ns 0.99
Conv((3, 3), 1 => 1)/cpu/reverse/Flux/(64, 64, 1, 128) 17857580.5 ns 18021872 ns 0.99
Conv((3, 3), 1 => 1)/cpu/reverse/SimpleChains/(64, 64, 1, 128) 1163588 ns 1173398.5 ns 0.99
Conv((3, 3), 1 => 1)/cpu/reverse/Enzyme/(64, 64, 1, 128) 2064722 ns 2068852.5 ns 1.00
Conv((3, 3), 1 => 1)/cpu/forward/NamedTuple/(64, 64, 1, 128) 2070598 ns 2099161 ns 0.99
Conv((3, 3), 1 => 1)/cpu/forward/ComponentArray/(64, 64, 1, 128) 2074535 ns 2097189 ns 0.99
Conv((3, 3), 1 => 1)/cpu/forward/Flux/(64, 64, 1, 128) 2063896 ns 2079116.5 ns 0.99
Conv((3, 3), 1 => 1)/cpu/forward/SimpleChains/(64, 64, 1, 128) 198094 ns 204831.5 ns 0.97
Dense(200 => 200)/cpu/reverse/ReverseDiff (compiled)/(200, 128) 292194 ns 293185 ns 1.00
Dense(200 => 200)/cpu/reverse/Zygote/(200, 128) 264052 ns 264543 ns 1.00
Dense(200 => 200)/cpu/reverse/Tracker/(200, 128) 364308 ns 363617 ns 1.00
Dense(200 => 200)/cpu/reverse/ReverseDiff/(200, 128) 407148 ns 406907 ns 1.00
Dense(200 => 200)/cpu/reverse/Flux/(200, 128) 273069 ns 273449 ns 1.00
Dense(200 => 200)/cpu/reverse/SimpleChains/(200, 128) 411968 ns 409753 ns 1.01
Dense(200 => 200)/cpu/reverse/Enzyme/(200, 128) 83345 ns 83506 ns 1.00
Dense(200 => 200)/cpu/forward/NamedTuple/(200, 128) 81101 ns 81021 ns 1.00
Dense(200 => 200)/cpu/forward/ComponentArray/(200, 128) 81492 ns 81942 ns 0.99
Dense(200 => 200)/cpu/forward/Flux/(200, 128) 86582 ns 86481 ns 1.00
Dense(200 => 200)/cpu/forward/SimpleChains/(200, 128) 104545 ns 104654 ns 1.00
Conv((3, 3), 16 => 16)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 16, 128) 196142686.5 ns 193361558 ns 1.01
Conv((3, 3), 16 => 16)/cpu/reverse/Zygote/(64, 64, 16, 128) 325153569 ns 327463711.5 ns 0.99
Conv((3, 3), 16 => 16)/cpu/reverse/Tracker/(64, 64, 16, 128) 446916300.5 ns 441012061 ns 1.01
Conv((3, 3), 16 => 16)/cpu/reverse/ReverseDiff/(64, 64, 16, 128) 496223370 ns 500375685 ns 0.99
Conv((3, 3), 16 => 16)/cpu/reverse/Flux/(64, 64, 16, 128) 421390779 ns 417608052 ns 1.01
Conv((3, 3), 16 => 16)/cpu/reverse/SimpleChains/(64, 64, 16, 128) 322878155 ns 330500815 ns 0.98
Conv((3, 3), 16 => 16)/cpu/reverse/Enzyme/(64, 64, 16, 128) 44140967 ns 44999058 ns 0.98
Conv((3, 3), 16 => 16)/cpu/forward/NamedTuple/(64, 64, 16, 128) 44144811 ns 45171381 ns 0.98
Conv((3, 3), 16 => 16)/cpu/forward/ComponentArray/(64, 64, 16, 128) 43902323 ns 44162742 ns 0.99
Conv((3, 3), 16 => 16)/cpu/forward/Flux/(64, 64, 16, 128) 59027307 ns 68647265.5 ns 0.86
Conv((3, 3), 16 => 16)/cpu/forward/SimpleChains/(64, 64, 16, 128) 28155959 ns 28099293.5 ns 1.00
Dense(2000 => 2000)/cpu/reverse/ReverseDiff (compiled)/(2000, 128) 18833149 ns 18960314 ns 0.99
Dense(2000 => 2000)/cpu/reverse/Zygote/(2000, 128) 19487768.5 ns 19668528 ns 0.99
Dense(2000 => 2000)/cpu/reverse/Tracker/(2000, 128) 23275244 ns 23628735 ns 0.99
Dense(2000 => 2000)/cpu/reverse/ReverseDiff/(2000, 128) 24092658 ns 24335108.5 ns 0.99
Dense(2000 => 2000)/cpu/reverse/Flux/(2000, 128) 19580559 ns 19734998 ns 0.99
Dense(2000 => 2000)/cpu/reverse/Enzyme/(2000, 128) 6501763.5 ns 6563782 ns 0.99
Dense(2000 => 2000)/cpu/forward/NamedTuple/(2000, 128) 6519016 ns 6579502 ns 0.99
Dense(2000 => 2000)/cpu/forward/ComponentArray/(2000, 128) 6495883 ns 6590071.5 ns 0.99
Dense(2000 => 2000)/cpu/forward/Flux/(2000, 128) 6478856.5 ns 6582798 ns 0.98

This comment was automatically generated by workflow using github-action-benchmark.

Please sign in to comment.