-
Notifications
You must be signed in to change notification settings - Fork 180
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: support bmm fp8 #469
Conversation
Another suggestion is to move group gemm and bmm fp8 to a common gemm.py, we should also update the group_gemm.rst (to gemm.rst) as well. |
make sense |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thanks for your contribution @zhyncs !
The documentation was not indexed properly in #469 , this PR fixes the issue.
🤖 I have created a release *beep* *boop* --- ## [0.1.6](v0.1.5...v0.1.6) (2024-08-27) ### SM75 Support Starting from [0.1.6](v0.1.5...v0.1.6), our pre-built wheels include experimental support sm75 (Turing architecture GPUs such as Tesla T4, Quadro RTX 6000 and RTX 2080). ### API Changes #### `plan`/`run` Since [0.1.6](v0.1.5...v0.1.6) on, `begin_forward`/`forward`/`end_forward` APIs are replaced with the new `plan`/`run` API. - `forward` is renamed to `run`, which is more precise and consistent with the naming convention of cutlass's python API. - `begin_forward` is renamed to `plan`, which is consistent with the naming convention of nvmath API. - `end_forward` is deprecated and has no effect after this PR. There is some slight difference between the old `forward` and the new `run` API: - All extra arguments such as `causal` and `logits_soft_cap` will be provided in `plan` (previously `begin_forward`) API, and cached until next `plan` call, and we only need to provide query and KV-Cache tensors in `run` API. The old `begin_forward`/`forward`/`end_forward` APIs are still functional, but we will gradually deprecate them in future releases. Check [#466](#466) for more details. #### `MultiLevelCascadeAttentionWrapper` Since [0.1.6](v0.1.5...v0.1.6) on, we introduce a new `MultiLevelCascadeAttentionWrapper` API for cascade inference, which supports multi-level cascade inference where all levels' KV-Cache can be managed in a unified Paged KV-Cache. See [documentation](https://docs.flashinfer.ai/api/python/cascade.html#flashinfer.cascade.MultiLevelCascadeAttentionWrapper) and [tutorial](https://docs.flashinfer.ai/tutorials/kv_layout.html#multi-level-cascade-inference-data-layout) on API usage and layout explaination. The old `BatchDecodeWithSharedPrefixPagedKVCacheWrapper` and `BatchPrefillWithSharedPrefixPagedKVCacheWrapper` will be deprecated in future releases. ### Features * sm75 support ([#448](#448), [#449](#449)) * add `MultiLevelCascadeAttentionWrapper` API ([#462](#462)) ([1e37989](1e37989)) * add accept num, emit num metric for ChainSpeculativeSampling ([#450](#450)) ([fa38b5e](fa38b5e)) * support bmm fp8 ([#469](#469)) ([f1c0b68](f1c0b68)) ### Refactor * refactor: replace `begin_forward`/`forward`/`end_forward` with `plan`/`run` [#466](#466) ### Misc * misc: improve error handling of sampling kernels ([#456](#456)) ([0dce178](0dce178)) ### Performance Improvements * slight optimization on f16->f8 fragment layout swizzling ([#453](#453)) ([0d61871](0d61871)) * slight optimization on fragment layout swizzle ([#458](#458)) ([7c397cb](7c397cb)) * use persistent kernel for merging attention states ([#459](#459)) ([be6bf5b](be6bf5b)) ### Acknowledgement We thank [@LiuXiaoxuanPKU](https://github.com/LiuXiaoxuanPKU) on enhance of speculative sampling operator, [@merrymercy](https://github.com/merrymercy) on API change suggestion and [@zhyncs](https://github.com/zhyncs) on integrating fp8 BMM cublas implementation. --- This PR was generated with [Release Please](https://github.com/googleapis/release-please). See [documentation](https://github.com/googleapis/release-please#release-please). --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: Zihao Ye <expye@outlook.com>
torch.bmm
doesn't support fp8 andtorch._scaled_mm
doesn't support 3d, so I write this one. @yzh119 cc @merrymercy @Ying1123 @ispobockThanks @yzh119 for assisting with debug.
AType: fp8 e4m3, fp8 e5m2
BType: fp8 e4m3, fp8 e5m2
DType: bf16, fp16
Does not support both AType and BType fp8 e5m2. ref https://docs.nvidia.com/cuda/cublas/#cublasltmatmul
works on H100