Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add optimized decoder for the deployment of DS2 #139

Merged
merged 48 commits into from
Sep 18, 2017

Conversation

kuke
Copy link
Collaborator

@kuke kuke commented Jun 29, 2017

Implement the CTC beam search decoder in C++ to speedup decoding. Compared with the prototype decoder in Python, this optimized decoder gets the identical decoding results and has the advantage of about 3x speedup in single thread (measured by the time module in Python) when given the same parameters.

To further achieve real-time decoding for deployment in reality, the width of beam search can be set appropriately. An experiment is carried out to illustrate the effect of beam size on the WER and speed of decoding with 100 samples, and here are some results:
2017-07-06 8 48 14

It is not hard to find that when beam size < 200, the average time of one sample's decoding is limited to 1s without a significant decay in WER. Therefore by setting a proper beam size in this range, the decoding in deployment can be completed within acceptable time.

@kuke kuke changed the title Add optimized decoder for deployment Add optimized decoder for the deployment of DS2 Jun 29, 2017
@lcy-seso
Copy link
Collaborator

lcy-seso commented Jul 5, 2017

接下来,有没有可能加入 Paddle 作为一个新的Layer,比如CtcDecodingLayer?

@kuke
Copy link
Collaborator Author

kuke commented Jul 6, 2017

@lcy-seso 这个是可以的,事实上我们也有这个计划,待decoder部分充分优化之后就会放到Paddle中去

@pkuyym
Copy link
Contributor

pkuyym commented Jul 6, 2017

放到PaddlePaddle里面,需要考虑Language Model。

@lcy-seso
Copy link
Collaborator

lcy-seso commented Jul 6, 2017

TensorFlow 的CTCDecoder Layer 不需要语言模型吧?可以留一个回调函数接口。

@pkuyym
Copy link
Contributor

pkuyym commented Jul 6, 2017

这个版本是需要的,TF的LM是最后reranking,这个是decoder过程就要考虑

@kuke kuke requested review from xinghai-sun and pkuyym July 6, 2017 11:52
@kuke
Copy link
Collaborator Author

kuke commented Jul 6, 2017

只能说我们在当前decode的过程中考虑了语言模型,而作为一个普通的decoder加入paddle,是完全可以和TensorFlow等同的

@kuke kuke force-pushed the ctc_decoder_deploy branch 2 times, most recently from ac18ee5 to fff62dc Compare July 27, 2017 02:04
Copy link
Collaborator Author

@kuke kuke left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated. Please continue to review

@@ -84,14 +84,16 @@ def infer():
use_gru=args.use_gru,
pretrained_model_path=args.model_path,
share_rnn_weights=args.share_rnn_weights)

vocab_list = [chars.encode("utf-8") for chars in data_generator.vocab_list]
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

* num_processes: Number of threads for beam search.
* cutoff_prob: Cutoff probability for pruning.
* cutoff_top_n: Cutoff number for pruning.
* ext_scorer: External scorer to evaluate a prefix.
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

@@ -0,0 +1,68 @@
#ifndef PATH_TRIE_H
#define PATH_TRIE_H
#pragma once
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we have #pragma once here, i think there's no need for #ifndef #define #endif

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

prob_idx = sorted(prob_idx, key=lambda asd: asd[1], reverse=True)
cutoff_len, cum_prob = 0, 0.0
for i in xrange(len(prob_idx)):
cum_prob += prob_idx[i][1]
cutoff_len += 1
if cum_prob >= cutoff_prob:
break
cutoff_len = min(cutoff_top_n, cutoff_top_n)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could move cutoff_len into for loop as a stop condition like if (cum_prob >= cutoff_prob or cutoff_len >= threshold) break.
I think min(cutoff_top_n, cutoff_top_n) should be a typo.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. It is not necessary.
  2. Corrected

