diff --git a/Project.toml b/Project.toml
index 70c04423..d8418a9d 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,7 +1,7 @@
 name = "LuxLib"
 uuid = "82251201-b29d-42c6-8e01-566dec8acb11"
 authors = ["Avik Pal and contributors"]
-version = "0.3.51"
+version = "1.0.0"
 
 [deps]
 ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
@@ -67,9 +67,9 @@ Hwloc = "3.2"
 KernelAbstractions = "0.9.22"
 LinearAlgebra = "1.10"
 LoopVectorization = "0.12.171"
-LuxCore = "0.1.13, 1"
+LuxCore = "1"
 MKL = "0.7"
-MLDataDevices = "1.0.0"
+MLDataDevices = "1"
 Markdown = "1.10"
 NNlib = "0.9.21"
 Octavian = "0.3.28"
diff --git a/benchmarks/Project.toml b/benchmarks/Project.toml
index e6436756..7fe762e6 100644
--- a/benchmarks/Project.toml
+++ b/benchmarks/Project.toml
@@ -3,6 +3,7 @@ BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
 InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
 LuxLib = "82251201-b29d-42c6-8e01-566dec8acb11"
 MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
+NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
 Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
diff --git a/benchmarks/setup.jl b/benchmarks/setup.jl
index f80ccf4b..06211e9d 100644
--- a/benchmarks/setup.jl
+++ b/benchmarks/setup.jl
@@ -1,4 +1,5 @@
 using MLDataDevices, StableRNGs, Random
+using NNlib
 using Zygote
 
 synchronize(::CPUDevice) = nothing
diff --git a/src/LuxLib.jl b/src/LuxLib.jl
index c1f3c00a..ab79b233 100644
--- a/src/LuxLib.jl
+++ b/src/LuxLib.jl
@@ -1,7 +1,6 @@
 module LuxLib
 
 using Compat: @compat
-using Random: AbstractRNG
 using Reexport: @reexport
 using Static: Static, known
 using UnrolledUtilities: unrolled_filter
@@ -10,9 +9,7 @@
 using ChainRulesCore: ChainRulesCore, NoTangent
 using LuxCore: LuxCore
 using MLDataDevices: get_device_type, AbstractGPUDevice
-using NNlib: NNlib, ConvDims, σ
-
-@reexport using NNlib
+using NNlib: NNlib
 
 const Optional{T} = Union{Nothing, T}
 const Numeric = Union{AbstractArray{<:T}, T} where {T <: Number}
@@ -23,7 +20,6 @@ include("utils.jl")
 include("traits.jl")
 include("impl/Impl.jl")
 include("api/API.jl")
-include("deprecations.jl")
 
 @compat(public,
     (internal_operation_mode, GenericBroadcastOp, GPUBroadcastOp, LoopedArrayOp))
diff --git a/src/api/layernorm.jl b/src/api/layernorm.jl
index c374a6e1..eb147d30 100644
--- a/src/api/layernorm.jl
+++ b/src/api/layernorm.jl
@@ -1,6 +1,6 @@
 @doc doc"""
-    layernorm(x, scale, bias, σ = identity, dims=Colon(),
-        epsilon = eps(eltype(x)) ^ (5 / 7))
+    layernorm(x::AbstractArray{xT, N}, scale, bias, σ = identity, dims=1:(N - 1),
+        epsilon = eps(eltype(x)) ^ (5 / 7)) where {xT, N}
 
 Layer Normalization. For details see [1].
 
@@ -18,17 +18,13 @@ and applies the activation function `σ` elementwise to `y`.
   - `x`: Input to be Normalized
   - `scale`: Scale factor (``\gamma``) (can be `nothing`)
   - `bias`: Bias factor (``\beta``) (can be `nothing`)
   - `σ`: Activation function (default: `identity`)
-  - `dims`: Dimensions along which the mean and std of `x` is computed (default: `Colon()`).
-    If `nothing` is passed, the dims are inferred based on the dimensions of scale and
-    bias. For example, if `x` is `N` dimensional and `scale` and `bias` are `M`
-    dimensional, then the dims will be `1:(N - M)`.
+  - `dims`: Dimensions along which the mean and std of `x` is computed. If `nothing` is
+    passed, the dims are inferred based on the dimensions of scale and bias. For example,
+    if `x` is `N` dimensional and `scale` and `bias` are `M` dimensional, then the dims
+    will be `1:(N - M)`.
   - `epsilon`: Value added to the denominator for numerical stability
     (default: `eps(eltype(x)) ^ (5 / 7)`)
 
-!!! danger "Default `dims` to be changed in v1"
-
-    By default, `dims` will exclude the batch dimension.
-
 ## Returns
 
 Normalized Array of same size as `x`.
@@ -38,9 +34,9 @@
 ## References
 
 [1] Ba, Jimmy Lei, Jamie Ryan Kiros, and Geoffrey E. Hinton. "Layer normalization." arXiv
    preprint arXiv:1607.06450 (2016).
 """
-function layernorm(x::AbstractArray{xT}, scale::Optional{<:AbstractArray},
-        bias::Optional{<:AbstractArray}, σ::F=identity, dims=Colon(),
-        epsilon::Real=default_epsilon(x)) where {F, xT}
+function layernorm(x::AbstractArray{xT, N}, scale::Optional{<:AbstractArray},
+        bias::Optional{<:AbstractArray}, σ::F=identity, dims=1:(N - 1),
+        epsilon::Real=default_epsilon(x)) where {F, xT, N}
     return layernorm_impl(
         x, scale, bias, select_fastest_activation(σ, x, scale, bias), dims, epsilon)
 end
diff --git a/src/impl/Impl.jl b/src/impl/Impl.jl
index fd2a128e..7a040456 100644
--- a/src/impl/Impl.jl
+++ b/src/impl/Impl.jl
@@ -29,7 +29,7 @@ using NNlib: NNlib, ConvDims
 using ..LuxLib: Optional, Numeric, ∂∅, internal_operation_mode,
                 AbstractInternalArrayOpMode, GenericBroadcastOp, GPUBroadcastOp, LoopedArrayOp
 using ..Utils: Utils, NotaNumber, batchview, concrete_bias_act_output_eltype, contiguous,
-               copy_drop_gradients, depwarn, eltype_mismatch, expand_batchdim,
+               copy_drop_gradients, eltype_mismatch, expand_batchdim,
                maybe_reduce_BLAS_threads, ofeltype_array, only_derivative, remove_tracking,
                reset_BLAS_threads, run_ka_kernel, safe_eltype, safe_vec, safe_warning,
                unsafe_known, @enzyme_alternative
diff --git a/src/impl/dropout.jl b/src/impl/dropout.jl
index 473b6a35..320eafbc 100644
--- a/src/impl/dropout.jl
+++ b/src/impl/dropout.jl
@@ -13,16 +13,8 @@ function dropout(rng::AbstractRNG, x::AbstractArray, ::AbstractArray, p::T,
 end
 
 function dropout(rng::AbstractRNG, x::AbstractArray, mask::AbstractArray,
-        p::T, ::True, ::False, invp::T, dims) where {T}
-    if dropout_shape(x, dims) != size(mask)
-        depwarn(
-            "`update_mask` is `Val(false)` but `mask` is not of the same size \
-             as `LuxLib.dropout_shape(x, dims)`. This has been deprecated and \
-             will be removed in the next release. Set `update_mask` to \
-             `Val(true)` to avoid this.", :dropout)
-        mask, rngₙ = generate_dropout_mask(rng, x, p, invp, dims)
-        return dropout_dot_mul(x, mask), mask, rngₙ
-    end
+        ::T, ::True, ::False, invp::T, dims) where {T}
+    check_dropout_mask_shape_mismatch(x, mask, dims)
     return dropout_dot_mul(x, mask), mask, rng
 end
 
@@ -31,6 +23,13 @@ function dropout(rng::AbstractRNG, x::AbstractArray, mask::AbstractArray,
     return (x, mask, rng)
 end
 
+function check_dropout_mask_shape_mismatch(x::AbstractArray, mask::AbstractArray, dims)
+    @assert dropout_shape(x, dims)==size(mask) "`mask` is not of the same size as `LuxLib.dropout_shape(x, dims)`."
+    return nothing
+end
+
+CRC.@non_differentiable check_dropout_mask_shape_mismatch(::Any...)
+
 ## alpha_dropout
 function alpha_dropout(rng::AbstractRNG, x::AbstractArray{T}, p, ::True) where {T}
     α = T(-1.7580993408473766)
diff --git a/test/common_ops/dense_tests.jl b/test/common_ops/dense_tests.jl
index f3989f49..08b431ba 100644
--- a/test/common_ops/dense_tests.jl
+++ b/test/common_ops/dense_tests.jl
@@ -102,7 +102,7 @@
 end
 
 @testitem "Fused Dense: StaticArrays" tags=[:dense] begin
-    using StaticArrays
+    using StaticArrays, NNlib
 
     x = @SArray rand(2, 4)
     weight = @SArray rand(3, 2)
@@ -112,7 +112,7 @@
 end
 
 @testitem "Fused Dense: CPU No Scalar Indexing" tags=[:dense] begin
-    using JLArrays
+    using JLArrays, NNlib
 
     x = JLArray(rand(Float32, 2, 4))
     weight = JLArray(rand(Float32, 3, 2))
diff --git a/test/common_ops/dropout_tests.jl b/test/common_ops/dropout_tests.jl
index e8b637df..f7f2368b 100644
--- a/test/common_ops/dropout_tests.jl
+++ b/test/common_ops/dropout_tests.jl
@@ -42,8 +42,6 @@
 end
 
 @testitem "Dropout with Preset Mask" tags=[:other_ops] setup=[SharedTestSetup] begin
-    Enzyme.API.runtimeActivity!(true) # TODO: remove in 1.0 after deprecation
-
     using Statistics
 
     rng = StableRNG(12345)
@@ -100,8 +98,7 @@
         __f = (x, mask) -> sum(first(dropout(
             StableRNG(0), x, mask, 0.5, Val(true), Val(false), 2.0, Colon())))
-        # Branching based on runtime values
-        @test @inferred(Zygote.gradient(__f, x, mask)) isa Any broken=true
+        @test @inferred(Zygote.gradient(__f, x, mask)) isa Any
 
         __f = let rng = rng, mask = mask
             x -> sum(first(dropout(
@@ -115,35 +112,6 @@
                 rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon())))
 
            mask = rand(T, (x_shape[1:(end - 1)]..., 13)) |> aType
-            # Try using mask if possible (not possible!!)
-            @test @inferred(dropout(
-                rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon())) isa Any
-            y, mask_, rng_ = dropout(
-                rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon())
-            @test y isa aType{T, length(x_shape)}
-            @test size(y) == x_shape
-            @test mask_ isa aType{T, length(x_shape)}
-            @test size(mask_) == x_shape
-            @test rng != rng_
-            @test mask != mask_
-            __f = (x, mask) -> sum(first(dropout(
-                StableRNG(0), x, mask, 0.5, Val(true), Val(false), 2.0, Colon())))
-            # Branching based on runtime activity
-            @test @inferred(Zygote.gradient(__f, x, mask)) isa Any broken=true
-            __f = let rng = rng, mask = mask
-                x -> sum(first(dropout(
-                    rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon())))
-            end
-            test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3,
-                soft_fail=(T == Float16 ? [AutoFiniteDiff()] : []),
-                broken_backends=(T == Float16 && Sys.iswindows() ? [AutoEnzyme()] : []))
-            @jet sum(first(dropout(
-                rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon())))
             # Testing Mode
             @test @inferred(dropout(
                 rng, x, mask, T(0.5), Val(false), Val(false), T(2), Colon())) isa Any
diff --git a/test/others/qa_tests.jl b/test/others/qa_tests.jl
index 7875b52f..ed7e9f98 100644
--- a/test/others/qa_tests.jl
+++ b/test/others/qa_tests.jl
@@ -1,5 +1,5 @@
 @testitem "Aqua: Quality Assurance" tags=[:others] begin
-    using Aqua, ChainRulesCore, EnzymeCore
+    using Aqua, ChainRulesCore, EnzymeCore, NNlib
     using EnzymeCore: EnzymeRules
 
     Aqua.test_all(
diff --git a/test/shared_testsetup.jl b/test/shared_testsetup.jl
index 6088d444..4cf27cfb 100644
--- a/test/shared_testsetup.jl
+++ b/test/shared_testsetup.jl
@@ -2,7 +2,7 @@
 import Reexport: @reexport
 
 using LuxLib, MLDataDevices
-@reexport using LuxTestUtils, StableRNGs, Test, Enzyme, Zygote
+@reexport using LuxTestUtils, StableRNGs, Test, Enzyme, Zygote, NNlib
 LuxTestUtils.jet_target_modules!(["LuxLib"])
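Note (illustrative only, not part of the patch): the sketch below shows the caller-facing effect of two of the changes above — `NNlib` is no longer reexported by `LuxLib`, so activation functions must be imported explicitly, and `layernorm` now defaults to `dims = 1:(N - 1)` (statistics over every dimension except the last, batch, dimension) instead of `dims = Colon()`. The array sizes and the choice of `relu` are arbitrary for the example.

```julia
# Hypothetical downstream usage after this release (assumes LuxLib 1.0 and NNlib are installed).
using LuxLib
using NNlib: relu   # needed explicitly now that LuxLib no longer reexports NNlib

x = randn(Float32, 4, 6, 8)   # the last dimension (8) is treated as the batch dimension

# New default: mean/variance are computed over dims 1:(N - 1) = 1:2, one set per batch element.
y = layernorm(x, nothing, nothing, relu)

# Learnable affine parameters (can also be `nothing`) broadcast against the normalized input.
scale = ones(Float32, 4, 6, 1)
bias = zeros(Float32, 4, 6, 1)
y_affine = layernorm(x, scale, bias, relu)
```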