From 443eb8356526b287b6ed34dee0b1daec6790c556 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 18 Aug 2024 20:47:23 -0700 Subject: [PATCH] refactor: move device agnostic functions to `DeviceAgnostic` --- ext/WeightInitializersAMDGPUExt.jl | 14 ++++++------ ext/WeightInitializersCUDAExt.jl | 14 ++++++------ ext/WeightInitializersGPUArraysExt.jl | 16 ++++++------- ext/WeightInitializersMetalExt.jl | 10 ++++---- ext/WeightInitializersoneAPIExt.jl | 10 ++++---- src/WeightInitializers.jl | 16 +++++++++---- src/autodiff.jl | 13 ----------- src/initializers.jl | 25 ++++++++++---------- src/utils.jl | 33 ++++++++++++++++----------- 9 files changed, 75 insertions(+), 76 deletions(-) delete mode 100644 src/autodiff.jl diff --git a/ext/WeightInitializersAMDGPUExt.jl b/ext/WeightInitializersAMDGPUExt.jl index 63031c5..ad0fa20 100644 --- a/ext/WeightInitializersAMDGPUExt.jl +++ b/ext/WeightInitializersAMDGPUExt.jl @@ -3,32 +3,32 @@ module WeightInitializersAMDGPUExt using AMDGPU: AMDGPU, ROCArray using GPUArrays: RNG using Random: Random -using WeightInitializers: WeightInitializers +using WeightInitializers: DeviceAgnostic -function WeightInitializers.__zeros( +function DeviceAgnostic.zeros( ::AMDGPU.rocRAND.RNG, ::Type{T}, dims::Integer...) where {T <: Number} return AMDGPU.zeros(T, dims...) end -function WeightInitializers.__ones( +function DeviceAgnostic.ones( ::AMDGPU.rocRAND.RNG, ::Type{T}, dims::Integer...) where {T <: Number} return AMDGPU.ones(T, dims...) end -function WeightInitializers.__zeros( +function DeviceAgnostic.zeros( ::RNG, ::ROCArray, ::Type{T}, dims::Integer...) where {T <: Number} return AMDGPU.zeros(T, dims...) end -function WeightInitializers.__ones( +function DeviceAgnostic.ones( ::RNG, ::ROCArray, ::Type{T}, dims::Integer...) where {T <: Number} return AMDGPU.ones(T, dims...) end -function WeightInitializers.__rand( +function DeviceAgnostic.rand( rng::RNG, ::ROCArray, ::Type{T}, dims::Integer...) where {T <: Number} y = ROCArray{T}(undef, dims...) Random.rand!(rng, y) return y end -function WeightInitializers.__randn( +function DeviceAgnostic.randn( rng::RNG, ::ROCArray, ::Type{T}, dims::Integer...) where {T <: Number} y = ROCArray{T}(undef, dims...) Random.randn!(rng, y) diff --git a/ext/WeightInitializersCUDAExt.jl b/ext/WeightInitializersCUDAExt.jl index 6dd9e1a..db7573f 100644 --- a/ext/WeightInitializersCUDAExt.jl +++ b/ext/WeightInitializersCUDAExt.jl @@ -3,34 +3,34 @@ module WeightInitializersCUDAExt using CUDA: CUDA, CURAND, CuArray using GPUArrays: RNG using Random: Random -using WeightInitializers: WeightInitializers +using WeightInitializers: DeviceAgnostic const AbstractCuRNG = Union{CUDA.RNG, CURAND.RNG} -function WeightInitializers.__zeros( +function DeviceAgnostic.zeros( ::AbstractCuRNG, ::Type{T}, dims::Integer...) where {T <: Number} return CUDA.zeros(T, dims...) end -function WeightInitializers.__ones( +function DeviceAgnostic.ones( ::AbstractCuRNG, ::Type{T}, dims::Integer...) where {T <: Number} return CUDA.ones(T, dims...) end -function WeightInitializers.__zeros( +function DeviceAgnostic.zeros( ::RNG, ::CuArray, ::Type{T}, dims::Integer...) where {T <: Number} return CUDA.zeros(T, dims...) end -function WeightInitializers.__ones( +function DeviceAgnostic.ones( ::RNG, ::CuArray, ::Type{T}, dims::Integer...) where {T <: Number} return CUDA.ones(T, dims...) end -function WeightInitializers.__rand( +function DeviceAgnostic.rand( rng::RNG, ::CuArray, ::Type{T}, dims::Integer...) where {T <: Number} y = CuArray{T}(undef, dims...) Random.rand!(rng, y) return y end -function WeightInitializers.__randn( +function DeviceAgnostic.randn( rng::RNG, ::CuArray, ::Type{T}, dims::Integer...) where {T <: Number} y = CuArray{T}(undef, dims...) Random.randn!(rng, y) diff --git a/ext/WeightInitializersGPUArraysExt.jl b/ext/WeightInitializersGPUArraysExt.jl index 21baf96..78e0ec6 100644 --- a/ext/WeightInitializersGPUArraysExt.jl +++ b/ext/WeightInitializersGPUArraysExt.jl @@ -1,22 +1,22 @@ module WeightInitializersGPUArraysExt using GPUArrays: RNG -using WeightInitializers: WeightInitializers +using WeightInitializers: DeviceAgnostic -for f in (:__zeros, :__ones, :__rand, :__randn) - @eval function WeightInitializers.$(f)( +for f in (:zeros, :ones, :rand, :randn) + @eval function DeviceAgnostic.$(f)( rng::RNG, ::Type{T}, dims::Integer...) where {T <: Number} - return WeightInitializers.$(f)(rng, rng.state, T, dims...) + return DeviceAgnostic.$(f)(rng, rng.state, T, dims...) end end ## Certain backends don't support sampling Complex numbers, so we avoid hitting those ## dispatches -for f in (:__rand, :__randn) - @eval function WeightInitializers.$(f)( +for f in (:rand, :randn) + @eval function DeviceAgnostic.$(f)( rng::RNG, ::Type{<:Complex{T}}, args::Integer...) where {T <: Number} - real_part = WeightInitializers.$(f)(rng, rng.state, T, args...) - imag_part = WeightInitializers.$(f)(rng, rng.state, T, args...) + real_part = DeviceAgnostic.$(f)(rng, rng.state, T, args...) + imag_part = DeviceAgnostic.$(f)(rng, rng.state, T, args...) return Complex{T}.(real_part, imag_part) end end diff --git a/ext/WeightInitializersMetalExt.jl b/ext/WeightInitializersMetalExt.jl index 70045a3..79e5b34 100644 --- a/ext/WeightInitializersMetalExt.jl +++ b/ext/WeightInitializersMetalExt.jl @@ -3,23 +3,23 @@ module WeightInitializersMetalExt using Metal: Metal, MtlArray using GPUArrays: RNG using Random: Random -using WeightInitializers: WeightInitializers +using WeightInitializers: DeviceAgnostic -function WeightInitializers.__zeros( +function DeviceAgnostic.zeros( ::RNG, ::MtlArray, ::Type{T}, dims::Integer...) where {T <: Number} return Metal.zeros(T, dims...) end -function WeightInitializers.__ones( +function DeviceAgnostic.ones( ::RNG, ::MtlArray, ::Type{T}, dims::Integer...) where {T <: Number} return Metal.ones(T, dims...) end -function WeightInitializers.__rand( +function DeviceAgnostic.rand( rng::RNG, ::MtlArray, ::Type{T}, dims::Integer...) where {T <: Number} y = MtlArray{T}(undef, dims...) Random.rand!(rng, y) return y end -function WeightInitializers.__randn( +function DeviceAgnostic.randn( rng::RNG, ::MtlArray, ::Type{T}, dims::Integer...) where {T <: Number} y = MtlArray{T}(undef, dims...) Random.randn!(rng, y) diff --git a/ext/WeightInitializersoneAPIExt.jl b/ext/WeightInitializersoneAPIExt.jl index e3c7a7e..e1827e1 100644 --- a/ext/WeightInitializersoneAPIExt.jl +++ b/ext/WeightInitializersoneAPIExt.jl @@ -3,23 +3,23 @@ module WeightInitializersoneAPIExt using oneAPI: oneAPI, oneArray using GPUArrays: RNG using Random: Random -using WeightInitializers: WeightInitializers +using WeightInitializers: DeviceAgnostic -function WeightInitializers.__zeros( +function DeviceAgnostic.zeros( ::RNG, ::oneArray, ::Type{T}, dims::Integer...) where {T <: Number} return oneAPI.zeros(T, dims...) end -function WeightInitializers.__ones( +function DeviceAgnostic.ones( ::RNG, ::oneArray, ::Type{T}, dims::Integer...) where {T <: Number} return oneAPI.ones(T, dims...) end -function WeightInitializers.__rand( +function DeviceAgnostic.rand( rng::RNG, ::oneArray, ::Type{T}, dims::Integer...) where {T <: Number} y = oneArray{T}(undef, dims...) Random.rand!(rng, y) return y end -function WeightInitializers.__randn( +function DeviceAgnostic.randn( rng::RNG, ::oneArray, ::Type{T}, dims::Integer...) where {T <: Number} y = oneArray{T}(undef, dims...) Random.randn!(rng, y) diff --git a/src/WeightInitializers.jl b/src/WeightInitializers.jl index 8a898e2..e96eebb 100644 --- a/src/WeightInitializers.jl +++ b/src/WeightInitializers.jl @@ -1,19 +1,25 @@ module WeightInitializers using ArgCheck: @argcheck -using ChainRulesCore: ChainRulesCore +using ChainRulesCore: @non_differentiable using GPUArraysCore: @allowscalar using LinearAlgebra: LinearAlgebra, Diagonal, qr using Random: Random, AbstractRNG, shuffle -using SpecialFunctions: SpecialFunctions, erfinv # Move to Ext in v2.0 +using SpecialFunctions: SpecialFunctions, erfinv # TODO: Move to Ext in v2.0 using Statistics: Statistics, std -const CRC = ChainRulesCore - include("partial.jl") include("utils.jl") include("initializers.jl") -include("autodiff.jl") + +# Mark the functions as non-differentiable +for f in [:zeros64, :ones64, :rand64, :randn64, :zeros32, :ones32, :rand32, :randn32, + :zeros16, :ones16, :rand16, :randn16, :zerosC64, :onesC64, :randC64, + :randnC64, :zerosC32, :onesC32, :randC32, :randnC32, :zerosC16, :onesC16, + :randC16, :randnC16, :glorot_normal, :glorot_uniform, :kaiming_normal, + :kaiming_uniform, :truncated_normal, :orthogonal, :sparse_init, :identity_init] + @eval @non_differentiable $(f)(::Any...) +end export zeros64, ones64, rand64, randn64, zeros32, ones32, rand32, randn32, zeros16, ones16, rand16, randn16 diff --git a/src/autodiff.jl b/src/autodiff.jl deleted file mode 100644 index ca3f8a8..0000000 --- a/src/autodiff.jl +++ /dev/null @@ -1,13 +0,0 @@ -# Wrappers -for f in (:__zeros, :__ones, :__rand, :__randn) - @eval CRC.@non_differentiable $(f)(::Any...) -end - -# Mark the functions as non-differentiable -for f in [:zeros64, :ones64, :rand64, :randn64, :zeros32, :ones32, :rand32, :randn32, - :zeros16, :ones16, :rand16, :randn16, :zerosC64, :onesC64, :randC64, - :randnC64, :zerosC32, :onesC32, :randC32, :randnC32, :zerosC16, :onesC16, - :randC16, :randnC16, :glorot_normal, :glorot_uniform, :kaiming_normal, - :kaiming_uniform, :truncated_normal, :orthogonal, :sparse_init, :identity_init] - @eval CRC.@non_differentiable $(f)(::Any...) -end diff --git a/src/initializers.jl b/src/initializers.jl index 4316fec..81de6a1 100644 --- a/src/initializers.jl +++ b/src/initializers.jl @@ -2,11 +2,10 @@ for T in ("16", "32", "64", "C16", "C32", "C64"), fname in (:ones, :zeros, :rand name = Symbol(fname, T) docstring = Utils.generic_docstring(string(name)) TP = Utils.NUM_TO_FPOINT[Symbol(T)] - __fname = Symbol("__", fname) @eval begin @doc $docstring function $(name)(rng::AbstractRNG, dims::Integer...; kwargs...) - return $__fname(rng, $TP, dims...; kwargs...) + return DeviceAgnostic.$(fname)(rng, $TP, dims...; kwargs...) end end end @@ -29,7 +28,7 @@ artificial intelligence and statistics_. 2010. function glorot_uniform( rng::AbstractRNG, ::Type{T}, dims::Integer...; gain::Number=1) where {T <: Number} scale = T(gain) * sqrt(T(24) / sum(Utils.nfan(dims...))) - x = __rand(rng, T, dims...) + x = DeviceAgnostic.rand(rng, T, dims...) half = T(0.5) @. x = (x - half) * scale return x @@ -52,7 +51,7 @@ artificial intelligence and statistics_. 2010. function glorot_normal( rng::AbstractRNG, ::Type{T}, dims::Integer...; gain::Number=1) where {T <: Number} std = T(gain) * sqrt(T(2) / sum(Utils.nfan(dims...))) - x = __randn(rng, T, dims...) + x = DeviceAgnostic.randn(rng, T, dims...) x .*= std return x end @@ -73,7 +72,7 @@ vision_. 2015. function kaiming_uniform(rng::AbstractRNG, ::Type{T}, dims::Integer...; gain::Number=√T(2)) where {T <: Number} bound = √T(3) * T(gain) / sqrt(T(first(Utils.nfan(dims...)))) - x = __rand(rng, T, dims...) + x = DeviceAgnostic.rand(rng, T, dims...) half = T(0.5) @. x = (x - half) * 2 * bound return x @@ -95,7 +94,7 @@ vision_. 2015. function kaiming_normal(rng::AbstractRNG, ::Type{T}, dims::Integer...; gain::Number=√T(2)) where {T <: Number} std = T(gain) / sqrt(T(first(Utils.nfan(dims...)))) - x = __randn(rng, T, dims...) + x = DeviceAgnostic.randn(rng, T, dims...) x .*= std return x end @@ -116,7 +115,7 @@ function truncated_normal(rng::AbstractRNG, ::Type{T}, dims::Integer...; mean=T( end l = Utils.norm_cdf((T(lo) - T(mean)) / T(std)) u = Utils.norm_cdf((T(hi) - T(mean)) / T(std)) - xs = __rand(rng, T, dims...) + xs = DeviceAgnostic.rand(rng, T, dims...) broadcast!(xs, xs) do x x = x * 2(u - l) + (2l - one(T)) x = erfinv(x) @@ -158,7 +157,7 @@ function orthogonal(rng::AbstractRNG, ::Type{T}, dims::Integer...; rows, cols = length(dims) == 2 ? dims : (prod(dims[1:(end - 1)]), dims[end]) rows < cols && return permutedims(orthogonal(rng, T, cols, rows; gain=T(gain))) - mat = __randn(rng, T, rows, cols) + mat = DeviceAgnostic.randn(rng, T, rows, cols) Q, R = qr(mat) mat .= Q * sign.(Diagonal(R)) .* T(gain) @@ -218,11 +217,11 @@ function sparse_init(rng::AbstractRNG, ::Type{T}, dims::Integer...; initialization.")) end - rows, cols = dims + rows, _ = dims prop_zero = min(1.0, sparsity) num_zeros = ceil(Integer, prop_zero * rows) - sparse_array = __randn(rng, T, dims...) + sparse_array = DeviceAgnostic.randn(rng, T, dims...) sparse_array .*= T(std) fill!(view(sparse_array, 1:num_zeros, :), zero(T)) @@ -293,11 +292,11 @@ julia> identity_init(Xoshiro(123), Float32, 3, 3, 1, 1; gain=1.5) """ function identity_init(rng::AbstractRNG, ::Type{T}, dims::Integer...; gain::Number=1, shift::Integer=0) where {T <: Number} - length(dims) == 1 && return __zeros(rng, T, dims...) # Bias initialization + length(dims) == 1 && return DeviceAgnostic.zeros(rng, T, dims...) # Bias initialization if length(dims) == 2 rows, cols = dims - mat = __zeros(rng, T, rows, cols) + mat = DeviceAgnostic.zeros(rng, T, rows, cols) diag_indices = 1:min(rows, cols) fill!(view(mat, diag_indices, diag_indices), T(gain)) return circshift(mat, shift) @@ -306,7 +305,7 @@ function identity_init(rng::AbstractRNG, ::Type{T}, dims::Integer...; # Convolution or more dimensions nin, nout = dims[end - 1], dims[end] centers = map(d -> cld(d, 2), dims[1:(end - 2)]) - weights = __zeros(rng, T, dims...) + weights = DeviceAgnostic.zeros(rng, T, dims...) @allowscalar for i in 1:min(nin, nout) index = (centers..., i, i) weights[index...] = T(gain) diff --git a/src/utils.jl b/src/utils.jl index 6ba097f..201283d 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -50,27 +50,34 @@ end end +module DeviceAgnostic + +using ChainRulesCore: @non_differentiable +using Random: AbstractRNG + # Helpers for device agnostic initializers -function __zeros(::AbstractRNG, ::Type{T}, dims::Integer...) where {T <: Number} - return zeros(T, dims...) +function zeros(::AbstractRNG, ::Type{T}, dims::Integer...) where {T <: Number} + return Base.zeros(T, dims...) end -function __ones(::AbstractRNG, ::Type{T}, dims::Integer...) where {T <: Number} - return ones(T, dims...) +ones(::AbstractRNG, ::Type{T}, dims::Integer...) where {T <: Number} = Base.ones(T, dims...) +function rand(rng::AbstractRNG, ::Type{T}, args::Integer...) where {T <: Number} + return Base.rand(rng, T, args...) end -function __rand(rng::AbstractRNG, ::Type{T}, args::Integer...) where {T <: Number} - return rand(rng, T, args...) -end -function __randn(rng::AbstractRNG, ::Type{T}, args::Integer...) where {T <: Number} - return randn(rng, T, args...) +function randn(rng::AbstractRNG, ::Type{T}, args::Integer...) where {T <: Number} + return Base.randn(rng, T, args...) end ## Certain backends don't support sampling Complex numbers, so we avoid hitting those ## dispatches -for f in (:__rand, :__randn) +for f in (:rand, :randn) @eval function $(f)( rng::AbstractRNG, ::Type{<:Complex{T}}, args::Integer...) where {T <: Number} - real_part = $(f)(rng, T, args...) - imag_part = $(f)(rng, T, args...) - return Complex{T}.(real_part, imag_part) + return Complex{T}.($(f)(rng, T, args...), $(f)(rng, T, args...)) end end + +for f in (:zeros, :ones, :rand, :randn) + @eval @non_differentiable $f(::Any...) +end + +end