Skip to content

Commit

Permalink
Remove extensions in favor of GPUArraysCore
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Feb 7, 2024
1 parent 792e10d commit bee1411
Show file tree
Hide file tree
Showing 8 changed files with 29 additions and 97 deletions.
24 changes: 5 additions & 19 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
name = "Lux"
uuid = "b2108857-7c20-44ae-9111-449ecde12c47"
authors = ["Avik Pal <avikpal@mit.edu> and contributors"]
version = "0.5.15"
version = "0.5.16"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623"
LuxDeviceUtils = "34f89e08-e1d5-43b4-8944-0b49ac560553"
Expand All @@ -32,8 +34,6 @@ ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
LuxAMDGPU = "83120cb1-ca15-4f04-bf3b-6967d2e6b60b"
LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda"
Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
Expand All @@ -44,15 +44,14 @@ LuxComponentArraysExt = "ComponentArrays"
LuxComponentArraysReverseDiffExt = ["ComponentArrays", "ReverseDiff"]
LuxFluxTransformExt = "Flux"
LuxLuxAMDGPUExt = "LuxAMDGPU"
LuxLuxCUDAExt = "LuxCUDA"
LuxMetalExt = "Metal"
LuxReverseDiffExt = "ReverseDiff"
LuxTrackerExt = "Tracker"
LuxZygoteExt = "Zygote"

[compat]
ADTypes = "0.1, 0.2"
Adapt = "3, 4"
ArrayInterface = "7"
ChainRules = "1"
ChainRulesCore = "1"
ComponentArrays = "0.15.2"
Expand All @@ -61,15 +60,14 @@ ConstructionBase = "1.5"
FillArrays = "0.13, 1"
Flux = "0.13, 0.14"
Functors = "0.2, 0.3, 0.4"
GPUArraysCore = "0.1"
LinearAlgebra = "1"
LuxAMDGPU = "0.1, 0.2"
LuxCUDA = "0.2, 0.3"
LuxCore = "0.1.6"
LuxDeviceUtils = "0.1"
LuxLib = "0.3"
MacroTools = "0.5"
Markdown = "1"
Metal = "0.5, 1"
Optimisers = "0.2, 0.3"
PrecompileTools = "1"
Random = "1"
Expand All @@ -83,15 +81,3 @@ TruncatedStacktraces = "1.1"
WeightInitializers = "0.1"
Zygote = "0.6"
julia = "1.9"

[extras]
ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2"
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
LuxAMDGPU = "83120cb1-ca15-4f04-bf3b-6967d2e6b60b"
LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda"
Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
25 changes: 1 addition & 24 deletions ext/LuxLuxAMDGPUExt.jl
Original file line number Diff line number Diff line change
@@ -1,29 +1,6 @@
module LuxLuxAMDGPUExt

using ChainRulesCore, Lux, LuxAMDGPU, LuxLib, Random
import ChainRulesCore as CRC

# utils.jl
Lux.replicate(rng::AMDGPU.rocRAND.RNG) = deepcopy(rng)

@inline function Lux._init_hidden_state(rng::AbstractRNG, rnn, x::AMDGPU.AnyROCArray)
return ROCArray(rnn.init_state(rng, rnn.out_dims, size(x, 2)))
end

@inline function Lux._conv(x::SubArray{T, N, <:AMDGPU.AnyROCArray}, weight,
cdims) where {T, N}
return conv(copy(x), weight, cdims)
end

@inline function Lux._conv_transpose(x::SubArray{T, N, <:AMDGPU.AnyROCArray}, weight,
cdims) where {T, N}
return ∇conv_data(copy(x), weight, cdims)
end

@inline function Lux._eachslice(x::AMDGPU.AnyROCArray, ::Val{dims}) where {dims}
# FIXME: This is not efficient but AMDGPU doesn't deal with views well
return [copy(selectdim(x, dims, i)) for i in axes(x, dims)]
end
using Lux, LuxAMDGPU

# Flux modifies Conv weights while mapping to AMD GPU
function Lux._maybe_flip_conv_weight(x::AMDGPU.AnyROCArray)
Expand Down
22 changes: 0 additions & 22 deletions ext/LuxLuxCUDAExt.jl

This file was deleted.

23 changes: 0 additions & 23 deletions ext/LuxMetalExt.jl

This file was deleted.

1 change: 1 addition & 0 deletions src/Lux.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ PrecompileTools.@recompile_invalidations begin
using LinearAlgebra, Markdown, Random, SparseArrays, Statistics
using Adapt, ConcreteStructs, Functors, Setfield
using ChainRulesCore
using ArrayInterface, GPUArraysCore
import TruncatedStacktraces: @truncate_stacktrace

