Skip to content
This repository has been archived by the owner on Nov 4, 2024. It is now read-only.

Commit

Permalink
test: skip samplers that don't support FP64
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Jul 3, 2024
1 parent 59ce7d6 commit 4e62456
Show file tree
Hide file tree
Showing 7 changed files with 49 additions and 25 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ WeightInitializersAMDGPUExt = ["AMDGPU", "GPUArrays"]
WeightInitializersCUDAExt = ["CUDA", "GPUArrays"]
WeightInitializersGPUArraysExt = "GPUArrays"
WeightInitializersMetalExt = ["Metal", "GPUArrays"]
WeightInitializersOneAPIExt = ["oneAPI", "GPUArrays"]
WeightInitializersoneAPIExt = ["oneAPI", "GPUArrays"]

[compat]
AMDGPU = "0.9.6"
Expand Down
2 changes: 1 addition & 1 deletion ext/WeightInitializersGPUArraysExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ for f in (:__rand, :__randn)
rng::RNG, ::Type{<:Complex{T}}, args...) where {T <: Number}
real_part = WeightInitializers.$(f)(rng, rng.state, T, args...)
imag_part = WeightInitializers.$(f)(rng, rng.state, T, args...)
return Complex.(real_part, imag_part)
return Complex{T}.(real_part, imag_part)
end
end

Expand Down
2 changes: 1 addition & 1 deletion ext/WeightInitializersoneAPIExt.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module WeightInitializersoneAPIExt

using oneAPI: oneArray
using oneAPI: oneAPI, oneArray
using GPUArrays: RNG
using Random: Random
using WeightInitializers: WeightInitializers
Expand Down
4 changes: 2 additions & 2 deletions src/initializers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -108,9 +108,9 @@ function truncated_normal(rng::AbstractRNG, ::Type{T}, dims::Integer...; mean=T(
u = _norm_cdf((T(hi) - T(mean)) / T(std))
xs = __rand(rng, T, dims...)
broadcast!(xs, xs) do x
x = x * 2(u - l) + (2l - 1)
x = x * 2(u - l) + (2l - one(T))
x = erfinv(x)
return clamp(x * T(std) * 2 + T(mean), T(lo), T(hi))
return clamp(x * T(std) * T(2) + T(mean), T(lo), T(hi))
end
return xs
end
Expand Down
21 changes: 9 additions & 12 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,23 +43,20 @@ end
@inline function __ones(::AbstractRNG, ::Type{T}, dims::Integer...) where {T <: Number}
return ones(T, dims...)
end
@inline function __rand(rng::AbstractRNG, ::Type{T}, args...) where {T <: Number}
@inline function __rand(rng::AbstractRNG, ::Type{T}, args::Integer...) where {T <: Number}
return rand(rng, T, args...)
end
@inline function __randn(rng::AbstractRNG, ::Type{T}, args...) where {T <: Number}
@inline function __randn(rng::AbstractRNG, ::Type{T}, args::Integer...) where {T <: Number}
return randn(rng, T, args...)
end

## Certain backends don't support sampling Complex numbers, so we avoid hitting those
## dispatches
@inline function __rand(rng::AbstractRNG, ::Type{<:Complex{T}}, args...) where {T <: Number}
real_part = __rand(rng, T, args...)
imag_part = __rand(rng, T, args...)
return Complex.(real_part, imag_part)
end
@inline function __randn(
rng::AbstractRNG, ::Type{<:Complex{T}}, args...) where {T <: Number}
real_part = __randn(rng, T, args...)
imag_part = __randn(rng, T, args...)
return Complex.(real_part, imag_part)
for f in (:__rand, :__randn)
@eval @inline function $(f)(
rng::AbstractRNG, ::Type{<:Complex{T}}, args::Integer...) where {T <: Number}
real_part = $(f)(rng, T, args...)
imag_part = $(f)(rng, T, args...)
return Complex{T}.(real_part, imag_part)
end
end
27 changes: 24 additions & 3 deletions test/initializers_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ end
@testitem "Orthogonal Initialization" setup=[SharedTestSetup] begin
using GPUArraysCore, LinearAlgebra

@testset "rng = $(typeof(rng)) & arrtype = $arrtype" for (rng, arrtype) in RNGS_ARRTYPES
@testset "rng = $(typeof(rng)) & arrtype = $arrtype" for (rng, arrtype, supports_fp64) in RNGS_ARRTYPES
# A matrix of dim = (m,n) with m > n should produce a QR decomposition.
# In the other case, the transpose should be taken to compute the QR decomposition.
for (rows, cols) in [(5, 3), (3, 5)]
Expand All @@ -35,11 +35,15 @@ end
end

@testset "Orthogonal Types $T" for T in (Float32, Float64)
!supports_fp64 && T == Float64 && continue

@test eltype(orthogonal(rng, T, 3, 4; gain=1.5)) == T
@test eltype(orthogonal(rng, T, 3, 4, 5; gain=1.5)) == T
end

@testset "Orthogonal AbstractArray Type $T" for T in (Float32, Float64)
!supports_fp64 && T == Float64 && continue

@test orthogonal(rng, T, 3, 5) isa AbstractArray{T, 2}
@test orthogonal(rng, T, 3, 5) isa arrtype{T, 2}

Expand Down Expand Up @@ -69,7 +73,7 @@ end
@testitem "Sparse Initialization" setup=[SharedTestSetup] begin
using Statistics

@testset "rng = $(typeof(rng)) & arrtype = $arrtype" for (rng, arrtype) in RNGS_ARRTYPES
@testset "rng = $(typeof(rng)) & arrtype = $arrtype" for (rng, arrtype, supports_fp64) in RNGS_ARRTYPES
# sparse_init should yield an error for non 2-d dimensions
# sparse_init should yield no zero elements if sparsity < 0
# sparse_init should yield all zero elements if sparsity > 1
Expand All @@ -93,10 +97,14 @@ end
end

@testset "sparse_init Types $T" for T in (Float16, Float32, Float64)
!supports_fp64 && T == Float64 && continue

@test eltype(sparse_init(rng, T, 3, 4; sparsity=0.5)) == T
end

@testset "sparse_init AbstractArray Type $T" for T in (Float16, Float32, Float64)
!supports_fp64 && T == Float64 && continue

@test sparse_init(T, 3, 5; sparsity=0.5) isa AbstractArray{T, 2}
@test sparse_init(rng, T, 3, 5; sparsity=0.5) isa arrtype{T, 2}

Expand All @@ -122,10 +130,17 @@ end
@testitem "Basic Initializations" setup=[SharedTestSetup] begin
using LinearAlgebra, Statistics

@testset "rng = $(typeof(rng)) & arrtype = $arrtype" for (rng, arrtype) in RNGS_ARRTYPES
@testset "rng = $(typeof(rng)) & arrtype = $arrtype" for (rng, arrtype, supports_fp64) in RNGS_ARRTYPES
@testset "Sizes and Types: $init" for init in [
zeros32, ones32, rand32, randn32, kaiming_uniform, kaiming_normal,
glorot_uniform, glorot_normal, truncated_normal, identity_init]
!supports_fp64 &&
(init === zeros32 ||
init === ones32 ||
init === rand32 ||
init === randn32) &&
continue

# Sizes
@test size(init(3)) == (3,)
@test size(init(rng, 3)) == (3,)
Expand All @@ -151,6 +166,8 @@ end
(randC32, ComplexF32), (rand64, Float64), (randC64, ComplexF64),
(randn16, Float16), (randnC16, ComplexF16), (randn32, Float32),
(randnC32, ComplexF32), (randn64, Float64), (randnC64, ComplexF64)]
!supports_fp64 && (fp == Float64 || fp == ComplexF64) && continue

# Sizes
@test size(init(3)) == (3,)
@test size(init(rng, 3)) == (3,)
Expand All @@ -172,6 +189,8 @@ end
glorot_normal, truncated_normal, identity_init],
T in (Float16, Float32, Float64, ComplexF16, ComplexF32, ComplexF64)

