diff --git a/Project.toml b/Project.toml
index bf474dfe..5de9de4e 100644
--- a/Project.toml
+++ b/Project.toml
@@ -44,6 +44,7 @@ LuxLibcuDNNExt = ["CUDA", "cuDNN"]
 AMDGPU = "0.9.6"
 Aqua = "0.8.7"
 ArrayInterface = "7.9"
+BFloat16s = "0.5.0"
 CUDA = "5.3.2"
 ChainRulesCore = "1.24"
 ComponentArrays = "0.15.16"
@@ -86,6 +87,7 @@ julia = "1.10"
 
 [extras]
 Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
+BFloat16s = "ab4f0b2a-ad5b-11e8-123f-65d77653426b"
 ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
 Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
 ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7"
@@ -104,4 +106,4 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 
 [targets]
-test = ["Aqua", "ComponentArrays", "Enzyme", "ExplicitImports", "Hwloc", "InteractiveUtils", "JLArrays", "LuxTestUtils", "Pkg", "Preferences", "ReTestItems", "ReverseDiff", "StableRNGs", "StaticArrays", "Test", "Tracker", "Zygote"]
+test = ["Aqua", "BFloat16s", "ComponentArrays", "Enzyme", "ExplicitImports", "Hwloc", "InteractiveUtils", "JLArrays", "LuxTestUtils", "Pkg", "Preferences", "ReTestItems", "ReverseDiff", "StableRNGs", "StaticArrays", "Test", "Tracker", "Zygote"]
diff --git a/test/common_ops/activation_tests.jl b/test/common_ops/activation_tests.jl
index 803abee5..56ec529b 100644
--- a/test/common_ops/activation_tests.jl
+++ b/test/common_ops/activation_tests.jl
@@ -8,7 +8,7 @@
     @testset "$mode" for (mode, aType, ongpu) in MODES
         @testset "$f: $T" for f in [identity, relu, sigmoid, sigmoid_fast, softplus,
                 logsigmoid, gelu, swish, lisht, tanh, tanh_fast],
-            T in [Float16, Float32, Float64]
+            T in [BFloat16, Float32, Float64]
 
             x = rand(rng, T, 4, 3) |> aType
 
@@ -16,9 +16,8 @@
             y2 = apply_act_fast(f, x)
             y3 = apply_act_fast2(f, x)
 
-            fp16 = T == Float16
-            atol = fp16 ? 1.0f-1 : 1.0f-3
-            rtol = fp16 ? 1.0f-1 : 1.0f-3
+            atol = 1.0f-3
+            rtol = 1.0f-3
 
             @test y1≈y2 atol=atol rtol=rtol
             @test y1≈y3 atol=atol rtol=rtol
diff --git a/test/common_ops/conv_tests.jl b/test/common_ops/conv_tests.jl
index abdcb6f3..815c6744 100644
--- a/test/common_ops/conv_tests.jl
+++ b/test/common_ops/conv_tests.jl
@@ -1,5 +1,5 @@
 @testsetup module ConvSetup
-using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib
+using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib, BFloat16s
 
 _expand(N, i::Tuple) = i
 _expand(N, i::Integer) = ntuple(_ -> i, N)
