Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use LogExpFunctions for losses #1866

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open

Conversation

ToucheSir
Copy link
Member

This package is already in the dep tree through multiple paths (Zygote -> ForwardDiff, StatsBase, etc.), so we might as well make use of it.

PR Checklist

  • Tests are added removed ;)
  • Entry in NEWS.md
    - [ ] Documentation, if applicable

@test xlogy(2, 3) ≈ 2.0 * log(3.0)
@inferred xlogy(2, 3)
@inferred xlogy(0, 1)
end
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do these tests pass?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tested locally before removing and they do. https://github.com/JuliaStats/LogExpFunctions.jl/blob/master/test/basicfuns.jl also looks like a strict superset of the Flux tests.

Comment on lines -16 to -19
function xlogy(x, y)
result = x * log(y)
ifelse(iszero(x), zero(result), result)
end
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One question is whether there are any performance differences, and whether we care. IIRC the replacements have if else instead of ifelse, but perhaps the compiler sorts it out?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I found JuliaStats/LogExpFunctions.jl#26. GPU is the big question mark, but if #1791 is any indication there may not be a difference there either.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks pretty comparable:

using Flux.Losses: xlogx as f_xlogx, xlogy as f_xlogy
using LogExpFunctions: xlogx as l_xlogx, xlogy as l_xlogy
using BenchmarkTools, CUDA

x, y, out = ntuple(_ -> rand(Float32, 100_000), 3);
cx, cy, cout = ntuple(_ -> CUDA.rand(Float32, 100_000), 3);

julia> @btime $out .= f_xlogx.($x);
  580.412 μs (0 allocations: 0 bytes)

julia> @btime $out .= l_xlogx.($x);
  580.883 μs (0 allocations: 0 bytes)

julia> @btime $out .= f_xlogy.($x, $y);
  622.826 μs (0 allocations: 0 bytes)

julia> @btime $out .= l_xlogy.($x, $y);
  657.381 μs (0 allocations: 0 bytes)

julia> @btime CUDA.@sync $cout .= f_xlogx.($cx);
  5.896 μs (7 allocations: 480 bytes)

julia> @btime CUDA.@sync $cout .= l_xlogx.($cx);
  5.832 μs (7 allocations: 480 bytes)

julia> @btime CUDA.@sync $cout .= f_xlogy.($cx, $cy);
  7.555 μs (23 allocations: 1.61 KiB)

julia> @btime CUDA.@sync $cout .= l_xlogy.($cx, $cy);
  7.114 μs (23 allocations: 1.61 KiB)

I did a couple of runs and there was a not insignificant amount of variability, but at least the relative times aren't too far off.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I gret similar numbers.

@codecov-commenter
Copy link

Codecov Report

Merging #1866 (a3d8cd9) into master (7b56813) will decrease coverage by 0.02%.
The diff coverage is 100.00%.

Impacted file tree graph

@@            Coverage Diff             @@
##           master    #1866      +/-   ##
==========================================
- Coverage   84.50%   84.47%   -0.03%     
==========================================
  Files          21       21              
  Lines        1484     1475       -9     
==========================================
- Hits         1254     1246       -8     
+ Misses        230      229       -1     
Impacted Files Coverage Δ
src/losses/utils.jl 83.33% <100.00%> (-3.34%) ⬇️

Continue to review full report at Codecov.

Legend - Click here to learn more
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 7b56813...a3d8cd9. Read the comment docs.

@mcabbott
Copy link
Member

mcabbott commented Feb 9, 2022

I'm not strongly opposed but I wonder what we gain here, the functions are so simple.

Surely one could write a very fast xlogx(::Float32) via rational functions etc, since it doesn't need the hard part of log. But LogExpFunctions doesn't do that. And I doubt it's ever a bottleneck here.

Flux currently depends indirectly on this package (and its dependencies), but would ultimately like not to depend on Zygote, ForwardDiff, et. al. I lean towards a few lines of duplication being better than a tight web of every package depending on every other one.

@darsnack
Copy link
Member

darsnack commented Feb 9, 2022

What about gradients here? Not that we were testing this before.

@darsnack
Copy link
Member

darsnack commented Feb 9, 2022

I think the main benefit of external packages over code duplication is maintainability and testing. If it is a relatively lightweight dep and we use it for only a few functions, then I see no harm in having it. It'd be a different story if we take on a dep that provides many functions in Flux, making them intricately woven.

@mcabbott
Copy link
Member

mcabbott commented Feb 9, 2022

Right, I just feel that maintenance and testing of lines

  result = x * log(y)
  ifelse(iszero(x), zero(result), result)

here might be less hassle than maintenance of lines in Project.toml, breaking changes downstream, etc. If we want it never to change, then never changing it seems simpler than hoping some other package won't. I agree this is subjective though.

Besides such tradeoffs, the broadcasted gradients here aren't duplicated there. Ideally the future implementation of broadcasting will work off @scalar_rule but not yet.

@ToucheSir
Copy link
Member Author

I suppose I hold the opposite perspective. There is a not significant amount of, for lack of a better term, derelict code kicking around in FluxML packages. Think a lot of the utility functions here (some which are not tested 🙈) and certain adjoints in Zygote.

For this particular case, I don't think the extra dep is a problem for a few reasons. Firstly, if LogExpFunctions breaks then Flux is going to feel it either way, since it's on a critical path of some direct dependencies. Secondly, xlogx and xlogy is a small surface area to break, but have enough edge cases that I think the expanded test coverage in LogExpFunctions is a good thing. Lastly, there is a better chance of silently breaking changes being caught downstream because more packages rely on this functionality than Flux. We can all think of examples where routines here were doing the wrong thing but not erroring, so having that canary would allow us to focus effort on more core parts of the library.

About broadcasting, I had a look back through the blame and the adjoint in question was added 2 years ago. Nowadays, every AD we care about has a fast path for broadcasting LogExpFunctions.xlogy, whether that's through DiffRules or ChainRules (Diffractor). Perhaps if it needed the rrule to store some intermediate results for the pullback, but since it doesn't I don't see why the dedicated broadcast rule would be necessary.

@ToucheSir
Copy link
Member Author

ToucheSir commented Feb 14, 2022

@mcabbott @darsnack another temperature check on this now that #1863 is ready for review?

@darsnack
Copy link
Member

I'm still in favor overall

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants