Add optimized decoder for the deployment of DS2 #139
Conversation
Next, would it be possible to add this to Paddle as a new layer, for example ...
@lcy-seso That's possible. In fact we plan to do so: once the decoder part is fully optimized, it will be moved into Paddle.
Putting it into PaddlePaddle requires taking the language model into account.
TensorFlow's CTCDecoder layer doesn't need a language model, does it? A callback interface could be left for one.
This version does need it. TF applies the LM as a final reranking step, whereas here the LM must be considered during the decoding process itself.
To be precise, the current decoding process takes the language model into account; as a plain decoder added to Paddle, it can be made fully equivalent to TensorFlow's.
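To make the distinction concrete, here is a minimal C++ sketch (not the PR's actual code) of scoring a prefix with the LM folded into beam search rather than applied as a final reranking pass; score_prefix is a hypothetical helper, and alpha/beta match the scorer parameters documented later in this PR:

    // Hypothetical prefix scoring during beam search: the LM contributes at
    // every expansion step instead of reranking finished hypotheses.
    double score_prefix(double log_prob_ctc, double log_prob_lm,
                        size_t word_count, float alpha, float beta) {
      // alpha = 0 disables the language model; beta = 0 disables the
      // word-count term (see the scorer docstring below).
      return log_prob_ctc + alpha * log_prob_lm + beta * word_count;
    }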
Updated. Please continue the review.
deep_speech_2/infer.py
Outdated
@@ -84,14 +84,16 @@ def infer():
        use_gru=args.use_gru,
        pretrained_model_path=args.model_path,
        share_rnn_weights=args.share_rnn_weights)

    vocab_list = [chars.encode("utf-8") for chars in data_generator.vocab_list]
Done
* num_processes: Number of threads for beam search.
* cutoff_prob: Cutoff probability for pruning.
* cutoff_top_n: Cutoff number for pruning.
* ext_scorer: External scorer to evaluate a prefix.
Done
@@ -0,0 +1,68 @@
#ifndef PATH_TRIE_H
#define PATH_TRIE_H
#pragma once
If we have #pragma once here, I think there's no need for the #ifndef/#define/#endif guard.
Done
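For reference, a minimal sketch of the suggested style, with #pragma once alone guarding the header:

    // path_trie.h -- sketch: with #pragma once, the #ifndef/#define/#endif
    // guard can be dropped entirely.
    #pragma once

    #include <fst/fstlib.h>

    // ... declarations of PathTrie follow ...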
prob_idx = sorted(prob_idx, key=lambda asd: asd[1], reverse=True)
cutoff_len, cum_prob = 0, 0.0
for i in xrange(len(prob_idx)):
    cum_prob += prob_idx[i][1]
    cutoff_len += 1
    if cum_prob >= cutoff_prob:
        break
cutoff_len = min(cutoff_top_n, cutoff_top_n)
cutoff_len could be moved into the for loop as a stop condition, like if cum_prob >= cutoff_prob or cutoff_len >= threshold: break. Also, I think min(cutoff_top_n, cutoff_top_n) is a typo.
- It is not necessary.
- Corrected
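For illustration, a minimal C++ sketch of the merged stop condition suggested above (the helper name pruned_length is hypothetical, mirroring the Python logic in the diff):

    // A minimal sketch, assuming prob_idx is already sorted by probability
    // in descending order.
    #include <utility>
    #include <vector>

    size_t pruned_length(const std::vector<std::pair<size_t, double>> &prob_idx,
                         double cutoff_prob, size_t cutoff_top_n) {
      double cum_prob = 0.0;
      size_t cutoff_len = 0;
      for (size_t i = 0; i < prob_idx.size(); ++i) {
        cum_prob += prob_idx[i].second;
        ++cutoff_len;
        // both pruning criteria live inside the loop, as suggested
        if (cum_prob >= cutoff_prob || cutoff_len >= cutoff_top_n) break;
      }
      return cutoff_len;
    }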
@@ -232,8 +228,8 @@ def ctc_beam_search_decoder_batch(probs_split,
     pool = multiprocessing.Pool(processes=num_processes)
     results = []
     for i, probs_list in enumerate(probs_split):
-        args = (probs_list, beam_size, vocabulary, blank_id, cutoff_prob, None,
-                nproc)
+        args = (probs_list, beam_size, vocabulary, blank_id, cutoff_prob,
We could add a comment to clarify why the global ext_nproc_scorer is used instead of passing ext_nproc_scorer to ctc_beam_search_decoder.
Will add the comment in a later PR.
for (size_t i = 0; i < num_time_steps; ++i) {
    VALID_CHECK_EQ(probs_seq[i].size(),
                   vocabulary.size() + 1,
                   "The shape of probs_seq does not match with "
The message lacks information, e.g. that the size of probs_seq should be equal to the size of the vocabulary plus one.
Modified. Now this macro outputs where the error happens and the explicit failing expression.
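A hedged sketch of what such a self-reporting check macro might look like; the actual macro in the PR may differ:

    // Reports file, line, and the failing expression before exiting.
    #include <cstdlib>
    #include <iostream>

    #define VALID_CHECK_EQ(__A, __B, __ERR)                              \
      do {                                                               \
        if ((__A) != (__B)) {                                            \
          std::cerr << "[" << __FILE__ << ":" << __LINE__ << "] "        \
                    << #__A " != " #__B ": " << (__ERR) << std::endl;    \
          std::exit(EXIT_FAILURE);                                       \
        }                                                                \
      } while (0)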
for (size_t i = 0; i < num_time_steps; ++i) {
    double max_prob = 0.0;
    size_t max_idx = 0;
    for (size_t j = 0; j < probs_seq[i].size(); j++) {
j++ --> ++j
Done
}  // end of loop over time

// compute approximate ctc score as the return score
for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) {
Marking here: this can be removed.
Done
Done
@@ -0,0 +1,75 @@
#ifndef CTC_BEAM_SEARCH_DECODER_H_
Please unify on either #pragma once or #ifndef/#define/#endif.
Done
#include <utility>
#include <vector>

using FSTMATCH = fst::SortedMatcher<fst::StdVectorFst>;
Don't use using in a header file.
Done
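A small sketch of one way to follow the suggestion, assuming the alias is only needed in the implementation file:

    // ctc_beam_search_decoder.cpp -- sketch: the alias is defined in the
    // implementation file, so it does not leak into every file that
    // includes the header.
    #include "ctc_beam_search_decoder.h"

    #include <fst/fstlib.h>

    using FSTMATCH = fst::SortedMatcher<fst::StdVectorFst>;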
deep_speech_2/decoders/swig/setup.sh
Outdated
@@ -0,0 +1,21 @@
#!/bin/bash
Change to #! /usr/bin/env bash; see https://stackoverflow.com/questions/16365130/the-difference-between-usr-bin-env-bash-and-usr-bin-bash
Done
deep_speech_2/decoders/swig/setup.sh
Outdated
@@ -0,0 +1,21 @@
#!/bin/bash

if [ ! -d kenlm ]; then
Please add error checking.
Will add it later.
Great job. Thanks!
deep_speech_2/README.md
Outdated
@@ -176,6 +176,7 @@ Data augmentation has often been a highly effective technique to boost the deep

Six optional augmentation components are provided to be selected, configured and inserted into the processing pipeline.

### Inference
Remove L179
Done
cutoff_prob=1.0,
cutoff_top_n=40,
Why add cutoff_top_n?
It's a parameter used for the Mandarin vocabulary cutoff.
const float NUM_FLT_MIN = std::numeric_limits<float>::min();

// check if __A == __B
#define VALID_CHECK_EQ(__A, __B, __ERR) \
Could you consider using GLOG instead for simplicity?
GLOG conflicts with the macro definitions in openfst. An improved macro function is used here instead.
#ifndef PATH_TRIE_H
#define PATH_TRIE_H
#pragma once
#include <fst/fstlib.h>
Should the include order follow the Google C++ style guide?
Done
std::vector<size_t> max_idx_vec;
for (size_t i = 0; i < num_time_steps; ++i) {
    double max_prob = 0.0;
    size_t max_idx = 0;
Adding const std::vector<double>& probs_seq_step = probs_seq[i]; and afterwards just using probs_seq_step would be a little faster. But I'm not sure whether the compiler has already done this implicitly?
Done
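For illustration, the hoisted form might look like the following; the argmax body of the inner loop is assumed from context:

    // Sketch of the suggested hoisting: bind the row once per time step so
    // the inner loop does not re-index the outer vector.
    std::vector<size_t> max_idx_vec;
    for (size_t i = 0; i < num_time_steps; ++i) {
      const std::vector<double> &probs_seq_step = probs_seq[i];
      double max_prob = 0.0;
      size_t max_idx = 0;
      for (size_t j = 0; j < probs_seq_step.size(); ++j) {
        if (probs_seq_step[j] > max_prob) {
          max_prob = probs_seq_step[j];
          max_idx = j;
        }
      }
      max_idx_vec.push_back(max_idx);
    }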
std::vector<std::string> vocabulary,
const double cutoff_prob = 1.0,
const size_t cutoff_top_n = 40,
Scorer *ext_scorer = NULL);
NULL --> nullptr
Done
const size_t num_processes,
double cutoff_prob = 1.0,
const size_t cutoff_top_n = 40,
Scorer *ext_scorer = NULL);
nullptr
Done
if (dictionary != nullptr) delete static_cast<fst::StdVectorFst*>(dictionary);
}

void Scorer::load_LM(const char* filename) {
load_lm
Done
             language model when alpha = 0.
:type alpha: float
:param beta: Parameter associated with word count. Don't use word
             count when beta = 0.
Be careful with the indentation.
Done
@@ -21,9 +21,10 @@ python -u infer.py \
    --num_conv_layers=2 \
    --num_rnn_layers=3 \
    --rnn_layer_size=2048 \
    --alpha=0.36 \
Please also update this in examples/tiny.
Done
Addressed the comments. Please continue the review.
#include "decoder_utils.h" | ||
#include "path_trie.h" | ||
|
||
std::string ctc_greedy_decoder( |
Done
deep_speech_2/decoders/swig/scorer.h
Outdated
void reset_params(float alpha, float beta);

// make ngram
std::vector<std::string> make_ngram(PathTrie *prefix);
It seems there would be an error when using a const reference.
    ext_scorer->fill_dictionary(true);
}
auto fst_dict = static_cast<fst::StdVectorFst *>(ext_scorer->dictionary);
fst::StdVectorFst *dict_ptr = fst_dict->Copy(true);
This shouldn't be a simple pointer assignment; instead, a different match type is returned depending on the state:
http://www.openfst.org/doxygen/fst/html/matcher_8h_source.html#l00041
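For reference, a hedged sketch of consulting the matcher's type via OpenFst's SortedMatcher (the function name query_match_type is hypothetical):

    // Sketch: SortedMatcher reports its match type based on how the FST's
    // arcs are sorted, so consulting the matcher differs from copying a
    // raw pointer.
    #include <fst/fstlib.h>

    fst::MatchType query_match_type(const fst::StdVectorFst &dict) {
      fst::SortedMatcher<fst::StdVectorFst> matcher(dict, fst::MATCH_INPUT);
      // Type(true) may return MATCH_INPUT, MATCH_NONE, or MATCH_UNKNOWN
      // depending on the FST's arc-sort properties.
      return matcher.Type(/*test=*/true);
    }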
Great job! LGTM
Resolve #277
@kuke Running sh setup.sh in the swig directory ...
@kuke Many thanks. I'm not familiar with swig, and I have two more questions: 1. running it on a Mac reports this problem, ...
@fanlu Hi~ Since this PR has already been merged, it is not easy for others to find by searching. Could you open an issue and collect all the problems you ran into there?
@lcy-seso OK
Implement the CTC beam search decoder in C++ to speed up decoding. Compared with the prototype decoder in Python, this optimized decoder produces identical decoding results and runs about 3x faster in a single thread (measured with Python's time module) given the same parameters.

To further achieve real-time decoding for deployment, the beam width can be set appropriately. An experiment on 100 samples was carried out to illustrate the effect of beam size on the WER and the decoding speed. The results show that when beam size < 200, the average decoding time per sample stays within 1 s without a significant degradation in WER. Therefore, by choosing a proper beam size in this range, decoding in deployment can be completed within an acceptable time.
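For context, a hedged sketch of how a caller might trade beam size against latency. Only the trailing defaulted parameters of ctc_beam_search_decoder are visible in the diffs above, so the leading parameters and the return type here are assumptions:

    // Hypothetical call; parameter order for probs_seq and beam_size is
    // inferred from the Python batch decoder's argument tuple.
    auto result = ctc_beam_search_decoder(
        probs_seq,           // per-time-step probability distributions
        /*beam_size=*/200,   // keeping beam size below ~200 bounds decoding
                             // to about 1 s per sample in the experiment
        vocabulary,
        /*cutoff_prob=*/1.0,
        /*cutoff_top_n=*/40,
        /*ext_scorer=*/nullptr);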