@@ -28,9 +28,8 @@ function run_conv_testing(gen_f::Function, activation, kernel, stride, padding,
     y_generic = LuxLib._generic_conv_bias_activation(activation, weight, x, bias, cdims)
 
-    fp16 = Tx == Float16 || Tw == Float16
-    atol = fp16 ? 1.0f-1 : 1.0f-3
-    rtol = fp16 ? 1.0f-1 : 1.0f-3
+    atol = 1.0f-3
+    rtol = 1.0f-3
 
     # Operation reordering has an effect on the accuracy of the results
     @test y≈y_generic atol=atol rtol=rtol
     @test eltype(y) == promote_type(Tw, Tx)
@@ -61,14 +60,13 @@ function run_conv_testing(gen_f::Function, activation, kernel, stride, padding,
     mp && push!(skip_backends, AutoReverseDiff())
     ((mp && ongpu) || (mode == "amdgpu" && (Tx == Float64 || Tw == Float64))) &&
         push!(skip_backends, AutoTracker())
 
-    test_gradients(__f_grad, weight, x, bias; atol, rtol, skip_backends,
-        soft_fail=(fp16 ? [AutoFiniteDiff()] : []))
+    test_gradients(__f_grad, weight, x, bias; atol, rtol, skip_backends)
 end
 
 anonact = x -> gelu(x)
 
-const ELTYPES = [(Float16, Float16), (Float32, Float16), (Float32, Float32),
-    (Float32, Float64), (Float64, Float64)]
+const ELTYPES = [(BFloat16, BFloat16), (Float32, BFloat16),
+    (Float32, Float32), (Float32, Float64), (Float64, Float64)]
 const ACTIVATIONS = [
     identity, tanh, tanh_fast, sigmoid, sigmoid_fast, relu, gelu, swish, anonact]
diff --git a/test/common_ops/dense_tests.jl b/test/common_ops/dense_tests.jl
index b2a0f065..106307bf 100644
--- a/test/common_ops/dense_tests.jl
+++ b/test/common_ops/dense_tests.jl
@@ -1,5 +1,5 @@
 @testsetup module DenseSetup
-using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib
+using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib, BFloat16s
 
 anonact = x -> x^3
 
@@ -25,24 +25,21 @@ function run_dense_testing(gen_f, Tw, Tx, M, N, hasbias, activation, aType, mode
         @test length(@inferred(Zygote.gradient(__f, activation, w, x, bias)))==4 broken=true
     end
 
-    fp16 = Tx == Float16 || Tw == Float16
-    atol = fp16 ? 1.0f-1 : 1.0f-3
-    rtol = fp16 ? 1.0f-1 : 1.0f-3
+    atol = 1.0f-3
+    rtol = 1.0f-3
 
     skip_backends = []
     Tw != Tx && push!(skip_backends, AutoReverseDiff())
-    fp16 && push!(skip_backends, AutoFiniteDiff())
 
     __f_grad = let activation = activation
         (w, x, b) -> __f(activation, w, x, b)
     end
-    test_gradients(__f_grad, w, x, bias; atol, rtol, skip_backends,
-        soft_fail=(fp16 ? [AutoFiniteDiff()] : []))
+    test_gradients(__f_grad, w, x, bias; atol, rtol, skip_backends)
 end
 
 const ALL_TEST_CONFIGS = Iterators.product(
-    ((Float16, Float16), (Float32, Float16), (Float32, Float32),
-        (Float32, Float64), (Float64, Float64)),
+    ((BFloat16, BFloat16), (Float32, BFloat16),
+        (Float32, Float32), (Float32, Float64), (Float64, Float64)),
     (4, 8),
     (4, 8),
     (true, false),
diff --git a/test/common_ops/dropout_tests.jl b/test/common_ops/dropout_tests.jl
index 015227b8..9042290b 100644
--- a/test/common_ops/dropout_tests.jl
+++ b/test/common_ops/dropout_tests.jl
@@ -2,7 +2,7 @@
     rng = StableRNG(12345)
 
     @testset "$mode" for (mode, aType, ongpu) in MODES
-        @testset "$T: $x_shape" for T in (Float16, Float32, Float64),
+        @testset "$T: $x_shape" for T in (BFloat16, Float32, Float64),
             x_shape in ((2, 3), (2, 2, 3), (2, 2, 3, 1), (2, 2, 1, 3, 1))
 
             x = randn(rng, T, x_shape) |> aType
@@ -26,9 +26,7 @@
             __f = let rng = rng, T = T
                 x -> sum(first(dropout(rng, x, T(0.5), Val(true), T(2), Colon())))
             end
-            test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3,
-                soft_fail=(T == Float16 ? [AutoFiniteDiff()] : []),
-                broken_backends=(T == Float16 && Sys.iswindows() ? [AutoEnzyme()] : []))
+            test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3)
 
             y, mask_, rng_ = dropout(rng, x, T(0.5), Val(false), T(2), Colon())
 
@@ -48,7 +46,7 @@ end
     rng = StableRNG(12345)
 
     @testset "$mode" for (mode, aType, ongpu) in MODES
-        @testset "$T: $x_shape" for T in (Float16, Float32, Float64),
+        @testset "$T: $x_shape" for T in (BFloat16, Float32, Float64),
             x_shape in ((2, 3), (2, 2, 3), (2, 2, 3, 1), (2, 2, 1, 3, 1))
 
             x = randn(rng, T, x_shape) |> aType
@@ -76,9 +74,7 @@ end
                 x -> sum(first(dropout(
                     rng, x, mask, T(0.5), Val(true), Val(true), T(2), Colon())))
             end
-            test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3,
-                soft_fail=(T == Float16 ? [AutoFiniteDiff()] : []),
-                broken_backends=(T == Float16 && Sys.iswindows() ? [AutoEnzyme()] : []))
+            test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3)
 
             @jet sum(first(dropout(
                 rng, x, mask, T(0.5), Val(true), Val(true), T(2), Colon())))
