Skip to content
This repository has been archived by the owner on Nov 4, 2024. It is now read-only.

Commit

Permalink
test: swap Float16 tests with BFloat16
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Aug 2, 2024
1 parent bc0131d commit 3995555
Show file tree
Hide file tree
Showing 10 changed files with 57 additions and 92 deletions.
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ LuxLibcuDNNExt = ["CUDA", "cuDNN"]
AMDGPU = "0.9.6"
Aqua = "0.8.7"
ArrayInterface = "7.9"
BFloat16s = "0.5.0"
CUDA = "5.3.2"
ChainRulesCore = "1.24"
ComponentArrays = "0.15.16"
Expand Down Expand Up @@ -86,6 +87,7 @@ julia = "1.10"

[extras]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
BFloat16s = "ab4f0b2a-ad5b-11e8-123f-65d77653426b"
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7"
Expand All @@ -104,4 +106,4 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Aqua", "ComponentArrays", "Enzyme", "ExplicitImports", "Hwloc", "InteractiveUtils", "JLArrays", "LuxTestUtils", "Pkg", "Preferences", "ReTestItems", "ReverseDiff", "StableRNGs", "StaticArrays", "Test", "Tracker", "Zygote"]
test = ["Aqua", "BFloat16s", "ComponentArrays", "Enzyme", "ExplicitImports", "Hwloc", "InteractiveUtils", "JLArrays", "LuxTestUtils", "Pkg", "Preferences", "ReTestItems", "ReverseDiff", "StableRNGs", "StaticArrays", "Test", "Tracker", "Zygote"]
7 changes: 3 additions & 4 deletions test/common_ops/activation_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,16 @@
@testset "$mode" for (mode, aType, ongpu) in MODES
@testset "$f: $T" for f in [identity, relu, sigmoid, sigmoid_fast, softplus,
logsigmoid, gelu, swish, lisht, tanh, tanh_fast],
T in [Float16, Float32, Float64]
T in [BFloat16, Float32, Float64]

x = rand(rng, T, 4, 3) |> aType

y1 = apply_act(f, x)
y2 = apply_act_fast(f, x)
y3 = apply_act_fast2(f, x)

fp16 = T == Float16
atol = fp16 ? 1.0f-1 : 1.0f-3
rtol = fp16 ? 1.0f-1 : 1.0f-3
atol = 1.0f-3
rtol = 1.0f-3

@test y1y2 atol=atol rtol=rtol
@test y1y3 atol=atol rtol=rtol
Expand Down
14 changes: 6 additions & 8 deletions test/common_ops/conv_tests.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
@testsetup module ConvSetup
using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib
using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib, BFloat16s