!supports_fp64 && (T == Float64 || T == ComplexF64) && continue

init === truncated_normal && !(T <: Real) && continue

@test init(T, 3) isa AbstractArray{T, 1}
Expand Down Expand Up @@ -206,6 +225,8 @@ end

@testset "Kwargs types" for T in (
Float16, Float32, Float64, ComplexF16, ComplexF32, ComplexF64)
!supports_fp64 && (T == Float64 || T == ComplexF64) && continue

if (T <: Real)
@test eltype(truncated_normal(T, 2, 5; mean=0, std=1, lo=-2, hi=2)) == T
@test eltype(orthogonal(T, 2, 5; gain=1.0)) == T
Expand Down
16 changes: 11 additions & 5 deletions test/shared_testsetup.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,25 +9,31 @@ const BACKEND_GROUP = lowercase(get(ENV, "BACKEND_GROUP", "All"))
RNGS_ARRTYPES = []
if BACKEND_GROUP == "all" || BACKEND_GROUP == "cpu"
append!(RNGS_ARRTYPES,
[(StableRNG(12345), AbstractArray), (Random.GLOBAL_RNG, AbstractArray)])
[(StableRNG(12345), AbstractArray, true), (Random.GLOBAL_RNG, AbstractArray, true)])
end
if BACKEND_GROUP == "all" || BACKEND_GROUP == "cuda"
using CUDA, GPUArrays
append!(RNGS_ARRTYPES,
[(CUDA.default_rng(), CuArray), (GPUArrays.default_rng(CuArray), CuArray)])
[(CUDA.default_rng(), CuArray, true),
(GPUArrays.default_rng(CuArray), CuArray, true)])
end
if BACKEND_GROUP == "all" || BACKEND_GROUP == "amdgpu"
using AMDGPU
append!(RNGS_ARRTYPES,
[(AMDGPU.rocrand_rng(), ROCArray), (AMDGPU.gpuarrays_rng(), ROCArray)])
[(AMDGPU.rocrand_rng(), ROCArray, true), (AMDGPU.gpuarrays_rng(), ROCArray, true)])
end
if BACKEND_GROUP == "all" || BACKEND_GROUP == "metal"
using Metal
push!(RNGS_ARRTYPES, (Metal.gpuarrays_rng(), MtlArray))
push!(RNGS_ARRTYPES, (Metal.gpuarrays_rng(), MtlArray, false))
end
if BACKEND_GROUP == "all" || BACKEND_GROUP == "oneapi"
using oneAPI
push!(RNGS_ARRTYPES, (oneAPI.gpuarrays_rng(), oneArray))
using oneAPI: oneL0

supports_fp64 = oneL0.module_properties(first(oneAPI.devices())).fp64flags &
oneL0.ZE_DEVICE_MODULE_FLAG_FP64 == oneL0.ZE_DEVICE_MODULE_FLAG_FP64

push!(RNGS_ARRTYPES, (oneAPI.gpuarrays_rng(), oneArray, supports_fp64))
end

export StableRNG, RNGS_ARRTYPES, BACKEND_GROUP
Expand Down

0 comments on commit 4e62456

Please sign in to comment.