Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: use xlogx and xlogy from LogExpFunctions #796

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
LossFunctions = "30fc2ffe-d236-52d8-8643-a9d8f7c094a7"
LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623"
LuxDeviceUtils = "34f89e08-e1d5-43b4-8944-0b49ac560553"
Expand Down Expand Up @@ -90,6 +91,7 @@ GPUArraysCore = "0.1.6"
Hwloc = "3.2.0"
InteractiveUtils = "<0.0.1, 1"
LinearAlgebra = "1.10"
LogExpFunctions = "0.3.28"
Logging = "1.10"
LossFunctions = "0.11.1"
LuxCore = "0.1.16"
Expand Down
1 change: 1 addition & 0 deletions src/Lux.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ using FastClosures: @closure
using ForwardDiff: ForwardDiff
using Functors: Functors, fmap
using GPUArraysCore: GPUArraysCore, @allowscalar
using LogExpFunctions: LogExpFunctions
using LossFunctions: LossFunctions
using MacroTools: MacroTools, block, combinedef, splitdef
using Markdown: @doc_str
Expand Down
31 changes: 1 addition & 30 deletions src/chainrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ CRC.@non_differentiable __state_if_stateful(::Any)
# Each line below marks calls matching the given signature as non-differentiable via the
# `@non_differentiable` macro from `CRC` (presumably ChainRulesCore — confirm the alias).
# `::Any...` covers any number of arguments; `__depwarn` is registered for exactly two.
CRC.@non_differentiable __set_state!(::Any...)
CRC.@non_differentiable __update_bn_state(::Any...)
CRC.@non_differentiable __warn_mismatch(::Any...)
CRC.@non_differentiable __depwarn(::Any, ::Any)

# Utilities
function CRC.rrule(::typeof(_eachslice), x, d::Val)
Expand Down Expand Up @@ -88,33 +89,3 @@ end
end
return __fused_agg(sum, lfn, x, y), ∇lfn
end

# Reverse-mode rule for the scalar `xlogx(x)`, i.e. `x * log(x)`.
#
# Returns the primal value and a pullback `Δ -> (NoTangent(), ∂x)` with
# `∂x = Δ * (log(x) + 1)`. At `x == 0` the primal is an exact zero and the tangent is
# reported as `ZeroTangent()`.
function CRC.rrule(::typeof(xlogx), x::Number)
    logx = log(x)
    y = x * logx
    if iszero(x)
        # Return a zero of the *result* type rather than `x` itself, so the primal has the
        # same type on both branches (e.g. `0.0`, not `0`, for an integer input).
        ∇zero = Δ -> (NoTangent(), ZeroTangent())
        return zero(y), ∇zero
    end
    # `true` acts as a multiplicative-identity-sized one that never promotes `logx`.
    ∇xlogx = @closure Δ -> (NoTangent(), @thunk(Δ * (logx + true)))
    return y, ∇xlogx
end

# Reverse-mode rule for the broadcast `xlogx.(x)` over an array of numbers.
#
# Returns the elementwise `x * log(x)` and a pullback producing
# `(NoTangent(), NoTangent(), ∂x)` with `∂x = Δ .* (log.(x) .+ 1)`.
# Entries where `x == 0` are forced to zero in both the value and the gradient, matching
# the scalar rule; without the mask `0 * log(0)` evaluates to NaN.
function CRC.rrule(
        ::typeof(Broadcast.broadcasted), ::typeof(xlogx), x::AbstractArray{<:Number})
    logx = log.(x)
    y = ifelse.(iszero.(x), zero.(x .* logx), x .* logx)
    # Both `ifelse` branches share one element type, so the result stays concretely typed.
    ∇xlogx = @closure Δ -> (NoTangent(), NoTangent(),
        @thunk(ifelse.(iszero.(x), zero.(Δ .* logx), Δ .* (logx .+ true))))
    return y, ∇xlogx
end

# Reverse-mode rule for the scalar `xlogy(x, y)`, i.e. `x * log(y)`.
#
# Returns the primal value and a pullback `Δ -> (NoTangent(), ∂x, ∂y)` with
# `∂x = Δ * log(y)` and `∂y = Δ * x / y`. When `x == 0` the primal is an exact zero and
# both tangents are reported as `ZeroTangent()`.
function CRC.rrule(::typeof(xlogy), x::Number, y::Number)
    if iszero(x)
        # The pullback must return one tangent per call component: the function itself,
        # `x`, and `y`.
        # NOTE(review): mathematically ∂(x·log y)/∂x = log(y) even at x == 0; the zero
        # tangent for `x` is kept from the original rule — confirm this is intended.
        ∇zero = Δ -> (NoTangent(), ZeroTangent(), ZeroTangent())
        # Build the zero from both arguments so it has the promoted floating-point type
        # of the non-zero branch, without evaluating `log(y)` here.
        return zero(x) * zero(float(y)), ∇zero
    end
    logy = log(y)
    ∇xlogy = @closure Δ -> (NoTangent(), @thunk(Δ * logy), @thunk(Δ * x / y))
    return x * logy, ∇xlogy
end