_expand(N, i::Tuple) = i
_expand(N, i::Integer) = ntuple(_ -> i, N)
Expand Down Expand Up @@ -28,9 +28,8 @@ function run_conv_testing(gen_f::Function, activation, kernel, stride, padding,

y_generic = LuxLib._generic_conv_bias_activation(activation, weight, x, bias, cdims)

fp16 = Tx == Float16 || Tw == Float16
atol = fp16 ? 1.0f-1 : 1.0f-3
rtol = fp16 ? 1.0f-1 : 1.0f-3
atol = 1.0f-3
rtol = 1.0f-3
# Operation reordering has an effect on the accuracy of the results
@test yy_generic atol=atol rtol=rtol
@test eltype(y) == promote_type(Tw, Tx)
Expand Down Expand Up @@ -61,14 +60,13 @@ function run_conv_testing(gen_f::Function, activation, kernel, stride, padding,
mp && push!(skip_backends, AutoReverseDiff())
((mp && ongpu) || (mode == "amdgpu" && (Tx == Float64 || Tw == Float64))) &&
push!(skip_backends, AutoTracker())
test_gradients(__f_grad, weight, x, bias; atol, rtol, skip_backends,
soft_fail=(fp16 ? [AutoFiniteDiff()] : []))
test_gradients(__f_grad, weight, x, bias; atol, rtol, skip_backends)
end

anonact = x -> gelu(x)

const ELTYPES = [(Float16, Float16), (Float32, Float16), (Float32, Float32),
(Float32, Float64), (Float64, Float64)]
const ELTYPES = [(BFloat16, BFloat16), (Float32, BFloat16),
(Float32, Float32), (Float32, Float64), (Float64, Float64)]
const ACTIVATIONS = [
identity, tanh, tanh_fast, sigmoid, sigmoid_fast, relu, gelu, swish, anonact]

Expand Down
15 changes: 6 additions & 9 deletions test/common_ops/dense_tests.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
@testsetup module DenseSetup
using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib
using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib, BFloat16s

anonact = x -> x^3

Expand All @@ -25,24 +25,21 @@ function run_dense_testing(gen_f, Tw, Tx, M, N, hasbias, activation, aType, mode
@test length(@inferred(Zygote.gradient(__f, activation, w, x, bias)))==4 broken=true
end

fp16 = Tx == Float16 || Tw == Float16
atol = fp16 ? 1.0f-1 : 1.0f-3
rtol = fp16 ? 1.0f-1 : 1.0f-3
atol = 1.0f-3
rtol = 1.0f-3

skip_backends = []
Tw != Tx && push!(skip_backends, AutoReverseDiff())
fp16 && push!(skip_backends, AutoFiniteDiff())

__f_grad = let activation = activation
(w, x, b) -> __f(activation, w, x, b)
end
test_gradients(__f_grad, w, x, bias; atol, rtol, skip_backends,
soft_fail=(fp16 ? [AutoFiniteDiff()] : []))
test_gradients(__f_grad, w, x, bias; atol, rtol, skip_backends)
end

const ALL_TEST_CONFIGS = Iterators.product(
((Float16, Float16), (Float32, Float16), (Float32, Float32),
(Float32, Float64), (Float64, Float64)),
((BFloat16, BFloat16), (Float32, BFloat16),
(Float32, Float32), (Float32, Float64), (Float64, Float64)),
(4, 8),
(4, 8),
(true, false),
Expand Down
26 changes: 8 additions & 18 deletions test/common_ops/dropout_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
rng = StableRNG(12345)

@testset "$mode" for (mode, aType, ongpu) in MODES
@testset "$T: $x_shape" for T in (Float16, Float32, Float64),
@testset "$T: $x_shape" for T in (BFloat16, Float32, Float64),
x_shape in ((2, 3), (2, 2, 3), (2, 2, 3, 1), (2, 2, 1, 3, 1))

x = randn(rng, T, x_shape) |> aType
Expand All @@ -26,9 +26,7 @@
__f = let rng = rng, T = T
x -> sum(first(dropout(rng, x, T(0.5), Val(true), T(2), Colon())))
end
test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3,
soft_fail=(T == Float16 ? [AutoFiniteDiff()] : []),
broken_backends=(T == Float16 && Sys.iswindows() ? [AutoEnzyme()] : []))
test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3)

y, mask_, rng_ = dropout(rng, x, T(0.5), Val(false), T(2), Colon())

Expand All @@ -48,7 +46,7 @@ end
rng = StableRNG(12345)

@testset "$mode" for (mode, aType, ongpu) in MODES
@testset "$T: $x_shape" for T in (Float16, Float32, Float64),
@testset "$T: $x_shape" for T in (BFloat16, Float32, Float64),
x_shape in ((2, 3), (2, 2, 3), (2, 2, 3, 1), (2, 2, 1, 3, 1))

x = randn(rng, T, x_shape) |> aType
Expand Down Expand Up @@ -76,9 +74,7 @@ end
x -> sum(first(dropout(
rng, x, mask, T(0.5), Val(true), Val(true), T(2), Colon())))
end
test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3,
soft_fail=(T == Float16 ? [AutoFiniteDiff()] : []),
broken_backends=(T == Float16 && Sys.iswindows() ? [AutoEnzyme()] : []))
test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3)

