Skip to content
This repository has been archived by the owner on Nov 4, 2024. It is now read-only.

Commit

Permalink
Minor cleanups
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Jun 27, 2024
1 parent 665e9f0 commit 7fd4a42
Show file tree
Hide file tree
Showing 9 changed files with 57 additions and 79 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ jobs:
- uses: julia-actions/julia-runtest@v1
env:
GROUP: "CPU"
RETESTITEMS_NWORKERS: 4
RETESTITEMS_NWORKER_THREADS: 2
- uses: julia-actions/julia-processcoverage@v1
with:
directories: src,ext
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/Downgrade.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
version: ['1.9']
version: ['1']
steps:
- uses: actions/checkout@v4
- uses: julia-actions/setup-julia@v2
Expand Down
29 changes: 17 additions & 12 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
name = "WeightInitializers"
uuid = "d49dbf32-c5c2-4618-8acc-27bb2598ef2d"
authors = ["Avik Pal <avikpal@mit.edu> and contributors"]
version = "0.1.7"
version = "0.1.8"

[deps]
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
PartialFunctions = "570af359-4316-4cb7-8c74-252c00c2016b"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Expand All @@ -18,26 +20,29 @@ CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
WeightInitializersCUDAExt = "CUDA"

[compat]
Aqua = "0.8"
CUDA = "5"
ChainRulesCore = "1.21"
LinearAlgebra = "1.9"
Aqua = "0.8.7"
ArgCheck = "2.3.0"
CUDA = "5.3.2"
ChainRulesCore = "1.23"
ExplicitImports = "1.6.0"
LinearAlgebra = "1.10"
PartialFunctions = "1.2"
Random = "1.9"
Random = "1.10"
ReTestItems = "1.24.0"
SpecialFunctions = "2"
StableRNGs = "1"
Statistics = "1.9"
Test = "1.9"
julia = "1.9"
Statistics = "1.10"
Test = "1.10"
julia = "1.10"

[extras]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7"
ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Aqua", "CUDA", "Random", "ReTestItems", "StableRNGs", "Statistics", "Test"]
test = ["Aqua", "CUDA", "Documenter", "ExplicitImports", "ReTestItems", "StableRNGs", "Test"]
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# WeightInitializers

