Skip to content
This repository has been archived by the owner on Nov 4, 2024. It is now read-only.

Commit

Permalink
fix: add missing dispatch
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Jul 12, 2024
1 parent 84c660f commit 25ea875
Show file tree
Hide file tree
Showing 5 changed files with 53 additions and 2 deletions.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ authors = ["Avik Pal <avikpal@mit.edu> and contributors"]
version = "0.1.10"

[deps]
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
Expand All @@ -29,6 +30,7 @@ WeightInitializersoneAPIExt = ["oneAPI", "GPUArrays"]
[compat]
AMDGPU = "0.9.6"
Aqua = "0.8.7"
ArgCheck = "2.3.0"
CUDA = "5.3.2"
ChainRulesCore = "1.23"
ConcreteStructs = "0.2.3"
Expand Down
1 change: 1 addition & 0 deletions src/WeightInitializers.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
module WeightInitializers

using ArgCheck: @argcheck
using ChainRulesCore: ChainRulesCore
using ConcreteStructs: @concrete
using GPUArraysCore: @allowscalar
Expand Down
6 changes: 5 additions & 1 deletion src/initializers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ deep linear neural networks", ICLR 2014, https://arxiv.org/abs/1312.6120
"""
function orthogonal(rng::AbstractRNG, ::Type{T}, dims::Integer...;
gain::Number=T(1.0)) where {T <: Number}
@assert length(dims)>1 "Creating vectors (length(dims) == 1) is not allowed"
@argcheck length(dims)>1 "Creating vectors (length(dims) == 1) is not allowed"

rows, cols = length(dims) == 2 ? dims : (prod(dims[1:(end - 1)]), dims[end])
rows < cols && return permutedims(orthogonal(rng, T, cols, rows; gain=T(gain)))
Expand Down Expand Up @@ -355,6 +355,10 @@ for tp in ("16", "32", "64", "C16", "C32", "C64"), func in (:zeros, :ones, :rand
function ($initializer)(::Type{T}, dims::Integer...; kwargs...) where {T}
throw(ArgumentError(string($initializer) * " doesn't accept a type argument."))
end
function ($initializer)(
::AbstractRNG, ::Type{T}, dims::Integer...; kwargs...) where {T}
throw(ArgumentError(string($initializer) * " doesn't accept a type argument."))
end

# Partial application
function ($initializer)(rng::AbstractRNG; kwargs...)
Expand Down
16 changes: 15 additions & 1 deletion src/partial.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,11 @@ end
function Base.show(
io::IO, ::MIME"text/plain", f::PartialWeightInitializationFunction{T}) where {T}
print(io, "$(f.f)(")
f.rng !== nothing ? print(io, "$(f.rng), ") : print(io, "rng, ")
if f.rng !== nothing
print(io, "$(nameof(typeof(f.rng)))(...), ")
else
print(io, "rng, ")
end
if T === Nothing
print(io, "::Type{T}, ")
else
Expand All @@ -27,7 +31,17 @@ function (f::PartialWeightInitializationFunction{<:Union{Nothing, Missing}})(
f.rng === nothing && return f.f(args...; f.kwargs..., kwargs...)
return f.f(f.rng, args...; f.kwargs..., kwargs...)
end
function (f::PartialWeightInitializationFunction{<:Union{Nothing, Missing}})(
rng::AbstractRNG, args...; kwargs...)
@argcheck f.rng === nothing
return f.f(rng, args...; f.kwargs..., kwargs...)
end
function (f::PartialWeightInitializationFunction{T})(args...; kwargs...) where {T <: Number}
f.rng === nothing && return f.f(T, args...; f.kwargs..., kwargs...)
return f.f(f.rng, T, args...; f.kwargs..., kwargs...)
end
function (f::PartialWeightInitializationFunction{T})(
rng::AbstractRNG, args...; kwargs...) where {T <: Number}
@argcheck f.rng === nothing
return f.f(rng, T, args...; f.kwargs..., kwargs...)
end
30 changes: 30 additions & 0 deletions test/initializers_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,14 +53,17 @@ end
@test orthogonal(rng, T, 3, 5) isa arrtype{T, 2}

cl = orthogonal(rng)
display(cl)
@test cl(T, 3, 5) isa arrtype{T, 2}

cl = orthogonal(rng, T)
display(cl)
@test cl(3, 5) isa arrtype{T, 2}
end

@testset "Orthogonal Closure" begin
cl = orthogonal(;)
display(cl)

# Sizes
@test size(cl(3, 4)) == (3, 4)
Expand Down Expand Up @@ -114,17 +117,22 @@ end
@test sparse_init(rng, T, 3, 5; sparsity=0.5) isa arrtype{T, 2}

cl = sparse_init(rng; sparsity=0.5)
display(cl)
@test cl(T, 3, 5) isa arrtype{T, 2}

cl = sparse_init(rng, T; sparsity=0.5)
display(cl)
@test cl(3, 5) isa arrtype{T, 2}
end

@testset "sparse_init Closure" begin
cl = sparse_init(; sparsity=0.5)
display(cl)

# Sizes
@test size(cl(3, 4)) == (3, 4)
@test size(cl(rng, 3, 4)) == (3, 4)

# Type
@test eltype(cl(4, 2)) == Float32
@test eltype(cl(rng, 4, 2)) == Float32
Expand Down Expand Up @@ -158,11 +166,14 @@ end
@test size(init(rng, 3, 4)) == (3, 4)
@test size(init(3, 4, 5)) == (3, 4, 5)
@test size(init(rng, 3, 4, 5)) == (3, 4, 5)

# Type
@test eltype(init(rng, 4, 2)) == Float32
@test eltype(init(4, 2)) == Float32

# RNG Closure
cl = init(rng)
display(cl)
@test cl(3) isa arrtype{Float32, 1}
@test cl(3, 5) isa arrtype{Float32, 2}
end
Expand All @@ -185,13 +196,21 @@ end
@test size(init(rng, 3, 4)) == (3, 4)
@test size(init(3, 4, 5)) == (3, 4, 5)
@test size(init(rng, 3, 4, 5)) == (3, 4, 5)

# Type
@test eltype(init(rng, 4, 2)) == fp
@test eltype(init(4, 2)) == fp

# RNG Closure
cl = init(rng)
display(cl)
@test cl(3) isa arrtype{fp, 1}
@test cl(3, 5) isa arrtype{fp, 2}

# throw error on type as input
@test_throws ArgumentError init(Float32)
@test_throws ArgumentError init(Float32, 3, 5)
@test_throws ArgumentError init(rng, Float32, 3, 5)
end

@testset "AbstractArray Type: $init $T" for init in [
Expand All @@ -216,12 +235,20 @@ end
@test init(rng, T, 3, 5) isa arrtype{T, 2}

cl = init(rng)
display(cl)
@test cl(T, 3) isa arrtype{T, 1}
@test cl(T, 3, 5) isa arrtype{T, 2}

cl = init(rng, T)
display(cl)
@test cl(3) isa arrtype{T, 1}
@test cl(3, 5) isa arrtype{T, 2}

cl = init(T)
display(cl)
@test cl(3) isa Array{T, 1}
@test cl(3, 5) isa Array{T, 2}
@test cl(rng, 3, 5) isa arrtype{T, 2}
end

@testset "Closure: $init" for init in [
Expand All @@ -233,13 +260,16 @@ end
end

cl = init(;)
display(cl)

# Sizes
@test size(cl(3)) == (3,)
@test size(cl(rng, 3)) == (3,)
@test size(cl(3, 4)) == (3, 4)
@test size(cl(rng, 3, 4)) == (3, 4)
@test size(cl(3, 4, 5)) == (3, 4, 5)
@test size(cl(rng, 3, 4, 5)) == (3, 4, 5)

# Type
@test eltype(cl(4, 2)) == Float32
@test eltype(cl(rng, 4, 2)) == Float32
Expand Down

0 comments on commit 25ea875

Please sign in to comment.