Q: What types can be used with autodiff? #454
Yeah, base types (`Float`), tables, tuples, and records are the only things that will work at the moment, I think. Adding support for general ADTs is on our roadmap, but it is a bit subtle, because of what it means for the type of e.g. `jvp :: (a -> b) -> (a, Tangent a) -> (b, Tangent b)`. Below are two reasons that we know of. If you have any thoughts or suggestions, do let us know!

**ADTs have named constructors.** Unlike tuples and records, ADTs have named constructors, with names given by the users. This is a problem: consider

```
data MyADT = MyConstructor Float Int
```

Now, what should its tangent type be? We could "downgrade" it into a tuple of the tangents of its fields, but then the user-given names are lost. In short, we'd like the tangent types of ADTs to have some correspondence to the names given by the user to the original ADT. On the other hand, we don't want to put the burden of writing another, extremely similar ADT by hand on the user.

**ADTs can be sum types.** Differentiating through sum types is even more involved, because then the tangent type can depend on the value, not only on its type, which pushes the signature towards something dependent like

```
jvp :: (a -> b) -> (x:a ** Tangent x) -> (y:b ** Tangent y)
```

where `**` forms a dependent pair. Of course an alternative here is to make the tangent type independent of the value, but that has its own downsides.
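As an aside for readers coming from Python rather than Dex, a rough analogue of the situation described above exists in JAX: tangents must mirror the structure of the primal value, and user-defined containers only participate once JAX knows how to take them apart. The sketch below only illustrates that idea; it is not Dex code, and `Params` is a made-up name.

```python
# Illustrative JAX (not Dex) sketch: tangent values mirror the primal's structure.
from typing import NamedTuple
import jax
import jax.numpy as jnp

class Params(NamedTuple):   # hypothetical record-like container
    w: jnp.ndarray
    b: jnp.ndarray

def f(p: Params):
    return jnp.sum(p.w * p.w) + p.b

p = Params(w=jnp.arange(3.0), b=jnp.float32(1.0))   # primal value
t = Params(w=jnp.ones(3),     b=jnp.float32(0.0))   # a tangent: same structure as p

# NamedTuples are already understood by JAX, so the "tangent type of Params" is
# simply another Params. An arbitrary user-defined class would first need to be
# registered (jax.tree_util.register_pytree_node), which is exactly the kind of
# boilerplate the answer above would like to avoid forcing on users.
y, dy = jax.jvp(f, (p,), (t,))
```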
As always, this answer is so good, and taught me so much more than the question I asked. Your time is highly appreciated.
I'd like to extend @apaszke's excellent answer about associated tangent types, and also address the question: what types work with differentiation?

**An analogy.** There's an analogy between types in programming languages and mathematical spaces. I'll reuse the dependent `jvp` signature from above:

```
jvp :: (a -> b) -> (x:a ** Tangent x) -> (y:b ** Tangent y)
```

| Programming languages | Differential geometry (math) |
| --- | --- |
| types | mathematical spaces |
| types that work with differentiation | differentiable manifolds |
| differentiable functions and their `jvp`s | differentiable functions and their differentials |
| the tangent type `Tangent x` at a value `x` | the tangent space at a point |
| a gradient descent update | the exponential map |
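For readers who know JAX better than Dex, the non-dependent version of the `jvp` signature above is what `jax.jvp` provides: take a point and a tangent at that point, get back the primal result and a tangent in the output space. A minimal sketch (JAX, not Dex):

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x) * x            # an ordinary differentiable program

x = jnp.float32(2.0)                 # a point in the input space
v = jnp.float32(1.0)                 # a tangent vector at x

y, dy = jax.jvp(f, (x,), (v,))       # (y, Tangent y): primal output and pushed-forward tangent
```

The analogy extends to optimization too: the `gradient_descent` sketch below has the type signature of the exponential map.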
```python
def gradient_descent(parameters, gradient, learning_rate=1e-3):
    """Performs gradient descent, updating parameters given their gradients.

    Gradient descent effectively has the type signature of the exponential map.
    This implementation updates parameters in place, as done in modern deep
    learning frameworks. You could imagine an equivalent functional
    implementation that returns new parameters.

    Args:
      parameters: bundle of arrays. Like `model.parameters()` from PyTorch.
      gradient: bundle of tangent arrays. Like `[p.grad for p in model.parameters()]` from PyTorch.
      learning_rate: step size for the update.
    """
    parameters -= learning_rate * gradient
```

**Putting it together.** Let's visualize this analogy.

"First-class Differentiable Programming" @ Probabilistic & Differentiable Programming Summit, June 2019. Visualization: differentiable manifolds, differentiable function and differential function, exponential map.
We can think of types in programming languages as mathematical spaces.
"Intro to Differentiable Swift" @ Swift for TensorFlow Open Design Meeting, March 2020Animated visualization ✨: what types work with differentiation? From differentiable manifolds to a
@dan-zheng asked me for some feedback on this thread. The answers above seem to embody an unfortunate and popular perspective, namely that forward-mode and reverse-mode AD are different questions requiring different vocabulary and techniques. Instead, a single, simple notion of differentiation and a single API suffice, and a single simple AD algorithm can handle forward, reverse, and other mixed modes with ease and without complicated operational details like graphs, mutation, and "backpropagation".

Instead of changing the algorithm, choose a suitable representation of linear maps. A good choice for low-dimensional domains is functions that are linear, while a good choice for low-dimensional codomains is the dual of such functions, in which the fundamental building blocks are defined dually: composition is reversed, projections become injections and vice versa, duplication and combination (addition) trade places, and curried scalar multiplication becomes itself. See *The simple essence of automatic differentiation* for details, including proofs; the algorithm is calculated from a simple, precise specification by solving a standard collection of algebra problems. The Microsoft Research talk is probably the most accessible explanation.

Another unfortunate choice in the first formulation above is the type of `jvp`.
With these two changes, your API would become more precise, i.e., you've statically eliminated many invalid representations. The remaining invalid representations can be eliminated via dependent types. These changes also lead to fixing the first serious problem I mentioned above of treating various "modes" of differentiation as if they were different questions (specifications), rather than different answers (algorithms). The key is in realizing that there are many valid linear-map representations you can use for derivatives.
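To ground the "choose a representation of linear maps" point for readers who know JAX, here is a small sketch (JAX, not the API proposed above): the derivative at a point can be represented as a linear function on tangents, or as the dual of that linear map acting on cotangents, without changing what differentiation means.

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.tanh(x) * x

x = jnp.arange(3.0)

# Representation 1: a linear function on tangents (convenient for forward mode).
y, df = jax.linearize(f, x)          # df is linear: Tangent x -> Tangent y
tangent_out = df(jnp.ones(3))

# Representation 2: the dual linear map on cotangents (convenient for reverse mode).
y2, df_dual = jax.vjp(f, x)          # df_dual : Cotangent y -> Cotangent x
(cotangent_in,) = df_dual(jnp.ones(3))   # one cotangent per primal argument
```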
Thanks @dan-zheng and @conal, this is really interesting. Since there seem to be a bunch of people following along, I'm going to make a study thread here to discuss this paper: #494. High-level: it sounds like many of these ideas may be out of scope for the type system of Dex? And there is a more practical question of "how to auto-define and name simple tangent types". However, it still feels really important.
@conal Thanks for the feedback. Before I get to the technical part of my answer, I wanted to ask you to refrain from passing judgement on anyone discussing any topic on our issue tracker. We're trying to build an inclusive community and welcome people from many backgrounds. In particular, we don't care if they want to understand the process of differentiation in terms of graph traversals, backpropagation, or category theory. I'm sure you too could learn a lot from them, if only you open yourself to their perspective. I'm aware that it's quite easy to get misread as you post things online, and I'm sure that you're writing your comments in good faith, but please be careful about how your message can be understood by others.

Moving on to technical material: we completely agree with your suggestions (1. and 2.), and it is in fact how AD is implemented in Dex, via the builtin functions we expose for that purpose.
Note that our type system even features a linear arrow that can verify that user-defined functions are truly linear and transposable.

Finally, I think it is worth noting that this approach is no silver bullet, which is likely the reason why many AD systems that do care a lot about forward-mode performance cannot take the path you're outlining. In particular, just as one can prove that both forward- and reverse-mode AD can be carried out in the same order of computational complexity as the input program, forward mode has the additional benefit of being able to preserve the same order of memory complexity. But this is conditional on being able to produce a program where the evaluation of the non-linear function is interleaved with the linear part, which is far from easy when linearization is considered fundamental (it would require whole-program optimization and very aggressive code motion in many program representations). See our LAFI abstract for an outline of how the ideas you're suggesting can be pushed even further to alleviate this issue (the gist of it is that we actually do make …).
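The "reverse mode via transposition of the linearized program" idea has a close analogue in JAX that may help readers follow along; this is only an illustration of the concept, not how Dex implements it internally.

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x) * x

x = jnp.arange(1.0, 4.0)

# Forward: linearize once to get a structurally linear function of tangents.
y, df = jax.linearize(f, x)

# Reverse: transpose that linear function instead of running a separate AD mode.
df_t = jax.linear_transpose(df, x)
(grad_via_transpose,) = df_t(jnp.ones_like(y))

# It agrees with the built-in VJP.
_, vjp_fn = jax.vjp(f, x)
(grad_via_vjp,) = vjp_fn(jnp.ones_like(y))
assert jnp.allclose(grad_via_transpose, grad_via_vjp)
```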
@apaszke Thanks for this response. Message received about tone. I originally wrote these notes just for @dan-zheng as a response to his inquiry and a follow-on to some of our past conversations. I regret sending the notes as they were to a group with whom I don’t have such a shared context.
Great. It sounds like we're closer than I thought. Seeing this made explicit helps. I think I'm suggesting something different, which is to have only what you call `linearize`.
The reason I'm aware of for combining the primal function (of type `a -> b`) with its derivative into a single result is to share work between computing the two.
Thanks for explaining! I now see how what you are proposing is slightly different from what we do. I'll try to paraphrase your point and describe how it compares to our approach, but of course please do point out any inaccuracies in my comparison.

If I understand your point correctly, you say the vocabulary that transforms and composes the linear maps in both forward- and reverse-mode is the same, and I agree that it can be made so (as you carefully outline in your paper). I would be tempted to say that there is a type class your linear map representation has to implement in order to be compatible with the process of differentiation. Going down that path has the benefit of using a single program transformation for both modes, but the downside of using two sets of rules (as you have to implement the type class twice, once for each of the two linear map types). This largely misses out on the close correspondence between the rules used to perform forward- and reverse-mode.

In our own jargon, we like to say that forward-mode rules implement JVPs (Jacobian-vector products), while reverse-mode rules implement VJPs (vector-Jacobian products). I like those names a lot, because they highlight their relation: each one is a transposed version of the other. Because of that, in Dex the only AD mode we really support (and have to implement rules for) is forward mode, while reverse mode is obtained not by reusing the same differentiation process with a different rule set, but through a program transposition transformation that is always valid to perform on the functions produced by forward-mode AD.

So yes, for the purpose of differentiation we do assume a particular representation of linear maps, which in our case is encoded in what we call a structurally linear program (this also has some interesting connections to linear logic, as @dougalm wonderfully explained in one of our issues). But this doesn't prevent us from getting reverse mode in the end, because we've simply found a path that doesn't require us to reengage the AD machinery, with the added benefit of having a significantly smaller rule set than would be necessary for both AD modes.

About your second question: this is mostly not about sharing (which is critical too, but not precluded by the signature `linearize :: (a -> b) -> (a -> (b, T a -o T b))`), since `jvp` can be recovered from `linearize`:
```
jvp :: (a -> b) -> (a, T a) -> (b, T b)
jvp f (x, t) =
  let (y, f') = linearize f x
  in (y, f' t)
```

If the function …
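For readers following along with JAX instead of Dex, the same derivation can be written against `jax.linearize`; this is a sketch of the idea above, not part of either library's API, and `my_jvp` is an illustrative name.

```python
import jax
import jax.numpy as jnp

def my_jvp(f, x, t):
    """jvp obtained from linearize, mirroring the Dex definition above."""
    y, f_lin = jax.linearize(f, x)   # (y, f') = linearize f x
    return y, f_lin(t)               # (y, f' t)

y, dy = my_jvp(jnp.sin, jnp.float32(1.0), jnp.float32(1.0))
```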
I'm curious what types of functions one should be able to `grad` over, i.e. what is the implicit restriction on `a` in

```
def grad (f:a->Float) (x:a) : a = snd (vjp f x) 1.0
```

Currently it seems to work for tables and tuples, but other things crash for me (for instance custom `data` types). I was playing around with something Flax-like for grouping params and functions, but I think this might be the wrong path given #331, and because I am not smart enough to figure out how to unpack a tuple of Params into a Param of tuples.
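As a side-by-side for the definition in the question, here is the same grad-from-vjp construction written with JAX; it is only a sketch, `my_grad` is an illustrative name, and it says nothing about which Dex types currently work.

```python
import jax
import jax.numpy as jnp

def my_grad(f, x):
    """grad written in terms of vjp, mirroring `snd (vjp f x) 1.0` above."""
    y, vjp_fn = jax.vjp(f, x)          # vjp f x
    (g,) = vjp_fn(jnp.ones_like(y))    # apply the transposed derivative to 1.0
    return g

# Works for arrays and for pytrees of arrays (tuples, dicts, ...):
print(my_grad(lambda v: jnp.sum(v * v), jnp.arange(3.0)))   # [0. 2. 4.]
```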