
Lower embedding bag forward only #6951

Merged: bhavya01 merged 2 commits into master from embeddingbag on Apr 22, 2024
Conversation

bhavya01 (Collaborator)

No description provided.

@bhavya01 requested a review from wonjoolee95 on April 22, 2024, 18:57
@bhavya01 closed this on Apr 22, 2024
@bhavya01 reopened this on Apr 22, 2024
@wonjoolee95 (Collaborator) left a comment

LGTM!

@bhavya01 merged commit 46919a4 into master on Apr 22, 2024
21 checks passed
@bhavya01 deleted the embeddingbag branch on April 22, 2024, 22:28
Comment on lines +1293 to +1308
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
XLANativeFunctions::_embedding_bag_forward_only(
    const at::Tensor& weight, const at::Tensor& indices,
    const at::Tensor& offsets, bool scale_grad_by_freq, int64_t mode,
    bool sparse, const c10::optional<at::Tensor>& per_sample_weights,
    bool include_last_offset, int64_t padding_idx) {
  TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
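  // Configurations not handled by the XLA lowering (mean mode, gradient
  // scaling by frequency, sparse gradients, or an explicit padding_idx)
  // fall back to the CPU implementation below.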
  if (mode == 1 || scale_grad_by_freq || sparse || padding_idx != -1) {
    return at::native::call_fallback_fn<
        &xla_cpu_fallback,
        ATEN_OP(_embedding_bag_forward_only)>::call(weight, indices, offsets,
                                                    scale_grad_by_freq, mode,
                                                    sparse, per_sample_weights,
                                                    include_last_offset,
                                                    padding_idx);
  }
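
For context on the condition above, here is a minimal usage sketch (not part of this PR) of which calls take the new lowering versus the CPU fallback. It assumes F.embedding_bag dispatches to _embedding_bag_forward_only when no gradient is required; names and shapes below are illustrative.

# Minimal sketch: sum mode avoids the fallback, mean mode (mode == 1) does not.
import torch
import torch.nn.functional as F
import torch_xla.core.xla_model as xm

device = xm.xla_device()
weight = torch.randn(10, 4, device=device)
indices = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9], device=device)
offsets = torch.tensor([0, 4], device=device)

with torch.no_grad():
    # Dense sum mode, no scale_grad_by_freq, default padding_idx:
    # the fallback condition is false, so the new XLA lowering runs.
    out_sum = F.embedding_bag(indices, weight, offsets, mode="sum")

    # Mean mode maps to mode == 1, so this call goes through xla_cpu_fallback.
    out_mean = F.embedding_bag(indices, weight, offsets, mode="mean")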
Collaborator

@bhavya01 Is there a reason we lower only _embedding_bag_forward_only instead of also lowering _embedding_bag? And about the fallback condition: is there a specific reason we are not lowering those cases, too?

Collaborator

If I recall correctly, the code for the rest of _embedding_bag was overly complicated, so we deemed it out of scope for this PR. @bhavya01, please correct me if I'm wrong.

Collaborator Author

That's right! We still need to lower _embedding_bag.

Collaborator

Hey @ysiraichi, would lowering _embedding_bag entirely be something that you need?

Collaborator

Well, it would be nice to have that. I won't be working on it right now, since I have higher-priority things to do. Anyway, I have this draft branch that can help us lower and maintain composite operations. It mainly does 2 things:

  • Allow us to write lowerings for composite operations in Python
  • Check for an implemented decomposition (PyTorch or PyTorch/XLA) whenever we hit the fallback function; if one is found, use it to decompose the operation into operations that are possibly already lowered (see the sketch below)
    • At the moment, PyTorch decompositions are only used in dynamo
    • This would allow us to use PyTorch decompositions on non-dynamo experiments
    • It would also allow us to use PyTorch/XLA-specific decompositions on both dynamo and non-dynamo experiments
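
To make the second bullet concrete, here is a rough, hypothetical sketch of a fallback-time decomposition lookup; the names maybe_decompose and xla_decomposition_table are illustrative and not taken from the draft branch.

# Hypothetical sketch of checking for a registered decomposition before
# resorting to the CPU fallback; not the draft branch's actual API.
import torch
from torch._decomp import decomposition_table  # PyTorch's registered decompositions

# Imagine PyTorch/XLA-specific decompositions (written in Python) living here.
xla_decomposition_table = {}

def maybe_decompose(op, *args, **kwargs):
    # Prefer a PyTorch/XLA-specific decomposition, then a stock PyTorch one,
    # so the op is re-expressed in terms of operations that may already lower.
    decomp = xla_decomposition_table.get(op) or decomposition_table.get(op)
    if decomp is not None:
        return decomp(*args, **kwargs)
    # No decomposition found: the caller proceeds to the existing CPU fallback.
    return None

# Lookups would be keyed by the ATen OpOverload, e.g.:
print(torch.ops.aten._embedding_bag_forward_only.default in decomposition_table)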

I won't be working on that draft PR for a while, so if anyone wants to take it over, I don't mind. If it gets merged, we would probably have an easier time lowering operations.