@@ -232,8 +228,8 @@ def ctc_beam_search_decoder_batch(probs_split,
pool = multiprocessing.Pool(processes=num_processes)
results = []
for i, probs_list in enumerate(probs_split):
args = (probs_list, beam_size, vocabulary, blank_id, cutoff_prob, None,
nproc)
args = (probs_list, beam_size, vocabulary, blank_id, cutoff_prob,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can comment more to clarify why using global ext_nproc_scorer instead of passing ext_nproc_scorer to ctc_beam_search_decoder.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would append the comment in later pr.

for (size_t i = 0; i < num_time_steps; ++i) {
VALID_CHECK_EQ(probs_seq[i].size(),
vocabulary.size() + 1,
"The shape of probs_seq does not match with "
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The comment lacks information like size of probs_seq should be equal to size of vocabulary plus one.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Modified. Now this macro function will output where the error happens and the explicit expression

for (size_t i = 0; i < num_time_steps; ++i) {
double max_prob = 0.0;
size_t max_idx = 0;
for (size_t j = 0; j < probs_seq[i].size(); j++) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

j++ --> ++j

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

} // end of loop over time

// compute aproximate ctc score as the return score
for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

mark here, this can be removed.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

@@ -0,0 +1,75 @@
#ifndef CTC_BEAM_SEARCH_DECODER_H_
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please unify #pragma once or #ifndef #define #endif

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

#include <utility>
#include <vector>

using FSTMATCH = fst::SortedMatcher<fst::StdVectorFst>;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

don't use using in header file.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

@@ -0,0 +1,21 @@
#!/bin/bash
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

@@ -0,0 +1,21 @@
#!/bin/bash

if [ ! -d kenlm ]; then
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please add error checking.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will append later

Copy link
Contributor

@xinghai-sun xinghai-sun left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great job. Thanks!

@@ -176,6 +176,7 @@ Data augmentation has often been a highly effective technique to boost the deep

Six optional augmentation components are provided to be selected, configured and inserted into the processing pipeline.

### Inference
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove L179

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

cutoff_prob=1.0,
cutoff_top_n=40,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why to add cutoff_top_n?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's a param used by Mandarin vocabulary cutoff

const float NUM_FLT_MIN = std::numeric_limits<float>::min();

// check if __A == _B
#define VALID_CHECK_EQ(__A, __B, __ERR) \
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you consider using GLOG instead for simplicity?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

GLOG will conflict with the macro definition in openfst. An improved macro function is used here instead.

#ifndef PATH_TRIE_H
#define PATH_TRIE_H
#pragma once
#include <fst/fstlib.h>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Include order follows Google Coding Style?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

std::vector<size_t> max_idx_vec;
for (size_t i = 0; i < num_time_steps; ++i) {
double max_prob = 0.0;
size_t max_idx = 0;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add const std::vector<double>& probs_seq_step = probs_seq[i];
And afterwards just using probs_seq_step would be a little faster.
But I'm not sure whether the compiler has already done this implicitely?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

std::vector<std::string> vocabulary,
const double cutoff_prob = 1.0,
const size_t cutoff_top_n = 40,
Scorer *ext_scorer = NULL);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NULL --> nullptr

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

const size_t num_processes,
double cutoff_prob = 1.0,
const size_t cutoff_top_n = 40,
Scorer *ext_scorer = NULL);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nullptr

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

if (dictionary != nullptr) delete static_cast<fst::StdVectorFst*>(dictionary);
}

void Scorer::load_LM(const char* filename) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

load_lm

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

language model when alpha = 0.
:type alpha: float
:param beta: Parameter associated with word count. Don't use word
count when beta = 0.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Be careful of the indent.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

@@ -21,9 +21,10 @@ python -u infer.py \
--num_conv_layers=2 \
--num_rnn_layers=3 \
--rnn_layer_size=2048 \
--alpha=0.36 \
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please also update this in examples/tiny.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

Copy link
Collaborator Author

@kuke kuke left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Followed comments. Please continue to review

@@ -176,6 +176,7 @@ Data augmentation has often been a highly effective technique to boost the deep

Six optional augmentation components are provided to be selected, configured and inserted into the processing pipeline.

### Inference
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

cutoff_prob=1.0,
cutoff_top_n=40,
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's a param used by Mandarin vocabulary cutoff

prob_idx = sorted(prob_idx, key=lambda asd: asd[1], reverse=True)
cutoff_len, cum_prob = 0, 0.0
for i in xrange(len(prob_idx)):
cum_prob += prob_idx[i][1]
cutoff_len += 1
if cum_prob >= cutoff_prob:
break
cutoff_len = min(cutoff_top_n, cutoff_top_n)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. It is not necessary.
  2. Corrected

@@ -232,8 +228,8 @@ def ctc_beam_search_decoder_batch(probs_split,
pool = multiprocessing.Pool(processes=num_processes)
results = []
for i, probs_list in enumerate(probs_split):
args = (probs_list, beam_size, vocabulary, blank_id, cutoff_prob, None,
nproc)
args = (probs_list, beam_size, vocabulary, blank_id, cutoff_prob,
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would append the comment in later pr.

#include "decoder_utils.h"
#include "path_trie.h"

std::string ctc_greedy_decoder(
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

@@ -0,0 +1,21 @@
#!/bin/bash

if [ ! -d kenlm ]; then
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will append later

language model when alpha = 0.
:type alpha: float
:param beta: Parameter associated with word count. Don't use word
count when beta = 0.
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

@@ -21,9 +21,10 @@ python -u infer.py \
--num_conv_layers=2 \
--num_rnn_layers=3 \
--rnn_layer_size=2048 \
--alpha=0.36 \
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

void reset_params(float alpha, float beta);

// make ngram
std::vector<std::string> make_ngram(PathTrie *prefix);
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems that there would be an error when using const reference

ext_scorer->fill_dictionary(true);
}
auto fst_dict = static_cast<fst::StdVectorFst *>(ext_scorer->dictionary);
fst::StdVectorFst *dict_ptr = fst_dict->Copy(true);
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

此处应该不是一个简单的指针赋值,而是根据状态不同返回不同的match type:
http://www.openfst.org/doxygen/fst/html/matcher_8h_source.html#l00041

@pkuyym
Copy link
Contributor

pkuyym commented Sep 18, 2017

Great job! LGTM

@kuke
Copy link
Collaborator Author

kuke commented Sep 18, 2017

Resolve #277

@kuke kuke merged commit 17ebb40 into PaddlePaddle:develop Sep 18, 2017
@fanlu
Copy link

fanlu commented Oct 14, 2017

@kuke 在swig目录中执行sh setup.sh
报了一个错
Install decoders ...
decoder_utils.h:55: Error: Syntax error in input(1).
running install
最终安装成功了
Processing dependencies for swig-decoders==0.1
Finished processing dependencies for swig-decoders==0.1
但是执行python -c "import swig_decoders"还是报错
Traceback (most recent call last):
File "", line 1, in
ImportError: No module named swig_decoders

@kuke
Copy link
Collaborator Author

kuke commented Oct 14, 2017

decoder_utils.h:55: Error: Syntax error in input(1)
This error results from that the version of swig is too low. Please upgrade swig first then reinstall the decoders.
@fanlu

@fanlu
Copy link

fanlu commented Oct 14, 2017

@kuke 非常感谢,对swig不了解,另外还有两个问题,1.在mac上执行会报这个问题,
openfst-1.6.3/src/include/fst/types.h:19:10: fatal error: 'cstdint' file not found
2.另外,在中文的处理中,ctc_beam_search_decoder.cpp 第122行,c == space_id 中文是没有空格的,怎么把语言模型的转移概率加进去呢?

@lcy-seso
Copy link
Collaborator

@fanlu hi~ 你好,考虑到这个PR已经merge,也不易被其它人搜索到。能否提一个issue,把所有遇到的问题都汇总到issue中呢?

@fanlu
Copy link

fanlu commented Oct 14, 2017

@lcy-seso 好的

@fanlu fanlu mentioned this pull request Oct 14, 2017
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants