Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Limit number of tokens per second for whisper. #1958

Merged
merged 1 commit into from
Mar 4, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion sherpa-onnx/csrc/offline-recognizer-whisper-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl {
auto cross_kv = model_->ForwardEncoder(std::move(mel));

auto results = decoder_->Decode(std::move(cross_kv.first),
std::move(cross_kv.second));
std::move(cross_kv.second), num_frames);

auto r = Convert(results[0], symbol_table_);
s->SetResult(r);
Expand Down
3 changes: 2 additions & 1 deletion sherpa-onnx/csrc/offline-whisper-decoder.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ class OfflineWhisperDecoder {
* @return Return a vector of size `N` containing the decoded results.
*/
virtual std::vector<OfflineWhisperDecoderResult> Decode(
Ort::Value n_layer_cross_k, Ort::Value n_layer_cross_v) = 0;
Ort::Value n_layer_cross_k, Ort::Value n_layer_cross_v,
int32_t num_feature_frames) = 0;

virtual void SetConfig(const OfflineWhisperModelConfig &config) = 0;
};
Expand Down
10 changes: 8 additions & 2 deletions sherpa-onnx/csrc/offline-whisper-greedy-search-decoder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ void OfflineWhisperGreedySearchDecoder::SetConfig(

std::vector<OfflineWhisperDecoderResult>
OfflineWhisperGreedySearchDecoder::Decode(Ort::Value cross_k,
Ort::Value cross_v) {
Ort::Value cross_v,
int32_t num_feature_frames) {
auto memory_info =
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);

Expand Down Expand Up @@ -99,7 +100,12 @@ OfflineWhisperGreedySearchDecoder::Decode(Ort::Value cross_k,
int32_t n_text_ctx = model_->TextCtx();

std::vector<int32_t> predicted_tokens;
for (int32_t i = 0; i < n_text_ctx / 2; ++i) {

// assume at most 6 tokens per second
int32_t num_possible_tokens = num_feature_frames / 100 * 6;
num_possible_tokens = std::min<int32_t>(num_possible_tokens, n_text_ctx / 2);

for (int32_t i = 0; i < num_possible_tokens; ++i) {
if (max_token_id == model_->EOT()) {
break;
}
Expand Down
5 changes: 3 additions & 2 deletions sherpa-onnx/csrc/offline-whisper-greedy-search-decoder.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,9 @@ class OfflineWhisperGreedySearchDecoder : public OfflineWhisperDecoder {
OfflineWhisperModel *model)
: config_(config), model_(model) {}

std::vector<OfflineWhisperDecoderResult> Decode(Ort::Value cross_k,
Ort::Value cross_v) override;
std::vector<OfflineWhisperDecoderResult> Decode(
Ort::Value cross_k, Ort::Value cross_v,
int32_t num_feature_frames) override;

void SetConfig(const OfflineWhisperModelConfig &config) override;

Expand Down
Loading