This repository has been archived by the owner on Nov 4, 2024. It is now read-only.

refactor: move device agnostic functions to DeviceAgnostic
avik-pal committed Aug 19, 2024
1 parent d23c1a3 commit 443eb83
Showing 9 changed files with 75 additions and 76 deletions.
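
In short, the commit moves the double-underscore device-agnostic helpers (WeightInitializers.__zeros, __ones, __rand, __randn) into a dedicated DeviceAgnostic submodule defined in src/utils.jl, and the GPU extension packages now overload DeviceAgnostic.zeros and friends instead. A minimal sketch of the call-site change, assuming WeightInitializers is loaded (the public initializers such as glorot_uniform are untouched):

    using Random: Xoshiro
    rng = Xoshiro(0)

    # Before this commit: prefixed helpers at the package top level.
    # x = WeightInitializers.__rand(rng, Float32, 2, 3)

    # After this commit: the same behavior, namespaced instead of prefixed.
    x = WeightInitializers.DeviceAgnostic.rand(rng, Float32, 2, 3)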
14 changes: 7 additions & 7 deletions ext/WeightInitializersAMDGPUExt.jl
@@ -3,32 +3,32 @@ module WeightInitializersAMDGPUExt
 using AMDGPU: AMDGPU, ROCArray
 using GPUArrays: RNG
 using Random: Random
-using WeightInitializers: WeightInitializers
+using WeightInitializers: DeviceAgnostic

-function WeightInitializers.__zeros(
+function DeviceAgnostic.zeros(
         ::AMDGPU.rocRAND.RNG, ::Type{T}, dims::Integer...) where {T <: Number}
     return AMDGPU.zeros(T, dims...)
 end
-function WeightInitializers.__ones(
+function DeviceAgnostic.ones(
         ::AMDGPU.rocRAND.RNG, ::Type{T}, dims::Integer...) where {T <: Number}
     return AMDGPU.ones(T, dims...)
 end

-function WeightInitializers.__zeros(
+function DeviceAgnostic.zeros(
         ::RNG, ::ROCArray, ::Type{T}, dims::Integer...) where {T <: Number}
     return AMDGPU.zeros(T, dims...)
 end
-function WeightInitializers.__ones(
+function DeviceAgnostic.ones(
         ::RNG, ::ROCArray, ::Type{T}, dims::Integer...) where {T <: Number}
     return AMDGPU.ones(T, dims...)
 end
-function WeightInitializers.__rand(
+function DeviceAgnostic.rand(
         rng::RNG, ::ROCArray, ::Type{T}, dims::Integer...) where {T <: Number}
     y = ROCArray{T}(undef, dims...)
     Random.rand!(rng, y)
     return y
 end
-function WeightInitializers.__randn(
+function DeviceAgnostic.randn(
         rng::RNG, ::ROCArray, ::Type{T}, dims::Integer...) where {T <: Number}
     y = ROCArray{T}(undef, dims...)
     Random.randn!(rng, y)
14 changes: 7 additions & 7 deletions ext/WeightInitializersCUDAExt.jl
@@ -3,34 +3,34 @@ module WeightInitializersCUDAExt
 using CUDA: CUDA, CURAND, CuArray
 using GPUArrays: RNG
 using Random: Random
-using WeightInitializers: WeightInitializers
+using WeightInitializers: DeviceAgnostic

 const AbstractCuRNG = Union{CUDA.RNG, CURAND.RNG}

-function WeightInitializers.__zeros(
+function DeviceAgnostic.zeros(
         ::AbstractCuRNG, ::Type{T}, dims::Integer...) where {T <: Number}
     return CUDA.zeros(T, dims...)
 end
-function WeightInitializers.__ones(
+function DeviceAgnostic.ones(
         ::AbstractCuRNG, ::Type{T}, dims::Integer...) where {T <: Number}
     return CUDA.ones(T, dims...)
 end

-function WeightInitializers.__zeros(
+function DeviceAgnostic.zeros(
         ::RNG, ::CuArray, ::Type{T}, dims::Integer...) where {T <: Number}
     return CUDA.zeros(T, dims...)
 end
-function WeightInitializers.__ones(
+function DeviceAgnostic.ones(
         ::RNG, ::CuArray, ::Type{T}, dims::Integer...) where {T <: Number}
     return CUDA.ones(T, dims...)
 end
-function WeightInitializers.__rand(
+function DeviceAgnostic.rand(
         rng::RNG, ::CuArray, ::Type{T}, dims::Integer...) where {T <: Number}
     y = CuArray{T}(undef, dims...)
     Random.rand!(rng, y)
     return y
 end
-function WeightInitializers.__randn(
+function DeviceAgnostic.randn(
         rng::RNG, ::CuArray, ::Type{T}, dims::Integer...) where {T <: Number}
     y = CuArray{T}(undef, dims...)
     Random.randn!(rng, y)
16 changes: 8 additions & 8 deletions ext/WeightInitializersGPUArraysExt.jl
@@ -1,22 +1,22 @@
 module WeightInitializersGPUArraysExt

 using GPUArrays: RNG
-using WeightInitializers: WeightInitializers
+using WeightInitializers: DeviceAgnostic

-for f in (:__zeros, :__ones, :__rand, :__randn)
-    @eval function WeightInitializers.$(f)(
+for f in (:zeros, :ones, :rand, :randn)
+    @eval function DeviceAgnostic.$(f)(
             rng::RNG, ::Type{T}, dims::Integer...) where {T <: Number}
-        return WeightInitializers.$(f)(rng, rng.state, T, dims...)
+        return DeviceAgnostic.$(f)(rng, rng.state, T, dims...)
     end
 end

 ## Certain backends don't support sampling Complex numbers, so we avoid hitting those
 ## dispatches
-for f in (:__rand, :__randn)
-    @eval function WeightInitializers.$(f)(
+for f in (:rand, :randn)
+    @eval function DeviceAgnostic.$(f)(
             rng::RNG, ::Type{<:Complex{T}}, args::Integer...) where {T <: Number}
-        real_part = WeightInitializers.$(f)(rng, rng.state, T, args...)
-        imag_part = WeightInitializers.$(f)(rng, rng.state, T, args...)
+        real_part = DeviceAgnostic.$(f)(rng, rng.state, T, args...)
+        imag_part = DeviceAgnostic.$(f)(rng, rng.state, T, args...)
         return Complex{T}.(real_part, imag_part)
     end
 end
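
For context, a GPUArrays RNG carries its backing device array in the rng.state field, so the extension above forwards a device-agnostic call to a method that dispatches on that array's type (CuArray, ROCArray, ...). A toy sketch of the same two-step dispatch pattern; every name here is hypothetical, not part of the package:

    using Random

    struct ToyRNG{A} <: Random.AbstractRNG
        state::A  # the backing array determines the device
    end

    # Step 1: forward to a method keyed on the state array's type.
    function device_zeros(rng::ToyRNG, ::Type{T}, dims::Integer...) where {T <: Number}
        return device_zeros(rng, rng.state, T, dims...)
    end

    # Step 2: CPU fallback, selected because state isa Vector.
    function device_zeros(::ToyRNG, ::Vector, ::Type{T}, dims::Integer...) where {T <: Number}
        return zeros(T, dims...)
    end

    device_zeros(ToyRNG(Float32[]), Float32, 2, 3)  # 2×3 Matrix{Float32} of zeros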
10 changes: 5 additions & 5 deletions ext/WeightInitializersMetalExt.jl
@@ -3,23 +3,23 @@ module WeightInitializersMetalExt
 using Metal: Metal, MtlArray
 using GPUArrays: RNG
 using Random: Random
-using WeightInitializers: WeightInitializers
+using WeightInitializers: DeviceAgnostic

-function WeightInitializers.__zeros(
+function DeviceAgnostic.zeros(
         ::RNG, ::MtlArray, ::Type{T}, dims::Integer...) where {T <: Number}
     return Metal.zeros(T, dims...)
 end
-function WeightInitializers.__ones(
+function DeviceAgnostic.ones(
         ::RNG, ::MtlArray, ::Type{T}, dims::Integer...) where {T <: Number}
     return Metal.ones(T, dims...)
 end
-function WeightInitializers.__rand(
+function DeviceAgnostic.rand(
         rng::RNG, ::MtlArray, ::Type{T}, dims::Integer...) where {T <: Number}
     y = MtlArray{T}(undef, dims...)
     Random.rand!(rng, y)
     return y
 end
-function WeightInitializers.__randn(
+function DeviceAgnostic.randn(
         rng::RNG, ::MtlArray, ::Type{T}, dims::Integer...) where {T <: Number}
     y = MtlArray{T}(undef, dims...)
     Random.randn!(rng, y)
10 changes: 5 additions & 5 deletions ext/WeightInitializersoneAPIExt.jl
@@ -3,23 +3,23 @@ module WeightInitializersoneAPIExt
 using oneAPI: oneAPI, oneArray
 using GPUArrays: RNG
 using Random: Random
-using WeightInitializers: WeightInitializers
+using WeightInitializers: DeviceAgnostic

-function WeightInitializers.__zeros(
+function DeviceAgnostic.zeros(
         ::RNG, ::oneArray, ::Type{T}, dims::Integer...) where {T <: Number}
     return oneAPI.zeros(T, dims...)
 end
-function WeightInitializers.__ones(
+function DeviceAgnostic.ones(
         ::RNG, ::oneArray, ::Type{T}, dims::Integer...) where {T <: Number}
     return oneAPI.ones(T, dims...)
 end
-function WeightInitializers.__rand(
+function DeviceAgnostic.rand(
         rng::RNG, ::oneArray, ::Type{T}, dims::Integer...) where {T <: Number}
     y = oneArray{T}(undef, dims...)
     Random.rand!(rng, y)
     return y
 end
-function WeightInitializers.__randn(
+function DeviceAgnostic.randn(
        rng::RNG, ::oneArray, ::Type{T}, dims::Integer...) where {T <: Number}
     y = oneArray{T}(undef, dims...)
     Random.randn!(rng, y)
16 changes: 11 additions & 5 deletions src/WeightInitializers.jl
@@ -1,19 +1,25 @@
 module WeightInitializers

 using ArgCheck: @argcheck
-using ChainRulesCore: ChainRulesCore
+using ChainRulesCore: @non_differentiable
 using GPUArraysCore: @allowscalar
 using LinearAlgebra: LinearAlgebra, Diagonal, qr
 using Random: Random, AbstractRNG, shuffle
-using SpecialFunctions: SpecialFunctions, erfinv # Move to Ext in v2.0
+using SpecialFunctions: SpecialFunctions, erfinv # TODO: Move to Ext in v2.0
 using Statistics: Statistics, std

-const CRC = ChainRulesCore
-
 include("partial.jl")
 include("utils.jl")
 include("initializers.jl")
-include("autodiff.jl")
+
+# Mark the functions as non-differentiable
+for f in [:zeros64, :ones64, :rand64, :randn64, :zeros32, :ones32, :rand32, :randn32,
+    :zeros16, :ones16, :rand16, :randn16, :zerosC64, :onesC64, :randC64,
+    :randnC64, :zerosC32, :onesC32, :randC32, :randnC32, :zerosC16, :onesC16,
+    :randC16, :randnC16, :glorot_normal, :glorot_uniform, :kaiming_normal,
+    :kaiming_uniform, :truncated_normal, :orthogonal, :sparse_init, :identity_init]
+    @eval @non_differentiable $(f)(::Any...)
+end

 export zeros64, ones64, rand64, randn64, zeros32, ones32, rand32, randn32, zeros16, ones16,
     rand16, randn16
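
The for loop above replaces the deleted src/autodiff.jl (next file): the non-differentiable marking now lives in the main module file. A minimal sketch of what that marking does, using a hypothetical stand-in initializer my_init:

    using ChainRulesCore: @non_differentiable

    my_init(dims::Integer...) = randn(Float32, dims...)  # stand-in for an initializer
    @non_differentiable my_init(::Any...)

    # With a ChainRules-aware AD such as Zygote loaded, the marked call is
    # treated as a constant, so no gradient flows into the RNG draw:
    #   Zygote.gradient(w -> sum(my_init(2, 2) .+ w), 1.0f0)  # == (4.0f0,)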
13 changes: 0 additions & 13 deletions src/autodiff.jl

This file was deleted.

25 changes: 12 additions & 13 deletions src/initializers.jl
@@ -2,11 +2,10 @@ for T in ("16", "32", "64", "C16", "C32", "C64"), fname in (:ones, :zeros, :rand
     name = Symbol(fname, T)
     docstring = Utils.generic_docstring(string(name))
     TP = Utils.NUM_TO_FPOINT[Symbol(T)]
-    __fname = Symbol("__", fname)

     @eval begin
         @doc $docstring function $(name)(rng::AbstractRNG, dims::Integer...; kwargs...)
-            return $__fname(rng, $TP, dims...; kwargs...)
+            return DeviceAgnostic.$(fname)(rng, $TP, dims...; kwargs...)
         end
     end
 end
@@ -29,7 +28,7 @@ artificial intelligence and statistics_. 2010.
 function glorot_uniform(
         rng::AbstractRNG, ::Type{T}, dims::Integer...; gain::Number=1) where {T <: Number}
     scale = T(gain) * sqrt(T(24) / sum(Utils.nfan(dims...)))
-    x = __rand(rng, T, dims...)
+    x = DeviceAgnostic.rand(rng, T, dims...)
     half = T(0.5)
     @. x = (x - half) * scale
     return x
@@ -52,7 +51,7 @@ artificial intelligence and statistics_. 2010.
 function glorot_normal(
         rng::AbstractRNG, ::Type{T}, dims::Integer...; gain::Number=1) where {T <: Number}
     std = T(gain) * sqrt(T(2) / sum(Utils.nfan(dims...)))
-    x = __randn(rng, T, dims...)
+    x = DeviceAgnostic.randn(rng, T, dims...)
     x .*= std
     return x
 end
@@ -73,7 +72,7 @@ vision_. 2015.
 function kaiming_uniform(rng::AbstractRNG, ::Type{T}, dims::Integer...;
         gain::Number=T(2)) where {T <: Number}
     bound = T(3) * T(gain) / sqrt(T(first(Utils.nfan(dims...))))
-    x = __rand(rng, T, dims...)
+    x = DeviceAgnostic.rand(rng, T, dims...)
     half = T(0.5)
     @. x = (x - half) * 2 * bound
     return x
@@ -95,7 +94,7 @@ vision_. 2015.
 function kaiming_normal(rng::AbstractRNG, ::Type{T}, dims::Integer...;
         gain::Number=T(2)) where {T <: Number}
     std = T(gain) / sqrt(T(first(Utils.nfan(dims...))))
-    x = __randn(rng, T, dims...)
+    x = DeviceAgnostic.randn(rng, T, dims...)
     x .*= std
     return x
 end
@@ -116,7 +115,7 @@ function truncated_normal(rng::AbstractRNG, ::Type{T}, dims::Integer...; mean=T(
     end
     l = Utils.norm_cdf((T(lo) - T(mean)) / T(std))
     u = Utils.norm_cdf((T(hi) - T(mean)) / T(std))
-    xs = __rand(rng, T, dims...)
+    xs = DeviceAgnostic.rand(rng, T, dims...)
     broadcast!(xs, xs) do x
         x = x * 2(u - l) + (2l - one(T))
         x = erfinv(x)
@@ -158,7 +157,7 @@ function orthogonal(rng::AbstractRNG, ::Type{T}, dims::Integer...;
     rows, cols = length(dims) == 2 ? dims : (prod(dims[1:(end - 1)]), dims[end])
     rows < cols && return permutedims(orthogonal(rng, T, cols, rows; gain=T(gain)))

-    mat = __randn(rng, T, rows, cols)
+    mat = DeviceAgnostic.randn(rng, T, rows, cols)
     Q, R = qr(mat)
     mat .= Q * sign.(Diagonal(R)) .* T(gain)

@@ -218,11 +217,11 @@ function sparse_init(rng::AbstractRNG, ::Type{T}, dims::Integer...;
             initialization."))
     end

-    rows, cols = dims
+    rows, _ = dims
     prop_zero = min(1.0, sparsity)
     num_zeros = ceil(Integer, prop_zero * rows)

-    sparse_array = __randn(rng, T, dims...)
+    sparse_array = DeviceAgnostic.randn(rng, T, dims...)
     sparse_array .*= T(std)
     fill!(view(sparse_array, 1:num_zeros, :), zero(T))

@@ -293,11 +292,11 @@ julia> identity_init(Xoshiro(123), Float32, 3, 3, 1, 1; gain=1.5)
 """
 function identity_init(rng::AbstractRNG, ::Type{T}, dims::Integer...;
         gain::Number=1, shift::Integer=0) where {T <: Number}
-    length(dims) == 1 && return __zeros(rng, T, dims...) # Bias initialization
+    length(dims) == 1 && return DeviceAgnostic.zeros(rng, T, dims...) # Bias initialization

     if length(dims) == 2
         rows, cols = dims
-        mat = __zeros(rng, T, rows, cols)
+        mat = DeviceAgnostic.zeros(rng, T, rows, cols)
         diag_indices = 1:min(rows, cols)
         fill!(view(mat, diag_indices, diag_indices), T(gain))
         return circshift(mat, shift)
@@ -306,7 +305,7 @@ function identity_init(rng::AbstractRNG, ::Type{T}, dims::Integer...;
     # Convolution or more dimensions
     nin, nout = dims[end - 1], dims[end]
     centers = map(d -> cld(d, 2), dims[1:(end - 2)])
-    weights = __zeros(rng, T, dims...)
+    weights = DeviceAgnostic.zeros(rng, T, dims...)
     @allowscalar for i in 1:min(nin, nout)
         index = (centers..., i, i)
         weights[index...] = T(gain)
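
Since the public signatures above are unchanged, existing calls keep working. A small usage check, assuming WeightInitializers is loaded; glorot_uniform samples uniformly on ±gain*sqrt(6 / (fan_in + fan_out)), which the rewritten body still satisfies:

    using Random: Xoshiro

    W = glorot_uniform(Xoshiro(0), Float32, 64, 32)  # 64×32 Matrix{Float32}
    # fan_in + fan_out = 96 and gain = 1, so every entry lies within ±0.25:
    maximum(abs, W) <= 0.25f0  # true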
33 changes: 20 additions & 13 deletions src/utils.jl
@@ -50,27 +50,34 @@ end

 end

+module DeviceAgnostic
+
+using ChainRulesCore: @non_differentiable
+using Random: AbstractRNG
+
 # Helpers for device agnostic initializers
-function __zeros(::AbstractRNG, ::Type{T}, dims::Integer...) where {T <: Number}
-    return zeros(T, dims...)
+function zeros(::AbstractRNG, ::Type{T}, dims::Integer...) where {T <: Number}
+    return Base.zeros(T, dims...)
 end
-function __ones(::AbstractRNG, ::Type{T}, dims::Integer...) where {T <: Number}
-    return ones(T, dims...)
-end
-function __rand(rng::AbstractRNG, ::Type{T}, args::Integer...) where {T <: Number}
-    return rand(rng, T, args...)
+ones(::AbstractRNG, ::Type{T}, dims::Integer...) where {T <: Number} = Base.ones(T, dims...)
+function rand(rng::AbstractRNG, ::Type{T}, args::Integer...) where {T <: Number}
+    return Base.rand(rng, T, args...)
 end
-function __randn(rng::AbstractRNG, ::Type{T}, args::Integer...) where {T <: Number}
-    return randn(rng, T, args...)
+function randn(rng::AbstractRNG, ::Type{T}, args::Integer...) where {T <: Number}
+    return Base.randn(rng, T, args...)
 end

 ## Certain backends don't support sampling Complex numbers, so we avoid hitting those
 ## dispatches
-for f in (:__rand, :__randn)
+for f in (:rand, :randn)
     @eval function $(f)(
             rng::AbstractRNG, ::Type{<:Complex{T}}, args::Integer...) where {T <: Number}
-        real_part = $(f)(rng, T, args...)
-        imag_part = $(f)(rng, T, args...)
-        return Complex{T}.(real_part, imag_part)
+        return Complex{T}.($(f)(rng, T, args...), $(f)(rng, T, args...))
     end
 end
+
+for f in (:zeros, :ones, :rand, :randn)
+    @eval @non_differentiable $f(::Any...)
+end
+
+end
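
One subtlety in the new module: because DeviceAgnostic now defines its own zeros, ones, rand, and randn, an unqualified call inside the module resolves to the local method table, so the bodies must spell out Base.zeros etc. to reach the actual constructors. A minimal sketch of the pitfall, with a hypothetical module name:

    module MyHelpers

    using Random: AbstractRNG

    function zeros(::AbstractRNG, ::Type{T}, dims::Integer...) where {T <: Number}
        # An unqualified `zeros(T, dims...)` here would resolve to this very
        # function and throw a MethodError (no RNG argument); Base.zeros is
        # unambiguous.
        return Base.zeros(T, dims...)
    end

    end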
