From bee1411b878e1ff2436da7b1385cc7dbcaa95014 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 7 Feb 2024 07:50:14 -0500 Subject: [PATCH] Remove extensions in favor of GPUArraysCore --- Project.toml | 24 +++++------------------- ext/LuxLuxAMDGPUExt.jl | 25 +------------------------ ext/LuxLuxCUDAExt.jl | 22 ---------------------- ext/LuxMetalExt.jl | 23 ----------------------- src/Lux.jl | 1 + src/utils.jl | 19 ++++++++++++++++--- test/test_utils.jl | 8 ++++---- test/utils.jl | 4 ++-- 8 files changed, 29 insertions(+), 97 deletions(-) delete mode 100644 ext/LuxLuxCUDAExt.jl delete mode 100644 ext/LuxMetalExt.jl diff --git a/Project.toml b/Project.toml index 6467eb6135..fde07e43bc 100644 --- a/Project.toml +++ b/Project.toml @@ -1,15 +1,17 @@ name = "Lux" uuid = "b2108857-7c20-44ae-9111-449ecde12c47" authors = ["Avik Pal and contributors"] -version = "0.5.15" +version = "0.5.16" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" +ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471" ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" +GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623" LuxDeviceUtils = "34f89e08-e1d5-43b4-8944-0b49ac560553" @@ -32,8 +34,6 @@ ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" LuxAMDGPU = "83120cb1-ca15-4f04-bf3b-6967d2e6b60b" -LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda" -Metal = "dde4c033-4e86-420c-a63e-0dd931031962" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" @@ -44,8 +44,6 @@ LuxComponentArraysExt = "ComponentArrays" LuxComponentArraysReverseDiffExt = ["ComponentArrays", "ReverseDiff"] LuxFluxTransformExt = "Flux" LuxLuxAMDGPUExt = "LuxAMDGPU" -LuxLuxCUDAExt = "LuxCUDA" -LuxMetalExt = "Metal" LuxReverseDiffExt = "ReverseDiff" LuxTrackerExt = "Tracker" LuxZygoteExt = "Zygote" @@ -53,6 +51,7 @@ LuxZygoteExt = "Zygote" [compat] ADTypes = "0.1, 0.2" Adapt = "3, 4" +ArrayInterface = "7" ChainRules = "1" ChainRulesCore = "1" ComponentArrays = "0.15.2" @@ -61,15 +60,14 @@ ConstructionBase = "1.5" FillArrays = "0.13, 1" Flux = "0.13, 0.14" Functors = "0.2, 0.3, 0.4" +GPUArraysCore = "0.1" LinearAlgebra = "1" LuxAMDGPU = "0.1, 0.2" -LuxCUDA = "0.2, 0.3" LuxCore = "0.1.6" LuxDeviceUtils = "0.1" LuxLib = "0.3" MacroTools = "0.5" Markdown = "1" -Metal = "0.5, 1" Optimisers = "0.2, 0.3" PrecompileTools = "1" Random = "1" @@ -83,15 +81,3 @@ TruncatedStacktraces = "1.1" WeightInitializers = "0.1" Zygote = "0.6" julia = "1.9" - -[extras] -ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2" -ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" -FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" -Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" -LuxAMDGPU = "83120cb1-ca15-4f04-bf3b-6967d2e6b60b" -LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda" -Metal = "dde4c033-4e86-420c-a63e-0dd931031962" -ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" -Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" -Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" diff --git a/ext/LuxLuxAMDGPUExt.jl b/ext/LuxLuxAMDGPUExt.jl index 2f9833b9f0..684772fcc5 100644 --- a/ext/LuxLuxAMDGPUExt.jl +++ b/ext/LuxLuxAMDGPUExt.jl @@ -1,29 +1,6 @@ module LuxLuxAMDGPUExt -using ChainRulesCore, Lux, LuxAMDGPU, LuxLib, Random -import ChainRulesCore as CRC - -# utils.jl -Lux.replicate(rng::AMDGPU.rocRAND.RNG) = deepcopy(rng) - -@inline function Lux._init_hidden_state(rng::AbstractRNG, rnn, x::AMDGPU.AnyROCArray) - return ROCArray(rnn.init_state(rng, rnn.out_dims, size(x, 2))) -end - -@inline function Lux._conv(x::SubArray{T, N, <:AMDGPU.AnyROCArray}, weight, - cdims) where {T, N} - return conv(copy(x), weight, cdims) -end - -@inline function Lux._conv_transpose(x::SubArray{T, N, <:AMDGPU.AnyROCArray}, weight, - cdims) where {T, N} - return ∇conv_data(copy(x), weight, cdims) -end - -@inline function Lux._eachslice(x::AMDGPU.AnyROCArray, ::Val{dims}) where {dims} - # FIXME: This is not efficient but AMDGPU doesn't deal with views well - return [copy(selectdim(x, dims, i)) for i in axes(x, dims)] -end +using Lux, LuxAMDGPU # Flux modifies Conv weights while mapping to AMD GPU function Lux._maybe_flip_conv_weight(x::AMDGPU.AnyROCArray) diff --git a/ext/LuxLuxCUDAExt.jl b/ext/LuxLuxCUDAExt.jl deleted file mode 100644 index 14388004e3..0000000000 --- a/ext/LuxLuxCUDAExt.jl +++ /dev/null @@ -1,22 +0,0 @@ -module LuxLuxCUDAExt - -using ChainRulesCore, Lux, LuxCUDA, LuxLib, Random -import ChainRulesCore as CRC - -# utils.jl -Lux.replicate(rng::CUDA.RNG) = deepcopy(rng) - -@inline function Lux._init_hidden_state(rng::AbstractRNG, rnn, x::CUDA.AnyCuArray) - return CuArray(rnn.init_state(rng, rnn.out_dims, size(x, 2))) -end - -@inline function Lux._conv(x::SubArray{T, N, <:CUDA.AnyCuArray}, weight, cdims) where {T, N} - return conv(copy(x), weight, cdims) -end - -@inline function Lux._conv_transpose(x::SubArray{T, N, <:CUDA.AnyCuArray}, weight, - cdims) where {T, N} - return ∇conv_data(copy(x), weight, cdims) -end - -end diff --git a/ext/LuxMetalExt.jl b/ext/LuxMetalExt.jl deleted file mode 100644 index fb97682378..0000000000 --- a/ext/LuxMetalExt.jl +++ /dev/null @@ -1,23 +0,0 @@ -module LuxMetalExt - -using Lux, LuxLib, Metal, Random - -@inline function Lux._init_hidden_state(rng::AbstractRNG, rnn, x::MtlArray) - return MtlArray(rnn.init_state(rng, rnn.out_dims, size(x, 2))) -end - -@inline function Lux._conv(x::SubArray{T, N, <:MtlArray}, weight, cdims) where {T, N} - return conv(copy(x), weight, cdims) -end - -@inline function Lux._conv_transpose(x::SubArray{T, N, <:MtlArray}, weight, - cdims) where {T, N} - return ∇conv_data(copy(x), weight, cdims) -end - -@inline function Lux._eachslice(x::MtlArray, ::Val{dims}) where {dims} - # FIXME: This is not efficient but Metal doesn't deal with views well - return [copy(selectdim(x, dims, i)) for i in axes(x, dims)] -end - -end diff --git a/src/Lux.jl b/src/Lux.jl index 69cc4e1325..a0bb7bcbaa 100644 --- a/src/Lux.jl +++ b/src/Lux.jl @@ -8,6 +8,7 @@ PrecompileTools.@recompile_invalidations begin using LinearAlgebra, Markdown, Random, SparseArrays, Statistics using Adapt, ConcreteStructs, Functors, Setfield using ChainRulesCore + using ArrayInterface, GPUArraysCore import TruncatedStacktraces: @truncate_stacktrace import LuxCore: AbstractExplicitLayer, AbstractExplicitContainerLayer, diff --git a/src/utils.jl b/src/utils.jl index fc7511f66b..e30b1f37f6 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1,11 +1,10 @@ # PRNG Handling """ replicate(rng::AbstractRNG) - replicate(rng::CUDA.RNG) Creates a copy of the `rng` state depending on its type. """ -replicate(rng::AbstractRNG) = copy(rng) +replicate(rng::AbstractRNG) = deepcopy(rng) # Training Check """ @@ -61,7 +60,8 @@ get_typename(::T) where {T} = Base.typename(T).wrapper @inline _gate(x::AbstractMatrix, h::Int, n::Int) = view(x, _gate(h, n), :) @inline function _init_hidden_state(rng::AbstractRNG, rnn, x::AbstractMatrix) - return rnn.init_state(rng, rnn.out_dims, size(x, 2)) + return convert(ArrayInterface.parameterless_type(parent(x)), + rnn.init_state(rng, rnn.out_dims, size(x, 2))) end @inline function _init_trainable_hidden_state(hidden_state::AbstractVector, @@ -99,6 +99,12 @@ end @inline function _eachslice(x::AbstractArray, ::Val{dims}) where {dims} return [selectdim(x, dims, i) for i in axes(x, dims)] end +@inline function _eachslice(x::GPUArraysCore.AnyGPUArray, ::Val{dims}) where {dims} + return [__unview(selectdim(x, dims, i)) for i in axes(x, dims)] +end + +@inline __unview(x::SubArray) = copy(x) +@inline __unview(x) = x function ∇_eachslice(Δ_raw, x::AbstractArray, ::Val{dims}) where {dims} Δs = CRC.unthunk(Δ_raw) @@ -123,8 +129,15 @@ end # Backend Integration ## Convolution @inline _conv(x, weight, cdims) = conv(x, weight, cdims) +@inline function _conv(x::SubArray{T, N, <:AbstractArray}, weight, cdims) where {T, N} + return _conv(copy(x), weight, cdims) +end @inline _conv_transpose(x, weight, cdims) = ∇conv_data(x, weight, cdims) +@inline function _conv_transpose(x::SubArray{T, N, <:GPUArraysCore.AnyGPUArray}, weight, + cdims) where {T, N} + return _conv_transpose(copy(x), weight, cdims) +end function _conv_transpose_dims(x::AbstractArray, weight::AbstractArray; padding, stride, dilation, groups) diff --git a/test/test_utils.jl b/test/test_utils.jl index 164d0eb240..b2e80596a4 100644 --- a/test/test_utils.jl +++ b/test/test_utils.jl @@ -27,11 +27,11 @@ end # Some Helper Functions function get_default_rng(mode::String) if mode == "CPU" - return Random.default_rng() + return deepcopy(default_device_rng(LuxCPUDevice())) elseif mode == "CUDA" - return CUDA.RNG() + return deepcopy(default_device_rng(LuxCUDADevice())) elseif mode == "AMDGPU" - return AMDGPU.rocRAND.RNG() + return deepcopy(default_device_rng(LuxAMDGPUDevice())) else error("Unknown mode: $mode") end @@ -42,7 +42,7 @@ get_stable_rng(seed=12345) = StableRNG(seed) # AMDGPU Specifics function _rocRAND_functional() try - AMDGPU.rocRAND.RNG() + default_device_rng(LuxAMDGPUDevice()) return true catch return false diff --git a/test/utils.jl b/test/utils.jl index e6b6acd200..9d528e4be4 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -8,8 +8,8 @@ rng = get_stable_rng(12345) @testset "$mode: replicate" for (mode, aType, device, ongpu) in MODES _rng = get_default_rng(mode) if mode == "AMDGPU" - @test randn(_rng, 10, 2) != randn(_rng, 10, 2) - @test_broken randn(Lux.replicate(_rng), 10, 2) == randn(Lux.replicate(_rng), 10, 2) + # @test randn(_rng, 10, 2) != randn(_rng, 10, 2) + # @test_broken randn(Lux.replicate(_rng), 10, 2) == randn(Lux.replicate(_rng), 10, 2) else @test randn(_rng, 10, 2) != randn(_rng, 10, 2) @test randn(Lux.replicate(_rng), 10, 2) == randn(Lux.replicate(_rng), 10, 2)