From b7e4ae403dc792899cf7e24696d58fba2cc60cf6 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 11 Jul 2024 21:07:33 -0700 Subject: [PATCH] fix: add missing dispatch --- Project.toml | 2 ++ src/WeightInitializers.jl | 1 + src/initializers.jl | 6 +++++- src/partial.jl | 16 +++++++++++++++- test/initializers_tests.jl | 37 +++++++++++++++++++++++++++++++++++++ 5 files changed, 60 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index c0d46ac..bf04f08 100644 --- a/Project.toml +++ b/Project.toml @@ -4,6 +4,7 @@ authors = ["Avik Pal and contributors"] version = "0.1.10" [deps] +ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471" GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" @@ -29,6 +30,7 @@ WeightInitializersoneAPIExt = ["oneAPI", "GPUArrays"] [compat] AMDGPU = "0.9.6" Aqua = "0.8.7" +ArgCheck = "2.3.0" CUDA = "5.3.2" ChainRulesCore = "1.23" ConcreteStructs = "0.2.3" diff --git a/src/WeightInitializers.jl b/src/WeightInitializers.jl index d115289..af3c5ef 100644 --- a/src/WeightInitializers.jl +++ b/src/WeightInitializers.jl @@ -1,5 +1,6 @@ module WeightInitializers +using ArgCheck: @argcheck using ChainRulesCore: ChainRulesCore using ConcreteStructs: @concrete using GPUArraysCore: @allowscalar diff --git a/src/initializers.jl b/src/initializers.jl index 2e13417..57d6d8d 100644 --- a/src/initializers.jl +++ b/src/initializers.jl @@ -153,7 +153,7 @@ deep linear neural networks", ICLR 2014, https://arxiv.org/abs/1312.6120 """ function orthogonal(rng::AbstractRNG, ::Type{T}, dims::Integer...; gain::Number=T(1.0)) where {T <: Number} - @assert length(dims)>1 "Creating vectors (length(dims) == 1) is not allowed" + @argcheck length(dims)>1 "Creating vectors (length(dims) == 1) is not allowed" rows, cols = length(dims) == 2 ? dims : (prod(dims[1:(end - 1)]), dims[end]) rows < cols && return permutedims(orthogonal(rng, T, cols, rows; gain=T(gain))) @@ -355,6 +355,10 @@ for tp in ("16", "32", "64", "C16", "C32", "C64"), func in (:zeros, :ones, :rand function ($initializer)(::Type{T}, dims::Integer...; kwargs...) where {T} throw(ArgumentError(string($initializer) * " doesn't accept a type argument.")) end + function ($initializer)( + ::AbstractRNG, ::Type{T}, dims::Integer...; kwargs...) where {T} + throw(ArgumentError(string($initializer) * " doesn't accept a type argument.")) + end # Partial application function ($initializer)(rng::AbstractRNG; kwargs...) diff --git a/src/partial.jl b/src/partial.jl index a4d34b0..d9b054c 100644 --- a/src/partial.jl +++ b/src/partial.jl @@ -7,7 +7,11 @@ end function Base.show( io::IO, ::MIME"text/plain", f::PartialWeightInitializationFunction{T}) where {T} print(io, "$(f.f)(") - f.rng !== nothing ? print(io, "$(f.rng), ") : print(io, "rng, ") + if f.rng !== nothing + print(io, "$(nameof(typeof(f.rng)))(...), ") + else + print(io, "rng, ") + end if T === Nothing print(io, "::Type{T}, ") else @@ -27,7 +31,17 @@ function (f::PartialWeightInitializationFunction{<:Union{Nothing, Missing}})( f.rng === nothing && return f.f(args...; f.kwargs..., kwargs...) return f.f(f.rng, args...; f.kwargs..., kwargs...) end +function (f::PartialWeightInitializationFunction{<:Union{Nothing, Missing}})( + rng::AbstractRNG, args...; kwargs...) + @argcheck f.rng === nothing + return f.f(rng, args...; f.kwargs..., kwargs...) +end function (f::PartialWeightInitializationFunction{T})(args...; kwargs...) where {T <: Number} f.rng === nothing && return f.f(T, args...; f.kwargs..., kwargs...) return f.f(f.rng, T, args...; f.kwargs..., kwargs...) end +function (f::PartialWeightInitializationFunction{T})( + rng::AbstractRNG, args...; kwargs...) where {T <: Number} + @argcheck f.rng === nothing + return f.f(rng, T, args...; f.kwargs..., kwargs...) +end diff --git a/test/initializers_tests.jl b/test/initializers_tests.jl index af968f8..39d6156 100644 --- a/test/initializers_tests.jl +++ b/test/initializers_tests.jl @@ -53,14 +53,17 @@ end @test orthogonal(rng, T, 3, 5) isa arrtype{T, 2} cl = orthogonal(rng) + display(cl) @test cl(T, 3, 5) isa arrtype{T, 2} cl = orthogonal(rng, T) + display(cl) @test cl(3, 5) isa arrtype{T, 2} end @testset "Orthogonal Closure" begin cl = orthogonal(;) + display(cl) # Sizes @test size(cl(3, 4)) == (3, 4) @@ -114,17 +117,22 @@ end @test sparse_init(rng, T, 3, 5; sparsity=0.5) isa arrtype{T, 2} cl = sparse_init(rng; sparsity=0.5) + display(cl) @test cl(T, 3, 5) isa arrtype{T, 2} cl = sparse_init(rng, T; sparsity=0.5) + display(cl) @test cl(3, 5) isa arrtype{T, 2} end @testset "sparse_init Closure" begin cl = sparse_init(; sparsity=0.5) + display(cl) + # Sizes @test size(cl(3, 4)) == (3, 4) @test size(cl(rng, 3, 4)) == (3, 4) + # Type @test eltype(cl(4, 2)) == Float32 @test eltype(cl(rng, 4, 2)) == Float32 @@ -158,11 +166,14 @@ end @test size(init(rng, 3, 4)) == (3, 4) @test size(init(3, 4, 5)) == (3, 4, 5) @test size(init(rng, 3, 4, 5)) == (3, 4, 5) + # Type @test eltype(init(rng, 4, 2)) == Float32 @test eltype(init(4, 2)) == Float32 + # RNG Closure cl = init(rng) + display(cl) @test cl(3) isa arrtype{Float32, 1} @test cl(3, 5) isa arrtype{Float32, 2} end @@ -185,13 +196,28 @@ end @test size(init(rng, 3, 4)) == (3, 4) @test size(init(3, 4, 5)) == (3, 4, 5) @test size(init(rng, 3, 4, 5)) == (3, 4, 5) + # Type @test eltype(init(rng, 4, 2)) == fp @test eltype(init(4, 2)) == fp + # RNG Closure cl = init(rng) + display(cl) @test cl(3) isa arrtype{fp, 1} @test cl(3, 5) isa arrtype{fp, 2} + + # Kwargs closure + cl = init(;) + display(cl) + @test cl(rng, 3) isa arrtype{fp, 1} + @test cl(rng, 3, 5) isa arrtype{fp, 2} + + # throw error on type as input + @test_throws ArgumentError init(Float32) + @test_throws ArgumentError init(Float32, 3, 5) + @test_throws ArgumentError init(rng, Float32) + @test_throws ArgumentError init(rng, Float32, 3, 5) end @testset "AbstractArray Type: $init $T" for init in [ @@ -216,12 +242,20 @@ end @test init(rng, T, 3, 5) isa arrtype{T, 2} cl = init(rng) + display(cl) @test cl(T, 3) isa arrtype{T, 1} @test cl(T, 3, 5) isa arrtype{T, 2} cl = init(rng, T) + display(cl) @test cl(3) isa arrtype{T, 1} @test cl(3, 5) isa arrtype{T, 2} + + cl = init(T) + display(cl) + @test cl(3) isa Array{T, 1} + @test cl(3, 5) isa Array{T, 2} + @test cl(rng, 3, 5) isa arrtype{T, 2} end @testset "Closure: $init" for init in [ @@ -233,6 +267,8 @@ end end cl = init(;) + display(cl) + # Sizes @test size(cl(3)) == (3,) @test size(cl(rng, 3)) == (3,) @@ -240,6 +276,7 @@ end @test size(cl(rng, 3, 4)) == (3, 4) @test size(cl(3, 4, 5)) == (3, 4, 5) @test size(cl(rng, 3, 4, 5)) == (3, 4, 5) + # Type @test eltype(cl(4, 2)) == Float32 @test eltype(cl(rng, 4, 2)) == Float32