diff --git a/Project.toml b/Project.toml index 361b329..a71f74f 100644 --- a/Project.toml +++ b/Project.toml @@ -1,9 +1,10 @@ name = "WeightInitializers" uuid = "d49dbf32-c5c2-4618-8acc-27bb2598ef2d" authors = ["Avik Pal and contributors"] -version = "0.1.4" +version = "0.1.5" [deps] +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" PartialFunctions = "570af359-4316-4cb7-8c74-252c00c2016b" PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" @@ -19,6 +20,7 @@ WeightInitializersCUDAExt = "CUDA" [compat] Aqua = "0.8" CUDA = "5" +ChainRulesCore = "1.21" PartialFunctions = "1.2" PrecompileTools = "1.2" Random = "1.9" diff --git a/src/WeightInitializers.jl b/src/WeightInitializers.jl index 4a33516..446fa8f 100644 --- a/src/WeightInitializers.jl +++ b/src/WeightInitializers.jl @@ -3,12 +3,21 @@ module WeightInitializers import PrecompileTools: @recompile_invalidations @recompile_invalidations begin - using PartialFunctions, Random, SpecialFunctions, Statistics + using ChainRulesCore, PartialFunctions, Random, SpecialFunctions, Statistics end include("utils.jl") include("initializers.jl") +# Mark the functions as non-differentiable +for f in [ + :zeros64, :ones64, :rand64, :randn64, :zeros32, :ones32, :rand32, :randn32, :zeros16, + :ones16, :rand16, :randn16, :zerosC64, :onesC64, :randC64, :randnC64, :zerosC32, + :onesC32, :randC32, :randnC32, :zerosC16, :onesC16, :randC16, :randnC16, :glorot_normal, + :glorot_uniform, :kaiming_normal, :kaiming_uniform, :truncated_normal] + @eval @non_differentiable $(f)(::Any...) +end + export zeros64, ones64, rand64, randn64, zeros32, ones32, rand32, randn32, zeros16, ones16, rand16, randn16 export zerosC64, onesC64, randC64, randnC64, zerosC32, onesC32, randC32, randnC32, zerosC16,