diff --git a/src/WeightInitializers.jl b/src/WeightInitializers.jl index af3c5ef..253b5fa 100644 --- a/src/WeightInitializers.jl +++ b/src/WeightInitializers.jl @@ -2,11 +2,10 @@ module WeightInitializers using ArgCheck: @argcheck using ChainRulesCore: ChainRulesCore -using ConcreteStructs: @concrete using GPUArraysCore: @allowscalar using LinearAlgebra: LinearAlgebra, Diagonal, qr using Random: Random, AbstractRNG, Xoshiro, shuffle -using SpecialFunctions: SpecialFunctions, erf, erfinv +using SpecialFunctions: SpecialFunctions, erf, erfinv # Move to Ext in v2.0 using Statistics: Statistics, std const CRC = ChainRulesCore diff --git a/src/initializers.jl b/src/initializers.jl index 57d6d8d..981746a 100644 --- a/src/initializers.jl +++ b/src/initializers.jl @@ -331,17 +331,16 @@ for initializer in (:glorot_uniform, :glorot_normal, :kaiming_uniform, :kaiming_ # Partial application function ($initializer)(rng::AbstractRNG; kwargs...) - return PartialWeightInitializationFunction{Nothing}($initializer, rng, kwargs) + return PartialFunction.Partial{Nothing}($initializer, rng, kwargs) end function ($initializer)(::Type{T}; kwargs...) where {T <: $NType} - return PartialWeightInitializationFunction{T}($initializer, nothing, kwargs) + return PartialFunction.Partial{T}($initializer, nothing, kwargs) end function ($initializer)(rng::AbstractRNG, ::Type{T}; kwargs...) where {T <: $NType} - return PartialWeightInitializationFunction{T}($initializer, rng, kwargs) + return PartialFunction.Partial{T}($initializer, rng, kwargs) end function ($initializer)(; kwargs...) - return PartialWeightInitializationFunction{Nothing}( - $initializer, nothing, kwargs) + return PartialFunction.Partial{Nothing}($initializer, nothing, kwargs) end end end @@ -362,14 +361,13 @@ for tp in ("16", "32", "64", "C16", "C32", "C64"), func in (:zeros, :ones, :rand # Partial application function ($initializer)(rng::AbstractRNG; kwargs...) - return PartialWeightInitializationFunction{Missing}($initializer, rng, kwargs) + return PartialFunction.Partial{Missing}($initializer, rng, kwargs) end function ($initializer)(rng::AbstractRNG, ::Type{T}; kwargs...) where {T} throw(ArgumentError(string($initializer) * " doesn't accept a type argument.")) end function ($initializer)(; kwargs...) - return PartialWeightInitializationFunction{Missing}( - $initializer, nothing, kwargs) + return PartialFunction.Partial{Missing}($initializer, nothing, kwargs) end end end diff --git a/src/partial.jl b/src/partial.jl index d9b054c..52cde29 100644 --- a/src/partial.jl +++ b/src/partial.jl @@ -1,11 +1,16 @@ -@concrete struct PartialWeightInitializationFunction{T} <: Function +module PartialFunction + +using ArgCheck: @argcheck +using ConcreteStructs: @concrete +using Random: AbstractRNG + +@concrete struct Partial{T} <: Function f <: Function rng <: Union{Nothing, AbstractRNG} kwargs end -function Base.show( - io::IO, ::MIME"text/plain", f::PartialWeightInitializationFunction{T}) where {T} +function Base.show(io::IO, ::MIME"text/plain", f::Partial{T}) where {T} print(io, "$(f.f)(") if f.rng !== nothing print(io, "$(nameof(typeof(f.rng)))(...), ") @@ -26,22 +31,21 @@ function Base.show( print(io, ")") end -function (f::PartialWeightInitializationFunction{<:Union{Nothing, Missing}})( - args...; kwargs...) +function (f::Partial{<:Union{Nothing, Missing}})(args...; kwargs...) f.rng === nothing && return f.f(args...; f.kwargs..., kwargs...) return f.f(f.rng, args...; f.kwargs..., kwargs...) end -function (f::PartialWeightInitializationFunction{<:Union{Nothing, Missing}})( - rng::AbstractRNG, args...; kwargs...) +function (f::Partial{<:Union{Nothing, Missing}})(rng::AbstractRNG, args...; kwargs...) @argcheck f.rng === nothing return f.f(rng, args...; f.kwargs..., kwargs...) end -function (f::PartialWeightInitializationFunction{T})(args...; kwargs...) where {T <: Number} +function (f::Partial{T})(args...; kwargs...) where {T <: Number} f.rng === nothing && return f.f(T, args...; f.kwargs..., kwargs...) return f.f(f.rng, T, args...; f.kwargs..., kwargs...) end -function (f::PartialWeightInitializationFunction{T})( - rng::AbstractRNG, args...; kwargs...) where {T <: Number} +function (f::Partial{T})(rng::AbstractRNG, args...; kwargs...) where {T <: Number} @argcheck f.rng === nothing return f.f(rng, T, args...; f.kwargs..., kwargs...) end + +end