
Random: Uniqueness Types? #401

Closed
srush opened this issue Dec 30, 2020 · 8 comments
srush commented Dec 30, 2020

Hi dex,

I was reading about how Haskell had just added Linear/Unique Types (https://en.wikipedia.org/wiki/Uniqueness_type) and was curious if that would work with the Dex type system?

In particular, I am constantly messing up random keys in JAX/Dex, and would love it if the type system could just prevent me from reusing them without splitting.

dougalm commented Dec 31, 2020

Great question. We've toyed with linear types in Dex (you can see some vestigial lollipops, --o, here and there). We're still exploring if/how to include them. Here are three possible use cases.

  1. Guaranteeing that certain array updates occur in-place. This is a traditional use of linear/uniqueness types. It's how Futhark does in-place updates, for example. So far, we've chosen to take the effect route instead. But both solutions have their advantages and we're still trying to understand the tradeoffs.

  2. Avoiding reuse of PRNG keys. As you point out, the PRNG system in Dex (and JAX) is a big foot-gun because it's so easy to accidentally reuse keys. I'd definitely like to explore a type-oriented solution here. Alternatively, we could try using the effect system, which might also take care of some of the plumbing. The conventional approach is to use a state-like effect that updates a PRNG state behind the scenes. But we don't like that because it breaks parallelism, which is one of the features we love about splittable PRNGs. But there's a parallelism-preserving effect that's almost right. As a Haskell monad, it would look like this.

newtype Random a = Random { runRandom :: Key -> a }

instance Monad Random where
  f >>= g = Random $ \k ->
              let (k1, k2) = splitKey k
              in runRandom (g $ runRandom f k1) k2
  return x = Random $ const x

The problem is that it breaks the law! We don't even have m >>= return = m anymore because >>= forces a split. It still obeys the law "up to re-rolling the dice". But I don't know how bad that is.
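To make the law-breaking concrete, here is a minimal runnable sketch of the monad above. The path-based `Key` is a hypothetical stand-in for a real splittable PRNG key (e.g. threefry); only the splitting structure matters here, and `getKey` is an illustrative helper, not part of any real API.

```haskell
-- Toy splittable key: the path of left/right splits taken from the root.
type Key = [Bool]

splitKey :: Key -> (Key, Key)
splitKey k = (k ++ [False], k ++ [True])

newtype Random a = Random { runRandom :: Key -> a }

instance Functor Random where
  fmap f (Random r) = Random (f . r)

instance Applicative Random where
  pure x = Random (const x)
  Random f <*> Random x = Random $ \k ->
    let (k1, k2) = splitKey k in f k1 (x k2)

instance Monad Random where
  f >>= g = Random $ \k ->
    let (k1, k2) = splitKey k
    in runRandom (g (runRandom f k1)) k2

-- An action that simply reveals which key it was run with:
getKey :: Random Key
getKey = Random id
```

Running `getKey` with the root key `[]` yields `[]`, but `getKey >>= return` yields `[False]`: the bind forced a split, so `m >>= return = m` fails, even though both sides are still draws from related keys ("up to re-rolling the dice").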

  3. Dex's AD system works with functions that are "linear" in the linear algebra sense (scaling the input scales the output and so on). It turns out that you can check that linearity using a linear-in-the-linear-type-sense type system. That's why our AD functions have types that include lollipops:
linearize       : (a ->  b) -> a -> (b & a --o b)
transposeLinear : (a --o b) -> (b --o a)
vjp             : (a ->  b) -> a -> (b & b --o a)

They used to be checked but they're currently just documentation. The advantage of actually checking them is that it lets you safely expose transposeLinear as a standalone function, because you can check that the function you apply it to is truly linear (in both senses of the word).

Hope that gives some context. The summary is: linear types could have many applications in Dex but we don't have concrete plans for them yet.

srush commented Dec 31, 2020

This is extremely interesting! Thanks so much for taking the time to write it up.

srush commented Dec 31, 2020

One more question @dougalm . I think I am missing something basic in your point 3.

Am I right to assume that in Dex a --o b implies that the function is linear (as in algebra), but not that it is a linearly typed function as in :

[image: the linear typing rule, i.e. a function that consumes its argument exactly once]

I can't think of any usability reason why a linear function would need to have the Linear typing restriction as well? (Although maybe there are some ways to exploit that for efficiency? https://www.cl.cam.ac.uk/~nk480/numlin.pdf)

apaszke commented Jan 4, 2021

This looks resolved, so I'll close it. Can't seem to find the button to convert the issue to a discussion for some reason 😕

apaszke closed this as completed Jan 4, 2021

apaszke commented Jan 4, 2021

Oh, and answering your question @srush: yes, it is a different notion of linearity. For example, f x = x + x is linear, and would be accepted by the Dex linearity checker, but in general would be rejected by a linear typing system that doesn't consider + to be any special operation.

dougalm commented Jan 8, 2021

Sorry, I meant to respond to this earlier.

Am I right to assume that in Dex a --o b implies that the function is linear
(as in algebra), but not that it is a linearly typed function as in [consumes
its argument exactly once]

Actually, it is meant to imply that it is linearly typed in the sense of consuming its argument exactly once, which also implies that it's linear in the linear algebra sense.

Before explaining what I mean I'll establish some terminology, because "linear" and "function" are both confusingly overloaded here.

I'll use "syntactic function" to mean a lambda term in a program. And I'll use "mathematical function" to mean a mathematical object mapping inputs to outputs. Usually we think of a syntactic function as denoting a mathematical function, but of course it's syntactic functions that we have to work with in a compiler.

I'll use "logically linear" to mean linear in the linear logic sense (equivalently, in the linear type sense). It's often described in terms of "consuming" an argument or a logical premise. We say that a syntactic function is logically linear if it obeys the linear typing rules. I'll use "algebraically linear" to refer to mathematical functions on vector spaces that obey f (x + y) == f x + f y and f (c * x) = c * (f x).

Ok, with that out of the way, I think there are two surprising claims we're making about linearity and AD. First, we're claiming that a logically linear type checker can prove that a syntactic function denotes a mathematical function that's algebraically linear. This isn't obvious at all, and I wouldn't have believed it if Gordon Plotkin hadn't patiently explained it to us.

The intuition is that we can think of "consuming an input" as "multiplying the input into the function's result". Then it seems plausible that consuming an input exactly once means that you're algebraically linear in that input. For example, here are three functions for which the logical and algebraic notions of linearity clearly coincide.

-- Logically linear and algebraically linear (consumes input once)
\x. 2 * x

-- Neither logically linear nor algebraically linear (consumes input twice)
\x. x * x

-- Neither logically linear nor algebraically linear (doesn't consume input at all)
-- (Note that this function is "affine", in both the logical and algebraic senses!)
\x. 2

To make this work formally in a standard logically linear type system, we just need the rule that the type of (*) is Float --o Float --o Float (i.e. it's bilinear) and then use a standard textbook logically linear type system on top of that.

But what about (+)? This is the tricky bit. We need to work with the uncurried form: (+) : (Float & Float) --o Float. It's linear in the pair as a whole. What's different about the constructor for pairs is that the linear typing rule says that you need to consume the arguments exactly once on each side of the pair. (In the turnstile formalism, the environment is duplicated on each side.) This sort of product type is sometimes called the "additive conjunction" in the literature, but Gordon reckons that's too highfalutin and prefers to just call it the "Cartesian product". Here are some examples using (+).

-- Logically linear and algebraically linear (consumes input once on each side)
\x. (+) (x, x)

-- Neither logically linear nor algebraically linear (only consumes input on one side)
\x. (+) (x, 2)

Interestingly, we can also make an uncurried version of (*) work if we introduce a different sort of product type, the "tensor product". AFAICT, we can't have a curried version of (+).
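These use-counting rules are simple enough to sketch as a checker. This is my own toy illustration (not Dex's actual type checker): a one-variable term language where (*) requires a constant on one side and (+) requires the variable to be consumed exactly once on each side.

```haskell
-- A tiny term language over one variable x.
data Expr = X | Lit Double | Mul Expr Expr | Add Expr Expr

-- Does the expression contain no occurrence of x at all?
isConst :: Expr -> Bool
isConst X         = False
isConst (Lit _)   = True
isConst (Mul a b) = isConst a && isConst b
isConst (Add a b) = isConst a && isConst b

-- Is the expression logically linear in x, per the rules above?
linearInX :: Expr -> Bool
linearInX X         = True
linearInX (Lit _)   = False                       -- doesn't consume x (merely affine)
linearInX (Mul a b) =                             -- (*): constant on exactly one side
  (linearInX a && isConst b) || (isConst a && linearInX b)
linearInX (Add a b) = linearInX a && linearInX b  -- (+): consume x once on each side
```

On the examples from this thread: `Mul (Lit 2) X` and `Add X X` pass, while `Mul X X`, `Add X (Lit 2)`, and `Lit 2` are all rejected, matching the logical/algebraic verdicts above.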

Hopefully that gives some support to the idea that logical linearity implies algebraic linearity if we set up the primitive typing rules in the right way. Of course, it doesn't go the other way. Just because a syntactic function's denotation is algebraically linear, it doesn't mean that the syntactic function itself is logically linear. For example:

-- Algebraically linear but *not* logically linear
\x. log (exp x)

But notice that even though this is algebraically linear, it's not something we'd know how to transpose in an AD system. So the second surprising claim is this: logical linearity isn't required for algebraic linearity but it is required for mechanical transposition of syntactic functions in the AD sense (at least, for the style of transposition that I know how to do). The reason is that we use the "consumed once" property when we do transposition of primitives like *, both in JAX and in Dex. We always have exactly one side that we're transposing with respect to and we take advantage of that in our rules.
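The "exactly one side" property can be made concrete with a toy transposer (again my own sketch, not Dex's actual machinery): terms below are logically linear in x by construction, and because Scale always keeps its constant on a known side, transposition is purely mechanical.

```haskell
-- Terms logically linear in a single input x, by construction.
data Lin = Var              -- x itself (consumed here, exactly once)
         | Scale Double Lin -- c * e, with c a constant: the (*) rule
         | Plus Lin Lin     -- (+): x consumed once on each side

-- Forward evaluation: the term as a function of x.
evalLin :: Lin -> Double -> Double
evalLin Var         x = x
evalLin (Scale c e) x = c * evalLin e x
evalLin (Plus a b)  x = evalLin a x + evalLin b x

-- Transposition: push a cotangent ct back to the input. For Scale we
-- know which side is the constant, so we just multiply ct by it; for
-- Plus we sum the two cotangent contributions.
transposeLin :: Lin -> Double -> Double
transposeLin Var         ct = ct
transposeLin (Scale c e) ct = transposeLin e (c * ct)
transposeLin (Plus a b)  ct = transposeLin a ct + transposeLin b ct
```

For f x = 2*x + 3*x, both `evalLin` and `transposeLin` multiply by 5, and the defining identity of the transpose, (f x) * y == x * (transposeLin f y), holds for scalars.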

Hope that's clearer! We should write this up one of these days.

mattjj commented Jan 17, 2023

Hope that's clearer! We should write this up one of these days.

@dougalm finally succeeded in writing this down! You Only Linearize Once: Tangents Transpose to Gradients by Radul et al.

srush commented Jan 17, 2023

Awesome! I'll check it out.