@@ -106,9 +102,7 @@
                 x -> sum(first(dropout(
                     rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon())))
             end
-            test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3,
-                soft_fail=(T == Float16 ? [AutoFiniteDiff()] : []),
-                broken_backends=(T == Float16 && Sys.iswindows() ? [AutoEnzyme()] : []))
+            test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3)
 
             @jet sum(first(dropout(
                 rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon())))
@@ -137,9 +131,7 @@
                 x -> sum(first(dropout(
                     rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon())))
             end
-            test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3,
-                soft_fail=(T == Float16 ? [AutoFiniteDiff()] : []),
-                broken_backends=(T == Float16 && Sys.iswindows() ? [AutoEnzyme()] : []))
+            test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3)
 
             @jet sum(first(dropout(
                 rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon())))
@@ -165,7 +157,7 @@ end
     rng = StableRNG(12345)
 
     @testset "$mode" for (mode, aType, ongpu) in MODES
-        @testset "$T: $x_shape" for T in (Float16, Float32, Float64),
+        @testset "$T: $x_shape" for T in (BFloat16, Float32, Float64),
             x_shape in ((2, 3), (2, 2, 3), (2, 2, 3, 1), (2, 2, 1, 3, 1))
 
             x = randn(rng, T, x_shape) |> aType
@@ -186,9 +178,7 @@ end
             __f = let rng = rng
                 x -> sum(first(alpha_dropout(rng, x, T(0.5), Val(true)))))
             end
-            test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3,
-                soft_fail=(T == Float16 ? [AutoFiniteDiff()] : []),
-                broken_backends=(T == Float16 && Sys.iswindows() ? [AutoEnzyme()] : []))
+            test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3)
 
             @jet sum(first(alpha_dropout(rng, x, T(0.5), Val(true))))
             @test @inferred(alpha_dropout(rng, x, T(0.5), Val(false))) isa Any
diff --git a/test/normalization/batchnorm_tests.jl b/test/normalization/batchnorm_tests.jl
index 5735f6ac..bcecb4b1 100644
--- a/test/normalization/batchnorm_tests.jl
+++ b/test/normalization/batchnorm_tests.jl
@@ -1,5 +1,5 @@
 @testsetup module BatchNormSetup
-using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib
+using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib, BFloat16s
 
 function _setup_batchnorm(gen_f, aType, T, sz; affine::Bool=true, track_stats::Bool)
     x = gen_f(T, sz) |> aType
