Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add accept num, emit num metric for ChainSpeculativeSampling #450

Merged
merged 3 commits into from
Aug 17, 2024

Conversation

LiuXiaoxuanPKU
Copy link
Contributor

No description provided.

@zhyncs zhyncs requested review from yzh119 and zhyncs August 16, 2024 06:33
@zhyncs zhyncs self-assigned this Aug 16, 2024
Copy link
Collaborator

@yzh119 yzh119 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for your contribution @LiuXiaoxuanPKU ! I added a few comments and suggestions.

include/flashinfer/sampling.cuh Outdated Show resolved Hide resolved
python/csrc/sampling.cu Outdated Show resolved Hide resolved
python/csrc/sampling.cu Outdated Show resolved Hide resolved
include/flashinfer/sampling.cuh Outdated Show resolved Hide resolved
@yzh119
Copy link
Collaborator

yzh119 commented Aug 16, 2024

This file has to be updated:

torch::Tensor chain_speculative_sampling(torch::Tensor draft_probs, torch::Tensor draft_token_ids,
torch::Tensor uniform_samples, torch::Tensor target_probs,
bool deterministic);

if you change the function signature.

@LiuXiaoxuanPKU
Copy link
Contributor Author

@yzh119 Thanks for the review. I just fixed the comments, feel free to take another pass. Thanks!

Copy link
Collaborator

@yzh119 yzh119 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thank you!

@yzh119 yzh119 merged commit fa38b5e into flashinfer-ai:main Aug 17, 2024
yzh119 added a commit that referenced this pull request Aug 27, 2024
🤖 I have created a release *beep* *boop*
---


##
[0.1.6](v0.1.5...v0.1.6)
(2024-08-27)

### SM75 Support

Starting from
[0.1.6](v0.1.5...v0.1.6),
our pre-built wheels include experimental support sm75 (Turing
architecture GPUs such as Tesla T4, Quadro RTX 6000 and RTX 2080).

### API Changes

#### `plan`/`run`

Since
[0.1.6](v0.1.5...v0.1.6)
on, `begin_forward`/`forward`/`end_forward` APIs are replaced with the
new `plan`/`run` API.
- `forward` is renamed to `run`, which is more precise and consistent
with the naming convention of cutlass's python API.
- `begin_forward` is renamed to `plan`, which is consistent with the
naming convention of nvmath API.
- `end_forward` is deprecated and has no effect after this PR.

There is some slight difference between the old `forward` and the new
`run` API:
- All extra arguments such as `causal` and `logits_soft_cap` will be
provided in `plan` (previously `begin_forward`) API, and cached until
next `plan` call, and we only need to provide query and KV-Cache tensors
in `run` API.

The old `begin_forward`/`forward`/`end_forward` APIs are still
functional, but we will gradually deprecate them in future releases.

Check [#466](#466) for
more details.

#### `MultiLevelCascadeAttentionWrapper`

Since
[0.1.6](v0.1.5...v0.1.6)
on, we introduce a new `MultiLevelCascadeAttentionWrapper` API for
cascade inference,
which supports multi-level cascade inference where all levels' KV-Cache
can be managed in a unified Paged KV-Cache.

See
[documentation](https://docs.flashinfer.ai/api/python/cascade.html#flashinfer.cascade.MultiLevelCascadeAttentionWrapper)
and
[tutorial](https://docs.flashinfer.ai/tutorials/kv_layout.html#multi-level-cascade-inference-data-layout)
on API usage and layout explaination.

The old `BatchDecodeWithSharedPrefixPagedKVCacheWrapper` and
`BatchPrefillWithSharedPrefixPagedKVCacheWrapper` will be deprecated in
future releases.

### Features

* sm75 support
([#448](#448),
[#449](#449))
* add `MultiLevelCascadeAttentionWrapper` API
([#462](#462))
([1e37989](1e37989))
* add accept num, emit num metric for ChainSpeculativeSampling
([#450](#450))
([fa38b5e](fa38b5e))
* support bmm fp8
([#469](#469))
([f1c0b68](f1c0b68))

### Refactor

* refactor: replace `begin_forward`/`forward`/`end_forward` with
`plan`/`run`
[#466](#466)

### Misc

* misc: improve error handling of sampling kernels
([#456](#456))
([0dce178](0dce178))

### Performance Improvements

* slight optimization on f16->f8 fragment layout swizzling
([#453](#453))
([0d61871](0d61871))
* slight optimization on fragment layout swizzle
([#458](#458))
([7c397cb](7c397cb))
* use persistent kernel for merging attention states
([#459](#459))
([be6bf5b](be6bf5b))

### Acknowledgement

We thank [@LiuXiaoxuanPKU](https://github.com/LiuXiaoxuanPKU) on enhance
of speculative sampling operator,
[@merrymercy](https://github.com/merrymercy) on API change suggestion
and [@zhyncs](https://github.com/zhyncs) on integrating fp8 BMM cublas
implementation.

---
This PR was generated with [Release
Please](https://github.com/googleapis/release-please). See
[documentation](https://github.com/googleapis/release-please#release-please).

---------

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: Zihao Ye <expye@outlook.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants