Qualify pairwise call #360

Open: wants to merge 4 commits into base: master

Crown421
Member

While trying to understand how to work with Coverage I was struck by some lines that stubbornly did not get covered. Initially I thought that was due to missing tests, but even adding those did not change the coverage.

After some effort, I have found that these functions don't get called because the default Distances.pairwise is already enough. I managed to get those lines covered by qualifying the pairwise call in the tests, but it seems they are not actually necessary? Maybe it makes sense to remove those definitions from src altogether?
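
To illustrate why an unqualified call can bypass package-specific methods entirely, here is a minimal sketch with two toy modules. `ToyDistances` and `ToyKernels` are hypothetical stand-ins invented for this example, not the real packages; the sketch assumes the package defines its own `pairwise` function rather than adding methods to the exported one.

```julia
# Hypothetical stand-in for a distances package that exports `pairwise`.
module ToyDistances
export pairwise
pairwise(d, x, y) = "ToyDistances.pairwise"
end

# Hypothetical stand-in for a package that defines its *own* `pairwise`,
# i.e. a distinct function, not a new method of ToyDistances.pairwise.
module ToyKernels
pairwise(d, x, y) = "ToyKernels.pairwise"
end

using .ToyDistances

# The unqualified name resolves to the exported function ...
unqualified = pairwise(nothing, [1.0], [2.0])
# ... so the other package's methods only run when the call is qualified.
qualified = ToyKernels.pairwise(nothing, [1.0], [2.0])

println(unqualified)
println(qualified)
```

In this toy setup, a test that calls `pairwise` without the module prefix exercises only `ToyDistances`, so no line of `ToyKernels.pairwise` would ever show up as covered.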

@willtebbutt
Member

Interesting. I'm intrigued to see precisely which lines have changed coverage in this PR (which I guess we'll see once the Format suggestions check passes?). If there is indeed no need for certain methods, I agree that it would make sense to remove them. With things like ColVecs / RowVecs, we definitely need something though, because Distances.jl doesn't know about their existence. Hard to say without seeing which lines now have coverage though.

@Crown421
Member Author

I have covered a few more lines, and run the formatter, so that should pass now, I hope.

I feel the solution could be to keep the tests in place, remove the qualifiers, and then also remove all non-covered lines. The test will then make sure that the functionality stays.

@willtebbutt
Member

willtebbutt commented Aug 17, 2021

After some digging around in Distances, I discovered this PR JuliaStats/Distances.jl#194 , that I had previously missed.

It resolves the long-standing issue whereby we couldn't add some methods that we needed to Distances.pairwise without committing type piracy. It resolves this by adding the methods that we needed to Distances itself. I suspect this change means that, at some point in the (hopefully not too distant) future, we'll be able to dispense with KernelFunctions.pairwise. I suspect that this is also why stuff now works.

What concerns me now is just AD-related stuff. Specifically, ChainRules doesn't have rules for Distances stuff (because ChainRules only contains rules for Base, Core, and the standard libraries), and Distances isn't receptive to accepting ChainRulesCore as a dependency (because it doesn't want to have to maintain AD-related functionality, which is reasonable). There's an on-going discussion about glue packages, that will hopefully get some kind of resolution in the short to medium term (JuliaLang/Pkg.jl#1285), but it's not a thing yet, meaning that we can't fully dispense with KernelFunctions.pairwise at the minute.

This brings me around to what I think should be done:

  1. also add AD tests to the new methods that you've implemented, and
  2. then try to remove things, and see what breaks.

My suspicion is that the majority of the methods that we have will be necessary for AD to continue to work, and we're not currently testing that at all properly. The "correct" way to do that in the modern world is using https://juliadiff.org/ChainRulesTestUtils.jl/dev/#Testing-AD-systems .

This leaves a couple of ways forward:

  1. you extend this PR as per 1 and 2 above, or
  2. we merge this PR as-is, and open an issue about this, and deal with it at a later date.

I'm happy to go with either, it's just a question of what you've got the time / inclination to do.

@Crown421
Member Author

Very interesting. I have seen the issue on conditional dependencies; it does not seem that it will be resolved very soon.

In principle I am open to adding the proper tests, as a way to learn about AD. I have been using it, but don't really understand it in any way. In practice, I have no idea how difficult it will be.
I have not added any methods, so I suppose it would mostly be writing new tests for existing functions (including figuring out which even need it).

@willtebbutt
Member

> I have not added any methods, so I suppose it would mostly be writing new tests for existing functions (including figuring out which even need it).

Indeed -- to be honest, it would make sense to test all of the methods (once you've got a function to test one of them, testing the others should be straightforward, at least in principle).

I think the way to go about it is just going to be to use the method I linked above (definitely don't try to roll your own testing code for AD). I suspect you can just use the rule config defined here: https://github.com/FluxML/Zygote.jl/blob/78bb9a3cad52de6e7c9a590d0f8ac4b6014a73f4/src/compiler/chainrules.jl#L4

@devmotion
Member

I noticed that some recent changes in Distances make it necessary to analyze more carefully which methods are called. Distances.pairwise is not owned by Distances anymore. Instead it is owned by StatsBase (even though technically it is defined in StatsAPI). The dangerous part is that StatsBase defines a very general fallback definition of pairwise that does not exploit the structure of the metrics and inputs. However, it is difficult to spot potential performance problems since (as intended) no errors or warnings are shown if the potentially slow fallback is hit. It even caused problems with Distances.PreMetric: for some inputs the fallback was used instead of the implementations in Distances (I don't remember all the details, but I added some fixes to Distances a while ago). So while tests might not fail if specific methods are removed, we should carefully evaluate whether the removal causes any performance regressions.
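
One way to check which method a call would actually hit is to ask Julia directly with `which`. The following is a minimal sketch using hypothetical toy definitions (`ToyMetric`, `ToySqEuclidean`, `toy_pairwise` and `hits_fallback` are invented for illustration and are not the real Distances or StatsBase code); it only demonstrates how a slightly different input type can silently fall through to a generic fallback.

```julia
# Hypothetical toy types; not the real Distances/StatsBase definitions.
abstract type ToyMetric end
struct ToySqEuclidean <: ToyMetric end

# Very general fallback: accepts anything and ignores any structure.
toy_pairwise(d, X, Y) = [sum(abs2, x .- y) for x in X, y in Y]

# Specialised method for a metric and vectors of real vectors.
function toy_pairwise(
    d::ToyMetric, X::Vector{<:Vector{<:Real}}, Y::Vector{<:Vector{<:Real}}
)
    return [sum(abs2, x .- y) for x in X, y in Y]
end

# True if a call with these arguments would dispatch to the untyped fallback.
function hits_fallback(args...)
    m = which(toy_pairwise, typeof.(args))
    return m.sig == Tuple{typeof(toy_pairwise),Any,Any,Any}
end

X = [rand(3) for _ in 1:5]
Y = [rand(3) for _ in 1:7]
views = [view(x, :) for x in X]  # SubArrays are not Vectors

specialised_ok = !hits_fallback(ToySqEuclidean(), X, Y)
silently_generic = hits_fallback(ToySqEuclidean(), views, Y)
```

Both calls return the same values without any warning, which is why a check of this kind (or a benchmark) is needed in addition to the tests when methods are removed.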

@Crown421
Member Author

After some exploration, I notice two issues:

  1. Since I am using Julia 1.7, ReverseDiff (and probably also ForwardDiff) segfault whenever I actually get to interesting parts. I understand this has something to do with a new ChainRules version?
  2. I am not sure I know what should be tested. I am trying to write AD tests for pairwise, for different input types (Vector of vectors, ColVecs, RowVecs ), but that seems to cause an issue with undefined zero. Am I actually on the right track?

@willtebbutt
Member

> Since I am using Julia 1.7, ReverseDiff (and probably also ForwardDiff) segfault whenever I actually get to interesting parts. I understand this has something to do with a new ChainRules version?

Oh interesting. I have no idea what to do about that. Would suggest using 1.6 if it's not obvious how to fix.

> I am not sure I know what should be tested. I am trying to write AD tests for pairwise, for different input types (Vector of vectors, ColVecs, RowVecs ), but that seems to cause an issue with undefined zero. Am I actually on the right track?

Hmmm could you provide a MWE?

@Crown421
Member Author

I think I have the same problem that makes the CI fail all the way through for Julia nightly. I have switched to 1.3 for now.

For development, I have started by looking at the following:

using Distances, LinearAlgebra, ReverseDiff

A = [rand(3) for _ in 1:5]
B = [rand(3) for _ in 1:7]
norm(pairwise(SqEuclidean(), A, B))

# gradient with respect to the vector of vectors A
ReverseDiff.gradient(a->norm(pairwise(SqEuclidean(), a, B)), A)

which yields

ERROR: LoadError: MethodError: no method matching zero(::Type{Array{Float64,1}})

With my limited understanding, I have been able to find this discourse discussion, but before going down that road, I want to make sure this is even what I should be looking at.
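
The error message itself can be reproduced in plain Julia without any AD package: `zero` is defined for scalar types and for array instances, but not for an array type, because the type does not carry a length. The sketch below shows this, together with one possible workaround that is an assumption rather than a confirmed fix: storing the points as columns of a matrix so that the element type is a scalar. Whether this is the exact call path inside ReverseDiff is not established here.

```julia
# zero of a scalar type is well defined.
scalar_zero = zero(Float64)

# zero of a Vector *type* is not: the type alone does not say how long
# the vector should be, so Julia throws a MethodError.
vector_zero_fails = try
    zero(Vector{Float64})
    false
catch err
    err isa MethodError
end

# zero of a Vector *instance* works, because the length is known.
instance_zero = zero([1.0, 2.0, 3.0])

# Possible workaround (an assumption, not a confirmed fix): keep the points
# as columns of a matrix, whose element type is Float64.
A = [rand(3) for _ in 1:5]
A_mat = reduce(hcat, A)                        # 3×5 Matrix{Float64}
A_back = [collect(c) for c in eachcol(A_mat)]  # back to a vector of vectors
```

A function under test could then take `A_mat` as its differentiable argument and rebuild the vector of columns internally.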

@willtebbutt
Member

Sorry for the slow response @Crown421

I would suggest starting with Zygote, and ChainRulesTestUtils.

In particular, take a look at the docs on testing AD systems, and utilise the method implemented in Zygote along with Zygote's config. I personally generally find this to be the most debuggable bit of testing infrastructure.

I'm imagining something along the lines of

f(A, B) = pairwise(SqEuclidean(), A, B)
test_rrule(Zygote.ZygoteRuleConfig(), f, A, B; rrule_f=rrule_via_ad)
