Support gmm and tgmm trace_pallas caching #7921
Conversation
Still need to add a test for the cache miss case.
global trace_pallas_arg_to_payload
# implicit assumption here that everything in kwargs is hashable and not a tensor,
# which is true for the gmm and tgmm.
hash_key = (kernel, static_argnums, tuple(static_argnames), tuple(jax_args),
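For context, a minimal sketch of the caching path this snippet belongs to. `to_jax_meta`, `run_trace`, and `trace_pallas_cached` are illustrative placeholders rather than the PR's exact helpers, and the tail of the real `hash_key` (the kwargs handling) is assumed based on the comment above:

```python
from typing import Any, Dict, Tuple

# Module-level cache mapping a hash key to the traced Pallas payload.
trace_pallas_arg_to_payload: Dict[Tuple[Any, ...], str] = {}

def to_jax_meta(t) -> Tuple[Tuple[int, ...], str]:
  # Hypothetical helper: reduce a tensor-like argument to hashable
  # (shape, dtype) metadata so the key never depends on object identity
  # or on tensor values.
  return (tuple(t.shape), str(t.dtype))

def run_trace(kernel, args, **kwargs) -> str:
  # Stand-in for the expensive Pallas lowering that produces the payload.
  return f"payload-for-{getattr(kernel, '__name__', kernel)}"

def trace_pallas_cached(kernel, *args, static_argnums=None,
                        static_argnames=(), **kwargs) -> str:
  global trace_pallas_arg_to_payload
  # Implicit assumption: everything in kwargs is hashable and not a tensor,
  # which holds for gmm and tgmm.
  jax_args = tuple(to_jax_meta(a) for a in args)
  hash_key = (kernel, static_argnums, tuple(static_argnames), jax_args,
              tuple(sorted(kwargs.items())))
  if hash_key in trace_pallas_arg_to_payload:
    # Cache hit: reuse the payload and skip re-tracing the kernel.
    return trace_pallas_arg_to_payload[hash_key]
  payload = run_trace(kernel, args, **kwargs)
  trace_pallas_arg_to_payload[hash_key] = payload
  return payload
```

Calling `trace_pallas_cached` twice with different tensor objects of the same shape and dtype (and the same static arguments) produces the same `hash_key`, so the second call is a cache hit.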
How does this work with different objects but with the same size, dtype and device?
jax_args are just meta tensors; I verified that the same size will always map to the same hash. We are not hashing the id(static_argnames), so as long as the value is the same it will generate the same hash.
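To make the value-based hashing concrete, a small check, assuming the meta tensors are `jax.ShapeDtypeStruct` objects (which hash and compare by shape and dtype rather than by object identity):

```python
import jax
import jax.numpy as jnp

# Two distinct meta-tensor objects describing the same shape and dtype.
a = jax.ShapeDtypeStruct((128, 256), jnp.float32)
b = jax.ShapeDtypeStruct((128, 256), jnp.float32)

assert a is not b          # different Python objects
assert a == b              # but they compare equal by value
assert hash(a) == hash(b)  # so both map to the same cache slot

# A different shape yields a different key and therefore a cache miss.
c = jax.ShapeDtypeStruct((64, 256), jnp.float32)
assert a != c
```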
That's interesting. I guess if it works, it works. Then why not just use @cache?
My understanding is that @cache caches on the inputs, and the inputs of this function are XLA tensors; I felt like @cache would try to access the values of those tensors. Here I only cache on the JAX meta tensors. Also, let me re-verify this with the real MoE models.
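The concern above is about @cache touching tensor values. A related issue that is easy to demonstrate is that functools.cache keys on the tensor objects themselves (identity-based hashing), so logically identical calls never share an entry, and the cache keeps the tensors alive. A small illustration with plain CPU tensors (the XLA tensor case is assumed to behave analogously):

```python
import functools

import torch

calls = 0

@functools.cache
def traced(t):
  # Stand-in for an expensive trace keyed directly on the tensor argument.
  global calls
  calls += 1
  return tuple(t.shape)

# Two tensors with identical shape and dtype are still distinct objects,
# so @cache stores two separate entries and re-runs the body both times.
traced(torch.empty(8, 8))
traced(torch.empty(8, 8))
assert calls == 2  # no reuse, unlike the metadata-keyed dict in this PR
```

Keying on the JAX meta tensors sidesteps both problems: the key is cheap to hash, independent of object identity, and never touches tensor storage.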
I see. That's fair.
Verified in the profile that this was able to reduce the tracing time of gmm from 6 ms to 2.4 ms.

