This repository has been archived by the owner on Nov 4, 2024. It is now read-only.

Comparing changes

base repository: LuxDL/WeightInitializers.jl
base: v0.1.7
head repository: LuxDL/WeightInitializers.jl
compare: main

Commits on Apr 8, 2024

  1. Bump julia-actions/setup-julia from 1 to 2

    Bumps [julia-actions/setup-julia](https://github.com/julia-actions/setup-julia) from 1 to 2.
    - [Release notes](https://github.com/julia-actions/setup-julia/releases)
    - [Commits](julia-actions/setup-julia@v1...v2)
    
    ---
    updated-dependencies:
    - dependency-name: julia-actions/setup-julia
      dependency-type: direct:production
      update-type: version-update:semver-major
    ...
    
    Signed-off-by: dependabot[bot] <support@github.com>
    dependabot[bot] authored Apr 8, 2024 · edc0dcf
  2. Merge pull request #24 from LuxDL/dependabot/github_actions/julia-actions/setup-julia-2

    Bump julia-actions/setup-julia from 1 to 2
    avik-pal authored Apr 8, 2024 · f5ffbb1

Commits on Jun 27, 2024

  1. Run formatter

    avik-pal committed Jun 27, 2024 · 665e9f0
  2. Minor cleanups

    avik-pal committed Jun 27, 2024 · 7fd4a42
  3. Generalize the code

    avik-pal committed Jun 27, 2024 · e74b058
  4. Finish rewriting the tests

    avik-pal committed Jun 27, 2024 · 4f5b4ea
  5. Merge pull request #25 from LuxDL/ap/remove_pt

    Cleaning up of the codebase
    avik-pal authored Jun 27, 2024 · ed1f825

Commits on Jul 3, 2024

  1. ci: cleaner ci

    avik-pal committed Jul 3, 2024 · 1c159e7
  2. 1dd722c
  3. feat: support GPUArrays RNG

    avik-pal committed Jul 3, 2024 · 04b9656
  4. fix: rand samplers

    avik-pal committed Jul 3, 2024 · 1f05133
  5. fb976d6
  6. chore: format suggestion

    Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
    avik-pal and github-actions[bot] committed Jul 3, 2024 · 59ce7d6
  7. 4e62456
  8. 893f5a0
  9. 8174633
  10. f8c1a36
  11. test: skip certain RNG tests for cuda/rocm

    avik-pal committed Jul 3, 2024 · 99bc756

Commits on Jul 8, 2024

  1. chore: bump crate-ci/typos from 1.22.9 to 1.23.1 (#27)

    Bumps [crate-ci/typos](https://github.com/crate-ci/typos) from 1.22.9 to 1.23.1.
    - [Release notes](https://github.com/crate-ci/typos/releases)
    - [Changelog](https://github.com/crate-ci/typos/blob/master/CHANGELOG.md)
    - [Commits](crate-ci/typos@v1.22.9...v1.23.1)
    
    ---
    updated-dependencies:
    - dependency-name: crate-ci/typos
      dependency-type: direct:production
      update-type: version-update:semver-minor
    ...
    
    Signed-off-by: dependabot[bot] <support@github.com>
    Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
    dependabot[bot] authored Jul 8, 2024 · d257250

Commits on Jul 9, 2024

  1. ci: more robust testing and ci (#28)

    * test: more explicit imports testing
    
    * ci: run only necessary tests
    avik-pal authored Jul 9, 2024 · d295cd5

Commits on Jul 12, 2024

  1. 1a16c64
  2. fix: partial application

    avik-pal committed Jul 12, 2024 · a6674df
  3. fix: add missing dispatch

    avik-pal committed Jul 12, 2024 · cbde88d

Commits on Jul 15, 2024

  1. chore: bump crate-ci/typos from 1.23.1 to 1.23.2

    Bumps [crate-ci/typos](https://github.com/crate-ci/typos) from 1.23.1 to 1.23.2.
    - [Release notes](https://github.com/crate-ci/typos/releases)
    - [Changelog](https://github.com/crate-ci/typos/blob/master/CHANGELOG.md)
    - [Commits](crate-ci/typos@v1.23.1...v1.23.2)
    
    ---
    updated-dependencies:
    - dependency-name: crate-ci/typos
      dependency-type: direct:production
      update-type: version-update:semver-patch
    ...
    
    Signed-off-by: dependabot[bot] <support@github.com>
    dependabot[bot] authored and avik-pal committed Jul 15, 2024 · dd3690a

Commits on Jul 26, 2024

  1. chore: bump to 1.0

    avik-pal authored Jul 26, 2024 · c6eeb42

Commits on Jul 29, 2024

  1. chore: bump crate-ci/typos from 1.23.2 to 1.23.5

    Bumps [crate-ci/typos](https://github.com/crate-ci/typos) from 1.23.2 to 1.23.5.
    - [Release notes](https://github.com/crate-ci/typos/releases)
    - [Changelog](https://github.com/crate-ci/typos/blob/master/CHANGELOG.md)
    - [Commits](crate-ci/typos@v1.23.2...v1.23.5)
    
    ---
    updated-dependencies:
    - dependency-name: crate-ci/typos
      dependency-type: direct:production
      update-type: version-update:semver-patch
    ...
    
    Signed-off-by: dependabot[bot] <support@github.com>
    dependabot[bot] authored and avik-pal committed Jul 29, 2024 · 138ffa7

Commits on Aug 9, 2024

  1. chore: bump crate-ci/typos from 1.23.5 to 1.23.6

    Bumps [crate-ci/typos](https://github.com/crate-ci/typos) from 1.23.5 to 1.23.6.
    - [Release notes](https://github.com/crate-ci/typos/releases)
    - [Changelog](https://github.com/crate-ci/typos/blob/master/CHANGELOG.md)
    - [Commits](crate-ci/typos@v1.23.5...v1.23.6)
    
    ---
    updated-dependencies:
    - dependency-name: crate-ci/typos
      dependency-type: direct:production
      update-type: version-update:semver-patch
    ...
    
    Signed-off-by: dependabot[bot] <support@github.com>
    dependabot[bot] authored and avik-pal committed Aug 9, 2024 · 79b8543
  2. chore: bump compat for AMDGPU in [weakdeps] to 1, (keep existing compat) (#34)

    Co-authored-by: CompatHelper Julia <compathelper_noreply@julialang.org>
    github-actions[bot] and CompatHelper Julia authored Aug 9, 2024 · 568cc39
  3. chore: update version for release

    avik-pal authored Aug 9, 2024 · 2cfa6f0

Commits on Aug 19, 2024

  1. a35860d
  2. 1bde2b3
  3. 1c450e3
  4. 3fa6045
  5. test: separate out the testing project file

    avik-pal committed Aug 19, 2024 · 0d95f4a

Commits on Aug 21, 2024

  1. refactor: move ChainRulesCore into an extension

    avik-pal committed Aug 21, 2024 · f7a95d8

Commits on Aug 26, 2024

  1. chore: bump crate-ci/typos from 1.23.6 to 1.24.1

    Bumps [crate-ci/typos](https://github.com/crate-ci/typos) from 1.23.6 to 1.24.1.
    - [Release notes](https://github.com/crate-ci/typos/releases)
    - [Changelog](https://github.com/crate-ci/typos/blob/master/CHANGELOG.md)
    - [Commits](crate-ci/typos@v1.23.6...v1.24.1)
    
    ---
    updated-dependencies:
    - dependency-name: crate-ci/typos
      dependency-type: direct:production
      update-type: version-update:semver-minor
    ...
    
    Signed-off-by: dependabot[bot] <support@github.com>
    dependabot[bot] authored and avik-pal committed Aug 26, 2024 · 72d9f2f

Commits on Sep 2, 2024

  1. chore: bump crate-ci/typos from 1.24.1 to 1.24.3

    Bumps [crate-ci/typos](https://github.com/crate-ci/typos) from 1.24.1 to 1.24.3.
    - [Release notes](https://github.com/crate-ci/typos/releases)
    - [Changelog](https://github.com/crate-ci/typos/blob/master/CHANGELOG.md)
    - [Commits](crate-ci/typos@v1.24.1...v1.24.3)
    
    ---
    updated-dependencies:
    - dependency-name: crate-ci/typos
      dependency-type: direct:production
      update-type: version-update:semver-patch
    ...
    
    Signed-off-by: dependabot[bot] <support@github.com>
    dependabot[bot] authored and avik-pal committed Sep 2, 2024 · 9bab632

Commits on Sep 23, 2024

  1. chore: bump crate-ci/typos from 1.24.3 to 1.24.6

    Bumps [crate-ci/typos](https://github.com/crate-ci/typos) from 1.24.3 to 1.24.6.
    - [Release notes](https://github.com/crate-ci/typos/releases)
    - [Changelog](https://github.com/crate-ci/typos/blob/master/CHANGELOG.md)
    - [Commits](crate-ci/typos@v1.24.3...v1.24.6)
    
    ---
    updated-dependencies:
    - dependency-name: crate-ci/typos
      dependency-type: direct:production
      update-type: version-update:semver-patch
    ...
    
    Signed-off-by: dependabot[bot] <support@github.com>
    dependabot[bot] authored and avik-pal committed Sep 23, 2024 · caa4db2

Commits on Oct 7, 2024

  1. chore: bump crate-ci/typos from 1.24.6 to 1.25.0

    Bumps [crate-ci/typos](https://github.com/crate-ci/typos) from 1.24.6 to 1.25.0.
    - [Release notes](https://github.com/crate-ci/typos/releases)
    - [Changelog](https://github.com/crate-ci/typos/blob/master/CHANGELOG.md)
    - [Commits](crate-ci/typos@v1.24.6...v1.25.0)
    
    ---
    updated-dependencies:
    - dependency-name: crate-ci/typos
      dependency-type: direct:production
      update-type: version-update:semver-minor
    ...
    
    Signed-off-by: dependabot[bot] <support@github.com>
    dependabot[bot] authored and avik-pal committed Oct 7, 2024 · 6601f20

Commits on Oct 8, 2024

  1. ci: run on 1.10 and 1 (#43)

    * ci: run on `1.10` and `1`
    
    * ci: run on `1.10` and `1`
    
    * test: mark truncated normal on Metal as unbroken
    avik-pal authored Oct 8, 2024 · 91ac0a4
  2. ci: run buildkite on 1.10 and 1

    avik-pal authored Oct 8, 2024 · 0b81d93
  3. chore: bump peter-evans/create-pull-request from 6 to 7 (#40)

    Bumps [peter-evans/create-pull-request](https://github.com/peter-evans/create-pull-request) from 6 to 7.
    - [Release notes](https://github.com/peter-evans/create-pull-request/releases)
    - [Commits](peter-evans/create-pull-request@v6...v7)
    
    ---
    updated-dependencies:
    - dependency-name: peter-evans/create-pull-request
      dependency-type: direct:production
      update-type: version-update:semver-major
    ...
    
    Signed-off-by: dependabot[bot] <support@github.com>
    Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
    dependabot[bot] authored Oct 8, 2024 · 0212028

Commits on Oct 14, 2024

  1. chore: bump crate-ci/typos from 1.25.0 to 1.26.0 (#44)

    Bumps [crate-ci/typos](https://github.com/crate-ci/typos) from 1.25.0 to 1.26.0.
    - [Release notes](https://github.com/crate-ci/typos/releases)
    - [Changelog](https://github.com/crate-ci/typos/blob/master/CHANGELOG.md)
    - [Commits](crate-ci/typos@v1.25.0...v1.26.0)
    
    ---
    updated-dependencies:
    - dependency-name: crate-ci/typos
      dependency-type: direct:production
      update-type: version-update:semver-minor
    ...
    
    Signed-off-by: dependabot[bot] <support@github.com>
    Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
    dependabot[bot] authored Oct 14, 2024 · be8bf62

Commits on Oct 18, 2024

  1. chore: bump compat for GPUArrays in [weakdeps] to 11, (keep existing compat) (#46)

    Co-authored-by: CompatHelper Julia <compathelper_noreply@julialang.org>
    github-actions[bot] and CompatHelper Julia authored Oct 18, 2024 · 4c322b2
  2. chore: bump compat for GPUArraysCore to 0.2, (keep existing compat) (#47)

    Co-authored-by: CompatHelper Julia <compathelper_noreply@julialang.org>
    Co-authored-by: Avik Pal <avikpal@mit.edu>
    3 people authored Oct 18, 2024 · a06f1c1
  3. chore: bump version for release

    avik-pal authored Oct 18, 2024 · ecc542a

Commits on Oct 28, 2024

  1. chore: bump crate-ci/typos from 1.26.0 to 1.26.8 (#49)

    Bumps [crate-ci/typos](https://github.com/crate-ci/typos) from 1.26.0 to 1.26.8.
    - [Release notes](https://github.com/crate-ci/typos/releases)
    - [Changelog](https://github.com/crate-ci/typos/blob/master/CHANGELOG.md)
    - [Commits](crate-ci/typos@v1.26.0...v1.26.8)
    
    ---
    updated-dependencies:
    - dependency-name: crate-ci/typos
      dependency-type: direct:production
      update-type: version-update:semver-patch
    ...
    
    Signed-off-by: dependabot[bot] <support@github.com>
    Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
    dependabot[bot] authored Oct 28, 2024 · f53a4c0

Commits on Nov 4, 2024

  1. docs: update readme

    avik-pal authored Nov 4, 2024 · b100b29
3 changes: 2 additions & 1 deletion .JuliaFormatter.toml
@@ -1,8 +1,9 @@
style = "sciml"
whitespace_in_kwargs = false
always_use_return = true
margin = 92
indent = 4
format_docstrings = true
separate_kwargs_with_semicolon = true
join_lines_based_on_source = false
always_for_in = true
annotate_untyped_fields_with_any = false
158 changes: 25 additions & 133 deletions .buildkite/pipeline.yml
@@ -1,134 +1,26 @@
steps:
- group: ":julia: CUDA GPU"
steps:
- label: ":julia: Julia {{matrix.julia}} + CUDA GPU"
plugins:
- JuliaCI/julia#v1:
version: "{{matrix.julia}}"
- JuliaCI/julia-test#v1:
test_args: "--quickfail"
- JuliaCI/julia-coverage#v1:
codecov: true
dirs:
- src
- ext
agents:
queue: "juliagpu"
cuda: "*"
env:
GROUP: "CUDA"
if: build.message !~ /\[skip tests\]/
timeout_in_minutes: 240
matrix:
setup:
julia:
- "1"

# Downstream CUDA Tests
- group: ":telescope: Downstream CUDA"
steps:
- label: ":julia: {{matrix.repo}} (Julia {{matrix.julia}} + CUDA GPU)"
plugins:
- JuliaCI/julia#v1:
version: "{{matrix.julia}}"
- JuliaCI/julia-coverage#v1:
codecov: true
dirs:
- src
command: |
julia --code-coverage=user --color=yes --project -e '
using Pkg
repo = ENV["DOWNSTREAM_TEST_REPO"]
println("--- :julia: Instantiating project")
withenv("JULIA_PKG_PRECOMPILE_AUTO" => 0) do
Pkg.instantiate()
try
Pkg.develop(repo)
println("+++ :julia: Running tests")
Pkg.test("$(repo)"; coverage=true)
catch err
err isa Pkg.Resolve.ResolverError || rethrow()
@info "Not compatible with this release. No problem." exception=err
exit(0)
end
end
println("+++ :julia: Finished Downstream Test")'
agents:
queue: "juliagpu"
cuda: "*"
env:
GROUP: "CUDA"
DOWNSTREAM_TEST_REPO: "{{matrix.repo}}"
if: build.message !~ /\[skip tests\]/ || build.message !~ /\[skip downstream\]/
timeout_in_minutes: 240
matrix:
setup:
julia:
- "1"
repo:
- "Lux"
- "Boltz"

# Downstream AMDGPU Tests
- group: ":telescope: Downstream AMD GPU"
steps:
- label: ":julia: {{matrix.repo}} (Julia {{matrix.julia}} + AMD GPU)"
plugins:
- JuliaCI/julia#v1:
version: "{{matrix.julia}}"
- JuliaCI/julia-coverage#v1:
codecov: true
dirs:
- src
command: |
julia --code-coverage=user --color=yes --project -e '
using Pkg
repo = ENV["DOWNSTREAM_TEST_REPO"]
println("--- :julia: Instantiating project")
withenv("JULIA_PKG_PRECOMPILE_AUTO" => 0) do
Pkg.instantiate()
try
Pkg.develop(repo)
println("+++ :julia: Running tests")
Pkg.test("$(repo)"; coverage=true)
catch err
err isa Pkg.Resolve.ResolverError || rethrow()
@info "Not compatible with this release. No problem." exception=err
exit(0)
end
end
println("+++ :julia: Finished Downstream Test")'
agents:
queue: "juliagpu"
rocm: "*"
rocmgpu: "*"
env:
GROUP: "AMDGPU"
JULIA_AMDGPU_CORE_MUST_LOAD: "1"
JULIA_AMDGPU_HIP_MUST_LOAD: "1"
JULIA_AMDGPU_DISABLE_ARTIFACTS: "1"
DOWNSTREAM_TEST_REPO: "{{matrix.repo}}"
if: build.message !~ /\[skip tests\]/ || build.message !~ /\[skip downstream\]/
timeout_in_minutes: 240
matrix:
setup:
julia:
- "1"
repo:
- "Lux"
- "Boltz"

env:
RETESTITEMS_NWORKERS: 4
RETESTITEMS_NWORKER_THREADS: 2
SECRET_CODECOV_TOKEN: "DpNKbuKYRX40vpyJCfTvQmxwls1hlCUWiZX4pnsukt9E8u4pf0WUcIroRv2UDDbGYjuk5izmZ9yAhZZhiGMhjFF/TIji3JiYe1sXWdfSrNk0N2+CNoXo+CIi3JvS7mB+YAIUTEi2Xph+L7R0d+It079PEispqVv4bdRuqgSbY7Rn3NSsoV1cB8uUaVFBJH4EewC6Hceg80QW7q+CBru+QECudKbAWnRVLoizRsgzIld+gTUqsI1PhR+vSpD+AfZzhVxmff55ttVcMUFGnL3w4L74qoLVPET52/GPLCOi3RLGSzBJjebSBqqKOwesT9xJ4yaZ21AEzyeOm86YRc2WYg==;U2FsdGVkX1/eBwyJ7Of++vKyAWDSBvSdJeiKmVmlaVKFU5CHejM+sDlSZWH/WmoBatLcqH+eUVEGXC+oWl5riw=="


- label: "Triggering Pipelines (Pull Request)"
if: "build.pull_request.base_branch == 'main'"
agents:
queue: "juliagpu"
plugins:
- monebag/monorepo-diff#v2.5.9:
diff: ".buildkite/scripts/diff.sh $BUILDKITE_COMMIT"
interpolation: false
watch:
- path:
- "src/"
- "ext/"
- "test/"
- "Project.toml"
- ".buildkite/"
config:
command: "buildkite-agent pipeline upload .buildkite/testing.yml"
agents:
queue: "juliagpu"

- label: "Triggering Pipelines (Main Branch / Tag)"
if: build.branch == "main" || build.tag != null
agents:
queue: "juliagpu"
command: "buildkite-agent pipeline upload .buildkite/testing.yml"
13 changes: 13 additions & 0 deletions .buildkite/scripts/diff.sh
@@ -0,0 +1,13 @@
#!/bin/bash
set -ueo pipefail

# Script to output the diff where the branch was created
# Usage: ./diff.sh $BUILDKITE_COMMIT

COMMIT_HASH=$1
SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )

BRANCH_POINT_COMMIT=$($SCRIPT_DIR/find_branch_point.sh "$COMMIT_HASH")
echo >&2 "Cannot find latest build. Running diff against: $BRANCH_POINT_COMMIT"
diff=$(git diff --name-only "$BRANCH_POINT_COMMIT")
echo "$diff"
25 changes: 25 additions & 0 deletions .buildkite/scripts/downstream.jl
@@ -0,0 +1,25 @@
using Pkg

repo = ARGS[1]
if contains(repo, "#")
repo, group = split(repo, "#")
else
group = ARGS[2]
end

println("--- :julia: Instantiating project")
withenv("JULIA_PKG_PRECOMPILE_AUTO" => 0, "GROUP" => group, "BACKEND_GROUP" => group) do
Pkg.instantiate()

try
Pkg.develop(repo)
println("+++ :julia: Running tests")
Pkg.test("$(repo)"; coverage=true)
catch err
err isa Pkg.Resolve.ResolverError || rethrow()
@info "Not compatible with this release. No problem." exception=err
exit(0)
end
end

println("+++ :julia: Finished Downstream Test")
6 changes: 6 additions & 0 deletions .buildkite/scripts/find_branch_point.sh
@@ -0,0 +1,6 @@
#!/bin/bash
set -ue

diff -u <(git rev-list --first-parent "$1") \
<(git rev-list --first-parent main) | \
sed -ne 's/^ //p' | head -1
163 changes: 163 additions & 0 deletions .buildkite/testing.yml
@@ -0,0 +1,163 @@
steps:
- group: ":julia: CUDA GPU"
steps:
- label: ":julia: Julia {{matrix.julia}} + CUDA GPU"
plugins:
- JuliaCI/julia#v1:
version: "{{matrix.julia}}"
- JuliaCI/julia-test#v1:
test_args: "--quickfail"
- JuliaCI/julia-coverage#v1:
codecov: true
dirs:
- src
- ext
agents:
queue: "juliagpu"
cuda: "*"
env:
BACKEND_GROUP: "CUDA"
if: build.message !~ /\[skip tests\]/ && build.message !~ /\[skip ci\]/
timeout_in_minutes: 60
matrix:
setup:
julia:
- "1.10"
- "1"

- group: ":telescope: Downstream CUDA"
steps:
- label: ":julia: {{matrix.repo}} (Julia 1 + CUDA GPU)"
plugins:
- JuliaCI/julia#v1:
version: "1"
- JuliaCI/julia-coverage#v1:
codecov: true
dirs:
- src
- ext
command: julia --code-coverage=user --color=yes --project .buildkite/scripts/downstream.jl "{{matrix.repo}}" "CUDA"
agents:
queue: "juliagpu"
cuda: "*"
if: build.message !~ /\[skip tests\]/ && build.message !~ /\[skip downstream\]/ && build.message !~ /\[skip ci\]/ && build.pull_request.labels includes "run downstream test"
timeout_in_minutes: 240
matrix:
setup:
repo:
- "Boltz"
- "Lux"

- group: ":julia: AMD GPU"
steps:
- label: ":julia: Julia: {{matrix.julia}} + AMD GPU"
plugins:
- JuliaCI/julia#v1:
version: "{{matrix.julia}}"
- JuliaCI/julia-test#v1:
test_args: "--quickfail"
- JuliaCI/julia-coverage#v1:
codecov: true
dirs:
- src
- ext
env:
JULIA_AMDGPU_CORE_MUST_LOAD: "1"
JULIA_AMDGPU_HIP_MUST_LOAD: "1"
JULIA_AMDGPU_DISABLE_ARTIFACTS: "1"
BACKEND_GROUP: "AMDGPU"
agents:
queue: "juliagpu"
rocm: "*"
rocmgpu: "*"
if: build.message !~ /\[skip tests\]/ && build.message !~ /\[skip ci\]/
timeout_in_minutes: 240
matrix:
setup:
julia:
- "1.10"
- "1"

- group: ":telescope: Downstream AMD GPU"
steps:
- label: ":julia: {{matrix.repo}} (Julia 1 + AMD GPU)"
plugins:
- JuliaCI/julia#v1:
version: "1"
- JuliaCI/julia-coverage#v1:
codecov: true
dirs:
- src
- ext
command: julia --code-coverage=user --color=yes --project .buildkite/scripts/downstream.jl "{{matrix.repo}}" "AMDGPU"
agents:
queue: "juliagpu"
rocm: "*"
rocmgpu: "*"
env:
JULIA_AMDGPU_CORE_MUST_LOAD: "1"
JULIA_AMDGPU_HIP_MUST_LOAD: "1"
JULIA_AMDGPU_DISABLE_ARTIFACTS: "1"
if: build.message !~ /\[skip tests\]/ && build.message !~ /\[skip downstream\]/ && build.message !~ /\[skip ci\]/ && build.pull_request.labels includes "run downstream test"
timeout_in_minutes: 60
matrix:
setup:
repo:
- "Boltz"
- "Lux"

- group: ":julia: Metal GPU"
steps:
- label: ":julia: Julia: {{matrix.julia}} + Metal"
plugins:
- JuliaCI/julia#v1:
version: "{{matrix.julia}}"
- JuliaCI/julia-test#v1:
test_args: "--quickfail"
# - JuliaCI/julia-coverage#v1:
# codecov: true
# dirs:
# - src
# - ext
agents:
queue: "juliaecosystem"
os: "macos"
arch: "aarch64"
env:
BACKEND_GROUP: "Metal"
if: build.message !~ /\[skip tests\]/
timeout_in_minutes: 60
matrix:
setup:
julia:
- "1.10"
- "1"

- group: ":julia: oneAPI GPU"
steps:
- label: ":julia: Julia: {{matrix.julia}} + oneAPI"
plugins:
- JuliaCI/julia#v1:
version: "{{matrix.julia}}"
- JuliaCI/julia-test#v1:
test_args: "--quickfail"
- JuliaCI/julia-coverage#v1:
codecov: true
dirs:
- src
- ext
env:
BACKEND_GROUP: "oneAPI"
agents:
queue: "juliagpu"
intel: "*"
if: build.message !~ /\[skip tests\]/
timeout_in_minutes: 60
matrix:
setup:
julia:
- "1.10"
- "1"

env:
SECRET_CODECOV_TOKEN: "DpNKbuKYRX40vpyJCfTvQmxwls1hlCUWiZX4pnsukt9E8u4pf0WUcIroRv2UDDbGYjuk5izmZ9yAhZZhiGMhjFF/TIji3JiYe1sXWdfSrNk0N2+CNoXo+CIi3JvS7mB+YAIUTEi2Xph+L7R0d+It079PEispqVv4bdRuqgSbY7Rn3NSsoV1cB8uUaVFBJH4EewC6Hceg80QW7q+CBru+QECudKbAWnRVLoizRsgzIld+gTUqsI1PhR+vSpD+AfZzhVxmff55ttVcMUFGnL3w4L74qoLVPET52/GPLCOi3RLGSzBJjebSBqqKOwesT9xJ4yaZ21AEzyeOm86YRc2WYg==;U2FsdGVkX1/eBwyJ7Of++vKyAWDSBvSdJeiKmVmlaVKFU5CHejM+sDlSZWH/WmoBatLcqH+eUVEGXC+oWl5riw=="
134 changes: 130 additions & 4 deletions .github/workflows/CI.yml
@@ -3,25 +3,40 @@ on:
pull_request:
branches:
- main
paths:
- "src/**"
- "ext/**"
- "test/**"
- "Project.toml"
- ".github/workflows/CI.yml"
push:
branches:
- main

concurrency:
# Skip intermediate builds: always.
# Cancel intermediate builds: only if it is a pull request build.
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }}

jobs:
test:
runs-on: ubuntu-latest
ci:
name: Julia ${{ matrix.version }} - ${{ matrix.os }}
if: ${{ !contains(github.event.head_commit.message, '[skip tests]') }}
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
matrix:
version:
- "min"
- "1"
os:
- ubuntu-latest
- macos-latest
- windows-latest
steps:
- uses: actions/checkout@v4
- uses: julia-actions/setup-julia@v1
- uses: julia-actions/setup-julia@v2
with:
version: ${{ matrix.version }}
- uses: actions/cache@v4
@@ -36,8 +51,87 @@ jobs:
${{ runner.os }}-
- uses: julia-actions/julia-buildpkg@v1
- uses: julia-actions/julia-runtest@v1
- uses: julia-actions/julia-processcoverage@v1
with:
directories: src,ext
- uses: codecov/codecov-action@v4
with:
files: lcov.info
token: ${{ secrets.CODECOV_TOKEN }}
verbose: true
fail_ci_if_error: true

downstream:
name: Downstream ${{ matrix.package.repo }}/${{ matrix.package.group }}
if: ${{ !contains(github.event.head_commit.message, '[skip tests]') && contains(github.event.pull_request.labels.*.name, 'run downstream test') }}
runs-on: ${{ matrix.os }}
timeout-minutes: 240
env:
GROUP: ${{ matrix.package.group }}
strategy:
fail-fast: false
matrix:
julia-version: ["1"]
os: [ubuntu-latest]
package:
- { user: LuxDL, repo: Lux.jl, group: CPU }
- { user: LuxDL, repo: Boltz.jl, group: CPU }
steps:
- uses: actions/checkout@v4
- uses: julia-actions/setup-julia@v2
with:
version: ${{ matrix.julia-version }}
arch: x64
- uses: julia-actions/julia-buildpkg@v1
- name: Clone Downstream
uses: actions/checkout@v4
with:
repository: ${{ matrix.package.user }}/${{ matrix.package.repo }}
path: downstream
- name: Load this and run the downstream tests
shell: julia --code-coverage=user --color=yes --project=downstream {0}
run: |
using Pkg
try
# force it to use this PR's version of the package
Pkg.develop(PackageSpec(path=".")) # resolver may fail with main deps
Pkg.update()
Pkg.test(; coverage=true) # resolver may fail with test time deps
catch err
err isa Pkg.Resolve.ResolverError || rethrow()
# If we can't resolve that means this is incompatible by SemVer and this is fine
# It means we marked this as a breaking change, so we don't need to worry about
# Mistakenly introducing a breaking change, as we have intentionally made one
@info "Not compatible with this release. No problem." exception=err
exit(0) # Exit immediately, as a success
end
env:
GROUP: "CPU"
GROUP: ${{ matrix.package.group }}
BACKEND_GROUP: ${{ matrix.package.group }}
- uses: julia-actions/julia-processcoverage@v1
- uses: codecov/codecov-action@v4
with:
files: lcov.info
token: ${{ secrets.CODECOV_TOKEN }}
verbose: true
fail_ci_if_error: true

downgrade:
if: ${{ !contains(github.event.head_commit.message, '[skip tests]') && github.base_ref == github.event.repository.default_branch }}
name: Downgrade Julia ${{ matrix.version }}
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
version: ["1.10"]
steps:
- uses: actions/checkout@v4
- uses: julia-actions/setup-julia@v2
with:
version: ${{ matrix.version }}
- uses: julia-actions/julia-downgrade-compat@v1
- uses: julia-actions/julia-buildpkg@v1
- uses: julia-actions/julia-runtest@v1
- uses: julia-actions/julia-processcoverage@v1
with:
directories: src,ext
@@ -47,3 +141,35 @@ jobs:
token: ${{ secrets.CODECOV_TOKEN }}
verbose: true
fail_ci_if_error: true

invalidations:
# Only run on PRs to the default branch.
# In the PR trigger above branches can be specified only explicitly whereas this check should work for master, main, or any other default branch
if: github.base_ref == github.event.repository.default_branch
runs-on: ubuntu-latest
steps:
- uses: julia-actions/setup-julia@v2
with:
version: "1"
- uses: actions/checkout@v4
- uses: julia-actions/julia-buildpkg@v1
- uses: julia-actions/julia-invalidations@v1
id: invs_pr

- uses: actions/checkout@v4
with:
ref: ${{ github.event.repository.default_branch }}
- uses: julia-actions/julia-buildpkg@v1
- uses: julia-actions/julia-invalidations@v1
id: invs_default

- name: Report invalidation counts
run: |
echo "Invalidations on default branch: ${{ steps.invs_default.outputs.total }} (${{ steps.invs_default.outputs.deps }} via deps)" >> $GITHUB_STEP_SUMMARY
echo "This branch: ${{ steps.invs_pr.outputs.total }} (${{ steps.invs_pr.outputs.deps }} via deps)" >> $GITHUB_STEP_SUMMARY
- name: Check if the PR does increase number of invalidations
if: steps.invs_pr.outputs.total > steps.invs_default.outputs.total
run: exit 1

env:
BACKEND_GROUP: "CPU"
2 changes: 1 addition & 1 deletion .github/workflows/CompatHelper.yml
@@ -15,7 +15,7 @@ jobs:
run: which julia
continue-on-error: true
- name: Install Julia, but only if it is not already available in the PATH
uses: julia-actions/setup-julia@v1
uses: julia-actions/setup-julia@v2
with:
version: '1'
arch: ${{ runner.arch }}
41 changes: 0 additions & 41 deletions .github/workflows/Downgrade.yml

This file was deleted.

68 changes: 0 additions & 68 deletions .github/workflows/Downstream.yml

This file was deleted.

40 changes: 0 additions & 40 deletions .github/workflows/FormatCheck.yml

This file was deleted.

2 changes: 1 addition & 1 deletion .github/workflows/FormatPR.yml
@@ -15,7 +15,7 @@ jobs:
# https://github.com/peter-evans/create-pull-request#reference-example
- name: Create Pull Request
id: cpr
uses: peter-evans/create-pull-request@v6
uses: peter-evans/create-pull-request@v7
with:
token: ${{ secrets.GITHUB_TOKEN }}
commit-message: Format .jl files
40 changes: 0 additions & 40 deletions .github/workflows/Invalidations.yml

This file was deleted.

19 changes: 19 additions & 0 deletions .github/workflows/QualityCheck.yml
@@ -0,0 +1,19 @@
name: Code Quality Check

on: [pull_request]

jobs:
code-style:
name: Format Suggestions
runs-on: ubuntu-latest
steps:
- uses: julia-actions/julia-format@v3

typos-check:
name: Spell Check with Typos
runs-on: ubuntu-latest
steps:
- name: Checkout Actions Repository
uses: actions/checkout@v4
- name: Check spelling
uses: crate-ci/typos@v1.26.8
2 changes: 2 additions & 0 deletions .typos.toml
@@ -0,0 +1,2 @@
[default.extend-words]
nin = "nin"
57 changes: 29 additions & 28 deletions Project.toml
@@ -1,44 +1,45 @@
name = "WeightInitializers"
uuid = "d49dbf32-c5c2-4618-8acc-27bb2598ef2d"
authors = ["Avik Pal <avikpal@mit.edu> and contributors"]
version = "0.1.7"
version = "1.0.4"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
PartialFunctions = "570af359-4316-4cb7-8c74-252c00c2016b"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[weakdeps]
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
oneAPI = "8f75cd03-7ff8-4ecb-9b8f-daf728133b1b"

[extensions]
WeightInitializersCUDAExt = "CUDA"
WeightInitializersAMDGPUExt = ["AMDGPU", "GPUArrays"]
WeightInitializersCUDAExt = ["CUDA", "GPUArrays"]
WeightInitializersChainRulesCoreExt = "ChainRulesCore"
WeightInitializersGPUArraysExt = "GPUArrays"
WeightInitializersMetalExt = ["Metal", "GPUArrays"]
WeightInitializersoneAPIExt = ["oneAPI", "GPUArrays"]

[compat]
Aqua = "0.8"
CUDA = "5"
ChainRulesCore = "1.21"
LinearAlgebra = "1.9"
PartialFunctions = "1.2"
PrecompileTools = "1.2"
Random = "1.9"
SpecialFunctions = "2"
StableRNGs = "1"
Statistics = "1.9"
Test = "1.9"
julia = "1.9"

[extras]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Aqua", "Test", "StableRNGs", "Random", "Statistics", "CUDA"]
AMDGPU = "0.9.6, 1"
ArgCheck = "2.3.0"
CUDA = "5.3.2"
ChainRulesCore = "1.23"
ConcreteStructs = "0.2.3"
GPUArrays = "10.2, 11"
GPUArraysCore = "0.1.6, 0.2"
LinearAlgebra = "1.10"
Metal = "1.3.0"
Random = "1.10"
SpecialFunctions = "2.4"
Statistics = "1.10"
julia = "1.10"
oneAPI = "1.5.0"
8 changes: 5 additions & 3 deletions README.md
@@ -1,18 +1,20 @@
# WeightInitializers

[![Join the chat at https://julialang.zulipchat.com #machine-learning](https://img.shields.io/static/v1?label=Zulip&message=chat&color=9558b2&labelColor=389826)](https://julialang.zulipchat.com/#narrow/stream/machine-learning)
[![Latest Docs](https://img.shields.io/badge/docs-latest-blue.svg)](https://lux.csail.mit.edu/dev/api/)
[![Stable Docs](https://img.shields.io/badge/docs-stable-blue.svg)](https://lux.csail.mit.edu/stable/api/)
[![Latest Docs](https://img.shields.io/badge/docs-latest-blue.svg)](https://lux.csail.mit.edu/dev/api/Building_Blocks/WeightInitializers)
[![Stable Docs](https://img.shields.io/badge/docs-stable-blue.svg)](https://lux.csail.mit.edu/stable/api/Building_Blocks/WeightInitializers)
[![Aqua QA](https://raw.githubusercontent.com/JuliaTesting/Aqua.jl/master/badge.svg)](https://github.com/JuliaTesting/Aqua.jl)

[![Build status](https://badge.buildkite.com/ffa2c8c3629cd58322446cddd3e8dcc4f121c28a574ee3e626.svg?branch=main)](https://buildkite.com/julialang/weightinitializers-dot-jl)
[![CI](https://github.com/LuxDL/WeightInitializers.jl/actions/workflows/CI.yml/badge.svg)](https://github.com/LuxDL/WeightInitializers.jl/actions/workflows/CI.yml)
[![codecov](https://codecov.io/gh/LuxDL/WeightInitializers.jl/branch/main/graph/badge.svg?token=1ZY0A2NPEM)](https://codecov.io/gh/LuxDL/WeightInitializers.jl)
[![Package Downloads](https://shields.io/endpoint?url=https://pkgs.genieframework.com/api/v1/badge/WeightInitializers)](https://pkgs.genieframework.com?packages=WeightInitializers)

[![ColPrac: Contributor's Guide on Collaborative Practices for Community Packages](https://img.shields.io/badge/ColPrac-Contributor's%20Guide-blueviolet)](https://github.com/SciML/ColPrac)
[![SciML Code Style](https://img.shields.io/static/v1?label=code%20style&message=SciML&color=9558b2&labelColor=389826)](https://github.com/SciML/SciMLStyle)

> [!WARNING]
> Package has been moved to a subdirectory in Lux https://github.com/LuxDL/Lux.jl/tree/main/lib/
This package is a light dependency providing common weight initialization schemes for deep
learning models.
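
For a quick sense of what those initialization schemes look like in use, here is a minimal sketch (not taken from the diff; it only uses initializers exported from src/WeightInitializers.jl as shown further down):

```julia
using WeightInitializers, Random

rng = Xoshiro(1234)

# A 3×4 Float32 weight matrix with Glorot (Xavier) uniform scaling
W = glorot_uniform(rng, 3, 4)

# The element type can also be passed explicitly
Wd = kaiming_normal(rng, Float64, 4, 4)
```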

38 changes: 38 additions & 0 deletions ext/WeightInitializersAMDGPUExt.jl
@@ -0,0 +1,38 @@
module WeightInitializersAMDGPUExt

using AMDGPU: AMDGPU, ROCArray
using GPUArrays: RNG
using Random: Random
using WeightInitializers: DeviceAgnostic

function DeviceAgnostic.zeros(
::AMDGPU.rocRAND.RNG, ::Type{T}, dims::Integer...) where {T <: Number}
return AMDGPU.zeros(T, dims...)
end
function DeviceAgnostic.ones(
::AMDGPU.rocRAND.RNG, ::Type{T}, dims::Integer...) where {T <: Number}
return AMDGPU.ones(T, dims...)
end

function DeviceAgnostic.zeros(
::RNG, ::ROCArray, ::Type{T}, dims::Integer...) where {T <: Number}
return AMDGPU.zeros(T, dims...)
end
function DeviceAgnostic.ones(
::RNG, ::ROCArray, ::Type{T}, dims::Integer...) where {T <: Number}
return AMDGPU.ones(T, dims...)
end
function DeviceAgnostic.rand(
rng::RNG, ::ROCArray, ::Type{T}, dims::Integer...) where {T <: Number}
y = ROCArray{T}(undef, dims...)
Random.rand!(rng, y)
return y
end
function DeviceAgnostic.randn(
rng::RNG, ::ROCArray, ::Type{T}, dims::Integer...) where {T <: Number}
y = ROCArray{T}(undef, dims...)
Random.randn!(rng, y)
return y
end

end
95 changes: 28 additions & 67 deletions ext/WeightInitializersCUDAExt.jl
@@ -1,79 +1,40 @@
module WeightInitializersCUDAExt

using WeightInitializers, CUDA
using Random
import WeightInitializers: __partial_apply, NUM_TO_FPOINT, identity_init, sparse_init,
orthogonal
using CUDA: CUDA, CURAND, CuArray
using GPUArrays: RNG
using Random: Random
using WeightInitializers: DeviceAgnostic

const AbstractCuRNG = Union{CUDA.RNG, CURAND.RNG}

for T in ("16", "32", "64", "C16", "C32", "C64"), fname in (:ones, :zeros)
name = Symbol(fname, T)
TP = NUM_TO_FPOINT[Symbol(T)]
@eval begin
function WeightInitializers.$(name)(rng::AbstractCuRNG, dims::Integer...; kwargs...)
return CUDA.$(fname)($TP, dims...; kwargs...)
end
end

@eval function WeightInitializers.$(name)(rng::AbstractCuRNG; kwargs...)
return __partial_apply($name, (rng, (; kwargs...)))
end
function DeviceAgnostic.zeros(
::AbstractCuRNG, ::Type{T}, dims::Integer...) where {T <: Number}
return CUDA.zeros(T, dims...)
end

function sparse_init(rng::AbstractCuRNG, ::Type{T}, dims::Integer...;
sparsity::Number, std::Number=T(0.01)) where {T <: Number}
if length(dims) != 2
throw(ArgumentError("Only 2-dimensional outputs are supported for sparse initialization."))
end

rows, cols = dims
prop_zero = min(1.0, sparsity)
num_zeros = ceil(Integer, prop_zero * rows)
sparse_array = randn(rng, T, dims...) .* T(std)
sparse_array[1:num_zeros, :] .= CUDA.zero(T)

return CUDA.@allowscalar mapslices(shuffle, sparse_array, dims=1)
function DeviceAgnostic.ones(
::AbstractCuRNG, ::Type{T}, dims::Integer...) where {T <: Number}
return CUDA.ones(T, dims...)
end

function identity_init(rng::AbstractCuRNG, ::Type{T}, dims::Integer...;
gain::Number=1, shift::Integer=0) where {T <: Number}
if length(dims) == 1
# Bias initialization
return CUDA.zeros(T, dims...)
elseif length(dims) == 2
# Matrix multiplication
rows, cols = dims
mat = CUDA.zeros(T, rows, cols)
diag_indices = 1:min(rows, cols)
CUDA.fill!(view(mat, diag_indices, diag_indices), T(gain))
return CUDA.circshift(mat, shift)
else
# Convolution or more dimensions
nin, nout = dims[end - 1], dims[end]
centers = map(d -> cld(d, 2), dims[1:(end - 2)])
weights = CUDA.zeros(T, dims...)
#we should really find a better way to do this
CUDA.@allowscalar for i in 1:min(nin, nout)
index = (centers..., i, i)
weights[index...] = T(gain)
end
return CUDA.circshift(weights, (ntuple(d -> 0, length(dims) - 2)..., shift, shift))
end
function DeviceAgnostic.zeros(
::RNG, ::CuArray, ::Type{T}, dims::Integer...) where {T <: Number}
return CUDA.zeros(T, dims...)
end

for initializer in (:sparse_init, :identity_init)
@eval function ($initializer)(rng::AbstractCuRNG, dims::Integer...; kwargs...)
return $initializer(rng, Float32, dims...; kwargs...)
end

@eval function ($initializer)(rng::AbstractCuRNG; kwargs...)
return __partial_apply($initializer, (rng, (; kwargs...)))
end
@eval function ($initializer)(rng::AbstractCuRNG,
::Type{T}; kwargs...) where {T <: Number}
return __partial_apply($initializer, ((rng, T), (; kwargs...)))
end
function DeviceAgnostic.ones(
::RNG, ::CuArray, ::Type{T}, dims::Integer...) where {T <: Number}
return CUDA.ones(T, dims...)
end
function DeviceAgnostic.rand(
rng::RNG, ::CuArray, ::Type{T}, dims::Integer...) where {T <: Number}
y = CuArray{T}(undef, dims...)
Random.rand!(rng, y)
return y
end
function DeviceAgnostic.randn(
rng::RNG, ::CuArray, ::Type{T}, dims::Integer...) where {T <: Number}
y = CuArray{T}(undef, dims...)
Random.randn!(rng, y)
return y
end

end
18 changes: 18 additions & 0 deletions ext/WeightInitializersChainRulesCoreExt.jl
@@ -0,0 +1,18 @@
module WeightInitializersChainRulesCoreExt

using ChainRulesCore: @non_differentiable
using WeightInitializers: WeightInitializers, DeviceAgnostic

for f in [:zeros64, :ones64, :rand64, :randn64, :zeros32, :ones32, :rand32, :randn32,
:zeros16, :ones16, :rand16, :randn16, :zerosC64, :onesC64, :randC64,
:randnC64, :zerosC32, :onesC32, :randC32, :randnC32, :zerosC16, :onesC16,
:randC16, :randnC16, :glorot_normal, :glorot_uniform, :kaiming_normal,
:kaiming_uniform, :truncated_normal, :orthogonal, :sparse_init, :identity_init]
@eval @non_differentiable WeightInitializers.$(f)(::Any...)
end

for f in (:zeros, :ones, :rand, :randn)
@eval @non_differentiable DeviceAgnostic.$(f)(::Any...)
end

end
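
As a rough illustration of what this extension buys (a sketch, not part of the diff, assuming ChainRulesCore's usual @non_differentiable behaviour): once the rules are defined, a sampled array is treated as a constant under AD, so every cotangent returned by the pullback is a NoTangent().

```julia
using ChainRulesCore, WeightInitializers, Random

# rrule is defined by the @non_differentiable marking above
y, pb = ChainRulesCore.rrule(glorot_uniform, Xoshiro(0), 3, 4)
cotangents = pb(ones(Float32, 3, 4))
all(c -> c isa ChainRulesCore.NoTangent, cotangents)   # expected: true
```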
24 changes: 24 additions & 0 deletions ext/WeightInitializersGPUArraysExt.jl
@@ -0,0 +1,24 @@
module WeightInitializersGPUArraysExt

using GPUArrays: RNG
using WeightInitializers: DeviceAgnostic

for f in (:zeros, :ones, :rand, :randn)
@eval function DeviceAgnostic.$(f)(
rng::RNG, ::Type{T}, dims::Integer...) where {T <: Number}
return DeviceAgnostic.$(f)(rng, rng.state, T, dims...)
end
end

## Certain backends don't support sampling Complex numbers, so we avoid hitting those
## dispatches
for f in (:rand, :randn)
@eval function DeviceAgnostic.$(f)(
rng::RNG, ::Type{<:Complex{T}}, args::Integer...) where {T <: Number}
real_part = DeviceAgnostic.$(f)(rng, rng.state, T, args...)
imag_part = DeviceAgnostic.$(f)(rng, rng.state, T, args...)
return Complex{T}.(real_part, imag_part)
end
end

end
29 changes: 29 additions & 0 deletions ext/WeightInitializersMetalExt.jl
@@ -0,0 +1,29 @@
module WeightInitializersMetalExt

using Metal: Metal, MtlArray
using GPUArrays: RNG
using Random: Random
using WeightInitializers: DeviceAgnostic

function DeviceAgnostic.zeros(
::RNG, ::MtlArray, ::Type{T}, dims::Integer...) where {T <: Number}
return Metal.zeros(T, dims...)
end
function DeviceAgnostic.ones(
::RNG, ::MtlArray, ::Type{T}, dims::Integer...) where {T <: Number}
return Metal.ones(T, dims...)
end
function DeviceAgnostic.rand(
rng::RNG, ::MtlArray, ::Type{T}, dims::Integer...) where {T <: Number}
y = MtlArray{T}(undef, dims...)
Random.rand!(rng, y)
return y
end
function DeviceAgnostic.randn(
rng::RNG, ::MtlArray, ::Type{T}, dims::Integer...) where {T <: Number}
y = MtlArray{T}(undef, dims...)
Random.randn!(rng, y)
return y
end

end
29 changes: 29 additions & 0 deletions ext/WeightInitializersoneAPIExt.jl
@@ -0,0 +1,29 @@
module WeightInitializersoneAPIExt

using oneAPI: oneAPI, oneArray
using GPUArrays: RNG
using Random: Random
using WeightInitializers: DeviceAgnostic

function DeviceAgnostic.zeros(
::RNG, ::oneArray, ::Type{T}, dims::Integer...) where {T <: Number}
return oneAPI.zeros(T, dims...)
end
function DeviceAgnostic.ones(
::RNG, ::oneArray, ::Type{T}, dims::Integer...) where {T <: Number}
return oneAPI.ones(T, dims...)
end
function DeviceAgnostic.rand(
rng::RNG, ::oneArray, ::Type{T}, dims::Integer...) where {T <: Number}
y = oneArray{T}(undef, dims...)
Random.rand!(rng, y)
return y
end
function DeviceAgnostic.randn(
rng::RNG, ::oneArray, ::Type{T}, dims::Integer...) where {T <: Number}
y = oneArray{T}(undef, dims...)
Random.randn!(rng, y)
return y
end

end
56 changes: 8 additions & 48 deletions src/WeightInitializers.jl
@@ -1,62 +1,22 @@
module WeightInitializers

import PrecompileTools: @recompile_invalidations

@recompile_invalidations begin
using ChainRulesCore, PartialFunctions, Random, SpecialFunctions, Statistics,
LinearAlgebra
end
using ArgCheck: @argcheck
using GPUArraysCore: @allowscalar
using LinearAlgebra: LinearAlgebra, Diagonal, qr
using Random: Random, AbstractRNG, shuffle
using SpecialFunctions: SpecialFunctions, erfinv # TODO: Move to Ext in v2.0
using Statistics: Statistics, std

include("partial.jl")
include("utils.jl")
include("initializers.jl")

# Mark the functions as non-differentiable
for f in [
:zeros64,
:ones64,
:rand64,
:randn64,
:zeros32,
:ones32,
:rand32,
:randn32,
:zeros16,
:ones16,
:rand16,
:randn16,
:zerosC64,
:onesC64,
:randC64,
:randnC64,
:zerosC32,
:onesC32,
:randC32,
:randnC32,
:zerosC16,
:onesC16,
:randC16,
:randnC16,
:glorot_normal,
:glorot_uniform,
:kaiming_normal,
:kaiming_uniform,
:truncated_normal,
:orthogonal,
:sparse_init,
:identity_init
]
@eval @non_differentiable $(f)(::Any...)
end

export zeros64, ones64, rand64, randn64, zeros32, ones32, rand32, randn32, zeros16, ones16,
rand16, randn16
export zerosC64, onesC64, randC64, randnC64, zerosC32, onesC32, randC32, randnC32, zerosC16,
onesC16, randC16, randnC16
export glorot_normal, glorot_uniform
export kaiming_normal, kaiming_uniform
export truncated_normal
export orthogonal
export sparse_init
export identity_init
export truncated_normal, orthogonal, sparse_init, identity_init

end
319 changes: 163 additions & 156 deletions src/initializers.jl

Large diffs are not rendered by default.

51 changes: 51 additions & 0 deletions src/partial.jl
@@ -0,0 +1,51 @@
module PartialFunction

using ArgCheck: @argcheck
using ConcreteStructs: @concrete
using Random: AbstractRNG

@concrete struct Partial{T} <: Function
f <: Function
rng <: Union{Nothing, AbstractRNG}
kwargs
end

function Base.show(io::IO, ::MIME"text/plain", f::Partial{T}) where {T}
print(io, "$(f.f)(")
if f.rng !== nothing
print(io, "$(nameof(typeof(f.rng)))(...), ")
else
print(io, "rng, ")
end
if T === Nothing
print(io, "::Type{T}, ")
else
T !== Missing ? print(io, "$(T), ") : nothing
end
print(io, "dims...")
kwargs_str = String[]
for (k, v) in pairs(f.kwargs)
push!(kwargs_str, "$(k)=$(v)")
end
length(kwargs_str) > 0 && print(io, "; ", join(kwargs_str, ", "))
print(io, ")")
end

function (f::Partial{<:Union{Nothing, Missing}})(args...; kwargs...)
f.rng === nothing && return f.f(args...; f.kwargs..., kwargs...)
return f.f(f.rng, args...; f.kwargs..., kwargs...)
end
function (f::Partial{<:Union{Nothing, Missing}})(rng::AbstractRNG, args...; kwargs...)
@argcheck f.rng === nothing
return f.f(rng, args...; f.kwargs..., kwargs...)
end
function (f::Partial{T})(args...; kwargs...) where {T <: Number}
f.rng === nothing && return f.f(T, args...; f.kwargs..., kwargs...)
return f.f(f.rng, T, args...; f.kwargs..., kwargs...)
end
function (f::Partial{T})(rng::AbstractRNG, args...; kwargs...) where {T <: Number}
@argcheck f.rng === nothing
return f.f(rng, T, args...; f.kwargs..., kwargs...)
end

end
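
The call overloads above are what let every initializer be partially applied: calling it without dimensions returns a Partial that remembers the rng, element type, and keyword arguments, and supplies them once the dimensions arrive. A small sketch of how that reads at the call site (mirroring the closure tests later in this diff; the exact keyword names are assumptions):

```julia
using WeightInitializers, Random

rng = Xoshiro(0)

# Fix the rng, eltype, and kwargs now; pass the dimensions later.
cl = sparse_init(rng, Float32; sparsity=0.5)
A = cl(100, 100)          # 100×100 Float32, 50 zeros in each column

# Or fix only the kwargs and supply rng/eltype at call time.
cl2 = truncated_normal(; mean=0.0f0, std=0.02f0)
B = cl2(rng, Float32, 3, 4)
```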
93 changes: 63 additions & 30 deletions src/utils.jl
@@ -1,45 +1,78 @@
@inline _nfan() = 1, 1 # fan_in, fan_out
@inline _nfan(n) = 1, n # A vector is treated as a n×1 matrix
@inline _nfan(n_out, n_in) = n_in, n_out # In case of Dense kernels: arranged as matrices
@inline _nfan(dims::Tuple) = _nfan(dims...)
@inline _nfan(dims...) = prod(dims[1:(end - 2)]) .* (dims[end - 1], dims[end]) # In case of convolution kernels
_norm_cdf(x::T) where {T} = T(0.5) * (1 + erf(x / 2))

function _default_rng()
@static if VERSION >= v"1.7"
return Xoshiro(1234)
else
return MersenneTwister(1234)
end
end
module Utils

using Random: Xoshiro
using SpecialFunctions: erf

nfan() = 1, 1 # fan_in, fan_out
nfan(n) = 1, n # A vector is treated as a n×1 matrix
nfan(n_out, n_in) = n_in, n_out # In case of Dense kernels: arranged as matrices
nfan(dims::Tuple) = nfan(dims...)
nfan(dims...) = prod(dims[1:(end - 2)]) .* (dims[end - 1], dims[end]) # In case of convolution kernels

norm_cdf(x::T) where {T} = T(0.5) * (1 + T(erf(x / 2))) # erf often doesn't respect the type

# This is needed if using `PartialFunctions.$` inside @eval block
__partial_apply(fn, inp) = fn$inp
default_rng() = Xoshiro(1234)

const NAME_TO_DIST = Dict(:zeros => "an AbstractArray of zeros",
:ones => "an AbstractArray of ones",
#! format: off
const NAME_TO_DIST = Dict(
:zeros => "an AbstractArray of zeros",
:ones => "an AbstractArray of ones",
:randn => "random numbers from a standard normal distribution",
:rand => "random numbers from a uniform distribution")
const NUM_TO_FPOINT = Dict(Symbol(16) => Float16, Symbol(32) => Float32,
Symbol(64) => Float64, :C16 => ComplexF16, :C32 => ComplexF32, :C64 => ComplexF64)
:rand => "random numbers from a uniform distribution"
)
const NUM_TO_FPOINT = Dict(
Symbol(16) => Float16,
Symbol(32) => Float32,
Symbol(64) => Float64,
:C16 => ComplexF16,
:C32 => ComplexF32,
:C64 => ComplexF64
)
#! format: on

@inline function __funcname(fname::String)
function function_name(fname::String)
fp = fname[(end - 2):end]
if Symbol(fp) in keys(NUM_TO_FPOINT)
return fname[1:(end - 3)], fp
else
return fname[1:(end - 2)], fname[(end - 1):end]
end
Symbol(fp) in keys(NUM_TO_FPOINT) && return fname[1:(end - 3)], fp
return fname[1:(end - 2)], fname[(end - 1):end]
end

@inline function __generic_docstring(fname::String)
funcname, fp = __funcname(fname)
function generic_docstring(fname::String)
funcname, fp = function_name(fname)
name = NAME_TO_DIST[Symbol(funcname)]
dist_type = NUM_TO_FPOINT[Symbol(fp)]
return """
$fname([::AbstractRNG=_default_rng()], size...;
$fname([::AbstractRNG=Utils.default_rng()], size...;
kwargs...) -> AbstractArray{$(dist_type), length(size)}
Return an `AbstractArray{$(dist_type)}` of the given `size` containing $(name).
"""
end

end

module DeviceAgnostic

using Random: AbstractRNG

# Helpers for device agnostic initializers
function zeros(::AbstractRNG, ::Type{T}, dims::Integer...) where {T <: Number}
return Base.zeros(T, dims...)
end
ones(::AbstractRNG, ::Type{T}, dims::Integer...) where {T <: Number} = Base.ones(T, dims...)
function rand(rng::AbstractRNG, ::Type{T}, args::Integer...) where {T <: Number}
return Base.rand(rng, T, args...)
end
function randn(rng::AbstractRNG, ::Type{T}, args::Integer...) where {T <: Number}
return Base.randn(rng, T, args...)
end

## Certain backends don't support sampling Complex numbers, so we avoid hitting those
## dispatches
for f in (:rand, :randn)
@eval function $(f)(
rng::AbstractRNG, ::Type{<:Complex{T}}, args::Integer...) where {T <: Number}
return Complex{T}.($(f)(rng, T, args...), $(f)(rng, T, args...))
end
end

end
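
A short sketch of how these internal helpers behave (Utils and DeviceAgnostic are submodules of WeightInitializers, as defined above; shown only for illustration):

```julia
using WeightInitializers: Utils, DeviceAgnostic
using Random: Xoshiro

Utils.nfan(128, 64)       # Dense kernel (n_out, n_in): fan_in = 64, fan_out = 128
Utils.nfan(3, 3, 16, 32)  # 3×3 conv kernel, 16 => 32 channels: (144, 288)

rng = Xoshiro(0)
DeviceAgnostic.randn(rng, Float32, 2, 2)     # plain Array on the CPU
DeviceAgnostic.randn(rng, ComplexF32, 2, 2)  # assembled from two real draws (fallback above)
```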
31 changes: 31 additions & 0 deletions test/Project.toml
@@ -0,0 +1,31 @@
[deps]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7"
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
Hwloc = "0e44f5e4-bd66-52a0-8798-143a42290a1d"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[compat]
Aqua = "0.8.7"
Documenter = "1.5.0"
ExplicitImports = "1.9.0"
GPUArrays = "10.2"
GPUArraysCore = "0.1.6"
Hwloc = "3.3"
InteractiveUtils = "<0.0.1, 1"
LinearAlgebra = "1.10"
Pkg = "1.10"
Random = "1.10"
ReTestItems = "1.24.0"
StableRNGs = "1"
Statistics = "1.10"
Test = "1.10"
350 changes: 350 additions & 0 deletions test/initializers_tests.jl
@@ -0,0 +1,350 @@
@testitem "Warning: truncated_normal" begin
@test_warn "Mean is more than 2 std outside the limits in truncated_normal, so \
the distribution of values may be inaccurate." truncated_normal(2; mean=-5.0f0)
end

@testitem "Identity Initialization" begin
@testset "Non-identity sizes" begin
@test identity_init(2, 3)[:, end] == zeros(Float32, 2)
@test identity_init(3, 2; shift=1)[1, :] == zeros(Float32, 2)
@test identity_init(1, 1, 3, 4)[:, :, :, end] == zeros(Float32, 1, 1, 3)
@test identity_init(2, 1, 3, 3)[end, :, :, :] == zeros(Float32, 1, 3, 3)
@test identity_init(1, 2, 3, 3)[:, end, :, :] == zeros(Float32, 1, 3, 3)
end
end

@testitem "Orthogonal Initialization" setup=[SharedTestSetup] begin
using GPUArraysCore, LinearAlgebra

@testset "rng = $(typeof(rng)) & arrtype = $arrtype" for (rng, arrtype, supports_fp64, backend) in RNGS_ARRTYPES
# A matrix of dim = (m,n) with m > n should produce a QR decomposition.
# In the other case, the transpose should be taken to compute the QR decomposition.
if backend == "oneapi" || backend == "metal" # `qr` not implemented
@test_broken orthogonal(rng, 3, 5) isa arrtype{Float32, 2}
continue
end

for (rows, cols) in [(5, 3), (3, 5)]
v = orthogonal(rng, rows, cols)
GPUArraysCore.@allowscalar rows < cols ? (@test v * v' ≈ I(rows)) :
(@test v' * v ≈ I(cols))
end

for mat in [(3, 4, 5), (2, 2, 5)]
v = orthogonal(rng, mat...)
cols = mat[end]
rows = div(prod(mat), cols)
v = reshape(v, (rows, cols))
GPUArraysCore.@allowscalar rows < cols ? (@test v * v' ≈ I(rows)) :
(@test v' * v ≈ I(cols))
end

@testset "Orthogonal Types $T" for T in (Float32, Float64)
!supports_fp64 && T == Float64 && continue

@test eltype(orthogonal(rng, T, 3, 4; gain=1.5)) == T
@test eltype(orthogonal(rng, T, 3, 4, 5; gain=1.5)) == T
end

@testset "Orthogonal AbstractArray Type $T" for T in (Float32, Float64)
!supports_fp64 && T == Float64 && continue

@test orthogonal(rng, T, 3, 5) isa AbstractArray{T, 2}
@test orthogonal(rng, T, 3, 5) isa arrtype{T, 2}

cl = orthogonal(rng)
display(cl)
@test cl(T, 3, 5) isa arrtype{T, 2}

cl = orthogonal(rng, T)
display(cl)
@test cl(3, 5) isa arrtype{T, 2}
end

@testset "Orthogonal Closure" begin
cl = orthogonal(;)
display(cl)

# Sizes
@test size(cl(3, 4)) == (3, 4)
@test size(cl(rng, 3, 4)) == (3, 4)
@test size(cl(3, 4, 5)) == (3, 4, 5)
@test size(cl(rng, 3, 4, 5)) == (3, 4, 5)

# Type
@test eltype(cl(4, 2)) == Float32
@test eltype(cl(rng, 4, 2)) == Float32
end
end
end

@testitem "Sparse Initialization" setup=[SharedTestSetup] begin
using Statistics

@testset "rng = $(typeof(rng)) & arrtype = $arrtype" for (rng, arrtype, supports_fp64, backend) in RNGS_ARRTYPES
# sparse_init should yield an error for non 2-d dimensions
# sparse_init should yield no zero elements if sparsity < 0
# sparse_init should yield all zero elements if sparsity > 1
# sparse_init should yield exactly ceil(n_in * sparsity) elements in each column for
# other sparsity values
# sparse_init should yield a kernel in its non-zero elements consistent with the std
# parameter

@test_throws ArgumentError sparse_init(3, 4, 5, sparsity=0.1)
@test_throws ArgumentError sparse_init(3, sparsity=0.1)
v = sparse_init(100, 100; sparsity=-0.1)
@test sum(v .== 0) == 0
v = sparse_init(100, 100; sparsity=1.1)
@test sum(v .== 0) == length(v)

for (n_in, n_out, sparsity, σ) in [(100, 100, 0.25, 0.1), (100, 400, 0.75, 0.01)]
expected_zeros = ceil(Integer, n_in * sparsity)
v = sparse_init(n_in, n_out; sparsity=sparsity, std=σ)
@test all([sum(v[:, col] .== 0) == expected_zeros for col in 1:n_out])
@test 0.9 * σ < std(v[v .!= 0]) < 1.1 * σ
end

@testset "sparse_init Type $T" for T in (Float16, Float32, Float64)
!supports_fp64 && T == Float64 && continue

@test eltype(sparse_init(rng, T, 3, 4; sparsity=0.5)) == T
end

@testset "sparse_init AbstractArray Type $T" for T in (Float16, Float32, Float64)
!supports_fp64 && T == Float64 && continue

@test sparse_init(T, 3, 5; sparsity=0.5) isa AbstractArray{T, 2}
@test sparse_init(rng, T, 3, 5; sparsity=0.5) isa arrtype{T, 2}

cl = sparse_init(rng; sparsity=0.5)
display(cl)
@test cl(T, 3, 5) isa arrtype{T, 2}

cl = sparse_init(rng, T; sparsity=0.5)
display(cl)
@test cl(3, 5) isa arrtype{T, 2}
end

@testset "sparse_init Closure" begin
cl = sparse_init(; sparsity=0.5)
display(cl)

# Sizes
@test size(cl(3, 4)) == (3, 4)
@test size(cl(rng, 3, 4)) == (3, 4)

# Type
@test eltype(cl(4, 2)) == Float32
@test eltype(cl(rng, 4, 2)) == Float32
end
end
end

@testitem "Basic Initializations" setup=[SharedTestSetup] begin
using LinearAlgebra, Statistics

@testset "rng = $(typeof(rng)) & arrtype = $arrtype" for (rng, arrtype, supports_fp64, backend) in RNGS_ARRTYPES
@testset "Sizes and Types: $init" for init in [
zeros32, ones32, rand32, randn32, kaiming_uniform, kaiming_normal,
glorot_uniform, glorot_normal, truncated_normal, identity_init]
!supports_fp64 &&
(init === zeros32 ||
init === ones32 ||
init === rand32 ||
init === randn32) &&
continue

if backend == "oneapi" && init === truncated_normal
@test_broken size(init(rng, 3)) == (3,) # `erfinv` not implemented
continue
end

# Sizes
@test size(init(3)) == (3,)
@test size(init(rng, 3)) == (3,)
@test size(init(3, 4)) == (3, 4)
@test size(init(rng, 3, 4)) == (3, 4)
@test size(init(3, 4, 5)) == (3, 4, 5)
@test size(init(rng, 3, 4, 5)) == (3, 4, 5)

# Type
@test eltype(init(rng, 4, 2)) == Float32
@test eltype(init(4, 2)) == Float32

# RNG Closure
cl = init(rng)
display(cl)
@test cl(3) isa arrtype{Float32, 1}
@test cl(3, 5) isa arrtype{Float32, 2}
end

@testset "Sizes and Types: $init" for (init, fp) in [
(zeros16, Float16), (zerosC16, ComplexF16), (zeros32, Float32),
(zerosC32, ComplexF32), (zeros64, Float64), (zerosC64, ComplexF64),
(ones16, Float16), (onesC16, ComplexF16), (ones32, Float32),
(onesC32, ComplexF32), (ones64, Float64), (onesC64, ComplexF64),
(rand16, Float16), (randC16, ComplexF16), (rand32, Float32),
(randC32, ComplexF32), (rand64, Float64), (randC64, ComplexF64),
(randn16, Float16), (randnC16, ComplexF16), (randn32, Float32),
(randnC32, ComplexF32), (randn64, Float64), (randnC64, ComplexF64)]
!supports_fp64 && (fp == Float64 || fp == ComplexF64) && continue

# Sizes
@test size(init(3)) == (3,)
@test size(init(rng, 3)) == (3,)
@test size(init(3, 4)) == (3, 4)
@test size(init(rng, 3, 4)) == (3, 4)
@test size(init(3, 4, 5)) == (3, 4, 5)
@test size(init(rng, 3, 4, 5)) == (3, 4, 5)

# Type
@test eltype(init(rng, 4, 2)) == fp
@test eltype(init(4, 2)) == fp

# RNG Closure
cl = init(rng)
display(cl)
@test cl(3) isa arrtype{fp, 1}
@test cl(3, 5) isa arrtype{fp, 2}

# Kwargs closure
cl = init(;)
display(cl)
@test cl(rng, 3) isa arrtype{fp, 1}
@test cl(rng, 3, 5) isa arrtype{fp, 2}

# throw error on type as input
@test_throws ArgumentError init(Float32)
@test_throws ArgumentError init(Float32, 3, 5)
@test_throws ArgumentError init(rng, Float32)
@test_throws ArgumentError init(rng, Float32, 3, 5)
end

@testset "AbstractArray Type: $init $T" for init in [
kaiming_uniform, kaiming_normal, glorot_uniform,
glorot_normal, truncated_normal, identity_init],
T in (Float16, Float32, Float64, ComplexF16, ComplexF32, ComplexF64)

!supports_fp64 && (T == Float64 || T == ComplexF64) && continue

init === truncated_normal && !(T <: Real) && continue

if backend == "oneapi" && init === truncated_normal && T == Float32
@test_broken init(rng, T, 3) isa AbstractArray{T, 1} # `erfinv` not implemented
continue
end

@test init(T, 3) isa AbstractArray{T, 1}
@test init(rng, T, 3) isa arrtype{T, 1}
@test init(T, 3, 5) isa AbstractArray{T, 2}
@test init(rng, T, 3, 5) isa arrtype{T, 2}

cl = init(rng)
display(cl)
@test cl(T, 3) isa arrtype{T, 1}
@test cl(T, 3, 5) isa arrtype{T, 2}

cl = init(rng, T)
display(cl)
@test cl(3) isa arrtype{T, 1}
@test cl(3, 5) isa arrtype{T, 2}

cl = init(T)
display(cl)
@test cl(3) isa Array{T, 1}
@test cl(3, 5) isa Array{T, 2}
@test cl(rng, 3, 5) isa arrtype{T, 2}
end

@testset "Closure: $init" for init in [
kaiming_uniform, kaiming_normal, glorot_uniform,
glorot_normal, truncated_normal, identity_init]
if backend == "oneapi" && init === truncated_normal
@test_broken size(init(rng, 3)) == (3,) # `erfinv` not implemented
continue
end

cl = init(;)
display(cl)

# Sizes
@test size(cl(3)) == (3,)
@test size(cl(rng, 3)) == (3,)
@test size(cl(3, 4)) == (3, 4)
@test size(cl(rng, 3, 4)) == (3, 4)
@test size(cl(3, 4, 5)) == (3, 4, 5)
@test size(cl(rng, 3, 4, 5)) == (3, 4, 5)

# Type
@test eltype(cl(4, 2)) == Float32
@test eltype(cl(rng, 4, 2)) == Float32
end

@testset "Kwargs types" for T in (
Float16, Float32, Float64, ComplexF16, ComplexF32, ComplexF64)
!supports_fp64 && (T == Float64 || T == ComplexF64) && continue

if (T <: Real)
@test eltype(truncated_normal(T, 2, 5; mean=0, std=1, lo=-2, hi=2)) == T
@test eltype(orthogonal(T, 2, 5; gain=1.0)) == T
end
@test eltype(glorot_uniform(T, 2, 5; gain=1.0)) == T
@test eltype(glorot_normal(T, 2, 5; gain=1.0)) == T
@test eltype(kaiming_uniform(T, 2, 5; gain=sqrt(2))) == T
@test eltype(kaiming_normal(T, 2, 5; gain=sqrt(2))) == T
@test eltype(identity_init(T, 2, 5; gain=1.0)) == T
@test eltype(sparse_init(T, 2, 5; sparsity=0.5, std=0.01)) == T
end

@testset "kaiming" begin
# kaiming_uniform should yield a kernel in range [-sqrt(6/n_out), sqrt(6/n_out)]
# and kaiming_normal should yield a kernel with stddev ~= sqrt(2/n_out)
for (n_in, n_out) in [(100, 100), (100, 400)]
v = kaiming_uniform(rng, n_in, n_out)
σ2 = sqrt(6 / n_out)
@test -1σ2 < minimum(v) < -0.9σ2
@test 0.9σ2 < maximum(v) < 1σ2

v = kaiming_normal(rng, n_in, n_out)
σ2 = sqrt(2 / n_out)

if (backend == "cuda" || backend == "amdgpu") && rng isa GPUArrays.RNG
@test_broken 0.9σ2 < std(v) < 1.1σ2
else
@test 0.9σ2 < std(v) < 1.1σ2
end
end
# Type
@test eltype(kaiming_uniform(rng, 3, 4; gain=1.5f0)) == Float32
@test eltype(kaiming_normal(rng, 3, 4; gain=1.5f0)) == Float32
end
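
# Illustrative derivation of the uniform bound used above (spelled out as an
# assumption, following the fan convention exercised in test/utils_tests.jl where
# the second dimension of a matrix is fan_in): kaiming_uniform samples from
# U(-b, b) with b = gain * sqrt(3 / fan_in), so the default gain = sqrt(2) gives
# b = sqrt(6 / n_out) for an (n_in, n_out) weight matrix.
let (n_in, n_out) = (100, 400)
    fan_in, _ = WeightInitializers.Utils.nfan(n_in, n_out)
    @test sqrt(2) * sqrt(3 / fan_in) ≈ sqrt(6 / n_out)
end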

@testset "glorot: $init" for init in [glorot_uniform, glorot_normal]
# glorot_uniform and glorot_normal should both yield a kernel with
# variance ≈ 2/(fan_in + fan_out)
for dims in [(1000,), (100, 100), (100, 400), (2, 3, 32, 64), (2, 3, 4, 32, 64)]
v = init(dims...)
fan_in, fan_out = WeightInitializers.Utils.nfan(dims...)
σ2 = 2 / (fan_in + fan_out)
@test 0.9σ2 < var(v) < 1.1σ2
end
@test eltype(init(3, 4; gain=1.5)) == Float32
end
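
# Worked example of the fan computation behind the variance target above
# (illustrative; assumes the (kernel..., in, out) layout that test/utils_tests.jl
# verifies for the 3-D case): for conv-style dims = (2, 3, 32, 64) the receptive
# field is 2 * 3 = 6, so fan_in = 6 * 32 = 192, fan_out = 6 * 64 = 384, and the
# target variance is 2 / 576.
let dims = (2, 3, 32, 64)
    fan_in, fan_out = WeightInitializers.Utils.nfan(dims...)
    @test (fan_in, fan_out) == (192, 384)
    @test 2 / (fan_in + fan_out) ≈ 2 / 576
end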

@testset "orthogonal" begin
# A matrix of size (m, n) with m > n is built directly from a QR decomposition;
# otherwise the QR decomposition of the transpose is used, so the result is semi-orthogonal either way.
for (rows, cols) in [(5, 3), (3, 5)]
v = orthogonal(rows, cols)
rows < cols ? (@test v * v' ≈ I(rows)) : (@test v' * v ≈ I(cols))
end
for mat in [(3, 4, 5), (2, 2, 5)]
v = orthogonal(mat...)
cols = mat[end]
rows = div(prod(mat), cols)
v = reshape(v, (rows, cols))
rows < cols ? (@test v * v' ≈ I(rows)) : (@test v' * v ≈ I(cols))
end
@test eltype(orthogonal(3, 4; gain=1.5)) == Float32
end
end
end
33 changes: 33 additions & 0 deletions test/qa_tests.jl
@@ -0,0 +1,33 @@
@testitem "Aqua: Quality Assurance" begin
using Aqua

Aqua.test_all(WeightInitializers; ambiguities=false)
Aqua.test_ambiguities(WeightInitializers; recursive=false)
end

@testitem "Explicit Imports: Quality Assurance" setup=[SharedTestSetup] begin
using ExplicitImports

@test check_no_implicit_imports(WeightInitializers) === nothing
@test check_no_stale_explicit_imports(WeightInitializers) === nothing
@test check_no_self_qualified_accesses(WeightInitializers) === nothing
@test check_all_explicit_imports_via_owners(WeightInitializers) === nothing
@test check_all_qualified_accesses_via_owners(WeightInitializers) === nothing
@test_broken check_all_explicit_imports_are_public(WeightInitializers) === nothing # mostly upstream problems

try # FIXME: Soft fail for now
acc = check_all_qualified_accesses_are_public(WeightInitializers)
@test acc === nothing
catch
@test_broken check_all_qualified_accesses_are_public(WeightInitializers) === nothing
end
end

@testitem "doctests: Quality Assurance" begin
using Documenter

doctestexpr = :(using Random, WeightInitializers)

DocMeta.setdocmeta!(WeightInitializers, :DocTestSetup, doctestexpr; recursive=true)
doctest(WeightInitializers; manual=false)
end
303 changes: 23 additions & 280 deletions test/runtests.jl
@@ -1,287 +1,30 @@
using Aqua
using WeightInitializers, Test, Statistics
using StableRNGs, Random, CUDA, LinearAlgebra
using Pkg, ReTestItems, WeightInitializers
using InteractiveUtils, Hwloc

CUDA.allowscalar(false)
@info sprint(versioninfo)

const GROUP = get(ENV, "GROUP", "All")
const BACKEND_GROUP = lowercase(get(ENV, "BACKEND_GROUP", "All"))

@testset "WeightInitializers.jl Tests" begin
rngs_arrtypes = []
const EXTRA_PKGS = String[]

if GROUP == "All" || GROUP == "CPU"
append!(rngs_arrtypes,
[(StableRNG(12345), AbstractArray), (Random.default_rng(), AbstractArray)])
end
(BACKEND_GROUP == "all" || BACKEND_GROUP == "cuda") && push!(EXTRA_PKGS, "CUDA")
(BACKEND_GROUP == "all" || BACKEND_GROUP == "amdgpu") && push!(EXTRA_PKGS, "AMDGPU")
(BACKEND_GROUP == "all" || BACKEND_GROUP == "metal") && push!(EXTRA_PKGS, "Metal")
(BACKEND_GROUP == "all" || BACKEND_GROUP == "oneapi") && push!(EXTRA_PKGS, "oneAPI")

if GROUP == "All" || GROUP == "CUDA"
append!(rngs_arrtypes, [(CUDA.default_rng(), CuArray)])
end

@testset "_nfan" begin
# Fallback
@test WeightInitializers._nfan() == (1, 1)
# Vector
@test WeightInitializers._nfan(4) == (1, 4)
# Matrix
@test WeightInitializers._nfan(4, 5) == (5, 4)
# Tuple
@test WeightInitializers._nfan((4, 5, 6)) == WeightInitializers._nfan(4, 5, 6)
# Convolution
@test WeightInitializers._nfan(4, 5, 6) == 4 .* (5, 6)
end

@testset "rng = $(typeof(rng)) & arrtype = $arrtype" for (rng, arrtype) in rngs_arrtypes
@testset "Sizes and Types: $init" for init in [zeros32, ones32, rand32, randn32,
kaiming_uniform, kaiming_normal, glorot_uniform, glorot_normal,
truncated_normal, identity_init
]
# Sizes
@test size(init(3)) == (3,)
@test size(init(rng, 3)) == (3,)
@test size(init(3, 4)) == (3, 4)
@test size(init(rng, 3, 4)) == (3, 4)
@test size(init(3, 4, 5)) == (3, 4, 5)
@test size(init(rng, 3, 4, 5)) == (3, 4, 5)
# Type
@test eltype(init(rng, 4, 2)) == Float32
@test eltype(init(4, 2)) == Float32
# RNG Closure
cl = init(rng)
@test cl(3) isa arrtype{Float32, 1}
@test cl(3, 5) isa arrtype{Float32, 2}
end

@testset "Sizes and Types: $init" for (init, fp) in [(zeros16, Float16),
(zerosC16, ComplexF16), (zeros32, Float32), (zerosC32, ComplexF32),
(zeros64, Float64), (zerosC64, ComplexF64), (ones16, Float16),
(onesC16, ComplexF16), (ones32, Float32), (onesC32, ComplexF32),
(ones64, Float64), (onesC64, ComplexF64), (rand16, Float16),
(randC16, ComplexF16), (rand32, Float32), (randC32, ComplexF32),
(rand64, Float64), (randC64, ComplexF64), (randn16, Float16),
(randnC16, ComplexF16), (randn32, Float32), (randnC32, ComplexF32),
(randn64, Float64), (randnC64, ComplexF64)]
# Sizes
@test size(init(3)) == (3,)
@test size(init(rng, 3)) == (3,)
@test size(init(3, 4)) == (3, 4)
@test size(init(rng, 3, 4)) == (3, 4)
@test size(init(3, 4, 5)) == (3, 4, 5)
@test size(init(rng, 3, 4, 5)) == (3, 4, 5)
# Type
@test eltype(init(rng, 4, 2)) == fp
@test eltype(init(4, 2)) == fp
# RNG Closure
cl = init(rng)
@test cl(3) isa arrtype{fp, 1}
@test cl(3, 5) isa arrtype{fp, 2}
end

@testset "AbstractArray Type: $init $T" for init in [kaiming_uniform,
kaiming_normal,
glorot_uniform, glorot_normal, truncated_normal, identity_init],
T in (Float16, Float32,
Float64, ComplexF16, ComplexF32, ComplexF64)

init === truncated_normal && !(T <: Real) && continue

@test init(T, 3) isa AbstractArray{T, 1}
@test init(rng, T, 3) isa arrtype{T, 1}
@test init(T, 3, 5) isa AbstractArray{T, 2}
@test init(rng, T, 3, 5) isa arrtype{T, 2}

cl = init(rng)
@test cl(T, 3) isa arrtype{T, 1}
@test cl(T, 3, 5) isa arrtype{T, 2}

cl = init(rng, T)
@test cl(3) isa arrtype{T, 1}
@test cl(3, 5) isa arrtype{T, 2}
end

@testset "Closure: $init" for init in [kaiming_uniform, kaiming_normal,
glorot_uniform, glorot_normal, truncated_normal, identity_init]
cl = init(;)
# Sizes
@test size(cl(3)) == (3,)
@test size(cl(rng, 3)) == (3,)
@test size(cl(3, 4)) == (3, 4)
@test size(cl(rng, 3, 4)) == (3, 4)
@test size(cl(3, 4, 5)) == (3, 4, 5)
@test size(cl(rng, 3, 4, 5)) == (3, 4, 5)
# Type
@test eltype(cl(4, 2)) == Float32
@test eltype(cl(rng, 4, 2)) == Float32
end

@testset "Kwargs types" for T in (
Float16, Float32, Float64, ComplexF16, ComplexF32, ComplexF64)
if (T <: Real)
@test eltype(truncated_normal(T, 2, 5; mean=0, std=1, lo=-2, hi=2)) == T
@test eltype(orthogonal(T, 2, 5; gain=1.0)) == T
end
@test eltype(glorot_uniform(T, 2, 5; gain=1.0)) == T
@test eltype(glorot_normal(T, 2, 5; gain=1.0)) == T
@test eltype(kaiming_uniform(T, 2, 5; gain=sqrt(2))) == T
@test eltype(kaiming_normal(T, 2, 5; gain=sqrt(2))) == T
@test eltype(identity_init(T, 2, 5; gain=1.0)) == T
@test eltype(sparse_init(T, 2, 5; sparsity=0.5, std=0.01)) == T
end

@testset "kaiming" begin
# kaiming_uniform should yield a kernel in range [-sqrt(6/n_out), sqrt(6/n_out)]
# and kaiming_normal should yield a kernel with stddev ~= sqrt(2/n_out)
for (n_in, n_out) in [(100, 100), (100, 400)]
v = kaiming_uniform(rng, n_in, n_out)
σ2 = sqrt(6 / n_out)
@test -1σ2 < minimum(v) < -0.9σ2
@test 0.9σ2 < maximum(v) < 1σ2

v = kaiming_normal(rng, n_in, n_out)
σ2 = sqrt(2 / n_out)
@test 0.9σ2 < std(v) < 1.1σ2
end
# Type
@test eltype(kaiming_uniform(rng, 3, 4; gain=1.5f0)) == Float32
@test eltype(kaiming_normal(rng, 3, 4; gain=1.5f0)) == Float32
end

@testset "glorot: $init" for init in [glorot_uniform, glorot_normal]
# glorot_uniform and glorot_normal should both yield a kernel with
# variance ≈ 2/(fan_in + fan_out)
for dims in [(1000,), (100, 100), (100, 400), (2, 3, 32, 64), (2, 3, 4, 32, 64)]
v = init(dims...)
fan_in, fan_out = WeightInitializers._nfan(dims...)
σ2 = 2 / (fan_in + fan_out)
@test 0.9σ2 < var(v) < 1.1σ2
end
@test eltype(init(3, 4; gain=1.5)) == Float32
end

@testset "orthogonal" begin
# A matrix of size (m, n) with m > n is built directly from a QR decomposition; otherwise the QR decomposition of the transpose is used.
for (rows, cols) in [(5, 3), (3, 5)]
v = orthogonal(rows, cols)
rows < cols ? (@test v * v' ≈ I(rows)) : (@test v' * v ≈ I(cols))
end
for mat in [(3, 4, 5), (2, 2, 5)]
v = orthogonal(mat...)
cols = mat[end]
rows = div(prod(mat), cols)
v = reshape(v, (rows, cols))
rows < cols ? (@test v * v' ≈ I(rows)) : (@test v' * v ≈ I(cols))
end
@test eltype(orthogonal(3, 4; gain=1.5)) == Float32
end
end

@testset "Orthogonal rng = $(typeof(rng)) & arrtype = $arrtype" for (rng, arrtype) in rngs_arrtypes
# A matrix of size (m, n) with m > n is built directly from a QR decomposition.
# Otherwise the QR decomposition of the transpose is used.
for (rows, cols) in [(5, 3), (3, 5)]
v = orthogonal(rng, rows, cols)
CUDA.@allowscalar rows < cols ? (@test v * v' ≈ I(rows)) :
(@test v' * v ≈ I(cols))
end
for mat in [(3, 4, 5), (2, 2, 5)]
v = orthogonal(rng, mat...)
cols = mat[end]
rows = div(prod(mat), cols)
v = reshape(v, (rows, cols))
CUDA.@allowscalar rows < cols ? (@test v * v' ≈ I(rows)) :
(@test v' * v ≈ I(cols))
end
# Type
@testset "Orthogonal Types $T" for T in (Float32, Float64)#(Float16, Float32, Float64)
@test eltype(orthogonal(rng, T, 3, 4; gain=1.5)) == T
@test eltype(orthogonal(rng, T, 3, 4, 5; gain=1.5)) == T
end
@testset "Orthogonal AbstractArray Type $T" for T in (Float32, Float64)#(Float16, Float32, Float64)
@test orthogonal(T, 3, 5) isa AbstractArray{T, 2}
@test orthogonal(rng, T, 3, 5) isa arrtype{T, 2}

cl = orthogonal(rng)
@test cl(T, 3, 5) isa arrtype{T, 2}

cl = orthogonal(rng, T)
@test cl(3, 5) isa arrtype{T, 2}
end
@testset "Orthogonal Closure" begin
cl = orthogonal(;)
# Sizes
@test size(cl(3, 4)) == (3, 4)
@test size(cl(rng, 3, 4)) == (3, 4)
@test size(cl(3, 4, 5)) == (3, 4, 5)
@test size(cl(rng, 3, 4, 5)) == (3, 4, 5)
# Type
@test eltype(cl(4, 2)) == Float32
@test eltype(cl(rng, 4, 2)) == Float32
end
end

@testset "sparse_init rng = $(typeof(rng)) & arrtype = $arrtype" for (rng, arrtype) in rngs_arrtypes
# sparse_init should yield an error for non 2-d dimensions
# sparse_init should yield no zero elements if sparsity < 0
# sparse_init should yield all zero elements if sparsity > 1
# sparse_init should yield exactly ceil(n_in * sparsity) zero elements in each column for other sparsity values
# sparse_init should yield a kernel in its non-zero elements consistent with the std parameter

@test_throws ArgumentError sparse_init(3, 4, 5, sparsity=0.1)
@test_throws ArgumentError sparse_init(3, sparsity=0.1)
v = sparse_init(100, 100; sparsity=-0.1)
@test sum(v .== 0) == 0
v = sparse_init(100, 100; sparsity=1.1)
@test sum(v .== 0) == length(v)

for (n_in, n_out, sparsity, σ) in [(100, 100, 0.25, 0.1), (100, 400, 0.75, 0.01)]
expected_zeros = ceil(Integer, n_in * sparsity)
v = sparse_init(n_in, n_out; sparsity=sparsity, std=σ)
@test all([sum(v[:, col] .== 0) == expected_zeros for col in 1:n_out])
@test 0.9 * σ < std(v[v .!= 0]) < 1.1 * σ
end

# Type
@testset "sparse_init Types $T" for T in (Float16, Float32, Float64)
@test eltype(sparse_init(rng, T, 3, 4; sparsity=0.5)) == T
end
@testset "sparse_init AbstractArray Type $T" for T in (Float16, Float32, Float64)
@test sparse_init(T, 3, 5; sparsity=0.5) isa AbstractArray{T, 2}
@test sparse_init(rng, T, 3, 5; sparsity=0.5) isa arrtype{T, 2}

cl = sparse_init(rng; sparsity=0.5)
@test cl(T, 3, 5) isa arrtype{T, 2}

cl = sparse_init(rng, T; sparsity=0.5)
@test cl(3, 5) isa arrtype{T, 2}
end
@testset "sparse_init Closure" begin
cl = sparse_init(; sparsity=0.5)
# Sizes
@test size(cl(3, 4)) == (3, 4)
@test size(cl(rng, 3, 4)) == (3, 4)
# Type
@test eltype(cl(4, 2)) == Float32
@test eltype(cl(rng, 4, 2)) == Float32
end
end

@testset "identity_init" begin
@testset "Non-identity sizes" begin
@test identity_init(2, 3)[:, end] == zeros(Float32, 2)
@test identity_init(3, 2; shift=1)[1, :] == zeros(Float32, 2)
@test identity_init(1, 1, 3, 4)[:, :, :, end] == zeros(Float32, 1, 1, 3)
@test identity_init(2, 1, 3, 3)[end, :, :, :] == zeros(Float32, 1, 3, 3)
@test identity_init(1, 2, 3, 3)[:, end, :, :] == zeros(Float32, 1, 3, 3)
end
end
if !isempty(EXTRA_PKGS)
@info "Installing Extra Packages for testing" EXTRA_PKGS=EXTRA_PKGS
Pkg.add(EXTRA_PKGS)
Pkg.update()
Base.retry_load_extensions()
Pkg.instantiate()
end

@testset "Warning: truncated_normal" begin
@test_warn "Mean is more than 2 std outside the limits in truncated_normal, so \
the distribution of values may be inaccurate." truncated_normal(2; mean=-5.0f0)
end
const RETESTITEMS_NWORKERS = parse(
Int, get(ENV, "RETESTITEMS_NWORKERS", string(min(Hwloc.num_physical_cores(), 4))))
const RETESTITEMS_NWORKER_THREADS = parse(Int,
get(ENV, "RETESTITEMS_NWORKER_THREADS",
string(max(Hwloc.num_virtual_cores() ÷ RETESTITEMS_NWORKERS, 1))))

@testset "Aqua: Quality Assurance" begin
Aqua.test_all(WeightInitializers; ambiguities=false)
Aqua.test_ambiguities(WeightInitializers; recursive=false)
end
end
ReTestItems.runtests(WeightInitializers; nworkers=RETESTITEMS_NWORKERS,
nworker_threads=RETESTITEMS_NWORKER_THREADS)
43 changes: 43 additions & 0 deletions test/shared_testsetup.jl
@@ -0,0 +1,43 @@
@testsetup module SharedTestSetup

using GPUArrays, GPUArraysCore, Random, StableRNGs

GPUArraysCore.allowscalar(false)

const BACKEND_GROUP = lowercase(get(ENV, "BACKEND_GROUP", "All"))

RNGS_ARRTYPES = []
if BACKEND_GROUP == "all" || BACKEND_GROUP == "cpu"
append!(RNGS_ARRTYPES,
[(StableRNG(12345), AbstractArray, true, "cpu"),
(Random.GLOBAL_RNG, AbstractArray, true, "cpu")])
end
if BACKEND_GROUP == "all" || BACKEND_GROUP == "cuda"
using CUDA
append!(RNGS_ARRTYPES,
[(CUDA.default_rng(), CuArray, true, "cuda"),
(GPUArrays.default_rng(CuArray), CuArray, true, "cuda")])
end
if BACKEND_GROUP == "all" || BACKEND_GROUP == "amdgpu"
using AMDGPU
append!(RNGS_ARRTYPES,
[(AMDGPU.rocrand_rng(), ROCArray, true, "amdgpu"),
(AMDGPU.gpuarrays_rng(), ROCArray, true, "amdgpu")])
end
if BACKEND_GROUP == "all" || BACKEND_GROUP == "metal"
using Metal
push!(RNGS_ARRTYPES, (Metal.gpuarrays_rng(), MtlArray, false, "metal"))
end
if BACKEND_GROUP == "all" || BACKEND_GROUP == "oneapi"
using oneAPI
using oneAPI: oneL0

supports_fp64 = oneL0.module_properties(first(oneAPI.devices())).fp64flags &
oneL0.ZE_DEVICE_MODULE_FLAG_FP64 == oneL0.ZE_DEVICE_MODULE_FLAG_FP64

push!(RNGS_ARRTYPES, (oneAPI.gpuarrays_rng(), oneArray, supports_fp64, "oneapi"))
end

export StableRNG, RNGS_ARRTYPES, BACKEND_GROUP, GPUArrays

end
9 changes: 9 additions & 0 deletions test/utils_tests.jl
@@ -0,0 +1,9 @@
@testitem "Utils.nfan" begin
using WeightInitializers: Utils

@test Utils.nfan() == (1, 1) # Fallback
@test Utils.nfan(4) == (1, 4) # Vector
@test Utils.nfan(4, 5) == (5, 4) # Matrix
@test Utils.nfan((4, 5, 6)) == Utils.nfan(4, 5, 6) # Tuple
@test Utils.nfan(4, 5, 6) == 4 .* (5, 6) # Convolution
end