# Reverse-mode rule for the broadcast `xlogy.(x, y)` over two arrays of numbers.
#
# Returns the elementwise `x * log(y)` and a pullback producing
# `(NoTangent(), NoTangent(), ∂x, ∂y)` with `∂x = Δ .* log.(y)` and `∂y = Δ .* x ./ y`.
# Entries where `x == 0` are forced to zero in the value and in `∂y`; without the mask
# they become NaN whenever the matching `y` is also zero (`0 * log(0)` and `0 / 0`).
function CRC.rrule(::typeof(Broadcast.broadcasted), ::typeof(xlogy),
        x::AbstractArray{<:Number}, y::AbstractArray{<:Number})
    logy = log.(y)
    z = ifelse.(iszero.(x), zero.(x .* logy), x .* logy)
    # `∂x` is left unmasked, as in the original rule.
    ∇xlogy = @closure Δ -> (NoTangent(), NoTangent(), @thunk(Δ .* logy),
        @thunk(ifelse.(iszero.(x), zero.(Δ .* x ./ y), Δ .* x ./ y)))
    return z, ∇xlogy
end
3 changes: 2 additions & 1 deletion src/helpers/losses.jl
Original file line number Diff line number Diff line change
Expand Up @@ -424,7 +424,8 @@ end

function __unsafe_apply_loss(loss::KLDivergenceLoss, ŷ, y)
# Result = cross-entropy term (delegated to `loss.celoss`) + entropy term, where the
# entropy term sums `xlogx` of every element of `y` over `loss.dims` and reduces the
# sums with `loss.agg`.
cross_entropy = __unsafe_apply_loss(loss.celoss, ŷ, y)
# NOTE(review): this page renders a diff of this file ("2 additions & 1 deletion"), so
# the first `entropy` assignment below looks like the removed line and the second like
# its replacement — confirm before treating both as live code.
entropy = loss.agg(sum(xlogx.(y); loss.dims)) # Intentional broadcasting for Zygote type stability
# Intentional broadcasting for Zygote type stability
entropy = loss.agg(sum(LogExpFunctions.xlogx.(y); loss.dims))
return entropy + cross_entropy
end

Expand Down
22 changes: 16 additions & 6 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -264,20 +264,28 @@ end

Return `x * log(x)` for `x ≥ 0`, handling `x == 0` by taking the limit from above, to get
zero.

!!! warning "Deprecated"

Use `LogExpFunctions.xlogx` instead.
"""
@inline function xlogx(x::Number)
# NOTE(review): this page renders a diff of this file, so the body appears to hold both
# the removed implementation (first two statements) and the added one (last two) —
# confirm which side is live.
# Direct form: compute `x * log(x)`, then select a zero of the result type when `x` is
# zero, since `0 * log(0)` would otherwise be NaN.
result = x * log(x)
return ifelse(iszero(x), zero(result), result)
# Deprecation form: emit a deprecation warning attributed to `:xlogx`, then forward to
# `LogExpFunctions.xlogx`.
__depwarn("`Lux.xlogx` is deprecated, use `LogExpFunctions.xlogx` instead.", :xlogx)
return LogExpFunctions.xlogx(x)
end

"""
xlogy(x::Number, y::Number)

Return `x * log(y)` for `y > 0`, and zero when `x == 0`.

!!! warning "Deprecated"

Use `LogExpFunctions.xlogy` instead.
"""
@inline function xlogy(x::Number, y::Number)
# NOTE(review): this page renders a diff of this file, so two versions appear merged
# here: the `@inline` header with the direct implementation, and the plain `function`
# header with the deprecation wrapper — confirm which side is live.
# Direct form: compute `x * log(y)`, then select a zero of the result type when `x` is
# zero.
result = x * log(y)
return ifelse(iszero(x), zero(result), result)
function xlogy(x::Number, y::Number)
# Deprecation form: emit a deprecation warning attributed to `:xlogy`, then forward to
# `LogExpFunctions.xlogy`.
__depwarn("`Lux.xlogy` is deprecated, use `LogExpFunctions.xlogy` instead.", :xlogy)
return LogExpFunctions.xlogy(x, y)
end

# Some functional forms of losses
Expand All @@ -295,7 +303,7 @@ Broadcast.broadcastable(f::__Fix3) = Ref(f)
end

@inline function __poisson_loss(x::T1, y::T2, ϵ) where {T1, T2}
# Elementwise Poisson-style loss: `x - y * log(x + ε)`, with `ε` obtained from
# `__get_epsilon(T1, ϵ)` (i.e. keyed on the type of `x`).
# Presumably `x` is the prediction and `y` the target — verify against callers.
# NOTE(review): this page renders a diff of this file, so the first `return` looks like
# the removed line and the second like its replacement — confirm.
return x - xlogy(y, x + __get_epsilon(T1, ϵ))
return x - LogExpFunctions.xlogy(y, x + __get_epsilon(T1, ϵ))
end

@inline function __msle_loss(x::T1, y::T2, ϵ) where {T1, T2}
Expand Down Expand Up @@ -379,3 +387,5 @@ end
# For an array whose elements are `ForwardDiff.Dual` numbers, return the second type
# parameter of the `Dual` element type.
@inline function __eltype(
        ::AbstractArray{<:ForwardDiff.Dual{Tag, Inner}}) where {Tag, Inner}
    return Inner
end

# Reverse `x` along `dims`; with the default `dims=:` every dimension is reversed.
@inline function __reverse(x; dims=:)
    return reverse(x; dims=dims)
end

# Thin wrapper over `Base.depwarn`. Routing the call through this function prevents a
# type stability issue with Zygote.
function __depwarn(msg, sym)
    return Base.depwarn(msg, sym)
end
1 change: 1 addition & 0 deletions test/helpers/loss_tests.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
@testitem "xlogx & xlogy" setup=[SharedTestSetup] tags=[:helpers] begin
# TODO: Remove in v1.0
using Lux: xlogx, xlogy
using ForwardDiff, Zygote, Enzyme

Expand Down
Loading