Skip to content
This repository has been archived by the owner on Nov 4, 2024. It is now read-only.

feat: custom partial application implementation #30

Merged
merged 3 commits into from
Jul 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .JuliaFormatter.toml
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
style = "sciml"
whitespace_in_kwargs = false
always_use_return = true
margin = 92
indent = 4
format_docstrings = true
separate_kwargs_with_semicolon = true
join_lines_based_on_source = false
always_for_in = true
annotate_untyped_fields_with_any = false
8 changes: 5 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
name = "WeightInitializers"
uuid = "d49dbf32-c5c2-4618-8acc-27bb2598ef2d"
authors = ["Avik Pal <avikpal@mit.edu> and contributors"]
version = "0.1.9"
version = "0.1.10"

[deps]
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
PartialFunctions = "570af359-4316-4cb7-8c74-252c00c2016b"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Expand All @@ -29,15 +30,16 @@ WeightInitializersoneAPIExt = ["oneAPI", "GPUArrays"]
[compat]
AMDGPU = "0.9.6"
Aqua = "0.8.7"
ArgCheck = "2.3.0"
CUDA = "5.3.2"
ChainRulesCore = "1.23"
ConcreteStructs = "0.2.3"
Documenter = "1.5.0"
ExplicitImports = "1.9.0"
GPUArrays = "10.2"
GPUArraysCore = "0.1.6"
LinearAlgebra = "1.10"
Metal = "1.1.0"
PartialFunctions = "1.2"
Pkg = "1.10"
Random = "1.10"
ReTestItems = "1.24.0"
Expand Down
6 changes: 3 additions & 3 deletions src/WeightInitializers.jl
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
module WeightInitializers

#! format: off
using ArgCheck: @argcheck
using ChainRulesCore: ChainRulesCore
using ConcreteStructs: @concrete
using GPUArraysCore: @allowscalar
using LinearAlgebra: LinearAlgebra, Diagonal, qr
using PartialFunctions: :$
using Random: Random, AbstractRNG, Xoshiro, shuffle
using SpecialFunctions: SpecialFunctions, erf, erfinv
using Statistics: Statistics, std
#! format: on

const CRC = ChainRulesCore

