Skip to content
This repository has been archived by the owner on Nov 4, 2024. It is now read-only.

Commit

Permalink
refactor: move PartialFunctions into a module
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Aug 19, 2024
1 parent a35860d commit 1bde2b3
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 20 deletions.
3 changes: 1 addition & 2 deletions src/WeightInitializers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,10 @@ module WeightInitializers

using ArgCheck: @argcheck
using ChainRulesCore: ChainRulesCore
using ConcreteStructs: @concrete
using GPUArraysCore: @allowscalar
using LinearAlgebra: LinearAlgebra, Diagonal, qr
using Random: Random, AbstractRNG, Xoshiro, shuffle
using SpecialFunctions: SpecialFunctions, erf, erfinv
using SpecialFunctions: SpecialFunctions, erf, erfinv # Move to Ext in v2.0
using Statistics: Statistics, std

const CRC = ChainRulesCore
Expand Down
14 changes: 6 additions & 8 deletions src/initializers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -331,17 +331,16 @@ for initializer in (:glorot_uniform, :glorot_normal, :kaiming_uniform, :kaiming_

# Partial application
function ($initializer)(rng::AbstractRNG; kwargs...)
return PartialWeightInitializationFunction{Nothing}($initializer, rng, kwargs)
return PartialFunction.Partial{Nothing}($initializer, rng, kwargs)
end
function ($initializer)(::Type{T}; kwargs...) where {T <: $NType}
return PartialWeightInitializationFunction{T}($initializer, nothing, kwargs)
return PartialFunction.Partial{T}($initializer, nothing, kwargs)
end
function ($initializer)(rng::AbstractRNG, ::Type{T}; kwargs...) where {T <: $NType}
return PartialWeightInitializationFunction{T}($initializer, rng, kwargs)
return PartialFunction.Partial{T}($initializer, rng, kwargs)
end
function ($initializer)(; kwargs...)
return PartialWeightInitializationFunction{Nothing}(
$initializer, nothing, kwargs)
return PartialFunction.Partial{Nothing}($initializer, nothing, kwargs)
end
end
end
Expand All @@ -362,14 +361,13 @@ for tp in ("16", "32", "64", "C16", "C32", "C64"), func in (:zeros, :ones, :rand

# Partial application
function ($initializer)(rng::AbstractRNG; kwargs...)
return PartialWeightInitializationFunction{Missing}($initializer, rng, kwargs)
return PartialFunction.Partial{Missing}($initializer, rng, kwargs)
end
function ($initializer)(rng::AbstractRNG, ::Type{T}; kwargs...) where {T}
throw(ArgumentError(string($initializer) * " doesn't accept a type argument."))
end
function ($initializer)(; kwargs...)
return PartialWeightInitializationFunction{Missing}(
$initializer, nothing, kwargs)
return PartialFunction.Partial{Missing}($initializer, nothing, kwargs)
end
end
end
24 changes: 14 additions & 10 deletions src/partial.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
@concrete struct PartialWeightInitializationFunction{T} <: Function
module PartialFunction

using ArgCheck: @argcheck
using ConcreteStructs: @concrete
using Random: AbstractRNG

@concrete struct Partial{T} <: Function
f <: Function
rng <: Union{Nothing, AbstractRNG}
kwargs
end

function Base.show(
io::IO, ::MIME"text/plain", f::PartialWeightInitializationFunction{T}) where {T}
function Base.show(io::IO, ::MIME"text/plain", f::Partial{T}) where {T}
print(io, "$(f.f)(")
if f.rng !== nothing
print(io, "$(nameof(typeof(f.rng)))(...), ")
Expand All @@ -26,22 +31,21 @@ function Base.show(
print(io, ")")
end

function (f::PartialWeightInitializationFunction{<:Union{Nothing, Missing}})(
args...; kwargs...)
function (f::Partial{<:Union{Nothing, Missing}})(args...; kwargs...)
f.rng === nothing && return f.f(args...; f.kwargs..., kwargs...)
return f.f(f.rng, args...; f.kwargs..., kwargs...)
end
function (f::PartialWeightInitializationFunction{<:Union{Nothing, Missing}})(
rng::AbstractRNG, args...; kwargs...)
function (f::Partial{<:Union{Nothing, Missing}})(rng::AbstractRNG, args...; kwargs...)
@argcheck f.rng === nothing
return f.f(rng, args...; f.kwargs..., kwargs...)
end
function (f::PartialWeightInitializationFunction{T})(args...; kwargs...) where {T <: Number}
function (f::Partial{T})(args...; kwargs...) where {T <: Number}
f.rng === nothing && return f.f(T, args...; f.kwargs..., kwargs...)
return f.f(f.rng, T, args...; f.kwargs..., kwargs...)
end
function (f::PartialWeightInitializationFunction{T})(
rng::AbstractRNG, args...; kwargs...) where {T <: Number}
function (f::Partial{T})(rng::AbstractRNG, args...; kwargs...) where {T <: Number}
@argcheck f.rng === nothing
return f.f(rng, T, args...; f.kwargs..., kwargs...)
end

end

0 comments on commit 1bde2b3

Please sign in to comment.