This repository has been archived by the owner on Nov 4, 2024. It is now read-only.

Weight Initialization Schemes for Deep Learning Frameworks

LuxDL/WeightInitializers.jl

WeightInitializers


Warning

This package has been moved to a subdirectory of Lux: https://github.com/LuxDL/Lux.jl/tree/main/lib/

This package is a light dependency providing common weight initialization schemes for deep learning models.

Example

The following snippets give a high-level overview of the package's functionality.

using WeightInitializers, Random

# Fix the rng seed for reproducibility
rng = MersenneTwister(42)

# Explicit rng call
weights = kaiming_normal(rng, 2, 5)
#2×5 Matrix{Float32}:
# -0.351662   0.0171745   1.12442   -0.296372   -1.67094
# -0.281053  -0.18941    -0.724099   0.0987538   0.634549

# Default rng call
weights = kaiming_normal(2, 5)
#2×5 Matrix{Float32}:
# -0.227513  -0.265372   0.265788  1.29955  -0.192836
#  0.687611   0.454679  -0.433656  0.20548   0.292002

# Passing kwargs (if needed) with explicit rng call
weights_cl = kaiming_normal(rng; gain=1.0)
weights = weights_cl(rng, 2, 5)
#2×5 Matrix{Float32}:
# 0.484056   0.231723   0.164379   0.306147   0.18365
# 0.0836414  0.666965  -0.396323  -0.711329  -0.382971

# Passing kwargs (if needed) with default rng call
weights_cl = kaiming_normal(; gain=1.0)
weights = weights_cl(2, 5)
#2×5 Matrix{Float32}:
# -0.160876  -0.187646   0.18794   0.918918  -0.136356
#  0.486214   0.321506  -0.306641  0.145296   0.206476
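The same calling conventions apply to the package's other exported schemes. A sketch, assuming `glorot_uniform` and `truncated_normal` are exported (they are in recent versions, but check the docs for your installed version), and that the element type can be passed as a leading argument:

```julia
using WeightInitializers, Random

rng = MersenneTwister(42)

# Other exported schemes follow the same calling convention
w1 = glorot_uniform(rng, 2, 5)    # Glorot/Xavier uniform initialization
w2 = truncated_normal(rng, 2, 5)  # normal distribution with truncated tails

# In recent versions the element type can be passed explicitly
# (verify against the docs for your installed version)
w3 = kaiming_normal(rng, Float64, 2, 5)
```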

API

The package is designed to work with deep learning libraries such as Flux and Lux. Every method takes the chosen rng and the dimensions of the AbstractArray to construct as input.

weights = init(rng, dims...)

The rng is optional; if it is not specified, a default one is used.

weights = init(dims...)

If keyword arguments are needed, the methods can be called with just the rng (optional) and the keywords; this returns a closure that behaves like the two forms above.

weights_init = init(rng; kwargs...)
weights = weights_init(rng, dims...)
# or
weights_init = init(; kwargs...)
weights = weights_init(dims...)
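Because the default-rng form has the signature `init(dims...)`, these initializers can be passed directly wherever a compatible library expects an initializer function. A sketch, assuming Flux's `Dense` layer and its `init` keyword (which Flux calls as `init(dims...)` with its own default rng):

```julia
using Flux, WeightInitializers

# Flux's Dense layer accepts an `init` function called as init(out, in)
layer = Dense(5 => 2; init=kaiming_normal)

# Keyword arguments can be baked in via the closure form shown above
layer2 = Dense(5 => 2; init=kaiming_normal(; gain=1.0))
```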