include("partial.jl")
include("utils.jl")
include("initializers.jl")
include("autodiff.jl")
Expand Down
73 changes: 49 additions & 24 deletions src/initializers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ deep linear neural networks", ICLR 2014, https://arxiv.org/abs/1312.6120
"""
function orthogonal(rng::AbstractRNG, ::Type{T}, dims::Integer...;
gain::Number=T(1.0)) where {T <: Number}
@assert length(dims)>1 "Creating vectors (length(dims) == 1) is not allowed"
@argcheck length(dims)>1 "Creating vectors (length(dims) == 1) is not allowed"

rows, cols = length(dims) == 2 ? dims : (prod(dims[1:(end - 1)]), dims[end])
rows < cols && return permutedims(orthogonal(rng, T, cols, rows; gain=T(gain)))
Expand Down Expand Up @@ -318,33 +318,58 @@ end
for initializer in (:glorot_uniform, :glorot_normal, :kaiming_uniform, :kaiming_normal,
:truncated_normal, :orthogonal, :sparse_init, :identity_init)
NType = ifelse(initializer === :truncated_normal, Real, Number)
@eval function ($initializer)(dims::Integer...; kwargs...)
return $initializer(_default_rng(), Float32, dims...; kwargs...)
end
@eval function ($initializer)(rng::AbstractRNG, dims::Integer...; kwargs...)
return $initializer(rng, Float32, dims...; kwargs...)
end
@eval function ($initializer)(
::Type{T}, dims::Integer...; kwargs...) where {T <: $NType}
return $initializer(_default_rng(), T, dims...; kwargs...)
end
@eval function ($initializer)(rng::AbstractRNG; kwargs...)
return __partial_apply($initializer, (rng, (; kwargs...)))
end
@eval function ($initializer)(
rng::AbstractRNG, ::Type{T}; kwargs...) where {T <: $NType}
return __partial_apply($initializer, ((rng, T), (; kwargs...)))
@eval begin
function ($initializer)(dims::Integer...; kwargs...)
return $initializer(_default_rng(), Float32, dims...; kwargs...)
end
function ($initializer)(rng::AbstractRNG, dims::Integer...; kwargs...)
return $initializer(rng, Float32, dims...; kwargs...)
end
function ($initializer)(::Type{T}, dims::Integer...; kwargs...) where {T <: $NType}
return $initializer(_default_rng(), T, dims...; kwargs...)
end

# Partial application
function ($initializer)(rng::AbstractRNG; kwargs...)
return PartialWeightInitializationFunction{Nothing}($initializer, rng, kwargs)
end
function ($initializer)(::Type{T}; kwargs...) where {T <: $NType}
return PartialWeightInitializationFunction{T}($initializer, nothing, kwargs)
end
function ($initializer)(rng::AbstractRNG, ::Type{T}; kwargs...) where {T <: $NType}
return PartialWeightInitializationFunction{T}($initializer, rng, kwargs)
end
function ($initializer)(; kwargs...)
return PartialWeightInitializationFunction{Nothing}(
$initializer, nothing, kwargs)
end
end
@eval ($initializer)(; kwargs...) = __partial_apply($initializer, (; kwargs...))
end

for tp in ("16", "32", "64", "C16", "C32", "C64"), func in (:zeros, :ones, :randn, :rand)
initializer = Symbol(func, tp)
@eval function ($initializer)(dims::Integer...; kwargs...)
return $initializer(_default_rng(), dims...; kwargs...)
end
@eval function ($initializer)(rng::AbstractRNG; kwargs...)
return __partial_apply($initializer, (rng, (; kwargs...)))
@eval begin
function ($initializer)(dims::Integer...; kwargs...)
return $initializer(_default_rng(), dims...; kwargs...)
end
function ($initializer)(::Type{T}, dims::Integer...; kwargs...) where {T}
throw(ArgumentError(string($initializer) * " doesn't accept a type argument."))
end
function ($initializer)(
::AbstractRNG, ::Type{T}, dims::Integer...; kwargs...) where {T}
throw(ArgumentError(string($initializer) * " doesn't accept a type argument."))
end

# Partial application
function ($initializer)(rng::AbstractRNG; kwargs...)
return PartialWeightInitializationFunction{Missing}($initializer, rng, kwargs)
end
function ($initializer)(rng::AbstractRNG, ::Type{T}; kwargs...) where {T}
throw(ArgumentError(string($initializer) * " doesn't accept a type argument."))
end
function ($initializer)(; kwargs...)
return PartialWeightInitializationFunction{Missing}(
$initializer, nothing, kwargs)
end
end
@eval ($initializer)(; kwargs...) = __partial_apply($initializer, (; kwargs...))
end
47 changes: 47 additions & 0 deletions src/partial.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# Closure-like wrapper returned by the partial-application methods of the weight
# initializers (e.g. `kaiming_uniform(rng; gain=2)`): it captures the underlying
# initializer `f`, an optional RNG, and keyword arguments, and is itself callable.
#
# The type parameter `T` records the captured element type:
#   - a `Number` subtype — the eltype was fixed at construction and is inserted
#     as the first (after rng) positional argument on call;
#   - `Nothing` — no eltype captured yet; the caller may still supply one;
#   - `Missing` — the wrapped initializer does not accept a type argument at all
#     (the `zeros32`/`ones32`-style family).
@concrete struct PartialWeightInitializationFunction{T} <: Function
    f <: Function                      # wrapped initializer function
    rng <: Union{Nothing, AbstractRNG} # captured RNG, or `nothing` if not bound yet
    kwargs                             # captured keyword arguments
end

# Pretty-print a partial application as the call signature it represents, e.g.
# `kaiming_uniform(Xoshiro(...), Float32, dims...; gain=1.0)`. Unbound slots are
# rendered as placeholders (`rng`, `::Type{T}`); a `Missing` eltype parameter
# (type-less initializers) omits the type slot entirely.
function Base.show(
        io::IO, ::MIME"text/plain", f::PartialWeightInitializationFunction{T}) where {T}
    print(io, "$(f.f)(")
    # RNG slot: show the captured RNG's type, or a placeholder when still open.
    if f.rng === nothing
        print(io, "rng, ")
    else
        print(io, "$(nameof(typeof(f.rng)))(...), ")
    end
    # Eltype slot: `Nothing` means "still to be supplied"; `Missing` means the
    # wrapped initializer takes no type argument, so nothing is printed.
    if T === Nothing
        print(io, "::Type{T}, ")
    elseif T !== Missing
        print(io, "$(T), ")
    end
    print(io, "dims...")
    # Captured keyword arguments, if any, rendered as `; k1=v1, k2=v2`.
    kwargs_repr = join(("$(k)=$(v)" for (k, v) in pairs(f.kwargs)), ", ")
    isempty(kwargs_repr) || print(io, "; ", kwargs_repr)
    return print(io, ")")
end

# Invoke a partial application that carries no fixed eltype: forward the requested
# dimensions to the wrapped initializer, prepending the captured RNG when one was
# bound. Call-site kwargs are splatted last, so they override captured ones.
function (fn::PartialWeightInitializationFunction{<:Union{Nothing, Missing}})(
        dims...; kwargs...)
    if fn.rng === nothing
        return fn.f(dims...; fn.kwargs..., kwargs...)
    else
        return fn.f(fn.rng, dims...; fn.kwargs..., kwargs...)
    end
end
# Invoke with an explicit call-site RNG. Only legal when no RNG was captured at
# construction (otherwise two RNGs would compete); enforced via `@argcheck`.
function (fn::PartialWeightInitializationFunction{<:Union{Nothing, Missing}})(
        rng::AbstractRNG, dims...; kwargs...)
    @argcheck fn.rng === nothing
    return fn.f(rng, dims...; fn.kwargs..., kwargs...)
end
# Invoke a partial application whose eltype `T` was fixed at construction: insert
# `T` (and the captured RNG, when bound) ahead of the requested dimensions.
# Call-site kwargs are splatted last, so they override captured ones.
function (fn::PartialWeightInitializationFunction{T})(
        dims...; kwargs...) where {T <: Number}
    if fn.rng === nothing
        return fn.f(T, dims...; fn.kwargs..., kwargs...)
    else
        return fn.f(fn.rng, T, dims...; fn.kwargs..., kwargs...)
    end
end
# Invoke a typed partial application with an explicit call-site RNG. Only legal
# when no RNG was captured at construction; enforced via `@argcheck`.
function (fn::PartialWeightInitializationFunction{T})(
        rng::AbstractRNG, dims...; kwargs...) where {T <: Number}
    @argcheck fn.rng === nothing
    return fn.f(rng, T, dims...; fn.kwargs..., kwargs...)
end
3 changes: 0 additions & 3 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,6 @@

@inline _default_rng() = Xoshiro(1234)

# This is needed if using `PartialFunctions.$` inside @eval block
@inline __partial_apply(fn, inp) = fn$inp

const NAME_TO_DIST = Dict(
:zeros => "an AbstractArray of zeros", :ones => "an AbstractArray of ones",
:randn => "random numbers from a standard normal distribution",
Expand Down
37 changes: 37 additions & 0 deletions test/initializers_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,14 +53,17 @@ end
@test orthogonal(rng, T, 3, 5) isa arrtype{T, 2}

cl = orthogonal(rng)
display(cl)
@test cl(T, 3, 5) isa arrtype{T, 2}

cl = orthogonal(rng, T)
display(cl)
@test cl(3, 5) isa arrtype{T, 2}
end

@testset "Orthogonal Closure" begin
cl = orthogonal(;)
display(cl)

# Sizes
@test size(cl(3, 4)) == (3, 4)
Expand Down Expand Up @@ -114,17 +117,22 @@ end
@test sparse_init(rng, T, 3, 5; sparsity=0.5) isa arrtype{T, 2}

cl = sparse_init(rng; sparsity=0.5)
display(cl)
@test cl(T, 3, 5) isa arrtype{T, 2}

cl = sparse_init(rng, T; sparsity=0.5)
display(cl)
@test cl(3, 5) isa arrtype{T, 2}
end

@testset "sparse_init Closure" begin
cl = sparse_init(; sparsity=0.5)
display(cl)

# Sizes
@test size(cl(3, 4)) == (3, 4)
@test size(cl(rng, 3, 4)) == (3, 4)

# Type
@test eltype(cl(4, 2)) == Float32
@test eltype(cl(rng, 4, 2)) == Float32
Expand Down Expand Up @@ -158,11 +166,14 @@ end
@test size(init(rng, 3, 4)) == (3, 4)
@test size(init(3, 4, 5)) == (3, 4, 5)
@test size(init(rng, 3, 4, 5)) == (3, 4, 5)

# Type
@test eltype(init(rng, 4, 2)) == Float32
@test eltype(init(4, 2)) == Float32

# RNG Closure
cl = init(rng)
display(cl)
@test cl(3) isa arrtype{Float32, 1}
@test cl(3, 5) isa arrtype{Float32, 2}
end
Expand All @@ -185,13 +196,28 @@ end
@test size(init(rng, 3, 4)) == (3, 4)
@test size(init(3, 4, 5)) == (3, 4, 5)
@test size(init(rng, 3, 4, 5)) == (3, 4, 5)

# Type
@test eltype(init(rng, 4, 2)) == fp
@test eltype(init(4, 2)) == fp

# RNG Closure
cl = init(rng)
display(cl)
@test cl(3) isa arrtype{fp, 1}
@test cl(3, 5) isa arrtype{fp, 2}

# Kwargs closure
cl = init(;)
display(cl)
@test cl(rng, 3) isa arrtype{fp, 1}
@test cl(rng, 3, 5) isa arrtype{fp, 2}

# throw error on type as input
@test_throws ArgumentError init(Float32)
@test_throws ArgumentError init(Float32, 3, 5)
@test_throws ArgumentError init(rng, Float32)
@test_throws ArgumentError init(rng, Float32, 3, 5)
end

@testset "AbstractArray Type: $init $T" for init in [
Expand All @@ -216,12 +242,20 @@ end
@test init(rng, T, 3, 5) isa arrtype{T, 2}

cl = init(rng)
display(cl)
@test cl(T, 3) isa arrtype{T, 1}
@test cl(T, 3, 5) isa arrtype{T, 2}

cl = init(rng, T)
display(cl)
@test cl(3) isa arrtype{T, 1}
@test cl(3, 5) isa arrtype{T, 2}

cl = init(T)
display(cl)
@test cl(3) isa Array{T, 1}
@test cl(3, 5) isa Array{T, 2}
@test cl(rng, 3, 5) isa arrtype{T, 2}
end

@testset "Closure: $init" for init in [
Expand All @@ -233,13 +267,16 @@ end
end

cl = init(;)
display(cl)

# Sizes
@test size(cl(3)) == (3,)
@test size(cl(rng, 3)) == (3,)
@test size(cl(3, 4)) == (3, 4)
@test size(cl(rng, 3, 4)) == (3, 4)
@test size(cl(3, 4, 5)) == (3, 4, 5)
@test size(cl(rng, 3, 4, 5)) == (3, 4, 5)

# Type
@test eltype(cl(4, 2)) == Float32
@test eltype(cl(rng, 4, 2)) == Float32
Expand Down
8 changes: 4 additions & 4 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@ const BACKEND_GROUP = lowercase(get(ENV, "BACKEND_GROUP", "All"))

const EXTRA_PKGS = String[]

BACKEND_GROUP == "all" || BACKEND_GROUP == "cuda" && push!(EXTRA_PKGS, "CUDA")
BACKEND_GROUP == "all" || BACKEND_GROUP == "amdgpu" && push!(EXTRA_PKGS, "AMDGPU")
BACKEND_GROUP == "all" || BACKEND_GROUP == "metal" && push!(EXTRA_PKGS, "Metal")
BACKEND_GROUP == "all" || BACKEND_GROUP == "oneapi" && push!(EXTRA_PKGS, "oneAPI")
# Queue optional GPU backend packages for installation, gated on BACKEND_GROUP
# ("all" enables every backend; a specific group enables only its package).
for (group, pkg) in (
    ("cuda", "CUDA"), ("amdgpu", "AMDGPU"), ("metal", "Metal"), ("oneapi", "oneAPI"))
    (BACKEND_GROUP == "all" || BACKEND_GROUP == group) && push!(EXTRA_PKGS, pkg)
end

if !isempty(EXTRA_PKGS)
@info "Installing Extra Packages for testing" EXTRA_PKGS=EXTRA_PKGS
Expand Down
Loading