@jet sum(first(dropout(
rng, x, mask, T(0.5), Val(true), Val(true), T(2), Colon())))
Expand Down Expand Up @@ -106,9 +102,7 @@ end
x -> sum(first(dropout(
rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon())))
end
test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3,
soft_fail=(T == Float16 ? [AutoFiniteDiff()] : []),
broken_backends=(T == Float16 && Sys.iswindows() ? [AutoEnzyme()] : []))
test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3)

@jet sum(first(dropout(
rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon())))
Expand Down Expand Up @@ -137,9 +131,7 @@ end
x -> sum(first(dropout(
rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon())))
end
test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3,
soft_fail=(T == Float16 ? [AutoFiniteDiff()] : []),
broken_backends=(T == Float16 && Sys.iswindows() ? [AutoEnzyme()] : []))
test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3)

@jet sum(first(dropout(
rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon())))
Expand All @@ -165,7 +157,7 @@ end
rng = StableRNG(12345)

@testset "$mode" for (mode, aType, ongpu) in MODES
@testset "$T: $x_shape" for T in (Float16, Float32, Float64),
@testset "$T: $x_shape" for T in (BFloat16, Float32, Float64),
x_shape in ((2, 3), (2, 2, 3), (2, 2, 3, 1), (2, 2, 1, 3, 1))

x = randn(rng, T, x_shape) |> aType
Expand All @@ -186,9 +178,7 @@ end
__f = let rng = rng
x -> sum(first(alpha_dropout(rng, x, T(0.5), Val(true))))
end
test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3,
soft_fail=(T == Float16 ? [AutoFiniteDiff()] : []),
broken_backends=(T == Float16 && Sys.iswindows() ? [AutoEnzyme()] : []))
test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3)

@jet sum(first(alpha_dropout(rng, x, T(0.5), Val(true))))
@test @inferred(alpha_dropout(rng, x, T(0.5), Val(false))) isa Any
Expand Down
24 changes: 5 additions & 19 deletions test/normalization/batchnorm_tests.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
@testsetup module BatchNormSetup
using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib
using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib, BFloat16s

