Lower the embedding op #6495
Conversation
Thanks @bhavya01! Is this ready to be reviewed? Also, can you add a corresponding cpp unit test, since it is a new op?
@wonjoolee95 This PR should be ready for review for #5982.
Do we already have a cpp test for this op? Also wil…
We do have a cpp test for the op: https://github.com/pytorch/xla/blob/master/test/cpp/test_aten_xla_tensor_5.cpp#L250. I think I will create a separate PR for embedding bag.
@bhavya01, can we ensure that the lowering gets invoked correctly by adding the metric check in the test? Thanks, let's also work on the EmbeddingBag lowering.
Added the check for metrics. EmbeddingBag is still WIP.
Thanks!
I left a few comments as follow-up items. Not a blocker, so feel free to merge this PR. But let's follow up on the items in the next PR (preferably the EmbeddingBag PR).
```diff
@@ -261,6 +261,9 @@ TEST_F(AtenXlaTensorTest, TestEmbedding) {
                        /*scale_grad_by_freq=*/false,
                        /*sparse=*/false);
   AllClose(b, xla_b);
+  ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
+  ExpectCounterChanged("xla::embedding_symint",
+                       cpp_test::GetIgnoredCounters());
```
nit: `xla::embedding_symint` -> `xla::embedding*`. Not a blocker, we can update in the EmbeddingBag PR.
Also want to confirm one thing -- before implementing this lowering (i.e. without the changes in this PR), does this metric assertion fail as expected? Just wanted to make sure that the `_symint` variant is getting recognized properly. It should, but just wanted to confirm. Again, not a blocker; we can follow up in the EmbeddingBag PR.
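For reference, here is a minimal sketch of the shape of such an override (the counter macro and its placement are assumptions based on the usual pytorch/xla pattern, not this PR's exact code). The assertion keys off the per-op counter that gets bumped when the override runs:

```cpp
// Hedged sketch of an XLANativeFunctions override, not this PR's exact code.
// The counter bumped below is what the test's
// ExpectCounterChanged("xla::embedding_symint", ...) observes; without an
// override, dispatch falls back and only "aten::*" counters move.
at::Tensor XLANativeFunctions::embedding_symint(const at::Tensor& weight,
                                                const at::Tensor& indices,
                                                c10::SymInt padding_idx,
                                                bool scale_grad_by_freq,
                                                bool sparse) {
  TORCH_LAZY_FN_COUNTER("xla::");  // registers/increments "xla::embedding_symint"
  // Pre-PR behavior: route to the native implementation, which dispatches
  // the supported XLA operations underneath.
  return at::native::embedding_symint(weight, indices, padding_idx,
                                      scale_grad_by_freq, sparse);
}
```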
```diff
@@ -3515,11 +3515,10 @@ at::Tensor XLANativeFunctions::embedding_symint(const at::Tensor& weight,
   return at::native::embedding_symint(weight, indices, padding_idx,
                                       scale_grad_by_freq, sparse);
 }
-// TODO: for now route to native, which dispatches supported XLA operations.
-// We need to make use of the TPU embedding core here eventually.
```
This comment (it seems to have existed for a while) -- `We need to make use of the TPU embedding core here eventually.` -- makes me a bit cautious about the `Embedding` op. I assume it means that to lower this op in a performant way, we somehow want to make use of the TPU embedding core? Just leaving this here as a future reference. Also, let's bring this comment back in the code so we are aware. Again, not a blocker; we can include it as a follow-up in the next PR.
Sounds good. Will add it back in a follow-up PR.
Reference: https://github.com/pytorch/pytorch/blob/113138aa5575301d914c18bd882d6ab3735aa18a/aten/src/ATen/native/Embedding.cpp#L37
The native PyTorch implementation also uses just the indices and the weight matrix. The other options seem to be ignored in the forward pass.
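To make that concrete, here is a minimal sketch of what the dense forward path at that reference boils down to (the function name is illustrative, not the actual ATen symbol): the result is an `index_select` of rows of the weight matrix, reshaped back to the index dimensions. `padding_idx`, `scale_grad_by_freq`, and `sparse` matter only for the backward/sparse paths.

```cpp
#include <ATen/ATen.h>

// Illustrative sketch of the dense embedding forward pass: gather rows of
// `weight` by flattened `indices`, then restore the leading index dims and
// append the embedding dimension.
at::Tensor embedding_forward_dense(const at::Tensor& weight,
                                   const at::Tensor& indices) {
  std::vector<int64_t> out_sizes = indices.sizes().vec();
  out_sizes.push_back(weight.size(1));  // embedding_dim
  return weight.index_select(0, indices.reshape({-1})).view(out_sizes);
}
```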