@@ -41,9 +41,8 @@ function run_batchnorm_testing(
     y_simple, nt_simple = __batchnorm_basic(
         x, scale, bias, rm, rv, training, act, T(0.9), epsilon)
 
-    fp16 = T == Float16
-    atol = fp16 ? 1.0f-2 : 1.0f-3
-    rtol = fp16 ? 1.0f-2 : 1.0f-3
+    atol = 1.0f-3
+    rtol = 1.0f-3
 
     @test y≈y_simple atol=atol rtol=rtol
     if track_stats
@@ -82,22 +81,9 @@ function run_batchnorm_testing(
         skip_backends = []
         act === relu && push!(skip_backends, AutoFiniteDiff())
 
-        soft_fail = if fp16
-            if Sys.iswindows()
-                [AutoTracker(), AutoFiniteDiff(), AutoReverseDiff(), AutoForwardDiff()]
-            else
-                true
-            end
-        else
-            false
-        end
-
-        broken_backends = Sys.iswindows() && fp16 ? [AutoEnzyme()] : []
-
         __f = (args...) -> sum(first(batchnorm(
             args..., rm, rv, training, act, T(0.9), epsilon)))
-        test_gradients(
-            __f, x, scale, bias; atol, rtol, skip_backends, soft_fail, broken_backends)
+        test_gradients(__f, x, scale, bias; atol, rtol, skip_backends)
     end
 
     if anonact !== act
@@ -109,7 +95,7 @@ function run_batchnorm_testing(
 end
 
 const ALL_TEST_CONFIGS = Iterators.product(
-    [Float16, Float32, Float64], ((4, 4, 6, 2), (8, 2), (4, 4, 4, 3, 2)),
+    [BFloat16, Float32, Float64], ((4, 4, 6, 2), (8, 2), (4, 4, 4, 3, 2)),
     (Val(true), Val(false)), (true, false), (true, false),
     (identity, relu, tanh_fast, sigmoid_fast, anonact))
 
diff --git a/test/normalization/groupnorm_tests.jl b/test/normalization/groupnorm_tests.jl
index 86363c5a..c532a130 100644
--- a/test/normalization/groupnorm_tests.jl
+++ b/test/normalization/groupnorm_tests.jl
@@ -1,5 +1,5 @@
 @testsetup module GroupNormSetup
-using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib
+using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib, BFloat16s
 
 function _setup_groupnorm(gen_f, aType, T, sz)
     x = gen_f(T, sz) |> aType
@@ -34,20 +34,17 @@ function run_groupnorm_testing(gen_f, T, sz, groups, act, aType, mode, ongpu)
     y_simple = _f2(x, scale, bias)
 
-    fp16 = T == Float16
-    atol = fp16 ? 1.0f-2 : 1.0f-3
-    rtol = fp16 ? 1.0f-2 : 1.0f-3
+    atol = 1.0f-3
+    rtol = 1.0f-3
 
     @test y≈y_simple atol=atol rtol=rtol
 
     # Check the rrules
-    if !fp16
-        ∂x, ∂scale, ∂bias = Zygote.gradient(sum ∘ _f, x, scale, bias)
-        ∂x_simple, ∂scale_simple, ∂bias_simple = Zygote.gradient(sum ∘ _f2, x, scale, bias)
-        @test ∂x≈∂x_simple atol=atol rtol=rtol
-        @test ∂scale≈∂scale_simple atol=atol rtol=rtol
-        @test ∂bias≈∂bias_simple atol=atol rtol=rtol
-    end
+    ∂x, ∂scale, ∂bias = Zygote.gradient(sum ∘ _f, x, scale, bias)
+    ∂x_simple, ∂scale_simple, ∂bias_simple = Zygote.gradient(sum ∘ _f2, x, scale, bias)
+    @test ∂x≈∂x_simple atol=atol rtol=rtol
+    @test ∂scale≈∂scale_simple atol=atol rtol=rtol
+    @test ∂bias≈∂bias_simple atol=atol rtol=rtol
 
     @test @inferred(groupnorm(x, scale, bias, groups, act, epsilon)) isa Any
     @jet groupnorm(x, scale, bias, groups, act, epsilon)
 
@@ -61,11 +58,11 @@ function run_groupnorm_testing(gen_f, T, sz, groups, act, aType, mode, ongpu)
     @test size(y) == sz
 
     __f = (args...) -> sum(groupnorm(args..., groups, act, epsilon))
-    soft_fail = fp16 ? fp16 : [AutoFiniteDiff()]
+    soft_fail = [AutoFiniteDiff()]
     test_gradients(__f, x, scale, bias; atol, rtol, soft_fail)
 end
 
-const ALL_TEST_CONFIGS = Iterators.product([Float16, Float32, Float64],
+const ALL_TEST_CONFIGS = Iterators.product([BFloat16, Float32, Float64],
     ((6, 2), (4, 6, 2), (8, 8, 8, 6, 2), (3, 16, 16, 12, 2), (4, 4, 6, 2),
         (2, 2, 6, 2), (3, 3, 12, 4)),
     (2, 3),
diff --git a/test/normalization/instancenorm_tests.jl b/test/normalization/instancenorm_tests.jl
index 4eb585a2..68a8495b 100644
--- a/test/normalization/instancenorm_tests.jl
+++ b/test/normalization/instancenorm_tests.jl
@@ -1,5 +1,5 @@
 @testsetup module InstanceNormSetup
-using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib
+using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib, BFloat16s
 
 __is_training(::Val{training}) where {training} = training
 
@@ -21,20 +21,17 @@ function run_instancenorm_testing(gen_f, T, sz, training, act, aType, mode, ongp
     y_simple, nt_simple = instancenorm(x, scale, bias, training, act, epsilon)
 
-    fp16 = T == Float16
-    atol = fp16 ? 1.0f-2 : 1.0f-3
-    rtol = fp16 ? 1.0f-2 : 1.0f-3
+    atol = 1.0f-3
+    rtol = 1.0f-3
 
     @test y≈y_simple atol=atol rtol=rtol
 
     # Check the rrules
-    if !fp16
-        ∂x, ∂scale, ∂bias = Zygote.gradient(sum ∘ _f, x, scale, bias)
-        ∂x_simple, ∂scale_simple, ∂bias_simple = Zygote.gradient(sum ∘ _f, x, scale, bias)
-        @test ∂x≈∂x_simple atol=atol rtol=rtol
-        @test ∂scale≈∂scale_simple atol=atol rtol=rtol
-        @test ∂bias≈∂bias_simple atol=atol rtol=rtol
-    end
+    ∂x, ∂scale, ∂bias = Zygote.gradient(sum ∘ _f, x, scale, bias)
+    ∂x_simple, ∂scale_simple, ∂bias_simple = Zygote.gradient(sum ∘ _f, x, scale, bias)
+    @test ∂x≈∂x_simple atol=atol rtol=rtol
+    @test ∂scale≈∂scale_simple atol=atol rtol=rtol
+    @test ∂bias≈∂bias_simple atol=atol rtol=rtol
 
     @test @inferred(instancenorm(x, scale, bias, training, act, epsilon)) isa Any
     @jet instancenorm(x, scale, bias, training, act, epsilon)
 
@@ -49,13 +46,13 @@ function run_instancenorm_testing(gen_f, T, sz, training, act, aType, mode, ongp
 
     if __is_training(training)
         __f = (args...) -> sum(first(instancenorm(args..., training, act, epsilon)))
-        soft_fail = fp16 ? fp16 : [AutoFiniteDiff()]
+        soft_fail = [AutoFiniteDiff()]
         test_gradients(__f, x, scale, bias; atol, rtol, soft_fail)
     end
 end
 
 const ALL_TEST_CONFIGS = Iterators.product(
-    [Float16, Float32, Float64], ((4, 4, 6, 2), (3, 4, 2), (4, 4, 4, 3, 2)),
+    [BFloat16, Float32, Float64], ((4, 4, 6, 2), (3, 4, 2), (4, 4, 4, 3, 2)),
     (Val(true), Val(false)), (identity, relu, tanh_fast, sigmoid_fast, anonact))
 
 const TEST_BLOCKS = collect(Iterators.partition(
diff --git a/test/normalization/layernorm_tests.jl b/test/normalization/layernorm_tests.jl
index fe665893..327fec86 100644
--- a/test/normalization/layernorm_tests.jl
+++ b/test/normalization/layernorm_tests.jl
@@ -1,5 +1,5 @@
 @testsetup module LayerNormSetup
-using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib, Statistics
+using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib, Statistics, BFloat16s
 using LuxTestUtils: check_approx
 
 function _setup_layernorm(gen_f, aType, T, x_size, affine_shape)
@@ -33,11 +33,10 @@ function run_layernorm_testing(gen_f, aType, T, x_size, affine_shape, act, ongpu
         @test check_approx(std(y; dims), 1; atol=1e-1, rtol=1e-1)
     end
 
-    fp16 = T == Float16
-    atol = fp16 ? 1.0f-2 : 1.0f-3
-    rtol = fp16 ? 1.0f-2 : 1.0f-3
+    atol = 1.0f-3
+    rtol = 1.0f-3
 
-    soft_fail = fp16 ? fp16 : [AutoFiniteDiff()]
+    soft_fail = [AutoFiniteDiff()]
     if affine_shape !== nothing
         __f = (args...) -> sum(_f(args...))
         test_gradients(__f, x, scale, bias; atol, rtol, soft_fail)
@@ -56,7 +55,7 @@ anonact = x -> x^3
 
 const ALL_TEST_CONFIGS = Any[]
 
-for T in (Float16, Float32, Float64),
+for T in (BFloat16, Float32, Float64),
     x_shape in ((3, 3, 2, 1), (2, 2, 2, 1), (2, 3, 2, 2)),
     affine_shape in (nothing, x_shape[1:3], (1, 1, 1), (1, 1, x_shape[3])),
     act in (identity, relu, tanh_fast, sigmoid_fast, anonact)
diff --git a/test/shared_testsetup.jl b/test/shared_testsetup.jl
index 9c43bd31..11d9c17a 100644
--- a/test/shared_testsetup.jl
+++ b/test/shared_testsetup.jl
@@ -2,7 +2,7 @@
 import Reexport: @reexport
 
 using LuxLib, MLDataDevices
-@reexport using LuxTestUtils, StableRNGs, Test, Enzyme, Zygote
+@reexport using BFloat16s, LuxTestUtils, StableRNGs, Test, Enzyme, Zygote
 
 LuxTestUtils.jet_target_modules!(["LuxLib"])