diff --git a/Project.toml b/Project.toml index 97d73c1..67384d9 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "WeightInitializers" uuid = "d49dbf32-c5c2-4618-8acc-27bb2598ef2d" authors = ["Avik Pal and contributors"] -version = "0.1.6" +version = "0.1.7" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/ext/WeightInitializersCUDAExt.jl b/ext/WeightInitializersCUDAExt.jl index 45b91df..ac07b42 100644 --- a/ext/WeightInitializersCUDAExt.jl +++ b/ext/WeightInitializersCUDAExt.jl @@ -30,7 +30,7 @@ function sparse_init(rng::AbstractCuRNG, ::Type{T}, dims::Integer...; rows, cols = dims prop_zero = min(1.0, sparsity) num_zeros = ceil(Integer, prop_zero * rows) - sparse_array = randn(rng, T, dims...) .* std + sparse_array = randn(rng, T, dims...) .* T(std) sparse_array[1:num_zeros, :] .= CUDA.zero(T) return CUDA.@allowscalar mapslices(shuffle, sparse_array, dims=1) @@ -46,7 +46,7 @@ function identity_init(rng::AbstractCuRNG, ::Type{T}, dims::Integer...; rows, cols = dims mat = CUDA.zeros(T, rows, cols) diag_indices = 1:min(rows, cols) - CUDA.fill!(view(mat, diag_indices, diag_indices), gain) + CUDA.fill!(view(mat, diag_indices, diag_indices), T(gain)) return CUDA.circshift(mat, shift) else # Convolution or more dimensions @@ -56,7 +56,7 @@ function identity_init(rng::AbstractCuRNG, ::Type{T}, dims::Integer...; #we should really find a better way to do this CUDA.@allowscalar for i in 1:min(nin, nout) index = (centers..., i, i) - weights[index...] = gain + weights[index...] = T(gain) end return CUDA.circshift(weights, (ntuple(d -> 0, length(dims) - 2)..., shift, shift)) end diff --git a/src/initializers.jl b/src/initializers.jl index 357b41c..fd31046 100644 --- a/src/initializers.jl +++ b/src/initializers.jl @@ -75,7 +75,7 @@ vision_. 2015. """ function kaiming_uniform(rng::AbstractRNG, ::Type{T}, dims::Integer...; gain::Number=√T(2)) where {T <: Number} - bound = √T(3) * gain / sqrt(T(first(_nfan(dims...)))) + bound = √T(3) * T(gain) / sqrt(T(first(_nfan(dims...)))) return (rand(rng, T, dims...) .- T(1 // 2)) .* 2 * bound end @@ -94,7 +94,7 @@ vision_. 2015. """ function kaiming_normal(rng::AbstractRNG, ::Type{T}, dims::Integer...; gain::Number=√T(2)) where {T <: Number} - std = gain / sqrt(T(first(_nfan(dims...)))) + std = T(gain) / sqrt(T(first(_nfan(dims...)))) return randn(rng, T, dims...) .* std end @@ -111,13 +111,13 @@ function truncated_normal(rng::AbstractRNG, ::Type{T}, dims::Integer...; mean=T( if (mean < lo - 2 * std) || (mean > hi + 2 * std) @warn "Mean is more than 2 std outside the limits in truncated_normal, so the distribution of values may be inaccurate." end - l = _norm_cdf((lo - mean) / std) - u = _norm_cdf((hi - mean) / std) + l = _norm_cdf((T(lo) - T(mean)) / T(std)) + u = _norm_cdf((T(hi) - T(mean)) / T(std)) xs = rand(rng, T, dims...) broadcast!(xs, xs) do x x = x * 2(u - l) + (2l - 1) x = erfinv(x) - return clamp(x * std * √2 + mean, lo, hi) + return clamp(x * T(std) * √2 + T(mean), T(lo), T(hi)) end return xs end @@ -162,7 +162,7 @@ function orthogonal(rng::AbstractRNG, ::Type{T}, dims::Integer...; end if rows < cols - return permutedims(orthogonal(rng, T, cols, rows; gain)) + return permutedims(orthogonal(rng, T, cols, rows; gain=T(gain))) end mat = randn(rng, T, rows, cols) @@ -236,7 +236,7 @@ function sparse_init(rng::AbstractRNG, ::Type{T}, dims::Integer...; rows, cols = dims prop_zero = min(1.0, sparsity) num_zeros = ceil(Integer, prop_zero * rows) - sparse_array = randn(rng, T, dims...) .* std + sparse_array = randn(rng, T, dims...) .* T(std) sparse_array[1:num_zeros, :] .= zero(T) return mapslices(shuffle, sparse_array; dims=1) end @@ -313,7 +313,7 @@ function identity_init(rng::AbstractRNG, ::Type{T}, dims::Integer...; rows, cols = dims mat = zeros(T, rows, cols) for i in 1:min(rows, cols) - mat[i, i] = gain + mat[i, i] = T(gain) end return circshift(mat, shift) else @@ -323,7 +323,7 @@ function identity_init(rng::AbstractRNG, ::Type{T}, dims::Integer...; weights = zeros(T, dims...) for i in 1:min(nin, nout) index = (centers..., i, i) - weights[index...] = gain + weights[index...] = T(gain) end return circshift(weights, (ntuple(d -> 0, length(dims) - 2)..., shift, shift)) end diff --git a/test/runtests.jl b/test/runtests.jl index a2afe08..aca13c8 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -114,6 +114,20 @@ const GROUP = get(ENV, "GROUP", "All") @test eltype(cl(rng, 4, 2)) == Float32 end + @testset "Kwargs types" for T in ( + Float16, Float32, Float64, ComplexF16, ComplexF32, ComplexF64) + if (T <: Real) + @test eltype(truncated_normal(T, 2, 5; mean=0, std=1, lo=-2, hi=2)) == T + @test eltype(orthogonal(T, 2, 5; gain=1.0)) == T + end + @test eltype(glorot_uniform(T, 2, 5; gain=1.0)) == T + @test eltype(glorot_normal(T, 2, 5; gain=1.0)) == T + @test eltype(kaiming_uniform(T, 2, 5; gain=sqrt(2))) == T + @test eltype(kaiming_normal(T, 2, 5; gain=sqrt(2))) == T + @test eltype(identity_init(T, 2, 5; gain=1.0)) == T + @test eltype(sparse_init(T, 2, 5; sparsity=0.5, std=0.01)) == T + end + @testset "kaiming" begin # kaiming_uniform should yield a kernel in range [-sqrt(6/n_out), sqrt(6/n_out)] # and kaiming_normal should yield a kernel with stddev ~= sqrt(2/n_out)