Skip to content
This repository has been archived by the owner on Nov 4, 2024. It is now read-only.

Commit

Permalink
refactor: move ChainRulesCore into an extension
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Aug 21, 2024
1 parent 0d95f4a commit df82ba0
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 17 deletions.
5 changes: 3 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
name = "WeightInitializers"
uuid = "d49dbf32-c5c2-4618-8acc-27bb2598ef2d"
authors = ["Avik Pal <avikpal@mit.edu> and contributors"]
version = "1.0.2"
version = "1.0.3"

[deps]
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Expand All @@ -16,13 +15,15 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
[weakdeps]
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
oneAPI = "8f75cd03-7ff8-4ecb-9b8f-daf728133b1b"

[extensions]
WeightInitializersAMDGPUExt = ["AMDGPU", "GPUArrays"]
WeightInitializersCUDAExt = ["CUDA", "GPUArrays"]
WeightInitializersChainRulesCoreExt = "ChainRulesCore"
WeightInitializersGPUArraysExt = "GPUArrays"
WeightInitializersMetalExt = ["Metal", "GPUArrays"]
WeightInitializersoneAPIExt = ["oneAPI", "GPUArrays"]
Expand Down
18 changes: 18 additions & 0 deletions ext/WeightInitializersChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
module WeightInitializersChainRulesCoreExt

using ChainRulesCore: @non_differentiable
using WeightInitializers: WeightInitializers, DeviceAgnostic

for f in [:zeros64, :ones64, :rand64, :randn64, :zeros32, :ones32, :rand32, :randn32,
:zeros16, :ones16, :rand16, :randn16, :zerosC64, :onesC64, :randC64,
:randnC64, :zerosC32, :onesC32, :randC32, :randnC32, :zerosC16, :onesC16,
:randC16, :randnC16, :glorot_normal, :glorot_uniform, :kaiming_normal,
:kaiming_uniform, :truncated_normal, :orthogonal, :sparse_init, :identity_init]
@eval @non_differentiable WeightInitializers.$(f)(::Any...)
end

for f in (:zeros, :ones, :rand, :randn)
@eval @non_differentiable DeviceAgnostic.$(f)(::Any...)
end

end
10 changes: 0 additions & 10 deletions src/WeightInitializers.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
module WeightInitializers

using ArgCheck: @argcheck
using ChainRulesCore: @non_differentiable
using GPUArraysCore: @allowscalar
using LinearAlgebra: LinearAlgebra, Diagonal, qr
using Random: Random, AbstractRNG, shuffle
Expand All @@ -12,15 +11,6 @@ include("partial.jl")
include("utils.jl")
include("initializers.jl")

# Mark the functions as non-differentiable
for f in [:zeros64, :ones64, :rand64, :randn64, :zeros32, :ones32, :rand32, :randn32,
:zeros16, :ones16, :rand16, :randn16, :zerosC64, :onesC64, :randC64,
:randnC64, :zerosC32, :onesC32, :randC32, :randnC32, :zerosC16, :onesC16,
:randC16, :randnC16, :glorot_normal, :glorot_uniform, :kaiming_normal,
:kaiming_uniform, :truncated_normal, :orthogonal, :sparse_init, :identity_init]
@eval @non_differentiable $(f)(::Any...)
end

export zeros64, ones64, rand64, randn64, zeros32, ones32, rand32, randn32, zeros16, ones16,
rand16, randn16
export zerosC64, onesC64, randC64, randnC64, zerosC32, onesC32, randC32, randnC32, zerosC16,
Expand Down
5 changes: 0 additions & 5 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@ end

module DeviceAgnostic

using ChainRulesCore: @non_differentiable
using Random: AbstractRNG

# Helpers for device agnostic initializers
Expand All @@ -76,8 +75,4 @@ for f in (:rand, :randn)
end
end

for f in (:zeros, :ones, :rand, :randn)
@eval @non_differentiable $f(::Any...)
end

end

0 comments on commit df82ba0

Please sign in to comment.