import LuxCore: AbstractExplicitLayer, AbstractExplicitContainerLayer,
Expand Down
19 changes: 16 additions & 3 deletions src/utils.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
# PRNG Handling
"""
replicate(rng::AbstractRNG)
replicate(rng::CUDA.RNG)
Creates a copy of the `rng` state depending on its type.
"""
replicate(rng::AbstractRNG) = copy(rng)
replicate(rng::AbstractRNG) = deepcopy(rng)

# Training Check
"""
Expand Down Expand Up @@ -61,7 +60,8 @@ get_typename(::T) where {T} = Base.typename(T).wrapper
@inline _gate(x::AbstractMatrix, h::Int, n::Int) = view(x, _gate(h, n), :)

@inline function _init_hidden_state(rng::AbstractRNG, rnn, x::AbstractMatrix)
return rnn.init_state(rng, rnn.out_dims, size(x, 2))
return convert(ArrayInterface.parameterless_type(parent(x)),
rnn.init_state(rng, rnn.out_dims, size(x, 2)))
end

@inline function _init_trainable_hidden_state(hidden_state::AbstractVector,
Expand Down Expand Up @@ -99,6 +99,12 @@ end
@inline function _eachslice(x::AbstractArray, ::Val{dims}) where {dims}
return [selectdim(x, dims, i) for i in axes(x, dims)]
end
@inline function _eachslice(x::GPUArraysCore.AnyGPUArray, ::Val{dims}) where {dims}
return [__unview(selectdim(x, dims, i)) for i in axes(x, dims)]
end

@inline __unview(x::SubArray) = copy(x)
@inline __unview(x) = x

function ∇_eachslice(Δ_raw, x::AbstractArray, ::Val{dims}) where {dims}
Δs = CRC.unthunk(Δ_raw)
Expand All @@ -123,8 +129,15 @@ end
# Backend Integration
## Convolution
@inline _conv(x, weight, cdims) = conv(x, weight, cdims)
@inline function _conv(x::SubArray{T, N, <:AbstractArray}, weight, cdims) where {T, N}
return _conv(copy(x), weight, cdims)
end

@inline _conv_transpose(x, weight, cdims) = ∇conv_data(x, weight, cdims)
@inline function _conv_transpose(x::SubArray{T, N, <:GPUArraysCore.AnyGPUArray}, weight,
cdims) where {T, N}
return _conv_transpose(copy(x), weight, cdims)
end

function _conv_transpose_dims(x::AbstractArray, weight::AbstractArray; padding, stride,
dilation, groups)
Expand Down
8 changes: 4 additions & 4 deletions test/test_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,11 @@ end
# Some Helper Functions
function get_default_rng(mode::String)
if mode == "CPU"
return Random.default_rng()
return deepcopy(default_device_rng(LuxCPUDevice()))
elseif mode == "CUDA"
return CUDA.RNG()
return deepcopy(default_device_rng(LuxCUDADevice()))
elseif mode == "AMDGPU"
return AMDGPU.rocRAND.RNG()
return deepcopy(default_device_rng(LuxAMDGPUDevice()))
else
error("Unknown mode: $mode")
end
Expand All @@ -42,7 +42,7 @@ get_stable_rng(seed=12345) = StableRNG(seed)
# AMDGPU Specifics
function _rocRAND_functional()
try
AMDGPU.rocRAND.RNG()
default_device_rng(LuxAMDGPUDevice())
return true
catch
return false
Expand Down
4 changes: 2 additions & 2 deletions test/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ rng = get_stable_rng(12345)
@testset "$mode: replicate" for (mode, aType, device, ongpu) in MODES
_rng = get_default_rng(mode)
if mode == "AMDGPU"
@test randn(_rng, 10, 2) != randn(_rng, 10, 2)
@test_broken randn(Lux.replicate(_rng), 10, 2) == randn(Lux.replicate(_rng), 10, 2)
# @test randn(_rng, 10, 2) != randn(_rng, 10, 2)
# @test_broken randn(Lux.replicate(_rng), 10, 2) == randn(Lux.replicate(_rng), 10, 2)
else
@test randn(_rng, 10, 2) != randn(_rng, 10, 2)
@test randn(Lux.replicate(_rng), 10, 2) == randn(Lux.replicate(_rng), 10, 2)
Expand Down

0 comments on commit bee1411

Please sign in to comment.