
Support gmm and tgmm trace_pallas caching #7921

Merged: 4 commits merged from JackCaoG/trace_pallas_cache into master on Aug 30, 2024

Conversation

@JackCaoG (Collaborator) commented on Aug 29, 2024:

Was able to reduce the tracing time of gmm from 6 ms to 2.4 ms.
[profiler screenshots]

@JackCaoG requested a review from alanwaketan on August 29, 2024 00:27
@JackCaoG added the tpuci label on Aug 29, 2024
@JackCaoG (Collaborator, Author) commented:

Still need to add a test for the cache-miss case.
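
As a rough sketch (not part of this PR), a cache-miss test could look roughly like the following. It assumes a TPU device, that these shapes have not been traced earlier in the process, and that the trace_pallas_arg_to_payload dict from the excerpt below is importable next to gmm; the import path, the gmm signature, and the shapes are illustrative, not the PR's actual test.

import torch
import torch_xla.core.xla_model as xm
# Assumption: gmm and the trace_pallas_arg_to_payload cache live in the same
# module; the import path and gmm signature here are illustrative.
from torch_xla.experimental.custom_kernel import gmm, trace_pallas_arg_to_payload


def test_gmm_cache_miss():
  device = xm.xla_device()
  lhs = torch.randn(512, 128, dtype=torch.bfloat16, device=device)
  rhs = torch.randn(4, 128, 256, dtype=torch.bfloat16, device=device)
  group_sizes = torch.tensor([128, 128, 128, 128], dtype=torch.int32, device=device)

  before = len(trace_pallas_arg_to_payload)
  gmm(lhs, rhs, group_sizes)  # first call for this signature: traces the kernel
  assert len(trace_pallas_arg_to_payload) == before + 1
  gmm(lhs, rhs, group_sizes)  # identical shapes/dtypes: cache hit, no new entry
  assert len(trace_pallas_arg_to_payload) == before + 1

  # A different lhs shape changes the meta-tensor hash key, so this is a miss.
  lhs2 = torch.randn(1024, 128, dtype=torch.bfloat16, device=device)
  group_sizes2 = torch.tensor([256, 256, 256, 256], dtype=torch.int32, device=device)
  gmm(lhs2, rhs, group_sizes2)
  assert len(trace_pallas_arg_to_payload) == before + 2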

Reviewed code excerpt (cache-key construction):

global trace_pallas_arg_to_payload
# implicit assumption here that everything in kwargs is hashable and not a tensor,
# which is true for gmm and tgmm.
hash_key = (kernel, static_argnums, tuple(static_argnames), tuple(jax_args),

Reviewer (Collaborator):

How does this work with different objects but with the same size, dtype and device?

@JackCaoG (Collaborator, Author):

jax_args are just meta tensors; I verified that the same sizes always map to the same hash. We are not hashing id(static_argnames), so as long as the values are the same it will generate the same hash.

Reviewer (Collaborator):

That's interesting. I guess if it works, it works. Then why not just use @cache?

@JackCaoG (Collaborator, Author):

My understanding is that @cache caches on the inputs, and the inputs of this function are XLA tensors; I suspect @cache would try to access the values of those tensors. Here I only cache on the JAX meta tensors.

Also, let me re-verify this with the real MoE models.

Reviewer (Collaborator):

I see. That's fair.
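
For readers following the thread, here is a minimal, self-contained sketch of the metadata-keyed caching being discussed: the key mirrors the hash_key excerpt above and is built only from the kernel, the static arguments, and jax.ShapeDtypeStruct meta tensors, so nothing ever reads the values of the lazy XLA tensors. Only the trace_pallas_arg_to_payload name comes from the PR; the helper's name, its signature, and the jax.jit(...).lower(...) stand-in for the real Pallas tracing are illustrative.

import jax
import jax.numpy as jnp

trace_pallas_arg_to_payload = {}  # hash key -> lowered kernel (the "payload")


def trace_pallas_cached(kernel, *meta_args, static_argnums=None,
                        static_argnames=None, **kwargs):
  # Meta tensors carry only shape and dtype, so building and hashing the key
  # never touches tensor data; this is why caching on shape/dtype metadata is
  # safe where a plain functools.cache on the torch/XLA inputs would not be.
  jax_args = [jax.ShapeDtypeStruct(shape, dtype) for shape, dtype in meta_args]
  hash_key = (kernel, static_argnums,
              tuple(static_argnames) if static_argnames is not None else None,
              tuple(jax_args), tuple(sorted(kwargs.items())))
  if hash_key not in trace_pallas_arg_to_payload:
    # Cache miss: lower the kernel once for this shape/dtype signature.
    trace_pallas_arg_to_payload[hash_key] = jax.jit(
        kernel, static_argnums=static_argnums,
        static_argnames=static_argnames).lower(*jax_args, **kwargs)
  return trace_pallas_arg_to_payload[hash_key]


# Usage: identical shape/dtype signatures reuse the same lowered object.
def add_one(x):
  return x + 1

first = trace_pallas_cached(add_one, ((512, 128), jnp.bfloat16))
second = trace_pallas_cached(add_one, ((512, 128), jnp.bfloat16))
assert first is second  # same metadata: cache hit
third = trace_pallas_cached(add_one, ((1024, 128), jnp.bfloat16))
assert third is not first  # different shape: cache miss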

@JackCaoG (Collaborator, Author) commented:

Verified in the profile that trace_pallas is cached.

@JackCaoG marked this pull request as ready for review on August 30, 2024 00:55
@JackCaoG merged commit 8955571 into master on Aug 30, 2024
23 checks passed
@JackCaoG deleted the JackCaoG/trace_pallas_cache branch on August 30, 2024 00:55