sparse optimizers #1285
Comments
Here is an example test case for optimization when we have sparse gradients, done by "knocking out" the x or y coordinate of the computed gradient and putting the rest in a sparse gradient (I did it in a kind of janky way; suggestions for improving this code stylistically are much appreciated!). When we get around to testing with weight decay I'll have to replace the assertEqual with a more approximate test, since the strategy we plan on using won't give us exactly the same results as the dense version.
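A minimal sketch of that kind of test (illustrative only, not the code from the actual test suite; SparseAdam and the Rosenbrock function are just convenient choices here): each step's gradient is delivered as a one-entry sparse tensor with the other coordinate knocked out, so over pairs of steps the optimizer still sees the full gradient.

```python
import torch

# Sketch: optimize a 2-parameter function, but hand the optimizer a sparse
# gradient containing only one coordinate per step ("knocking out" the other),
# alternating which coordinate survives.
def drosenbrock(p):
    x, y = p.tolist()
    return torch.tensor([-400.0 * x * (y - x ** 2) - 2.0 * (1 - x),
                         200.0 * (y - x ** 2)])

param = torch.tensor([1.5, 1.5], requires_grad=True)
opt = torch.optim.SparseAdam([param], lr=1e-2)

for step in range(500):
    grad = drosenbrock(param.detach())
    keep = step % 2                          # which coordinate survives this step
    idx = torch.tensor([[keep]])             # COO indices, shape (1, nnz)
    val = grad[keep].unsqueeze(0)            # matching values, shape (nnz,)
    param.grad = torch.sparse_coo_tensor(idx, val, (2,))
    opt.step()

print(param)  # the iterate should head toward the minimum at (1, 1)
```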
Discussion with @adamlerer: if we amortize sparse updates in this way, we need some sort of "flushing" operation which a user can call to update the tensor to the "correct value" (since we may have accumulated updates on parameters that haven't yet been applied to the underlying tensor itself). This is a change to the API for optimizers (a new flush function a user has to call at the end of optimizing), but since no one is actually optimizing on sparse models, this doesn't seem like a bad thing to ask them to do.
Now that I feel that I have a much better grasp on what is going on with the optimizers, let me give an expanded "state of play":
It will probably be interesting to see what TF is doing for all of the other optimizers we have.
Towards fixing #1285. This is a WIP commit to get the ball rolling on code review (I am sure I have done great violence to the various coding standards of your project). Things to be done:
- Continue adding sparse support for other parameters and optimizers
- Add some more tests, including a unit test ensuring that a single step is what we expect

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
We should probably introduce lazy Adam. See https://twitter.com/haldaume3/status/901434195117051904 and https://discuss.pytorch.org/t/sparse-embedding-failing-with-adam-torch-cuda-sparse-floattensor-has-no-attribute-addcmul-/5589 as data points.
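For reference, the gist of a "lazy" Adam step (a rough sketch in the spirit of TF's LazyAdamOptimizer, not an exact reimplementation of it or of any PyTorch optimizer): the moment estimates are read and written only for rows that actually appear in the sparse gradient, and untouched rows are not moved at all. Here `param`, `exp_avg`, and `exp_avg_sq` are assumed to be same-shaped dense tensors (e.g. an embedding weight and its state buffers), `param.grad` a sparse COO tensor, and `step` the global step count.

```python
import torch

# Sketch of a lazy/sparse Adam step on one 2-D parameter: only the rows
# present in the sparse gradient touch the moment buffers or the parameter.
def lazy_adam_step(param, exp_avg, exp_avg_sq, step,
                   lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    grad = param.grad.coalesce()
    rows, vals = grad.indices()[0], grad.values()
    # update first and second moments for the touched rows only
    exp_avg[rows] = beta1 * exp_avg[rows] + (1 - beta1) * vals
    exp_avg_sq[rows] = beta2 * exp_avg_sq[rows] + (1 - beta2) * vals * vals
    # bias correction uses the global step count, as in dense Adam
    bias1 = 1 - beta1 ** step
    bias2 = 1 - beta2 ** step
    denom = (exp_avg_sq[rows] / bias2).sqrt().add_(eps)
    param.data[rows] -= lr * (exp_avg[rows] / bias1) / denom
```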
In the meantime, a documentation update would be helpful to document the sparse parameter to nn.Embedding and mention where it can't be used (weight decay and some optimizers).
I assume weight decay should work the same way @ezyang points out lazy Adam should be done. No reason to decay weights in batches where they're not used, right?
Well, there might be a reason - if you do it lazily, the weight vector is never up to date, so the statistics you might compute (e.g. weight distribution) would be incorrect.
My 2c: I thought about it a bit, and this naive lazy update kind of makes sense for Adam. If you apply all of the updates (m, v, p) sparsely, then indeed this should behave pretty similarly to (probably even better than) dense Adam. I would argue that you shouldn't be using a momentum method with sparse updates, but whatever...

Weight decay is different. It's a component of the loss, and if you do it sparsely, you are no longer computing weight decay, you're computing something else. There's a very good reason why weight decay should be applied to all parameters: the whole point of L2 regularization is that it should drive down the value of parameters g_i for which dL/dg_i is small. If you have a parameter which is almost always "sparse" (i.e. dL/dg_i is 0 most of the time), then it should be pushed down strongly (i.e. weight decay should be applied at each gradient update, although this can be done in an amortized fashion to save on compute).
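To spell out the amortized version (a small worked step, using lr for the learning rate and wd for the decay coefficient): plain L2 decay multiplies a parameter by (1 - lr*wd) on every step, so if a row's gradient is zero for k consecutive steps, the k skipped decay updates compound and can be applied in one shot the next time the row is touched:

p_{t+k} = (1 - lr*wd)^k * p_t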
[apologies for hijacking the conversation] Just wanted to point out that we recently developed a technique to deal with sparse updates here and here which avoids all the bookkeeping required for the lazy updates. It was originally developed for asynchronous algorithms (where the lazy updates are not an option) but it also works great for sequential algorithms. I believe this technique should be applicable to the case of Adam as well, but I haven't looked into it yet. It's more of a research project, but if someone is interested in this, just drop me a line. [/end hijacking the conversation]
Well... I do want to point out that the Tensorflow implementation works extremely well for embedding layers. It's very helpful to have both momentum methods and weight decay in embedding layers, but the current pytorch sparse approach doesn't work at all in this case. If I remove weight decay and use Adagrad (which works with sparse layers) I don't get good results. With tensorflow (using keras) I've been able to build text classification (sentiment analysis) models that are far more accurate, have fewer parameters, and train more quickly and reliably using their approach to sparse updates. To be clear: they're far better regardless of whether I'm using weight decay - their approach to sparse updates seems to be giving much better results (and I've confirmed that other details like initializers are the same).

The goal here shouldn't be to perform only as well as dense Adam (which performs very poorly in this case), but as well as tensorflow's sparse Adam (and sparse momentum, and sparse weight decay, etc). Whilst I'm not sure about @adamlerer's theory, I do think that a lazy approach should at least be benchmarked against the simple gradient masking approach to see when/if it's actually better. If it's not better for real examples, then it's probably not a good idea. Just my $0.02...
@jph00 thanks for the anecdotal results, that's really good to know. Are you sure that it's LazyAdam in particular that leads to good embedding layers (vs. some other TF vs PT difference)? I.e. have you trained the same TF model with DenseAdam / Adagrad / SGD and observed substantially worse embeddings? To clarify, you say you use sparse momentum + weight decay, but I don't see any weight decay option in tensorflow/tensorflow@819a690d, so how are you applying weight decay? Is it dense or sparse?

Re: my thoughts, I guess my main request is that if S is a sparse tensor then it should be the case that
It's definitely not just lazyadam - I'm guessing it's just sparse_apply in general. But I don't know the exact source of the difference; at this point all I can say for sure is that keras, with tf backend, trains simple models with embeddings extremely quickly and extremely accurately compared to pytorch, and allows the use of weight decay and any optimizer. I've only briefly scanned the keras code, but it seems to generally be just passing stuff down to tf with minimal fiddling, so I was guessing that the reason must be in how tf is handling the embeddings. But I don't know that for sure. I also noticed that setting sparse=True in pytorch resulted in better results for simple models than setting sparse=False. So even in pytorch something about that approach is resulting in improvements. Which is to say: trying to match dense results doesn't seem like it should be the goal here... |
Jeremy, is it possible to share your Keras & PyTorch scripts (if they are not proprietary and stuff)? I'm down to take a deeper look, and work on PyTorch features to get to parity.
Of course, I'd love to! I don't have anything proprietary - it'll all be part of the course... Let me spend some time today trying to create something as clear and concise as I can. I'll add a link here when done. |
Ugh, I spent the last 4 hours trying to create a clear example, and ended up finding that the key issue is actually in the difference between torchtext and keras padding (keras pre-pads by default, torchtext post-pads). I had no idea this could be such a big issue - but it fully explains the differences I was seeing! So sorry to waste your time with my incorrect finger-pointing at the sparse updates. I'm not going to fully back-track and say that updates to just the embeddings that were used isn't a good idea; I don't have the data either way on this. I hope sometime after the next course to have time to go back over this and do some proper experiments. I will say however that the fact that keras/tf handles these models so quickly regardless of weight decay, adam, etc, is pretty damn cool, although I'm not really sure how they do it! |
Would like to see SparseRMSprop. Talked to Alex Pritzel about reproducing Neural Episodic Control, and he mentioned specifically that they update the Differentiable Neural Dictionary with sparse updates in TensorFlow, so only accessed keys have gradient averages calculated. Given that the DND is used in other work as well, this optimiser would be nice to have.
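That "only accessed keys" behaviour boils down to something like the following (a rough sketch, not taken from the NEC paper or any particular library; it mirrors the lazy Adam sketch above, minus the first moment): the squared-gradient running average lives alongside the parameter, and only rows present in the sparse gradient are ever read or written.

```python
import torch

# Sketch of a sparse RMSprop step on a 2-D parameter (e.g. the keys of a
# differentiable dictionary): rows absent from the gradient are untouched,
# so their running averages are never even read.
def sparse_rmsprop_step(param, square_avg, lr=1e-2, alpha=0.99, eps=1e-8):
    grad = param.grad.coalesce()
    rows, vals = grad.indices()[0], grad.values()
    square_avg[rows] = alpha * square_avg[rows] + (1 - alpha) * vals * vals
    param.data[rows] -= lr * vals / (square_avg[rows].sqrt() + eps)
```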
Sparse RMSprop is feasible with a library I just released |
Computing weight decay this way is actually common when optimizing recommendation models like L2-regularized matrix factorization with SGD. The state of the PyTorch optimizers w.r.t. sparsity makes this a lot harder than it needs to be. Attempting to apply SGD weight decay to a sparse embedding layer currently raises an error.
Since there isn't another behavior currently implemented for this case, it seems like allowing sparse weight decay here wouldn't really hurt. Maybe the main concern would be avoiding a performance regression in the dense case?
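A minimal reproduction of that situation (a sketch; the exact error text depends on the PyTorch version): sparse gradients from an nn.Embedding(sparse=True) combined with a nonzero weight_decay in SGD.

```python
import torch
import torch.nn as nn

# Sparse embedding + SGD with weight decay: the decay term tries to combine
# the dense parameter with the sparse gradient, which is where it falls over.
emb = nn.Embedding(1000, 16, sparse=True)
opt = torch.optim.SGD(emb.parameters(), lr=0.1, weight_decay=1e-4)

loss = emb(torch.tensor([1, 2, 3])).sum()
loss.backward()
opt.step()   # errors here when weight_decay != 0 and the gradient is sparse
```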
Now that sparse tensors are mostly working (#1147), it would be awesome if all the optimizers worked with sparse tensors. This requires some cleverness to do amortized updates to the parameters. For example, for weight decay, you would do something like (pseudocode):
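Something along these lines (a sketch; `last_touched` is illustrative bookkeeping, not an existing attribute): when a row finally appears in a sparse gradient, apply all of the decay it has skipped since the last time it was touched, then record the current step.

```python
# lazy / amortized weight decay for the rows hit by a sparse gradient
def amortized_weight_decay(param, rows, last_touched, step, lr, weight_decay):
    skipped = (step - last_touched[rows]).float()     # steps each row missed
    param.data[rows] *= (1.0 - lr * weight_decay) ** skipped.unsqueeze(-1)
    last_touched[rows] = step
```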
Note this isn't exactly equivalent (to be equivalent you'd need to apply the weight decay before the forward pass, not after backwards), but it's a good approximation. You can do the same thing for momentum.
I'm guessing the same thing works for Adam/Adamax as well but I haven't worked through the equations. https://arxiv.org/pdf/1412.6980.pdf
@ezyang expressed interest in working on this.