Skip to content
This repository has been archived by the owner on Nov 4, 2024. It is now read-only.

Commit

Permalink
Mark the initialization functions as non-differentiable
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Feb 13, 2024
1 parent bde282c commit 0ea3234
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 2 deletions.
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
name = "WeightInitializers"
uuid = "d49dbf32-c5c2-4618-8acc-27bb2598ef2d"
authors = ["Avik Pal <avikpal@mit.edu> and contributors"]
version = "0.1.4"
version = "0.1.5"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
PartialFunctions = "570af359-4316-4cb7-8c74-252c00c2016b"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Expand All @@ -19,6 +20,7 @@ WeightInitializersCUDAExt = "CUDA"
[compat]
Aqua = "0.8"
CUDA = "5"
ChainRulesCore = "1.21"
PartialFunctions = "1.2"
PrecompileTools = "1.2"
Random = "1.9"
Expand Down
13 changes: 12 additions & 1 deletion src/WeightInitializers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,23 @@ module WeightInitializers
import PrecompileTools: @recompile_invalidations

@recompile_invalidations begin
using PartialFunctions, Random, SpecialFunctions, Statistics
using ChainRulesCore, PartialFunctions, Random, SpecialFunctions, Statistics
end

const CRC = ChainRulesCore

include("utils.jl")
include("initializers.jl")

# Mark the functions as non-differentiable
for f in [
:zeros64, :ones64, :rand64, :randn64, :zeros32, :ones32, :rand32, :randn32, :zeros16,
:ones16, :rand16, :randn16, :zerosC64, :onesC64, :randC64, :randnC64, :zerosC32,
:onesC32, :randC32, :randnC32, :zerosC16, :onesC16, :randC16, :randnC16, :glorot_normal,
:glorot_uniform, :kaiming_normal, :kaiming_uniform, :truncated_normal]
@eval CRC.@non_differentiable ($(f)(::Any...))
end

export zeros64, ones64, rand64, randn64, zeros32, ones32, rand32, randn32, zeros16, ones16,
rand16, randn16
export zerosC64, onesC64, randC64, randnC64, zerosC32, onesC32, randC32, randnC32, zerosC16,
Expand Down

0 comments on commit 0ea3234

Please sign in to comment.