diff --git a/Project.toml b/Project.toml index 52f436bbb..911aa5a2b 100644 --- a/Project.toml +++ b/Project.toml @@ -18,6 +18,7 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688" LossFunctions = "30fc2ffe-d236-52d8-8643-a9d8f7c094a7" LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623" LuxDeviceUtils = "34f89e08-e1d5-43b4-8944-0b49ac560553" @@ -90,6 +91,7 @@ GPUArraysCore = "0.1.6" Hwloc = "3.2.0" InteractiveUtils = "<0.0.1, 1" LinearAlgebra = "1.10" +LogExpFunctions = "0.3.28" Logging = "1.10" LossFunctions = "0.11.1" LuxCore = "0.1.16" diff --git a/src/Lux.jl b/src/Lux.jl index 73cddcbc9..478d4b43e 100644 --- a/src/Lux.jl +++ b/src/Lux.jl @@ -15,6 +15,7 @@ using FastClosures: @closure using ForwardDiff: ForwardDiff using Functors: Functors, fmap using GPUArraysCore: GPUArraysCore, @allowscalar +using LogExpFunctions: LogExpFunctions using LossFunctions: LossFunctions using MacroTools: MacroTools, block, combinedef, splitdef using Markdown: @doc_str diff --git a/src/chainrules.jl b/src/chainrules.jl index 2e72c018a..ac3fa3a49 100644 --- a/src/chainrules.jl +++ b/src/chainrules.jl @@ -14,6 +14,7 @@ CRC.@non_differentiable __state_if_stateful(::Any) CRC.@non_differentiable __set_state!(::Any...) CRC.@non_differentiable __update_bn_state(::Any...) CRC.@non_differentiable __warn_mismatch(::Any...) +CRC.@non_differentiable __depwarn(::Any, ::Any) # Utilities function CRC.rrule(::typeof(_eachslice), x, d::Val) @@ -88,33 +89,3 @@ end end return __fused_agg(sum, lfn, x, y), ∇lfn end - -function CRC.rrule(::typeof(xlogx), x::Number) - iszero(x) && return x, Δ -> (NoTangent(), ZeroTangent()) - logx = log(x) - ∇xlogx = @closure Δ -> (NoTangent(), @thunk(Δ*(logx + true))) - return x * logx, ∇xlogx -end - -function CRC.rrule( - ::typeof(Broadcast.broadcasted), ::typeof(xlogx), x::AbstractArray{<:Number}) - logx = log.(x) - y = x .* logx - ∇xlogx = @closure Δ -> (NoTangent(), NoTangent(), @thunk(Δ.*(logx .+ true))) - return y, ∇xlogx -end - -function CRC.rrule(::typeof(xlogy), x::Number, y::Number) - iszero(x) && return x, Δ -> (NoTangent(), ZeroTangent()) - logy = log(y) - ∇xlogy = @closure Δ -> (NoTangent(), @thunk(Δ*logy), @thunk(Δ * x/y)) - return x * logy, ∇xlogy -end - -function CRC.rrule(::typeof(Broadcast.broadcasted), ::typeof(xlogy), - x::AbstractArray{<:Number}, y::AbstractArray{<:Number}) - logy = log.(y) - z = x .* logy - ∇xlogy = @closure Δ -> (NoTangent(), NoTangent(), @thunk(Δ.*logy), @thunk(Δ .* x./y)) - return z, ∇xlogy -end diff --git a/src/helpers/losses.jl b/src/helpers/losses.jl index ba56cc97c..60c8b603b 100644 --- a/src/helpers/losses.jl +++ b/src/helpers/losses.jl @@ -424,7 +424,8 @@ end function __unsafe_apply_loss(loss::KLDivergenceLoss, ŷ, y) cross_entropy = __unsafe_apply_loss(loss.celoss, ŷ, y) - entropy = loss.agg(sum(xlogx.(y); loss.dims)) # Intentional broadcasting for Zygote type stability + # Intentional broadcasting for Zygote type stability + entropy = loss.agg(sum(LogExpFunctions.xlogx.(y); loss.dims)) return entropy + cross_entropy end diff --git a/src/utils.jl b/src/utils.jl index e686b77e5..f565cce41 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -264,20 +264,28 @@ end Return `x * log(x)` for `x ≥ 0`, handling `x == 0` by taking the limit from above, to get zero. + +!!! warning "Deprecated" + + Use `LogExpFunctions.xlogx` instead. """ @inline function xlogx(x::Number) - result = x * log(x) - return ifelse(iszero(x), zero(result), result) + __depwarn("`Lux.xlogx` is deprecated, use `LogExpFunctions.xlogx` instead.", :xlogx) + return LogExpFunctions.xlogx(x) end """ xlogy(x::Number, y::Number) Return `x * log(y)` for `y > 0`, and zero when `x == 0`. + +!!! warning "Deprecated" + + Use `LogExpFunctions.xlogy` instead. """ -@inline function xlogy(x::Number, y::Number) - result = x * log(y) - return ifelse(iszero(x), zero(result), result) +function xlogy(x::Number, y::Number) + __depwarn("`Lux.xlogy` is deprecated, use `LogExpFunctions.xlogy` instead.", :xlogy) + return LogExpFunctions.xlogy(x, y) end # Some functional forms of losses @@ -295,7 +303,7 @@ Broadcast.broadcastable(f::__Fix3) = Ref(f) end @inline function __poisson_loss(x::T1, y::T2, ϵ) where {T1, T2} - return x - xlogy(y, x + __get_epsilon(T1, ϵ)) + return x - LogExpFunctions.xlogy(y, x + __get_epsilon(T1, ϵ)) end @inline function __msle_loss(x::T1, y::T2, ϵ) where {T1, T2} @@ -379,3 +387,5 @@ end @inline __eltype(::AbstractArray{<:ForwardDiff.Dual{T, V}}) where {T, V} = V @inline __reverse(x; dims=:) = reverse(x; dims) + +__depwarn(msg, sym) = Base.depwarn(msg, sym) # Prevents a type stability issue with Zygote diff --git a/test/helpers/loss_tests.jl b/test/helpers/loss_tests.jl index afc38b92d..472eb830f 100644 --- a/test/helpers/loss_tests.jl +++ b/test/helpers/loss_tests.jl @@ -1,4 +1,5 @@ @testitem "xlogx & xlogy" setup=[SharedTestSetup] tags=[:helpers] begin + # TODO: Remove in v1.0 using Lux: xlogx, xlogy using ForwardDiff, Zygote, Enzyme