From aef869548b40e63468c381d2ef94b810cf1c68c9 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 11 Jul 2024 20:30:10 -0700 Subject: [PATCH] feat: custom partial application implementation --- .JuliaFormatter.toml | 2 +- Project.toml | 6 ++-- src/WeightInitializers.jl | 5 ++- src/initializers.jl | 67 ++++++++++++++++++++++++------------- src/partial.jl | 70 +++++++++++++++++++++++++++++++++++++++ src/utils.jl | 3 -- 6 files changed, 120 insertions(+), 33 deletions(-) create mode 100644 src/partial.jl diff --git a/.JuliaFormatter.toml b/.JuliaFormatter.toml index 547dbee..f593e92 100644 --- a/.JuliaFormatter.toml +++ b/.JuliaFormatter.toml @@ -1,9 +1,9 @@ style = "sciml" whitespace_in_kwargs = false -always_use_return = true margin = 92 indent = 4 format_docstrings = true separate_kwargs_with_semicolon = true join_lines_based_on_source = false always_for_in = true +annotate_untyped_fields_with_any = false diff --git a/Project.toml b/Project.toml index 0517ad8..c0d46ac 100644 --- a/Project.toml +++ b/Project.toml @@ -1,13 +1,13 @@ name = "WeightInitializers" uuid = "d49dbf32-c5c2-4618-8acc-27bb2598ef2d" authors = ["Avik Pal and contributors"] -version = "0.1.9" +version = "0.1.10" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471" GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" -PartialFunctions = "570af359-4316-4cb7-8c74-252c00c2016b" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" @@ -31,13 +31,13 @@ AMDGPU = "0.9.6" Aqua = "0.8.7" CUDA = "5.3.2" ChainRulesCore = "1.23" +ConcreteStructs = "0.2.3" Documenter = "1.5.0" ExplicitImports = "1.9.0" GPUArrays = "10.2" GPUArraysCore = "0.1.6" LinearAlgebra = "1.10" Metal = "1.1.0" -PartialFunctions = "1.2" Pkg = "1.10" Random = "1.10" ReTestItems = "1.24.0" diff --git a/src/WeightInitializers.jl b/src/WeightInitializers.jl index 8838112..d115289 100644 --- a/src/WeightInitializers.jl +++ b/src/WeightInitializers.jl @@ -1,17 +1,16 @@ module WeightInitializers -#! format: off using ChainRulesCore: ChainRulesCore +using ConcreteStructs: @concrete using GPUArraysCore: @allowscalar using LinearAlgebra: LinearAlgebra, Diagonal, qr -using PartialFunctions: :$ using Random: Random, AbstractRNG, Xoshiro, shuffle using SpecialFunctions: SpecialFunctions, erf, erfinv using Statistics: Statistics, std -#! format: on const CRC = ChainRulesCore +include("partial.jl") include("utils.jl") include("initializers.jl") include("autodiff.jl") diff --git a/src/initializers.jl b/src/initializers.jl index 76bfdee..2e13417 100644 --- a/src/initializers.jl +++ b/src/initializers.jl @@ -318,33 +318,54 @@ end for initializer in (:glorot_uniform, :glorot_normal, :kaiming_uniform, :kaiming_normal, :truncated_normal, :orthogonal, :sparse_init, :identity_init) NType = ifelse(initializer === :truncated_normal, Real, Number) - @eval function ($initializer)(dims::Integer...; kwargs...) - return $initializer(_default_rng(), Float32, dims...; kwargs...) - end - @eval function ($initializer)(rng::AbstractRNG, dims::Integer...; kwargs...) - return $initializer(rng, Float32, dims...; kwargs...) - end - @eval function ($initializer)( - ::Type{T}, dims::Integer...; kwargs...) where {T <: $NType} - return $initializer(_default_rng(), T, dims...; kwargs...) - end - @eval function ($initializer)(rng::AbstractRNG; kwargs...) - return __partial_apply($initializer, (rng, (; kwargs...))) - end - @eval function ($initializer)( - rng::AbstractRNG, ::Type{T}; kwargs...) where {T <: $NType} - return __partial_apply($initializer, ((rng, T), (; kwargs...))) + @eval begin + function ($initializer)(dims::Integer...; kwargs...) + return $initializer(_default_rng(), Float32, dims...; kwargs...) + end + function ($initializer)(rng::AbstractRNG, dims::Integer...; kwargs...) + return $initializer(rng, Float32, dims...; kwargs...) + end + function ($initializer)(::Type{T}, dims::Integer...; kwargs...) where {T <: $NType} + return $initializer(_default_rng(), T, dims...; kwargs...) + end + + # Partial application + function ($initializer)(rng::AbstractRNG; kwargs...) + return PartialWeightInitializationFunction{Nothing}($initializer, rng, kwargs) + end + function ($initializer)(::Type{T}; kwargs...) where {T <: $NType} + return PartialWeightInitializationFunction{T}($initializer, nothing, kwargs) + end + function ($initializer)(rng::AbstractRNG, ::Type{T}; kwargs...) where {T <: $NType} + return PartialWeightInitializationFunction{T}($initializer, rng, kwargs) + end + function ($initializer)(; kwargs...) + return PartialWeightInitializationFunction{Nothing}( + $initializer, nothing, kwargs) + end end - @eval ($initializer)(; kwargs...) = __partial_apply($initializer, (; kwargs...)) end for tp in ("16", "32", "64", "C16", "C32", "C64"), func in (:zeros, :ones, :randn, :rand) initializer = Symbol(func, tp) - @eval function ($initializer)(dims::Integer...; kwargs...) - return $initializer(_default_rng(), dims...; kwargs...) - end - @eval function ($initializer)(rng::AbstractRNG; kwargs...) - return __partial_apply($initializer, (rng, (; kwargs...))) + @eval begin + function ($initializer)(dims::Integer...; kwargs...) + return $initializer(_default_rng(), dims...; kwargs...) + end + function ($initializer)(::Type{T}, dims::Integer...; kwargs...) where {T} + throw(ArgumentError(string($initializer) * " doesn't accept a type argument.")) + end + + # Partial application + function ($initializer)(rng::AbstractRNG; kwargs...) + return PartialWeightInitializationFunction{Missing}($initializer, rng, kwargs) + end + function ($initializer)(rng::AbstractRNG, ::Type{T}; kwargs...) where {T} + throw(ArgumentError(string($initializer) * " doesn't accept a type argument.")) + end + function ($initializer)(; kwargs...) + return PartialWeightInitializationFunction{Missing}( + $initializer, nothing, kwargs) + end end - @eval ($initializer)(; kwargs...) = __partial_apply($initializer, (; kwargs...)) end diff --git a/src/partial.jl b/src/partial.jl new file mode 100644 index 0000000..7e2c499 --- /dev/null +++ b/src/partial.jl @@ -0,0 +1,70 @@ +@concrete struct PartialWeightInitializationFunction{T} <: Function + f <: Function + rng <: Union{Nothing, AbstractRNG} + kwargs +end + +function Base.show( + io::IO, ::MIME"text/plain", f::PartialWeightInitializationFunction{T}) where {T} + print(io, "$(f.f)(") + f.rng !== nothing ? print(io, "$(f.rng), ") : print(io, "rng, ") + if T === Nothing + print(io, "::Type{T}, ") + else + T !== Missing ? print(io, "$(T), ") : nothing + end + print(io, "dims...") + kwargs_str = String[] + for (k, v) in pairs(f.kwargs) + push!(kwargs_str, "$(k)=$(v)") + end + length(kwargs_str) > 0 && print(io, "; ", join(kwargs_str, ", ")) + print(io, ")") +end + +# ::Type{T} is already specified +function (f::PartialWeightInitializationFunction{T, F, <:AbstractRNG})( + dims::Integer...; kwargs...) where {T <: Number, F} + return f.f(f.rng, T, dims...; f.kwargs..., kwargs...) +end +function (f::PartialWeightInitializationFunction{T, F, Nothing})( + rng::AbstractRNG; kwargs...) where {T <: Number, F} + return PartialWeightInitializationFunction{T}(f.f, rng, (; f.kwargs..., kwargs...)) +end +function (f::PartialWeightInitializationFunction{T, F, Nothing})( + rng::AbstractRNG, dims::Integer...; kwargs...) where {T <: Number, F} + return f.f(rng, T, dims...; f.kwargs..., kwargs...) +end + +# ::Type{T} is not needed +function (f::PartialWeightInitializationFunction{Missing, F, <:AbstractRNG})( + dims::Integer...; kwargs...) where {F} + return f.f(f.rng, dims...; f.kwargs..., kwargs...) +end +function (f::PartialWeightInitializationFunction{Missing, F, Nothing})( + rng::AbstractRNG; kwargs...) where {F} + return PartialWeightInitializationFunction{Missing}( + f.f, rng, (; f.kwargs..., kwargs...)) +end +function (f::PartialWeightInitializationFunction{Missing, F, Nothing})( + rng::AbstractRNG, dims::Integer...; kwargs...) where {F} + return f.f(rng, dims...; f.kwargs..., kwargs...) +end + +# ::Type{T} is not specified +function (f::PartialWeightInitializationFunction{Nothing, F, Union{<:AbstractRNG, Nothing}})( + ::Type{T}; kwargs...) where {T <: Number, F} + return PartialWeightInitializationFunction{T}(f.f, f.rng, (; f.kwargs..., kwargs...)) +end +function (f::PartialWeightInitializationFunction{Nothing, F, <:AbstractRNG})( + ::Type{T}, dims::Integer...; kwargs...) where {T <: Number, F} + return f.f(f.rng, T, dims...; f.kwargs..., kwargs...) +end +function (f::PartialWeightInitializationFunction{Nothing, F, Nothing})( + rng::AbstractRNG, ::Type{T}; kwargs...) where {T <: Number, F} + return PartialWeightInitializationFunction{T}(f.f, rng, (; f.kwargs..., kwargs...)) +end +function (f::PartialWeightInitializationFunction{Nothing, F, Nothing})( + rng::AbstractRNG, ::Type{T}, dims::Integer...; kwargs...) where {T <: Number, F} + return f.f(rng, T, dims...; f.kwargs..., kwargs...) +end diff --git a/src/utils.jl b/src/utils.jl index 3b9c618..1672c3a 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -7,9 +7,6 @@ @inline _default_rng() = Xoshiro(1234) -# This is needed if using `PartialFunctions.$` inside @eval block -@inline __partial_apply(fn, inp) = fn$inp - const NAME_TO_DIST = Dict( :zeros => "an AbstractArray of zeros", :ones => "an AbstractArray of ones", :randn => "random numbers from a standard normal distribution",