function _setup_batchnorm(gen_f, aType, T, sz; affine::Bool=true, track_stats::Bool)
x = gen_f(T, sz) |> aType
Expand Down Expand Up @@ -41,9 +41,8 @@ function run_batchnorm_testing(
y_simple, nt_simple = __batchnorm_basic(
x, scale, bias, rm, rv, training, act, T(0.9), epsilon)

fp16 = T == Float16
atol = fp16 ? 1.0f-2 : 1.0f-3
rtol = fp16 ? 1.0f-2 : 1.0f-3
atol = 1.0f-3
rtol = 1.0f-3

@test yy_simple atol=atol rtol=rtol
if track_stats
Expand Down Expand Up @@ -82,22 +81,9 @@ function run_batchnorm_testing(
skip_backends = []
act === relu && push!(skip_backends, AutoFiniteDiff())

soft_fail = if fp16
if Sys.iswindows()
[AutoTracker(), AutoFiniteDiff(), AutoReverseDiff(), AutoForwardDiff()]
else
true
end
else
false
end

broken_backends = Sys.iswindows() && fp16 ? [AutoEnzyme()] : []

__f = (args...) -> sum(first(batchnorm(
args..., rm, rv, training, act, T(0.9), epsilon)))
test_gradients(
__f, x, scale, bias; atol, rtol, skip_backends, soft_fail, broken_backends)
test_gradients(__f, x, scale, bias; atol, rtol, skip_backends)
end

if anonact !== act
Expand All @@ -109,7 +95,7 @@ function run_batchnorm_testing(
end

const ALL_TEST_CONFIGS = Iterators.product(
[Float16, Float32, Float64], ((4, 4, 6, 2), (8, 2), (4, 4, 4, 3, 2)),
[BFloat16, Float32, Float64], ((4, 4, 6, 2), (8, 2), (4, 4, 4, 3, 2)),
(Val(true), Val(false)), (true, false), (true, false),
(identity, relu, tanh_fast, sigmoid_fast, anonact))

Expand Down
23 changes: 10 additions & 13 deletions test/normalization/groupnorm_tests.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
@testsetup module GroupNormSetup
using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib
using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib, BFloat16s

function _setup_groupnorm(gen_f, aType, T, sz)
x = gen_f(T, sz) |> aType
Expand Down Expand Up @@ -34,20 +34,17 @@ function run_groupnorm_testing(gen_f, T, sz, groups, act, aType, mode, ongpu)

y_simple = _f2(x, scale, bias)

fp16 = T == Float16
atol = fp16 ? 1.0f-2 : 1.0f-3
rtol = fp16 ? 1.0f-2 : 1.0f-3
atol = 1.0f-3
rtol = 1.0f-3

@test yy_simple atol=atol rtol=rtol

# Check the rrules
if !fp16
∂x, ∂scale, ∂bias = Zygote.gradient(sum _f, x, scale, bias)
∂x_simple, ∂scale_simple, ∂bias_simple = Zygote.gradient(sum _f2, x, scale, bias)
@test ∂x∂x_simple atol=atol rtol=rtol
@test ∂scale∂scale_simple atol=atol rtol=rtol
@test ∂bias∂bias_simple atol=atol rtol=rtol
end
∂x, ∂scale, ∂bias = Zygote.gradient(sum _f, x, scale, bias)
∂x_simple, ∂scale_simple, ∂bias_simple = Zygote.gradient(sum _f2, x, scale, bias)
@test ∂x∂x_simple atol=atol rtol=rtol
@test ∂scale∂scale_simple atol=atol rtol=rtol
@test ∂bias∂bias_simple atol=atol rtol=rtol

@test @inferred(groupnorm(x, scale, bias, groups, act, epsilon)) isa Any
@jet groupnorm(x, scale, bias, groups, act, epsilon)
Expand All @@ -61,11 +58,11 @@ function run_groupnorm_testing(gen_f, T, sz, groups, act, aType, mode, ongpu)
@test size(y) == sz

__f = (args...) -> sum(groupnorm(args..., groups, act, epsilon))
soft_fail = fp16 ? fp16 : [AutoFiniteDiff()]
soft_fail = [AutoFiniteDiff()]
test_gradients(__f, x, scale, bias; atol, rtol, soft_fail)
end

const ALL_TEST_CONFIGS = Iterators.product([Float16, Float32, Float64],
const ALL_TEST_CONFIGS = Iterators.product([BFloat16, Float32, Float64],
((6, 2), (4, 6, 2), (8, 8, 8, 6, 2), (3, 16, 16, 12, 2),
(4, 4, 6, 2), (2, 2, 6, 2), (3, 3, 12, 4)),
(2, 3),
Expand Down
23 changes: 10 additions & 13 deletions test/normalization/instancenorm_tests.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
@testsetup module InstanceNormSetup
using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib
using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib, BFloat16s

__is_training(::Val{training}) where {training} = training

Expand All @@ -21,20 +21,17 @@ function run_instancenorm_testing(gen_f, T, sz, training, act, aType, mode, ongp

y_simple, nt_simple = instancenorm(x, scale, bias, training, act, epsilon)

fp16 = T == Float16
atol = fp16 ? 1.0f-2 : 1.0f-3
rtol = fp16 ? 1.0f-2 : 1.0f-3
atol = 1.0f-3
rtol = 1.0f-3

@test yy_simple atol=atol rtol=rtol

# Check the rrules
if !fp16
∂x, ∂scale, ∂bias = Zygote.gradient(sum _f, x, scale, bias)
∂x_simple, ∂scale_simple, ∂bias_simple = Zygote.gradient(sum _f, x, scale, bias)
@test ∂x∂x_simple atol=atol rtol=rtol
@test ∂scale∂scale_simple atol=atol rtol=rtol
@test ∂bias∂bias_simple atol=atol rtol=rtol
end
∂x, ∂scale, ∂bias = Zygote.gradient(sum _f, x, scale, bias)
∂x_simple, ∂scale_simple, ∂bias_simple = Zygote.gradient(sum _f, x, scale, bias)
@test ∂x∂x_simple atol=atol rtol=rtol
@test ∂scale∂scale_simple atol=atol rtol=rtol
@test ∂bias∂bias_simple atol=atol rtol=rtol

@test @inferred(instancenorm(x, scale, bias, training, act, epsilon)) isa Any
@jet instancenorm(x, scale, bias, training, act, epsilon)
Expand All @@ -49,13 +46,13 @@ function run_instancenorm_testing(gen_f, T, sz, training, act, aType, mode, ongp

if __is_training(training)
__f = (args...) -> sum(first(instancenorm(args..., training, act, epsilon)))
soft_fail = fp16 ? fp16 : [AutoFiniteDiff()]
soft_fail = [AutoFiniteDiff()]
test_gradients(__f, x, scale, bias; atol, rtol, soft_fail)
end
end

const ALL_TEST_CONFIGS = Iterators.product(
[Float16, Float32, Float64], ((4, 4, 6, 2), (3, 4, 2), (4, 4, 4, 3, 2)),
[BFloat16, Float32, Float64], ((4, 4, 6, 2), (3, 4, 2), (4, 4, 4, 3, 2)),
(Val(true), Val(false)), (identity, relu, tanh_fast, sigmoid_fast, anonact))

const TEST_BLOCKS = collect(Iterators.partition(
Expand Down
11 changes: 5 additions & 6 deletions test/normalization/layernorm_tests.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
@testsetup module LayerNormSetup
using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib, Statistics
using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib, Statistics, BFloat16s
using LuxTestUtils: check_approx

function _setup_layernorm(gen_f, aType, T, x_size, affine_shape)
Expand Down Expand Up @@ -33,11 +33,10 @@ function run_layernorm_testing(gen_f, aType, T, x_size, affine_shape, act, ongpu
@test check_approx(std(y; dims), 1; atol=1e-1, rtol=1e-1)
end

fp16 = T == Float16
atol = fp16 ? 1.0f-2 : 1.0f-3
rtol = fp16 ? 1.0f-2 : 1.0f-3
atol = 1.0f-3
rtol = 1.0f-3

soft_fail = fp16 ? fp16 : [AutoFiniteDiff()]
soft_fail = [AutoFiniteDiff()]
if affine_shape !== nothing
__f = (args...) -> sum(_f(args...))
test_gradients(__f, x, scale, bias; atol, rtol, soft_fail)
Expand All @@ -56,7 +55,7 @@ anonact = x -> x^3

const ALL_TEST_CONFIGS = Any[]

for T in (Float16, Float32, Float64),
for T in (BFloat16, Float32, Float64),
x_shape in ((3, 3, 2, 1), (2, 2, 2, 1), (2, 3, 2, 2)),
affine_shape in (nothing, x_shape[1:3], (1, 1, 1), (1, 1, x_shape[3])),
act in (identity, relu, tanh_fast, sigmoid_fast, anonact)
Expand Down
2 changes: 1 addition & 1 deletion test/shared_testsetup.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import Reexport: @reexport

using LuxLib, MLDataDevices
@reexport using LuxTestUtils, StableRNGs, Test, Enzyme, Zygote
@reexport using BFloat16s, LuxTestUtils, StableRNGs, Test, Enzyme, Zygote

LuxTestUtils.jet_target_modules!(["LuxLib"])

Expand Down

0 comments on commit 3995555

Please sign in to comment.