Skip to content
This repository has been archived by the owner on Nov 4, 2024. It is now read-only.

Commit

Permalink
feat: custom partial application implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Jul 12, 2024
1 parent d295cd5 commit aef8695
Show file tree
Hide file tree
Showing 6 changed files with 120 additions and 33 deletions.
2 changes: 1 addition & 1 deletion .JuliaFormatter.toml
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
style = "sciml"
whitespace_in_kwargs = false
always_use_return = true
margin = 92
indent = 4
format_docstrings = true
separate_kwargs_with_semicolon = true
join_lines_based_on_source = false
always_for_in = true
annotate_untyped_fields_with_any = false
6 changes: 3 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
name = "WeightInitializers"
uuid = "d49dbf32-c5c2-4618-8acc-27bb2598ef2d"
authors = ["Avik Pal <avikpal@mit.edu> and contributors"]
version = "0.1.9"
version = "0.1.10"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
PartialFunctions = "570af359-4316-4cb7-8c74-252c00c2016b"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Expand All @@ -31,13 +31,13 @@ AMDGPU = "0.9.6"
Aqua = "0.8.7"
CUDA = "5.3.2"
ChainRulesCore = "1.23"
ConcreteStructs = "0.2.3"
Documenter = "1.5.0"
ExplicitImports = "1.9.0"
GPUArrays = "10.2"
GPUArraysCore = "0.1.6"
LinearAlgebra = "1.10"
Metal = "1.1.0"
PartialFunctions = "1.2"
Pkg = "1.10"
Random = "1.10"
ReTestItems = "1.24.0"
Expand Down
5 changes: 2 additions & 3 deletions src/WeightInitializers.jl
Original file line number Diff line number Diff line change
@@ -1,17 +1,16 @@
module WeightInitializers

#! format: off
using ChainRulesCore: ChainRulesCore
using ConcreteStructs: @concrete
using GPUArraysCore: @allowscalar
using LinearAlgebra: LinearAlgebra, Diagonal, qr
using PartialFunctions: :$
using Random: Random, AbstractRNG, Xoshiro, shuffle
using SpecialFunctions: SpecialFunctions, erf, erfinv
using Statistics: Statistics, std
#! format: on

const CRC = ChainRulesCore

include("partial.jl")
include("utils.jl")
include("initializers.jl")
include("autodiff.jl")
Expand Down
67 changes: 44 additions & 23 deletions src/initializers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -318,33 +318,54 @@ end
for initializer in (:glorot_uniform, :glorot_normal, :kaiming_uniform, :kaiming_normal,
:truncated_normal, :orthogonal, :sparse_init, :identity_init)
NType = ifelse(initializer === :truncated_normal, Real, Number)
@eval function ($initializer)(dims::Integer...; kwargs...)
return $initializer(_default_rng(), Float32, dims...; kwargs...)
end
@eval function ($initializer)(rng::AbstractRNG, dims::Integer...; kwargs...)
return $initializer(rng, Float32, dims...; kwargs...)
end
@eval function ($initializer)(
::Type{T}, dims::Integer...; kwargs...) where {T <: $NType}
return $initializer(_default_rng(), T, dims...; kwargs...)
end
@eval function ($initializer)(rng::AbstractRNG; kwargs...)
return __partial_apply($initializer, (rng, (; kwargs...)))
end
@eval function ($initializer)(
rng::AbstractRNG, ::Type{T}; kwargs...) where {T <: $NType}
return __partial_apply($initializer, ((rng, T), (; kwargs...)))
@eval begin
function ($initializer)(dims::Integer...; kwargs...)
return $initializer(_default_rng(), Float32, dims...; kwargs...)
end
function ($initializer)(rng::AbstractRNG, dims::Integer...; kwargs...)
return $initializer(rng, Float32, dims...; kwargs...)
end
function ($initializer)(::Type{T}, dims::Integer...; kwargs...) where {T <: $NType}
return $initializer(_default_rng(), T, dims...; kwargs...)
end

# Partial application
function ($initializer)(rng::AbstractRNG; kwargs...)
return PartialWeightInitializationFunction{Nothing}($initializer, rng, kwargs)
end
function ($initializer)(::Type{T}; kwargs...) where {T <: $NType}
return PartialWeightInitializationFunction{T}($initializer, nothing, kwargs)
end
function ($initializer)(rng::AbstractRNG, ::Type{T}; kwargs...) where {T <: $NType}
return PartialWeightInitializationFunction{T}($initializer, rng, kwargs)
end
function ($initializer)(; kwargs...)
return PartialWeightInitializationFunction{Nothing}(
$initializer, nothing, kwargs)
end
end
@eval ($initializer)(; kwargs...) = __partial_apply($initializer, (; kwargs...))
end

