
Lower the embedding op #6495

Merged (5 commits into master on Feb 12, 2024)
Conversation

@bhavya01 (Collaborator) commented Feb 7, 2024

Reference: https://github.com/pytorch/pytorch/blob/113138aa5575301d914c18bd882d6ab3735aa18a/aten/src/ATen/native/Embedding.cpp#L37

The native PyTorch implementation also uses just the indices and the weight matrix; the other options appear to be ignored.
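
(For reference, a minimal sketch of what the linked native implementation does, mirroring the flow in Embedding.cpp but dropping the symbolic-shape handling and the input checks; the function name here is mine:)

```cpp
#include <ATen/ATen.h>

// Minimal sketch of at::native::embedding's forward path: only the weight
// matrix and the indices participate; padding_idx, scale_grad_by_freq, and
// sparse do not change the forward computation.
at::Tensor embedding_sketch(const at::Tensor& weight,
                            const at::Tensor& indices) {
  if (indices.dim() == 1) {
    // 1-D indices: a plain row gather on the weight matrix.
    return weight.index_select(0, indices);
  }
  // N-D indices: flatten, gather rows, then restore the index shape with
  // the embedding dimension appended.
  auto size = indices.sizes().vec();
  size.push_back(weight.size(1));
  return weight.index_select(0, indices.reshape(-1)).view(size);
}
```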

@wonjoolee95 (Collaborator) commented Feb 8, 2024

Thanks, @bhavya01! Is this ready to be reviewed? Also, can you add a corresponding cpp unit test, since this is a new op?

bhavya01 requested a review from wonjoolee95 on February 8, 2024 at 00:24
bhavya01 self-assigned this on Feb 8, 2024
@bhavya01 (Collaborator, Author) commented Feb 8, 2024

@wonjoolee95 This PR should be ready for review for #5982.

@wonjoolee95 (Collaborator) commented:

Do we already have a cpp test for this op? Also, will embedding_bag also be lowered in this PR, or will that be in a follow-up PR?

@bhavya01 (Collaborator, Author) commented Feb 8, 2024

We do have a cpp test for the op https://github.com/pytorch/xla/blob/master/test/cpp/test_aten_xla_tensor_5.cpp#L250

I think I will create a separate PR for embedding_bag.
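
(For reference, embedding_bag composes the same row gather with a per-bag reduction. A minimal sketch of the default sum mode against the public ATen tensor API; the helper name and the simplifications, i.e. no mean/max modes, per-sample weights, or padding handling, are mine:)

```cpp
#include <ATen/ATen.h>
#include <vector>

// Minimal sketch of embedding_bag semantics in sum mode: offsets mark where
// each bag starts in the flat indices tensor, and each output row is the sum
// of the embedding rows selected by that bag's indices. Illustrative only.
at::Tensor embedding_bag_sum_sketch(const at::Tensor& weight,
                                    const at::Tensor& indices,
                                    const at::Tensor& offsets) {
  int64_t num_bags = offsets.size(0);
  std::vector<at::Tensor> bags;
  bags.reserve(num_bags);
  for (int64_t b = 0; b < num_bags; ++b) {
    int64_t start = offsets[b].item<int64_t>();
    int64_t end =
        (b + 1 < num_bags) ? offsets[b + 1].item<int64_t>() : indices.size(0);
    bags.push_back(
        weight.index_select(0, indices.slice(0, start, end)).sum(0));
  }
  return at::stack(bags);
}
```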

@wonjoolee95 (Collaborator) commented:

@bhavya01, can we ensure that the lowering gets invoked correctly by adding the metric check in the embedding cpp test, like https://github.com/pytorch/xla/blob/master/test/cpp/test_aten_xla_tensor_5.cpp#L246?

Thanks, let's also work on embedding_bag, as we have a customer request for that op. Thanks for working on this!

@bhavya01 (Collaborator, Author) commented Feb 9, 2024

Added the check for metrics. EmbeddingBag is still WIP.

@wonjoolee95 (Collaborator) left a comment


Thanks!

I left a few comments as follow-up items. Not a blocker, so feel free to merge this PR. But let's follow up on the items in the next PR (preferably the EmbeddingBag PR).

```diff
@@ -261,6 +261,9 @@ TEST_F(AtenXlaTensorTest, TestEmbedding) {
                        /*scale_grad_by_freq=*/false,
                        /*sparse=*/false);
   AllClose(b, xla_b);
+  ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
+  ExpectCounterChanged("xla::embedding_symint",
+                       cpp_test::GetIgnoredCounters());
```
wonjoolee95 (Collaborator) commented:

nit: xla::embedding_symint -> xla::embedding*. Not a blocker, we can update in the EmbeddingBag PR.

wonjoolee95 (Collaborator) commented:

Also want to confirm one thing -- before implementing this lowering (i.e. without the changes in this PR), does this metric assertion fail as expected? Just wanted to make sure that the _symint variant is getting recognized properly. It should, but just wanted to confirm.

But not a blocker, we can follow up in the EmbeddingBag PR.

bhavya01 (Collaborator, Author) commented:

Yes, without the change this check fails.

[Screenshot: the counter check failing without the lowering]

```diff
@@ -3515,11 +3515,10 @@ at::Tensor XLANativeFunctions::embedding_symint(const at::Tensor& weight,
     return at::native::embedding_symint(weight, indices, padding_idx,
                                         scale_grad_by_freq, sparse);
   }
-  // TODO: for now route to native, which dispatches supported XLA operations.
-  // We need to make use of the TPU embedding core here eventually.
```
wonjoolee95 (Collaborator) commented:

The comment "We need to make use of the TPU embedding core here eventually" (it seems to have existed for a while) makes me a bit cautious about the Embedding op. I assume it means that to lower this op in a performant way, we somehow want to make use of the TPU embedding core? Just leaving this as a future reference. Also, let's bring this comment back in the code so we are aware. Again, not a blocker; we can include it as a follow-up in the next PR.

bhavya01 (Collaborator, Author) commented:

Sounds good. Will add it back in a follow-up PR.
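
(Aside, for context on the thread above: lowering torch.embedding on the XLA side reduces to a row gather over the weight matrix. A hypothetical sketch using xla::TorchIndexSelect from xla/client/lib/slicing.h; the helper choice and names are assumptions for illustration, not necessarily the code this PR adds:)

```cpp
#include "xla/client/lib/slicing.h"
#include "xla/client/xla_builder.h"

// Hypothetical sketch: for 1-D indices, the embedding lookup is exactly an
// index_select along dimension 0 of the weight matrix, which XLA's client
// library exposes as TorchIndexSelect (a wrapper around a gather).
xla::XlaOp LowerEmbedding(xla::XlaOp weight, xla::XlaOp indices) {
  return xla::TorchIndexSelect(weight, indices, /*dim=*/0);
}
```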

@bhavya01 merged commit 657b692 into master on Feb 12, 2024
18 checks passed
amithrm pushed a commit to amithrm/xla that referenced this pull request Mar 1, 2024
bhavya01 added a commit that referenced this pull request Apr 22, 2024