From df82ba0e6bdc7ecc5ff8a3bc580b741168038b72 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 21 Aug 2024 08:23:33 -0700 Subject: [PATCH] refactor: move ChainRulesCore into an extension --- Project.toml | 5 +++-- ext/WeightInitializersChainRulesCoreExt.jl | 18 ++++++++++++++++++ src/WeightInitializers.jl | 10 ---------- src/utils.jl | 5 ----- 4 files changed, 21 insertions(+), 17 deletions(-) create mode 100644 ext/WeightInitializersChainRulesCoreExt.jl diff --git a/Project.toml b/Project.toml index b01313d..308235c 100644 --- a/Project.toml +++ b/Project.toml @@ -1,11 +1,10 @@ name = "WeightInitializers" uuid = "d49dbf32-c5c2-4618-8acc-27bb2598ef2d" authors = ["Avik Pal and contributors"] -version = "1.0.2" +version = "1.0.3" [deps] ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197" -ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471" GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" @@ -16,6 +15,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [weakdeps] AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" Metal = "dde4c033-4e86-420c-a63e-0dd931031962" oneAPI = "8f75cd03-7ff8-4ecb-9b8f-daf728133b1b" @@ -23,6 +23,7 @@ oneAPI = "8f75cd03-7ff8-4ecb-9b8f-daf728133b1b" [extensions] WeightInitializersAMDGPUExt = ["AMDGPU", "GPUArrays"] WeightInitializersCUDAExt = ["CUDA", "GPUArrays"] +WeightInitializersChainRulesCoreExt = "ChainRulesCore" WeightInitializersGPUArraysExt = "GPUArrays" WeightInitializersMetalExt = ["Metal", "GPUArrays"] WeightInitializersoneAPIExt = ["oneAPI", "GPUArrays"] diff --git a/ext/WeightInitializersChainRulesCoreExt.jl b/ext/WeightInitializersChainRulesCoreExt.jl new file mode 100644 index 0000000..2b54893 --- /dev/null +++ b/ext/WeightInitializersChainRulesCoreExt.jl @@ -0,0 +1,18 @@ +module WeightInitializersChainRulesCoreExt + +using ChainRulesCore: @non_differentiable +using WeightInitializers: WeightInitializers, DeviceAgnostic + +for f in [:zeros64, :ones64, :rand64, :randn64, :zeros32, :ones32, :rand32, :randn32, + :zeros16, :ones16, :rand16, :randn16, :zerosC64, :onesC64, :randC64, + :randnC64, :zerosC32, :onesC32, :randC32, :randnC32, :zerosC16, :onesC16, + :randC16, :randnC16, :glorot_normal, :glorot_uniform, :kaiming_normal, + :kaiming_uniform, :truncated_normal, :orthogonal, :sparse_init, :identity_init] + @eval @non_differentiable WeightInitializers.$(f)(::Any...) +end + +for f in (:zeros, :ones, :rand, :randn) + @eval @non_differentiable DeviceAgnostic.$(f)(::Any...) +end + +end diff --git a/src/WeightInitializers.jl b/src/WeightInitializers.jl index e96eebb..6702f3f 100644 --- a/src/WeightInitializers.jl +++ b/src/WeightInitializers.jl @@ -1,7 +1,6 @@ module WeightInitializers using ArgCheck: @argcheck -using ChainRulesCore: @non_differentiable using GPUArraysCore: @allowscalar using LinearAlgebra: LinearAlgebra, Diagonal, qr using Random: Random, AbstractRNG, shuffle @@ -12,15 +11,6 @@ include("partial.jl") include("utils.jl") include("initializers.jl") -# Mark the functions as non-differentiable -for f in [:zeros64, :ones64, :rand64, :randn64, :zeros32, :ones32, :rand32, :randn32, - :zeros16, :ones16, :rand16, :randn16, :zerosC64, :onesC64, :randC64, - :randnC64, :zerosC32, :onesC32, :randC32, :randnC32, :zerosC16, :onesC16, - :randC16, :randnC16, :glorot_normal, :glorot_uniform, :kaiming_normal, - :kaiming_uniform, :truncated_normal, :orthogonal, :sparse_init, :identity_init] - @eval @non_differentiable $(f)(::Any...) -end - export zeros64, ones64, rand64, randn64, zeros32, ones32, rand32, randn32, zeros16, ones16, rand16, randn16 export zerosC64, onesC64, randC64, randnC64, zerosC32, onesC32, randC32, randnC32, zerosC16, diff --git a/src/utils.jl b/src/utils.jl index 201283d..e2a3a36 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -52,7 +52,6 @@ end module DeviceAgnostic -using ChainRulesCore: @non_differentiable using Random: AbstractRNG # Helpers for device agnostic initializers @@ -76,8 +75,4 @@ for f in (:rand, :randn) end end -for f in (:zeros, :ones, :rand, :randn) - @eval @non_differentiable $f(::Any...) -end - end