for tp in ("16", "32", "64", "C16", "C32", "C64"), func in (:zeros, :ones, :randn, :rand)
initializer = Symbol(func, tp)
@eval function ($initializer)(dims::Integer...; kwargs...)
return $initializer(_default_rng(), dims...; kwargs...)
end
@eval function ($initializer)(rng::AbstractRNG; kwargs...)
return __partial_apply($initializer, (rng, (; kwargs...)))
@eval begin
function ($initializer)(dims::Integer...; kwargs...)
return $initializer(_default_rng(), dims...; kwargs...)
end
function ($initializer)(::Type{T}, dims::Integer...; kwargs...) where {T}
throw(ArgumentError(string($initializer) * " doesn't accept a type argument."))
end

# Partial application
function ($initializer)(rng::AbstractRNG; kwargs...)
return PartialWeightInitializationFunction{Missing}($initializer, rng, kwargs)
end
function ($initializer)(rng::AbstractRNG, ::Type{T}; kwargs...) where {T}
throw(ArgumentError(string($initializer) * " doesn't accept a type argument."))
end
function ($initializer)(; kwargs...)
return PartialWeightInitializationFunction{Missing}(
$initializer, nothing, kwargs)
end
end
@eval ($initializer)(; kwargs...) = __partial_apply($initializer, (; kwargs...))
end
70 changes: 70 additions & 0 deletions src/partial.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
@concrete struct PartialWeightInitializationFunction{T} <: Function
f <: Function
rng <: Union{Nothing, AbstractRNG}
kwargs
end

function Base.show(
io::IO, ::MIME"text/plain", f::PartialWeightInitializationFunction{T}) where {T}
print(io, "$(f.f)(")
f.rng !== nothing ? print(io, "$(f.rng), ") : print(io, "rng, ")
if T === Nothing
print(io, "::Type{T}, ")
else
T !== Missing ? print(io, "$(T), ") : nothing
end
print(io, "dims...")
kwargs_str = String[]
for (k, v) in pairs(f.kwargs)
push!(kwargs_str, "$(k)=$(v)")
end
length(kwargs_str) > 0 && print(io, "; ", join(kwargs_str, ", "))
print(io, ")")
end

# ::Type{T} is already specified
function (f::PartialWeightInitializationFunction{T, F, <:AbstractRNG})(
dims::Integer...; kwargs...) where {T <: Number, F}
return f.f(f.rng, T, dims...; f.kwargs..., kwargs...)
end
function (f::PartialWeightInitializationFunction{T, F, Nothing})(
rng::AbstractRNG; kwargs...) where {T <: Number, F}
return PartialWeightInitializationFunction{T}(f.f, rng, (; f.kwargs..., kwargs...))
end
function (f::PartialWeightInitializationFunction{T, F, Nothing})(
rng::AbstractRNG, dims::Integer...; kwargs...) where {T <: Number, F}
return f.f(rng, T, dims...; f.kwargs..., kwargs...)
end

# ::Type{T} is not needed
function (f::PartialWeightInitializationFunction{Missing, F, <:AbstractRNG})(
dims::Integer...; kwargs...) where {F}
return f.f(f.rng, dims...; f.kwargs..., kwargs...)
end
function (f::PartialWeightInitializationFunction{Missing, F, Nothing})(
rng::AbstractRNG; kwargs...) where {F}
return PartialWeightInitializationFunction{Missing}(
f.f, rng, (; f.kwargs..., kwargs...))
end
function (f::PartialWeightInitializationFunction{Missing, F, Nothing})(
rng::AbstractRNG, dims::Integer...; kwargs...) where {F}
return f.f(rng, dims...; f.kwargs..., kwargs...)
end

# ::Type{T} is not specified
function (f::PartialWeightInitializationFunction{Nothing, F, Union{<:AbstractRNG, Nothing}})(
::Type{T}; kwargs...) where {T <: Number, F}
return PartialWeightInitializationFunction{T}(f.f, f.rng, (; f.kwargs..., kwargs...))
end
function (f::PartialWeightInitializationFunction{Nothing, F, <:AbstractRNG})(
::Type{T}, dims::Integer...; kwargs...) where {T <: Number, F}
return f.f(f.rng, T, dims...; f.kwargs..., kwargs...)
end
function (f::PartialWeightInitializationFunction{Nothing, F, Nothing})(
rng::AbstractRNG, ::Type{T}; kwargs...) where {T <: Number, F}
return PartialWeightInitializationFunction{T}(f.f, rng, (; f.kwargs..., kwargs...))
end
function (f::PartialWeightInitializationFunction{Nothing, F, Nothing})(
rng::AbstractRNG, ::Type{T}, dims::Integer...; kwargs...) where {T <: Number, F}
return f.f(rng, T, dims...; f.kwargs..., kwargs...)
end
3 changes: 0 additions & 3 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,6 @@

@inline _default_rng() = Xoshiro(1234)

# This is needed if using `PartialFunctions.$` inside @eval block
@inline __partial_apply(fn, inp) = fn$inp

const NAME_TO_DIST = Dict(
:zeros => "an AbstractArray of zeros", :ones => "an AbstractArray of ones",
:randn => "random numbers from a standard normal distribution",
Expand Down

0 comments on commit aef8695

Please sign in to comment.