
[RFC][WIP] Tensor Expression level automatic differentiation #1996

Closed
sgrechanik-h opened this issue Oct 25, 2018 · 25 comments

@sgrechanik-h (Contributor)

I'm working on automatic differentiation at the level of compute expressions, and I would like to share some progress and hear any comments. Currently the automatic differentiation works well enough for some operations that it is possible to train a simple model; here is a tutorial on how to do this. For many other operations the performance is still unacceptable, but I'm working on it.

My implementation mostly follows this paper. In this notebook I describe how my implementation works internally and give a list of operations which are known to work or not to work. Basically, the AD consists of two parts:

  • The automatic differentiation itself which simply differentiates expressions according to the well-known rules and produces inefficient expressions. The code is here.
  • A set of transformations to optimize the resulting inefficient expressions. The code is here.

All transformations work at the level of compute expressions (before scheduling). Their general goal is to eliminate summation over zeros by hoisting conditional expressions of the form cond ? val : 0 and then using them to simplify the iteration domains of reductions. Hopefully, these transformations may be useful for other tasks besides AD once they are powerful enough. Currently the main problem is that they don't understand modular arithmetic (which is needed for differentiating dilated and strided convolutions and for the flattening operation).
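
To make the pattern concrete, here is a small worked example (illustrative only, not taken from the notebook). For a matrix product $Y_{ij} = \sum_k X_{ik} W_{kj}$, naive differentiation builds the full Jacobian and contracts it with the adjoint of the output:

$$\frac{\partial Y_{ij}}{\partial W_{mn}} = X_{im}\,[j = n], \qquad \bar{W}_{mn} = \sum_{i,j} \bar{Y}_{ij}\, X_{im}\, [j = n],$$

where $[j = n]$ is exactly a cond ? 1 : 0 factor, so the reduction sums mostly zeros. Hoisting the condition and simplifying the iteration domain collapses the sum over $j$ to the familiar $\bar{W}_{mn} = \sum_i \bar{Y}_{in}\, X_{im}$.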

The git branch
The squashed commit
The tutorial on training a simple model
The notebook describing some internals

@tqchen changed the title from "[RFC][WIP] TVM-level automatic differentiation" to "[RFC][WIP] Tensor Expression level automatic differentiation" on Oct 25, 2018
@tqchen (Member) commented Oct 31, 2018

This looks great. I would also encourage us to think a bit more about how this can play together with the current system. Specifically, while being able to run AD end to end at the tensor expression level is fun, we may not necessarily use this method, since AD on the high-level IR still offers advantages such as general graph optimizations and algorithm selection (Winograd vs direct). So technically we would still encourage the use of high-level IR AD.

On the other hand, this would become practical if we had a proposal for how to quickly write an expression-level op that can be integrated with the high-level IR and its AD system. That would be a powerful way to write custom operators and automatically generate gradients for them.

@jroesch (Member) commented Oct 31, 2018

@MarisaKirisame could you provide more feedback after PLDI?

@MarisaKirisame (Contributor)

@jroesch yeah

@sgrechanik-h (Contributor, Author)

@tqchen A couple of months ago I tried to integrate this autodiff into NNVM; the attempt (now abandoned) is still preserved in some old commits:
https://github.com/sgrechanik-h/tvm/blob/109d2f785d3e73e56665fbc987f6bd0dc5823d60/nnvm/src/top/tensor/gradient.cc
The idea was to add a special operation called gradient which accepted the original operation's name and attributes as its own attributes. This operation's compute function was implemented by automatically differentiating the original operation's compute. But I suppose there must be a better approach.

Being able to easily define new high-level ops just from tensor expressions sounds like a good idea. I definitely missed something like this in NNVM.

Also I think it would be nice to be able to override operations' attributes, like gradients, from python code. At least, this would be great for comparing automatic gradients with manually written gradients.

@tqchen (Member) commented Oct 31, 2018

FYI, we are moving toward NNVMv2 (Relay), and this is a good time to think about related questions while the scaffolding is being done. Relay already supports overriding an op's attributes from Python.

@sergei-mironov (Contributor)

> On the other hand, this would become practical if we had a proposal for how to quickly write an expression-level op that can be integrated with the high-level IR and its AD system. That would be a powerful way to write custom operators and automatically generate gradients for them.

Indeed, I think that higher-level IRs will always be able to decide which method to use. Since Relay uses TVM under the hood, it should be in a good position to make the choice: either substitute operations while building the backward path (e.g. Winograd -> matmul), switch to hand-written gradients, or use the semantically correct default implementation provided by the current approach.

@sergei-mironov (Contributor) commented Nov 19, 2018

Continuing the discussion about training started in #2122.

@junrushao1994 @were @sgrechanik-h @szha I suggest we continue discussion related to training here.

So how could we do manual (or automatic) scheduling on an automatically generated backprop?

We plan to do the following regarding scheduling:

  1. Provide some automatic scheduling out of the box.
  2. Provide an API for manual scheduling of gradients. We could allow programmers to access gradient nodes by their source nodes.
  3. Provide an API to override a node's automatic gradient with a manually defined gradient when required.

@were (Contributor) commented Nov 19, 2018

For 1. and 2., as far as I can imagine the API will look something like:

diff = tvm.diff([op1, op2, op3])
sch = tvm.create_schedule(diff.op)

The differentiation result is also some kind of op node, so we can unify the API with TVM's general flow.
I am also working on some new op node scheduling; if this part requires help, I am willing to contribute to this issue.

@sgrechanik-h (Contributor, Author)

I've updated our automatic differentiation branch. Now the result of differentiating flatten is acceptable, and operations like max pool work better as well. We have also improved the API; the simple use case looks pretty much the same, up to function renaming:

[dw1, dw2, dw3] = tvm.differentiate(loss, [w1, w2, w3])

(The function differentiate is defined here, and here is a small tutorial.)
However, it is now possible to get individual adjoints from the result:

x = tvm.placeholder((32, 3, 28, 28), name='x')
w1 = tvm.placeholder((10, 3, 3, 3), name='w1')
t1 = topi.nn.conv2d(x, w1, 1, 0, 1)
t2 = topi.nn.flatten(t1)
t3 = topi.sum(t2)

res = tvm.differentiate(t3)
res.adjoints[t1]

(This may be useful for manually scheduling intermediate tensors.)
It is also possible to override the Jacobian computation for some tensors:

def mydiff(out, inp, head):
    # flatten maps the (32, 10, 26, 26) conv output (n, c, h, w) to (n, c*26*26 + h*26 + w)
    return tvm.compute(inp.shape, lambda ax0, ax1, ax2, ax3: head[ax0, ax3 + ax2*26 + ax1*676])

res = tvm.differentiate(t3, [x, w1], manual={(t2, t1): mydiff})

(This may be useful when the autodiff does a poor job.)

@tqchen (Member) commented Feb 16, 2019

@sgrechanik-h Thanks for the hard work on proposing and bringing in the implementation; this seems to be an RFC that many folks are interested in.

I have no doubt about the possibility of making a correct and efficient tensor expression autodiff. The key issue we need to address first is how we can make a clear set of APIs and namespaces so that:

  • We have a clear separation between Relay's grad and the current tensor expression one.
  • There is a clear indication that frontend users should use Relay's grad, while the tensor expression one is for operator implementations.
  • The API is consistent, so a user does not get confused when using one or the other.
  • Because the primary support for end-to-end grad is in Relay, we want to name the files clearly (e.g. tvm/autodiff -> tvm/expr_grad.h) to make that distinction, so it is clearer that Relay is the entry point of the TVM stack's differentiation.

These discussions and decisions are important because they will affect how users use the API, as per https://docs.tvm.ai/contribute/code_review.html#deliberate-on-api-and-data-structures

I suggest we discuss them and reach consensus before we move on.

cc @merrymercy @jroesch @yzhliu @junrushao1994 @ZihengJiang @grwlf @dmlc/tvm-committer

@tqchen (Member) commented Feb 16, 2019

@sgrechanik-h It would be great if you could edit the first post and list the choices or options clearly so that everyone can discuss them in a coordinated way; see #2535

@sgrechanik-h (Contributor, Author)

Here are the things to discuss that I came up with.

Naming

Concerning names, here is a list of things in the tensor expression AD that we might want to rename:

  • C++ files include/tvm/autodiff.h and src/pass/autodiff.cc, and a python module tvm.autodiff. autodiff is the most important thing to rename. I'd suggest something like tensordiff, tensor_expr_diff, tensor_expr_ad. I think using just expr without tensor may be a little bit confusing, since there are exprs in Relay too. Using grad vs diff is a matter of taste, I guess, diff seems more general though.
  • The main user-facing python function tvm.autodiff.differentiate, reexported as tvm.differentiate for convenience. I would suggest keeping its name as is; however, we probably shouldn't reexport it, to avoid confusion.
  • Some more or less internal C++ functions and data structures (Differentiate, Jacobian, DiffBuildingBlock, here). They live under tvm::ir:: in C++ and are exported into tvm.autodiff in python. I would also suggest keeping their names; however, we should probably create a separate namespace for them in C++, like tvm::autodiff.
  • The relay/autodiff integration file src/relay/op/autodiff_integration.cc. Its name will be changed when we rename autodiff. Not sure about the _integration part.
  • The relay/autodiff integration op autogenerated_gradient. Not sure, suggestions are welcome.
  • Functions relay._ir_pass.AutogeneratePrimalGradient and relay._ir_pass.AutogeneratePrimalGradientForAll. Not sure.

API

The main thing to discuss here is the python function tvm.autodiff.differentiate. Recently I decided that its gradient overriding mechanism is rather difficult to use and tried to improve it. I'm not sure if it's better now, but let's discuss the new version. Its signature and docstring may be found here; here is just a short description:

def differentiate(output, inputs=None, head=None, override=None, fdiff=None):
    """Perform reverse-mode automatic differentiation."""

Parameters:

  • output : Tensor - The tensor to differentiate.
  • inputs : List[Tensor] - The list of input tensors. When the list is empty or None, differentiation will be performed with respect to all tensors the output depends on.
  • head : Tensor, optional - The adjoint of the output. By default full Jacobians will be computed (or simply gradients in the case of a scalar output).
  • override - A dict describing how to override differentiation for certain tensors (see the link).
  • fdiff - The function performing differentiation and multiplication of single tensors, by default tvm.autodiff.DiffBuildingBlock is used.

The return value of the function is an instance of DifferentiationResult. It mimics a list of gradients (corresponding to inputs) but also contains two other fields, adjoints and adjoint_summands, for accessing intermediate adjoints.
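
For illustration, here is a usage sketch under this signature (the dense op and shapes are arbitrary, and the API exists only in the branch, so details may change):

import tvm
import topi

x = tvm.placeholder((32, 100), name='x')
w = tvm.placeholder((10, 100), name='w')
y = topi.nn.dense(x, w)  # non-scalar output of shape (32, 10)

# Passing an explicit head turns the full-Jacobian default into a
# vector-Jacobian product: head plays the role of dLoss/dy.
head = tvm.placeholder((32, 10), name='head')
dx, dw = tvm.differentiate(y, [x, w], head)  # adjoints with the shapes of x and w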

Several issues to discuss:

  • The name of the head parameter. Maybe it would be better to call it output_adjoint or output_gradient or, more precisely, gradient_wrt_output, although I tend to avoid the word "gradient" since we don't require the final output to be scalar.
  • The gradient overriding mechanism is still quite obscure; maybe we should remove it from this function completely (the corresponding parameters may be retained in the C++ function Differentiate).
  • The fdiff parameter, however, is much easier to understand, so I suggest retaining it.
  • We should probably support multiple outputs, since differentiating them one by one may be less efficient.

Another important function is the C++ counterpart, tvm::ir::Differentiate. This function is for developers, though. Currently my only suggestion is that we might want to support multiple outputs and heads.

Relay integration API

The user mostly needs two functions (code here):

  • AutogeneratePrimalGradient(op_name, plevel) - override FPrimalGradient for the given operation (with the given plevel, 100 by default)
  • AutogeneratePrimalGradientForAll(plevel) - override FPrimalGradient for every existing operation. The plevel here is 5 by default, so it shouldn't override gradients for operations with FPrimalGradient already defined.
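
For illustration, from Python the exported bindings mentioned above would be invoked roughly like this ("nn.softmax" is just a placeholder op name; this API exists only in the branch):

from tvm import relay

# Override FPrimalGradient for a single op at a high priority level (default 100),
# so the autogenerated gradient takes precedence even over a hand-written one:
relay._ir_pass.AutogeneratePrimalGradient("nn.softmax", 100)

# Register autogenerated gradients for all ops at a low priority level (default 5),
# so existing hand-written FPrimalGradient implementations are kept:
relay._ir_pass.AutogeneratePrimalGradientForAll(5)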

The operation representing the generated gradients currently takes a single tuple as input and returns a single tuple as output (here). The attributes of the operation are as follows:

  • original_op : Op - The original operation.
  • original_attrs : Attrs - The attributes of the original operation.
  • original_out_type : Type, optional - The type of the original expression.

However, I think this operation should be considered an implementation detail, and the stability of its attributes, inputs and outputs shouldn't be relied upon.

@tqchen (Member) commented Feb 18, 2019

  • We can use gradient or differentiation to indicate the gradient. Given that most APIs (TensorFlow, PyTorch) and Relay's high-level API use gradient, I suggest we use the name gradient.
  • For the namespace, how about tvm.expr.gradient?

@sgrechanik-h (Contributor, Author)

> We can use gradient or differentiation to indicate the gradient. Given that most APIs (TensorFlow, PyTorch) and Relay's high-level API use gradient, I suggest we use the name gradient.

Do you mean renaming autodiff.{h,cc} -> gradient.{h,cc}?

> For the namespace, how about tvm.expr.gradient?

Wouldn't it make more sense to put it in tvm.tensor.gradient, since it works on tensors, not only on expressions? Also, placing a submodule gradient into tvm.expr (or tvm.tensor) would require creating a directory tvm/expr/ and moving the contents of tvm/expr.py into tvm/expr/__init__.py.

@yzhliu mentioned this issue on Mar 2, 2019
@sgrechanik-h (Contributor, Author) commented Jun 19, 2019

Hello everyone. I want to tell you about the current status of tensor expression automatic differentiation. The latest version can be found here. The main improvements are as follows:

  • I've implemented a solver for systems of linear integer equations. This considerably improves the performance of certain operations like dilated and strided convolutions (see the worked example after this list).
  • I've redesigned the zero elimination module. Now there is a class Domain which represents an iteration domain (a set of integer tuples, usually convex), and most of the functions transform domains into other domains (returning objects of the class DomainTransformation representing two domains and variable mappings).
  • I've moved to the new simplifiers. This was important because the Halide simplifier assumes that division is Euclidean, which leads to the generation of incorrect code.
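
As a worked illustration of why the solver matters (hand-derived, not code from the branch), consider a 1-D stride-2 convolution $Y_j = \sum_k X_{2j+k} W_k$. Its adjoint is a reduction constrained by a linear integer equation:

$$\bar{X}_i = \sum_{j,k} \bar{Y}_j\, W_k\, [\,i = 2j + k\,] \;\longrightarrow\; \sum_{k \,:\, 2 \mid (i-k)} \bar{Y}_{(i-k)/2}\, W_k$$

(bounds on $j$ omitted). Solving $i = 2j + k$ for $j$ removes a reduction axis but introduces the divisibility condition $2 \mid (i - k)$, which is exactly the kind of modular arithmetic the old transformations could not handle.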

However, there are several TVM-related problems that should be addressed before creating pull requests:

  1. TVM bound inference sometimes expands tensor bounds so much that the tensors can't fit into memory. This is a known problem ([TVM] Fix GatherBound to avoid allocating too much #2104), but it seems nobody knows how to solve it yet. In the linked branch I use a simple fix, which however breaks some tests by triggering a strange-looking assertion.
  2. Certain parts of TVM still use Euclidean division, which sometimes results in incorrect code being generated. Hopefully, this problem will be mostly fixed by @tqchen in [ARITH] Migrate simplifier to new infra #3368. (Although the PR is still unfinished, I use a slightly modified version of it in the autodiff branch.)

@MarisaKirisame (Contributor)

I see a fundamental problem in this PR.

Jacobian(Y, W):
tensor compute.jacobian{0x165b360}[0] : float32 [32, 3000, 3000, 10000]
axes (i : [0, 31], j : [0, 2999], jac_i0 : [0, 2999], jac_i1 : [0, 9999])
Reduction
identity [0.000000f]
lhs [x.der] rhs [y.der]
combiner [(x.der + y.der)]
axes (k : [0, 9999])
condition (uint1)1
source[0] = (X(i, k)*float32(((jac_i0 == j) && (jac_i1 == k))))

This is a really, really big tensor, and the approach this PR takes has a "cliff of death" performance profile.

This PR then relies on simplification to eliminate all those tensors. If any tensor is not eliminated (which seems to be the case for more complex tensors), the performance will be very bad.

Reverse-mode automatic differentiation should only calculate vector-Jacobian products.

The Jacobian of Y, W should be dW times Jacobian(Y, W); the Jacobian should simply never be manifested.

Can this be fixed so that the algorithm is not asymptotically slower without optimization?

@yzhliu (Member) commented Mar 6, 2020

@MarisaKirisame could you elaborate on "The Jacobian of Y, W should be dW times Jacobian(Y, W)"? I'm not sure I correctly understand the symbols you use.
I think the main challenge is to infer the bounds for the Jacobian's axes in the scenario where output axes can be arbitrary linear combinations of the input tensors' axes.

@MarisaKirisame (Contributor)

@yzhliu let me explain it another way.
Suppose we work only at the scalar level. People have proved that reverse mode only takes about 3x more computation than the original function. This does not need any optimization: the gradient of f will only be a (small, typically less than 10) constant factor slower than the computation of f itself.

In this PR, the gradient of f might be MANY times more expensive.
This is because it is calculating the Jacobian rather than the product of a vector/matrix/tensor with that Jacobian, which can be fused so that the computation can be expressed in a much simpler form.
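
To put numbers on it using the shapes from the dump above: the materialized Jacobian has $32 \cdot 3000 \cdot 3000 \cdot 10000 \approx 2.9 \times 10^{12}$ entries, while the adjoint that reverse mode actually needs,

$$\bar{W}_{mn} = \sum_{i,j} \bar{Y}_{ij}\, \frac{\partial Y_{ij}}{\partial W_{mn}},$$

has only $3000 \cdot 10000 = 3 \times 10^{7}$ entries. The contraction should be computed without ever manifesting the Jacobian, rather than relying on later simplification to recover that form.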

@sergei-grechanik (Contributor)

@MarisaKirisame One approach that I know of is to differentiate by literally reversing the computation, replacing loads with stores (with += actually) and vice versa. This is usually considered a worse approach because it leads to synchronization problems, but it should be better for some ops, and I guess the worst case performance shouldn't be that horrible.

@MarisaKirisame (Contributor)

@sergei-grechanik That is what people do. The problem with parallel AD is that shared reads become shared writes, but you can pull optimization tricks to turn a shared write into a single write (for example, https://people.csail.mit.edu/tzumao/gradient_halide/).
I think this is the smarter approach: when their optimization fails, they only get write synchronization, but when our optimization fails, we get giant tensors.

@sergei-grechanik (Contributor)

@MarisaKirisame Yes, I think you are right. I guess implementing this approach requires extending tensor expressions with an additional construct or working on a different level.

@MarisaKirisame (Contributor)

It probably requires working on the tensor IR, as the TVM compute level is pure.

@sergei-grechanik (Contributor)

@MarisaKirisame What do you think about adding a scatter-like reduction to TVM's tensor expressions? (I mean for(i : ...) A[f(i)] += g(i), i.e. the lhs is allowed to have an arbitrary index expression in the subscript.) It would still be pure.

@MarisaKirisame (Contributor)

@sergei-grechanik Looks great! Another way to do it is to add a pass that turns scatter into gather, effectively doing Halide's autodiff. So the IR is not pure, but it is after some more processing.
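
To make the two formulations concrete, here is a small NumPy sketch (purely illustrative, not tied to any TVM construct). The forward op is a gather (a stride-2 slice); its adjoint written naturally is a scatter of the proposed A[f(i)] += g(i) form, and the equivalent gather over the input's own index space is what a scatter-to-gather pass (or the current zero-elimination machinery) has to recover:

import numpy as np

n = 8
x = np.random.rand(2 * n).astype('float32')

# Forward: a gather, y[i] = x[2 * i]
y = x[::2]
dy = np.ones_like(y)  # adjoint of y

# Adjoint as a scatter: dx[f(i)] += dy[i] with f(i) = 2 * i
dx_scatter = np.zeros_like(x)
for i in range(n):
    dx_scatter[2 * i] += dy[i]

# The same adjoint as a gather over dx's own index space:
# dx[j] = dy[j // 2] if j is even, else 0
j = np.arange(2 * n)
dx_gather = np.where(j % 2 == 0, dy[j // 2], 0.0).astype('float32')

assert np.allclose(dx_scatter, dx_gather)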

@tqchen closed this as completed on May 11, 2020
@tqchen (Member) commented May 11, 2020

The tensor expression AD has landed in mainline.