[![Join the chat at https://julialang.zulipchat.com #machine-learning](https://img.shields.io/static/v1?label=Zulip&message=chat&color=9558b2&labelColor=389826)](https://julialang.zulipchat.com/#narrow/stream/machine-learning)
[![Latest Docs](https://img.shields.io/badge/docs-latest-blue.svg)](https://lux.csail.mit.edu/dev/api/)
[![Stable Docs](https://img.shields.io/badge/docs-stable-blue.svg)](https://lux.csail.mit.edu/stable/api/)
[![Latest Docs](https://img.shields.io/badge/docs-latest-blue.svg)](https://lux.csail.mit.edu/dev/api/Building_Blocks/WeightInitializers)
[![Stable Docs](https://img.shields.io/badge/docs-stable-blue.svg)](https://lux.csail.mit.edu/stable/api/Building_Blocks/WeightInitializers)
[![Aqua QA](https://raw.githubusercontent.com/JuliaTesting/Aqua.jl/master/badge.svg)](https://github.com/JuliaTesting/Aqua.jl)

[![Build status](https://badge.buildkite.com/ffa2c8c3629cd58322446cddd3e8dcc4f121c28a574ee3e626.svg?branch=main)](https://buildkite.com/julialang/weightinitializers-dot-jl)
Expand Down
25 changes: 5 additions & 20 deletions ext/WeightInitializersCUDAExt.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
module WeightInitializersCUDAExt

using WeightInitializers, CUDA
using Random
import WeightInitializers: __partial_apply, NUM_TO_FPOINT, identity_init, sparse_init,
orthogonal
using CUDA: CUDA, CURAND
using Random: Random, shuffle
using WeightInitializers: WeightInitializers, NUM_TO_FPOINT, __partial_apply

const AbstractCuRNG = Union{CUDA.RNG, CURAND.RNG}

Expand All @@ -21,7 +20,7 @@ for T in ("16", "32", "64", "C16", "C32", "C64"), fname in (:ones, :zeros)
end
end

function sparse_init(rng::AbstractCuRNG, ::Type{T}, dims::Integer...;
function WeightInitializers.sparse_init(rng::AbstractCuRNG, ::Type{T}, dims::Integer...;
sparsity::Number, std::Number=T(0.01)) where {T <: Number}
if length(dims) != 2
throw(ArgumentError("Only 2-dimensional outputs are supported for sparse initialization."))
Expand All @@ -36,7 +35,7 @@ function sparse_init(rng::AbstractCuRNG, ::Type{T}, dims::Integer...;
return CUDA.@allowscalar mapslices(shuffle, sparse_array, dims=1)
end

function identity_init(rng::AbstractCuRNG, ::Type{T}, dims::Integer...;
function WeightInitializers.identity_init(::AbstractCuRNG, ::Type{T}, dims::Integer...;
gain::Number=1, shift::Integer=0) where {T <: Number}
if length(dims) == 1
# Bias initialization
Expand All @@ -62,18 +61,4 @@ function identity_init(rng::AbstractCuRNG, ::Type{T}, dims::Integer...;
end
end

# Convenience front-end methods for the CUDA-RNG initializers: for each of
# `sparse_init` and `identity_init`, generate three additional call forms so
# the initializer can be used without an explicit element type (defaulting to
# Float32) or partially applied for later invocation.
for initializer in (:sparse_init, :identity_init)
    # (rng, dims...) — no eltype supplied: delegate with Float32.
    @eval function ($initializer)(rng::AbstractCuRNG, dims::Integer...; kwargs...)
        return $initializer(rng, Float32, dims...; kwargs...)
    end

    # (rng; kwargs...) — capture rng and kwargs, returning a partially-applied
    # closure via `__partial_apply` (called later with dims).
    @eval function ($initializer)(rng::AbstractCuRNG; kwargs...)
        return __partial_apply($initializer, (rng, (; kwargs...)))
    end
    # (rng, T; kwargs...) — capture rng, element type, and kwargs for a later
    # call that supplies only the dimensions.
    @eval function ($initializer)(
        rng::AbstractCuRNG, ::Type{T}; kwargs...) where {T <: Number}
        return __partial_apply($initializer, ((rng, T), (; kwargs...)))
    end
end

end
27 changes: 13 additions & 14 deletions src/WeightInitializers.jl
Original file line number Diff line number Diff line change
@@ -1,28 +1,27 @@
module WeightInitializers

using ChainRulesCore, PartialFunctions, Random, SpecialFunctions, Statistics, LinearAlgebra
#! format: off
using ChainRulesCore: ChainRulesCore
using GPUArraysCore: GPUArraysCore
using LinearAlgebra: LinearAlgebra, Diagonal, qr
using PartialFunctions: :$
using Random: Random, AbstractRNG, Xoshiro, shuffle
using SpecialFunctions: SpecialFunctions, erf, erfinv
using Statistics: Statistics, std
#! format: on

const CRC = ChainRulesCore

include("utils.jl")
include("initializers.jl")

# Mark the functions as non-differentiable
for f in [:zeros64, :ones64, :rand64, :randn64, :zeros32, :ones32, :rand32, :randn32,
:zeros16, :ones16, :rand16, :randn16, :zerosC64, :onesC64, :randC64,
:randnC64, :zerosC32, :onesC32, :randC32, :randnC32, :zerosC16, :onesC16,
:randC16, :randnC16, :glorot_normal, :glorot_uniform, :kaiming_normal,
:kaiming_uniform, :truncated_normal, :orthogonal, :sparse_init, :identity_init]
@eval @non_differentiable $(f)(::Any...)
end
include("autodiff.jl")

export zeros64, ones64, rand64, randn64, zeros32, ones32, rand32, randn32, zeros16, ones16,
rand16, randn16
export zerosC64, onesC64, randC64, randnC64, zerosC32, onesC32, randC32, randnC32, zerosC16,
onesC16, randC16, randnC16
export glorot_normal, glorot_uniform
export kaiming_normal, kaiming_uniform
export truncated_normal
export orthogonal
export sparse_init
export identity_init
export truncated_normal, orthogonal, sparse_init, identity_init

end
8 changes: 8 additions & 0 deletions src/autodiff.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# Every constructor/initializer below returns constant or (pseudo-)random
# output, so AD systems should treat calls to them as non-differentiable
# leaves. Register a `ChainRulesCore.@non_differentiable` rule for each.
for fname in (:zeros64, :ones64, :rand64, :randn64,
        :zeros32, :ones32, :rand32, :randn32,
        :zeros16, :ones16, :rand16, :randn16,
        :zerosC64, :onesC64, :randC64, :randnC64,
        :zerosC32, :onesC32, :randC32, :randnC32,
        :zerosC16, :onesC16, :randC16, :randnC16,
        :glorot_normal, :glorot_uniform, :kaiming_normal, :kaiming_uniform,
        :truncated_normal, :orthogonal, :sparse_init, :identity_init)
    @eval CRC.@non_differentiable $(fname)(::Any...)
end
20 changes: 4 additions & 16 deletions src/initializers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -152,26 +152,14 @@ function orthogonal(rng::AbstractRNG, ::Type{T}, dims::Integer...;
gain::Number=T(1.0)) where {T <: Number}
@assert length(dims)>1 "Creating vectors (length(dims) == 1) is not allowed"

if length(dims) == 2
rows, cols = dims
else
rows = prod(dims[1:(end - 1)])
cols = dims[end]
end

if rows < cols
return permutedims(orthogonal(rng, T, cols, rows; gain=T(gain)))
end
rows, cols = length(dims) == 2 ? dims : (prod(dims[1:(end - 1)]), dims[end])
rows < cols && return permutedims(orthogonal(rng, T, cols, rows; gain=T(gain)))

mat = randn(rng, T, rows, cols)
Q, R = qr(mat)
mat .= Q * sign.(Diagonal(R)) .* T(gain)

if length(dims) > 2
return reshape(mat, dims)
else
return mat
end
return length(dims) > 2 ? reshape(mat, dims) : mat
end

"""
Expand Down Expand Up @@ -296,7 +284,7 @@ identity_tensor = identity_init(MersenneTwister(123), Float32, # Bias ini
5; gain=1.5, shift=(1, 0))
```
"""
function identity_init(rng::AbstractRNG, ::Type{T}, dims::Integer...;
function identity_init(::AbstractRNG, ::Type{T}, dims::Integer...;
gain::Number=1, shift::Integer=0) where {T <: Number}
if length(dims) == 1
# Bias initialization
Expand Down
19 changes: 5 additions & 14 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,12 @@
@inline _nfan(n_out, n_in) = n_in, n_out # In case of Dense kernels: arranged as matrices
@inline _nfan(dims::Tuple) = _nfan(dims...)
@inline _nfan(dims...) = prod(dims[1:(end - 2)]) .* (dims[end - 1], dims[end]) # In case of convolution kernels
_norm_cdf(x::T) where {T} = T(0.5) * (1 + erf(x / 2))
@inline _norm_cdf(x::T) where {T} = T(0.5) * (1 + erf(x / 2))

# Deterministic fallback RNG (fixed seed 1234) used when the caller does not
# supply one. `Xoshiro` only exists on Julia ≥ 1.7, so older versions fall
# back to `MersenneTwister`; `@static` resolves the branch once at load time,
# leaving a single one-line method in either case.
@static if VERSION >= v"1.7"
    _default_rng() = Xoshiro(1234)
else
    _default_rng() = MersenneTwister(1234)
end
@inline _default_rng() = Xoshiro(1234)

# This is needed if using `PartialFunctions.$` inside @eval block
__partial_apply(fn, inp) = fn$inp
@inline __partial_apply(fn, inp) = fn$inp

const NAME_TO_DIST = Dict(
:zeros => "an AbstractArray of zeros", :ones => "an AbstractArray of ones",
Expand All @@ -26,11 +20,8 @@ const NUM_TO_FPOINT = Dict(

# Split a generated-constructor name into its base name and numeric-precision
# suffix by checking `NUM_TO_FPOINT` keys: tries the 3-character suffix first
# (complex types such as "C64"), then falls back to the 2-character suffix.
@inline function __funcname(fname::String)
    fp = fname[(end - 2):end]
    if Symbol(fp) in keys(NUM_TO_FPOINT)
        return fname[1:(end - 3)], fp
    else
        return fname[1:(end - 2)], fname[(end - 1):end]
    end
    # NOTE(review): the two lines below are unreachable — both branches of the
    # `if/else` above return. This looks like diff residue where the old
    # `if/else` body and the new short-circuit form were merged into one
    # listing; only one of the two implementations should be kept.
    Symbol(fp) in keys(NUM_TO_FPOINT) && return fname[1:(end - 3)], fp
    return fname[1:(end - 2)], fname[(end - 1):end]
end

@inline function __generic_docstring(fname::String)
Expand Down

0 comments on commit 7fd4a42

Please sign in to comment.